zhangziang committed
Commit · f783161
1 Parent(s): 706e128

initial commit track binary
Browse files
- .gitattributes +6 -0
- .gitignore +210 -0
- LICENSE +395 -0
- app.py +199 -0
- app_utils.py +202 -0
- assets/axis_ref.png +3 -0
- assets/axis_render.blend +3 -0
- assets/axis_tgt.png +3 -0
- axis_renderer.py +136 -0
- inference.py +238 -0
- orianyV2_demo.ipynb +0 -0
- paths.py +16 -0
- requirements.txt +22 -0
- vggt/heads/camera_head.py +162 -0
- vggt/heads/dpt_head.py +497 -0
- vggt/heads/head_act.py +125 -0
- vggt/heads/track_head.py +108 -0
- vggt/heads/track_modules/__init__.py +5 -0
- vggt/heads/track_modules/base_track_predictor.py +209 -0
- vggt/heads/track_modules/blocks.py +246 -0
- vggt/heads/track_modules/modules.py +218 -0
- vggt/heads/track_modules/utils.py +226 -0
- vggt/heads/utils.py +108 -0
- vggt/layers/__init__.py +11 -0
- vggt/layers/attention.py +98 -0
- vggt/layers/block.py +259 -0
- vggt/layers/drop_path.py +34 -0
- vggt/layers/layer_scale.py +27 -0
- vggt/layers/mlp.py +40 -0
- vggt/layers/patch_embed.py +88 -0
- vggt/layers/rope.py +188 -0
- vggt/layers/swiglu_ffn.py +72 -0
- vggt/layers/vision_transformer.py +407 -0
- vggt/models/aggregator.py +331 -0
- vggt/models/vggt.py +96 -0
- vggt/utils/geometry.py +166 -0
- vggt/utils/load_fn.py +146 -0
- vggt/utils/pose_enc.py +130 -0
- vggt/utils/rotation.py +138 -0
- vggt/utils/visual_track.py +239 -0
- vision_tower.py +279 -0
.gitattributes
CHANGED
@@ -1,3 +1,6 @@
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.blend filter=lfs diff=lfs merge=lfs -text
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/axis_ref.png filter=lfs diff=lfs merge=lfs -text
+assets/axis_tgt.png filter=lfs diff=lfs merge=lfs -text
+assets/axis_render.blend filter=lfs diff=lfs merge=lfs -text
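The patterns added above route the new binary assets through Git LFS, so the repository stores small pointer files instead of the blobs themselves (the pointer format is visible later in this commit for assets/axis_render.blend). A minimal sketch, not part of the commit, for checking whether a checked-out file is still an LFS pointer; the path is only an example:

import pathlib

def is_lfs_pointer(path):
    """Heuristic: LFS pointer files are tiny and start with the spec line."""
    p = pathlib.Path(path)
    if p.stat().st_size > 1024:
        return False
    return p.read_text(errors="ignore").startswith("version https://git-lfs.github.com/spec/v1")

# True while only the pointer is present; False once `git lfs pull` has fetched the real binary.
print(is_lfs_pointer("assets/axis_render.blend"))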
.gitignore
ADDED
@@ -0,0 +1,210 @@
test_demo/
test_demo_output/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/

# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
LICENSE
ADDED
@@ -0,0 +1,395 @@
Attribution 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution 4.0 International Public License ("Public License"). To the
extent this Public License may be interpreted as a contract, You are
granted the Licensed Rights in consideration of Your acceptance of
these terms and conditions, and the Licensor grants You such rights in
consideration of benefits the Licensor receives from making the
Licensed Material available under these terms and conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  j. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  k. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part; and

            b. produce, reproduce, and Share Adapted Material.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.


=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
app.py
ADDED
@@ -0,0 +1,199 @@
import gradio as gr
import numpy as np
from PIL import Image
import torch

# ====== Original imports and model loading, unchanged ======
from paths import *
from vision_tower import VGGT_OriAny_Ref
from inference import *
from app_utils import *
from axis_renderer import BlendRenderer

from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
print(ckpt_path)


mark_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# device = 'cuda:0'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()
model = model.to(device)
print('Model loaded.')

axis_renderer = BlendRenderer(RENDER_FILE)


# ====== Helper: safe image handling ======
def safe_image_input(image):
    """Return a valid numpy array, or None."""
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        return image
    try:
        return np.array(image)
    except Exception:
        return None


# ====== Inference function ======
@torch.no_grad()
def run_inference(image_ref, image_tgt, do_rm_bkg):
    image_ref = safe_image_input(image_ref)
    image_tgt = safe_image_input(image_tgt)

    if image_ref is None:
        raise gr.Error("Please upload a reference image before running inference.")

    # Convert to PIL (used for background removal and later compositing)
    pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB")
    pil_tgt = None

    if image_tgt is not None:
        pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB")
        if do_rm_bkg:
            pil_ref = background_preprocess(pil_ref, True)
            pil_tgt = background_preprocess(pil_tgt, True)
    else:
        if do_rm_bkg:
            pil_ref = background_preprocess(pil_ref, True)

    try:
        ans_dict = inf_single_case(model, pil_ref, pil_tgt)
    except Exception as e:
        print("Inference error:", e)
        raise gr.Error(f"Inference failed: {str(e)}")

    def safe_float(val, default=0.0):
        try:
            return float(val)
        except:
            return float(default)

    az = safe_float(ans_dict.get('ref_az_pred', 0))
    el = safe_float(ans_dict.get('ref_el_pred', 0))
    ro = safe_float(ans_dict.get('ref_ro_pred', 0))
    alpha = int(ans_dict.get('ref_alpha_pred', 1))  # note: the target defaults to alpha=1, but the reference may not

    # ===== Render the axis gizmo for the reference image =====
    axis_renderer.render_axis(az, el, ro, alpha, save_path=REF_AXIS_IMAGE)
    axis_ref = Image.open(REF_AXIS_IMAGE).convert("RGBA")

    # Overlay the axes onto the reference image
    # Make sure the sizes match
    if axis_ref.size != pil_ref.size:
        axis_ref = axis_ref.resize(pil_ref.size, Image.LANCZOS)
    pil_ref_rgba = pil_ref.convert("RGBA")
    overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")

    # ===== Handle the target image (if provided) =====
    if pil_tgt is not None:
        rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
        rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
        rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))

        tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
        print("Target: Azi", tgt_azi, "Ele", tgt_ele, "Rot", tgt_rot)

        # the target defaults to alpha=1 (by convention)
        axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=TGT_AXIS_IMAGE)
        axis_tgt = Image.open(TGT_AXIS_IMAGE).convert("RGBA")

        if axis_tgt.size != pil_tgt.size:
            axis_tgt = axis_tgt.resize(pil_tgt.size, Image.LANCZOS)
        pil_tgt_rgba = pil_tgt.convert("RGBA")
        overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
    else:
        overlaid_tgt = None
        rel_az = rel_el = rel_ro = 0.0

    return [
        overlaid_ref,  # reference image with the rendered axes overlaid
        overlaid_tgt,  # target image with the rendered axes overlaid (may be None)
        f"{az:.2f}",
        f"{el:.2f}",
        f"{ro:.2f}",
        str(alpha),
        f"{rel_az:.2f}",
        f"{rel_el:.2f}",
        f"{rel_ro:.2f}",
    ]


# ====== Gradio Blocks UI ======
with gr.Blocks(title="Orient-Anything Demo") as demo:
    gr.Markdown("# Orient-Anything Demo")
    gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")

    with gr.Row():
        # Left: input images (reference + target on the same row)
        with gr.Column():
            with gr.Row():
                ref_img = gr.Image(
                    label="Reference Image (required)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
                tgt_img = gr.Image(
                    label="Target Image (optional)",
                    type="numpy",
                    height=256,
                    width=256,
                    value=None,
                    interactive=True
                )
            rm_bkg = gr.Checkbox(label="Remove Background", value=True)
            run_btn = gr.Button("Run Inference", variant="primary")

        # Right: result images + text outputs
        with gr.Column():
            # Result images: reference result + target result (optional)
            with gr.Row():
                res_ref_img = gr.Image(
                    label="Rendered Reference",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )
                res_tgt_img = gr.Image(
                    label="Rendered Target (if provided)",
                    type="pil",
                    height=256,
                    width=256,
                    interactive=False
                )

            # Text outputs below the images
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Absolute Pose (Reference)")
                    az_out = gr.Textbox(label="Azimuth (0~360°)")
                    el_out = gr.Textbox(label="Polar (-90~90°)")
                    ro_out = gr.Textbox(label="Rotation (-90~90°)")
                    alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
                with gr.Column():
                    gr.Markdown("### Relative Pose (Target w.r.t Reference)")
                    rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
                    rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
                    rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")

    # Wire up the click event
    run_btn.click(
        fn=run_inference,
        inputs=[ref_img, tgt_img, rm_bkg],
        outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out],
        preprocess=True,
        postprocess=True
    )

# Launch (disable the API to avoid schema errors)
demo.launch(show_api=False)
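For reference, a minimal sketch of exercising the same pipeline without the Gradio UI. It assumes the top of app.py (model, axis_renderer, run_inference) has already been executed in the current session, e.g. in a notebook; the input file names are hypothetical placeholders:

import numpy as np
from PIL import Image

# hypothetical input files; any RGB images would do
ref = np.array(Image.open("example_ref.jpg").convert("RGB"))
tgt = np.array(Image.open("example_tgt.jpg").convert("RGB"))

# run_inference returns [overlaid_ref, overlaid_tgt, az, el, ro, alpha, rel_az, rel_el, rel_ro]
outputs = run_inference(ref, tgt, do_rm_bkg=True)
overlaid_ref, overlaid_tgt = outputs[0], outputs[1]
print("Reference pose (az, el, ro, alpha):", outputs[2:6])
print("Relative pose (az, el, ro):", outputs[6:9])
overlaid_ref.save("ref_with_axes.png")
if overlaid_tgt is not None:
    overlaid_tgt.save("tgt_with_axes.png")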
app_utils.py
ADDED
@@ -0,0 +1,202 @@
import rembg
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import PIL
from typing import Any
import matplotlib.pyplot as plt
import io

def resize_foreground(
    image: Image,
    ratio: float,
) -> Image:
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = Image.fromarray(new_image)
    return new_image

def remove_background(image: Image,
                      rembg_session: Any = None,
                      force: bool = False,
                      **rembg_kwargs,
) -> Image:
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image

def background_preprocess(input_image, do_remove_background):
    if input_image is None:
        return None
    rembg_session = rembg.new_session() if do_remove_background else None

    if do_remove_background:
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)

    return input_image

def axis_angle_rotation_batch(axis: torch.Tensor, theta: torch.Tensor, homogeneous: bool = False) -> torch.Tensor:
    """
    Batched axis-angle rotation.
    Args:
        axis: (3,) or (N,3)
        theta: scalar or (N,), in radians
        homogeneous: whether to return 4x4 homogeneous matrices

    Returns:
        (N,3,3) or (N,4,4)
    """
    axis = torch.as_tensor(axis).float()
    theta = torch.as_tensor(theta).float()

    if axis.ndim == 1:
        axis = axis.unsqueeze(0)  # (1,3)
    if theta.ndim == 0:
        theta = theta.unsqueeze(0)  # (1,)

    N = axis.shape[0]

    # normalize axis
    axis = axis / torch.norm(axis, dim=1, keepdim=True)

    x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    one_minus_cos = 1 - cos_t

    # Rodrigues' rotation formula, written out elementwise
    rot = torch.zeros((N, 3, 3), dtype=axis.dtype, device=axis.device)
    rot[:, 0, 0] = cos_t + x*x*one_minus_cos
    rot[:, 0, 1] = x*y*one_minus_cos - z*sin_t
    rot[:, 0, 2] = x*z*one_minus_cos + y*sin_t
    rot[:, 1, 0] = y*x*one_minus_cos + z*sin_t
    rot[:, 1, 1] = cos_t + y*y*one_minus_cos
    rot[:, 1, 2] = y*z*one_minus_cos - x*sin_t
    rot[:, 2, 0] = z*x*one_minus_cos - y*sin_t
    rot[:, 2, 1] = z*y*one_minus_cos + x*sin_t
    rot[:, 2, 2] = cos_t + z*z*one_minus_cos

    if homogeneous:
        rot_homo = torch.eye(4, dtype=axis.dtype, device=axis.device).unsqueeze(0).repeat(N, 1, 1)
        rot_homo[:, :3, :3] = rot
        return rot_homo

    return rot

def azi_ele_rot_to_Obj_Rmatrix_batch(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor) -> torch.Tensor:
    """Batched: (azi, ele, rot) in degrees -> rotation matrices (N,3,3)"""
    # convert to tensors in radians
    azi = torch.as_tensor(azi).float() * torch.pi / 180.
    ele = torch.as_tensor(ele).float() * torch.pi / 180.
    rot = torch.as_tensor(rot).float() * torch.pi / 180.

    # make sure there is a batch dimension
    if azi.ndim == 0:
        azi = azi.unsqueeze(0)
    if ele.ndim == 0:
        ele = ele.unsqueeze(0)
    if rot.ndim == 0:
        rot = rot.unsqueeze(0)

    N = azi.shape[0]

    device = azi.device
    dtype = azi.dtype

    z0_axis = torch.tensor([0.,0.,1.], device=device, dtype=dtype).expand(N, -1)
    y0_axis = torch.tensor([0.,1.,0.], device=device, dtype=dtype).expand(N, -1)
    x0_axis = torch.tensor([1.,0.,0.], device=device, dtype=dtype).expand(N, -1)
    # print(z0_axis.shape, azi.shape)
    R_azi = axis_angle_rotation_batch(z0_axis, -1 * azi)
    R_ele = axis_angle_rotation_batch(y0_axis, ele)
    R_rot = axis_angle_rotation_batch(x0_axis, rot)

    R_res = R_rot @ R_ele @ R_azi
    return R_res

def Cam_Rmatrix_to_azi_ele_rot_batch(R: torch.Tensor):
    """Batched: rotation matrix -> (azi, ele, rot), in degrees"""
    R = torch.as_tensor(R).float()

    # if the input is (3,3), add a batch dimension
    if R.ndim == 2:
        R = R.unsqueeze(0)

    r0 = R[:, :, 0]  # shape (N,3)
    r1 = R[:, :, 1]
    r2 = R[:, :, 2]

    ele = torch.asin(r0[:, 2])  # r0.z
    cos_ele = torch.cos(ele)

    # defaults for azi and rot
    azi = torch.zeros_like(ele)
    rot = torch.zeros_like(ele)

    # regular case
    normal_mask = (cos_ele.abs() >= 1e-6)
    if normal_mask.any():
        azi[normal_mask] = torch.atan2(r0[normal_mask, 1], r0[normal_mask, 0])
        rot[normal_mask] = torch.atan2(-r1[normal_mask, 2], r2[normal_mask, 2])

    # gimbal-lock special case
    gimbal_mask = ~normal_mask
    if gimbal_mask.any():
        # set azi to 0 here
        azi[gimbal_mask] = 0.0
        rot[gimbal_mask] = torch.atan2(-r1[gimbal_mask, 0], r1[gimbal_mask, 1])

    # radians to degrees
    azi = azi * 180. / torch.pi
    ele = ele * 180. / torch.pi
    rot = rot * 180. / torch.pi

    return azi, ele, rot

def Get_target_azi_ele_rot(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor, rel_azi: torch.Tensor, rel_ele: torch.Tensor, rel_rot: torch.Tensor):
    Rmat0 = azi_ele_rot_to_Obj_Rmatrix_batch(azi = azi, ele = ele, rot = rot)
    Rmat_rel = azi_ele_rot_to_Obj_Rmatrix_batch(azi = rel_azi, ele = rel_ele, rot = rel_rot)
    # Rmat_rel = Rmat1 @ Rmat0.permute(0, 2, 1)
    # azi_out, ele_out, rot_out = Cam_Rmatrix_to_azi_ele_rot_batch(Rmat_rel.permute(0, 2, 1))

    Rmat1 = Rmat_rel @ Rmat0
    azi_out, ele_out, rot_out = Cam_Rmatrix_to_azi_ele_rot_batch(Rmat1.permute(0, 2, 1))

    return azi_out, ele_out, rot_out
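A small self-check of the rotation utilities above (a sketch, not part of the repository): it composes the reference and relative rotations explicitly, then confirms that the angles returned by Get_target_azi_ele_rot rebuild the same composed matrix. The example angles are arbitrary:

import torch
from app_utils import azi_ele_rot_to_Obj_Rmatrix_batch, Get_target_azi_ele_rot

azi, ele, rot = 30.0, 10.0, 5.0              # reference pose, degrees
rel_azi, rel_ele, rel_rot = 45.0, 0.0, 0.0   # relative pose, degrees

# Explicit composition: R_target = R_rel @ R_ref
R_ref = azi_ele_rot_to_Obj_Rmatrix_batch(azi, ele, rot)
R_rel = azi_ele_rot_to_Obj_Rmatrix_batch(rel_azi, rel_ele, rel_rot)
R_tgt = R_rel @ R_ref

# Recover target angles through the helper, then rebuild the matrix from them
t_azi, t_ele, t_rot = Get_target_azi_ele_rot(azi, ele, rot, rel_azi, rel_ele, rel_rot)
R_back = azi_ele_rot_to_Obj_Rmatrix_batch(t_azi, t_ele, t_rot)

print(torch.allclose(R_tgt, R_back, atol=1e-4))  # expected: True (away from gimbal lock)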
assets/axis_ref.png
ADDED
Git LFS Details
assets/axis_render.blend
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76f8fd3b4ce574a6973ed9637a6c8194fcf46edf72f9266786036c21cf7023a1
size 2136460
assets/axis_tgt.png
ADDED
Git LFS Details
axis_renderer.py
ADDED
@@ -0,0 +1,136 @@
import bpy
import math
import os
from paths import *

class BlendRenderer:
    def __init__(self, blend_file_path=RENDER_FILE):
        """
        Initialize the renderer: load the given .blend file and apply basic settings.

        :param blend_file_path: full path of the .blend file to load
        """
        if not os.path.isfile(blend_file_path):
            raise FileNotFoundError(f"Blend file not found: {blend_file_path}")

        # Load the blend file
        bpy.ops.wm.open_mainfile(filepath=blend_file_path)

        # Use the Cycles render engine
        bpy.context.scene.render.engine = 'CYCLES'

        # Render on the CPU
        bpy.context.scene.cycles.device = 'CPU'

        # Use 4 samples
        bpy.context.scene.cycles.samples = 4

        # Cap all light bounces at 4 (diffuse, glossy, transmission, etc.)
        bpy.context.scene.cycles.max_bounces = 4

        # Render resolution
        bpy.context.scene.render.resolution_x = 512
        bpy.context.scene.render.resolution_y = 512
        bpy.context.scene.render.resolution_percentage = 100

        # Enable a transparent background (RGBA output)
        bpy.context.scene.render.film_transparent = True

        # Walk all objects and initialize render visibility
        for obj in bpy.data.objects:
            if obj.type == 'LIGHT':
                obj.hide_render = False
            elif obj.type == 'CAMERA':
                obj.hide_render = False
            elif obj.type == 'MESH':
                obj.hide_render = True  # by default no mesh is rendered

        # Set the active camera (pick the first one)
        cameras = [obj for obj in bpy.data.objects if obj.type == 'CAMERA']
        if cameras:
            bpy.context.scene.camera = cameras[0]

        print(f"Loaded blend file: {blend_file_path}")
        print("Render settings applied: 512x512, CPU, samples=4, bounces=4, transparent background.")

        # Object names inside the .blend file, keyed by the predicted number of directions
        # (the Chinese labels mean: single-axis plane, three axes, two-way marker, four-way marker)
        self.alpha_axis_map = {
            0: "单轴平面",
            1: "三轴",
            2: "双向标注",
            4: "四向标注"
        }


    def _get_all_children(self, obj):
        """Recursively collect all descendants of an object (including nested children)."""
        children = []
        for child in obj.children:
            children.append(child)
            children.extend(self._get_all_children(child))
        return children

    def render_axis(self, azi, ele, rot, alpha, save_path):
        """
        Render the axis gizmo at a specific orientation.

        :param azi: azimuth (rotation about the Z axis, degrees)
        :param ele: elevation (rotation about the Y axis, degrees)
        :param rot: in-plane rotation (about the X axis, degrees)
        :param save_path: where to save the rendered image (e.g. '/output/render.png')
        """
        # Walk all objects and reset render visibility
        for obj in bpy.data.objects:
            if obj.type == 'LIGHT':
                obj.hide_render = False
            elif obj.type == 'CAMERA':
                obj.hide_render = False
            elif obj.type == 'MESH':
                obj.hide_render = True  # by default no mesh is rendered
        # Pick the target object according to alpha
        target_name = self.alpha_axis_map.get(alpha, "单轴平面")
        target_obj = None
        for obj in bpy.data.objects:
            # if obj.type == 'MESH' and obj.name == target_name:
            if obj.name == target_name:
                target_obj = obj
                break

        if target_obj is None:
            raise ValueError(f'Object named "{target_name}" not found in the scene.')

        # Collect the object and all of its children
        all_objects_to_render = [target_obj] + self._get_all_children(target_obj)

        # Make them visible to the renderer
        for obj in all_objects_to_render:
            if obj.type == 'MESH':
                obj.hide_render = False

        # Apply the rotation (ZYX order: Z=azi, Y=ele, X=rot -> Euler = (rot, ele, -azi))
        # Note: Blender expects radians
        target_obj.rotation_mode = 'ZYX'  # use ZYX Euler angles
        target_obj.rotation_euler = (rot*math.pi/180, ele*math.pi/180, -azi*math.pi/180)

        # Make sure the output directory exists
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Set the output path
        bpy.context.scene.render.filepath = save_path

        # Render and write the still image
        bpy.ops.render.render(write_still=True)

        print(f"Rendered and saved to: {save_path}")


if __name__ == "__main__":
    renderer = BlendRenderer(RENDER_FILE)
    # Example usage:
    renderer.render_axis(45, 0, 0, 1, "./test_demo_output/render_1_dir_azi45.png")
    renderer.render_axis(0, 45, 0, 2, "./test_demo_output/render_2_dir_ele45.png")
    renderer.render_axis(0, 0, 45, 4, "./test_demo_output/render_4_dir_rot45.png")
    # renderer.render_1_dir()
    # renderer.render_2_dir()
    # renderer.render_4_dir()
inference.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from app_utils import *
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torchvision import transforms as TF
|
| 7 |
+
|
| 8 |
+
from scipy.special import i0
|
| 9 |
+
from scipy.optimize import curve_fit
|
| 10 |
+
from scipy.integrate import trapezoid
|
| 11 |
+
from functools import partial
|
| 12 |
+
|
| 13 |
+
def von_mises_pdf_alpha_numpy(alpha, x, mu, kappa):
|
| 14 |
+
normalization = 2 * np.pi
|
| 15 |
+
pdf = np.exp(kappa * np.cos(alpha * (x - mu))) / normalization
|
| 16 |
+
return pdf
|
| 17 |
+
|
| 18 |
+
def val_fit_alpha(distribute):
|
| 19 |
+
fit_alphas = []
|
| 20 |
+
for y_noise in distribute:
|
| 21 |
+
x = np.linspace(0, 2 * np.pi, 360)
|
| 22 |
+
y_noise /= trapezoid(y_noise, x) + 1e-8
|
| 23 |
+
|
| 24 |
+
initial_guess = [x[np.argmax(y_noise)], 1]
|
| 25 |
+
|
| 26 |
+
# support 1,2,4
|
| 27 |
+
alphas = [1.0, 2.0, 4.0]
|
| 28 |
+
saved_params = []
|
| 29 |
+
saved_r_squared = []
|
| 30 |
+
|
| 31 |
+
for alpha in alphas:
|
| 32 |
+
try:
|
| 33 |
+
von_mises_pdf_alpha_partial = partial(von_mises_pdf_alpha_numpy, alpha)
|
| 34 |
+
params, covariance = curve_fit(von_mises_pdf_alpha_partial, x, y_noise, p0=initial_guess)
|
| 35 |
+
|
| 36 |
+
residuals = y_noise - von_mises_pdf_alpha_partial(x, *params)
|
| 37 |
+
ss_res = np.sum(residuals**2)
|
| 38 |
+
ss_tot = np.sum((y_noise - np.mean(y_noise))**2)
|
| 39 |
+
r_squared = 1 - (ss_res / (ss_tot+1e-8))
|
| 40 |
+
|
| 41 |
+
saved_params.append(params)
|
| 42 |
+
saved_r_squared.append(r_squared)
|
| 43 |
+
if r_squared > 0.8:
|
| 44 |
+
break
|
| 45 |
+
except:
|
| 46 |
+
saved_params.append((0.,0.))
|
| 47 |
+
saved_r_squared.append(0.)
|
| 48 |
+
|
| 49 |
+
max_index = np.argmax(saved_r_squared)
|
| 50 |
+
alpha = alphas[max_index]
|
| 51 |
+
mu_fit, kappa_fit = saved_params[max_index]
|
| 52 |
+
r_squared = saved_r_squared[max_index]
|
| 53 |
+
|
| 54 |
+
if alpha == 1. and kappa_fit>=0.5 and r_squared>=0.5:
|
| 55 |
+
pass
|
| 56 |
+
elif alpha == 2. and kappa_fit>=0.35 and r_squared>=0.35:
|
| 57 |
+
pass
|
| 58 |
+
elif alpha == 4. and kappa_fit>=0.25 and r_squared>=0.25:
|
| 59 |
+
pass
|
| 60 |
+
else:
|
| 61 |
+
alpha=0.
|
| 62 |
+
fit_alphas.append(alpha)
|
| 63 |
+
return torch.tensor(fit_alphas)
|
| 64 |
+
|
| 65 |
+
def preprocess_images(image_list, mode="crop"):
|
| 66 |
+
|
| 67 |
+
# Check for empty list
|
| 68 |
+
if len(image_list) == 0:
|
| 69 |
+
raise ValueError("At least 1 image is required")
|
| 70 |
+
|
| 71 |
+
# Validate mode
|
| 72 |
+
if mode not in ["crop", "pad"]:
|
| 73 |
+
raise ValueError("Mode must be either 'crop' or 'pad'")
|
| 74 |
+
|
| 75 |
+
images = []
|
| 76 |
+
shapes = set()
|
| 77 |
+
to_tensor = TF.ToTensor()
|
| 78 |
+
target_size = 518
|
| 79 |
+
|
| 80 |
+
# First process all images and collect their shapes
|
| 81 |
+
# for image_path in image_path_list:
|
| 82 |
+
for img in image_list:
|
| 83 |
+
# If there's an alpha channel, blend onto white background:
|
| 84 |
+
if img.mode == "RGBA":
|
| 85 |
+
# Create white background
|
| 86 |
+
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
|
| 87 |
+
# Alpha composite onto the white background
|
| 88 |
+
img = Image.alpha_composite(background, img)
|
| 89 |
+
|
| 90 |
+
# Now convert to "RGB" (this step assigns white for transparent areas)
|
| 91 |
+
img = img.convert("RGB")
|
| 92 |
+
width, height = img.size
|
| 93 |
+
|
| 94 |
+
if mode == "pad":
|
| 95 |
+
# Make the largest dimension 518px while maintaining aspect ratio
|
| 96 |
+
if width >= height:
|
| 97 |
+
new_width = target_size
|
| 98 |
+
new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
|
| 99 |
+
else:
|
| 100 |
+
new_height = target_size
|
| 101 |
+
new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
|
| 102 |
+
else: # mode == "crop"
|
| 103 |
+
# Original behavior: set width to 518px
|
| 104 |
+
new_width = target_size
|
| 105 |
+
# Calculate height maintaining aspect ratio, divisible by 14
|
| 106 |
+
new_height = round(height * (new_width / width) / 14) * 14
|
| 107 |
+
|
| 108 |
+
# Resize with new dimensions (width, height)
|
| 109 |
+
try:
|
| 110 |
+
img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
|
| 111 |
+
img = to_tensor(img) # Convert to tensor (0, 1)
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(e)
|
| 114 |
+
print(width, height)
|
| 115 |
+
print(new_width, new_height)
|
| 116 |
+
assert False
|
| 117 |
+
|
| 118 |
+
# Center crop height if it's larger than 518 (only in crop mode)
|
| 119 |
+
if mode == "crop" and new_height > target_size:
|
| 120 |
+
start_y = (new_height - target_size) // 2
|
| 121 |
+
img = img[:, start_y : start_y + target_size, :]
|
| 122 |
+
|
| 123 |
+
# For pad mode, pad to make a square of target_size x target_size
|
| 124 |
+
if mode == "pad":
|
| 125 |
+
h_padding = target_size - img.shape[1]
|
| 126 |
+
w_padding = target_size - img.shape[2]
|
| 127 |
+
|
| 128 |
+
if h_padding > 0 or w_padding > 0:
|
| 129 |
+
pad_top = h_padding // 2
|
| 130 |
+
pad_bottom = h_padding - pad_top
|
| 131 |
+
pad_left = w_padding // 2
|
| 132 |
+
pad_right = w_padding - pad_left
|
| 133 |
+
|
| 134 |
+
# Pad with white (value=1.0)
|
| 135 |
+
img = torch.nn.functional.pad(
|
| 136 |
+
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
shapes.add((img.shape[1], img.shape[2]))
|
| 140 |
+
images.append(img)
|
| 141 |
+
|
| 142 |
+
# Check if we have different shapes
|
| 143 |
+
# In theory our model can also work well with different shapes
|
| 144 |
+
if len(shapes) > 1:
|
| 145 |
+
print(f"Warning: Found images with different shapes: {shapes}")
|
| 146 |
+
# Find maximum dimensions
|
| 147 |
+
max_height = max(shape[0] for shape in shapes)
|
| 148 |
+
max_width = max(shape[1] for shape in shapes)
|
| 149 |
+
|
| 150 |
+
# Pad images if necessary
|
| 151 |
+
padded_images = []
|
| 152 |
+
for img in images:
|
| 153 |
+
h_padding = max_height - img.shape[1]
|
| 154 |
+
w_padding = max_width - img.shape[2]
|
| 155 |
+
|
| 156 |
+
if h_padding > 0 or w_padding > 0:
|
| 157 |
+
pad_top = h_padding // 2
|
| 158 |
+
pad_bottom = h_padding - pad_top
|
| 159 |
+
pad_left = w_padding // 2
|
| 160 |
+
pad_right = w_padding - pad_left
|
| 161 |
+
|
| 162 |
+
img = torch.nn.functional.pad(
|
| 163 |
+
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
|
| 164 |
+
)
|
| 165 |
+
padded_images.append(img)
|
| 166 |
+
images = padded_images
|
| 167 |
+
|
| 168 |
+
images = torch.stack(images) # concatenate images
|
| 169 |
+
|
| 170 |
+
# Ensure correct shape when single image
|
| 171 |
+
if len(image_list) == 1:
|
| 172 |
+
# Verify shape is (1, C, H, W)
|
| 173 |
+
if images.dim() == 3:
|
| 174 |
+
images = images.unsqueeze(0)
|
| 175 |
+
|
| 176 |
+
return images
|
| 177 |
+
|
| 178 |
+
@torch.no_grad()
|
| 179 |
+
def inf_single_batch(model, batch):
|
| 180 |
+
device = model.get_device()
|
| 181 |
+
batch_img_inputs = batch # (B, S, 3, H, W)
|
| 182 |
+
# print(batch_img_inputs.shape)
|
| 183 |
+
B, S, C, H, W = batch_img_inputs.shape
|
| 184 |
+
pose_enc = model(batch_img_inputs) # (B, S, D) S = 1
|
| 185 |
+
|
| 186 |
+
pose_enc = pose_enc.view(B*S, -1)
|
| 187 |
+
angle_az_pred = torch.argmax(pose_enc[:, 0:360] , dim=-1)
|
| 188 |
+
angle_el_pred = torch.argmax(pose_enc[:, 360:360+180] , dim=-1) - 90
|
| 189 |
+
angle_ro_pred = torch.argmax(pose_enc[:, 360+180:360+180+360] , dim=-1) - 180
|
| 190 |
+
|
| 191 |
+
# ori_val
|
| 192 |
+
# trained with BCE loss
|
| 193 |
+
distribute = F.sigmoid(pose_enc[:, 0:360]).cpu().float().numpy()
|
| 194 |
+
# trained with CE loss
|
| 195 |
+
# distribute = pose_enc[:, 0:360].cpu().float().numpy()
|
| 196 |
+
alpha_pred = val_fit_alpha(distribute = distribute)
|
| 197 |
+
|
| 198 |
+
# ref_val
|
| 199 |
+
if S > 1:
|
| 200 |
+
ref_az_pred = angle_az_pred.reshape(B,S)[:,0]
|
| 201 |
+
ref_el_pred = angle_el_pred.reshape(B,S)[:,0]
|
| 202 |
+
ref_ro_pred = angle_ro_pred.reshape(B,S)[:,0]
|
| 203 |
+
ref_alpha_pred = alpha_pred.reshape(B,S)[:,0]
|
| 204 |
+
rel_az_pred = angle_az_pred.reshape(B,S)[:,1]
|
| 205 |
+
rel_el_pred = angle_el_pred.reshape(B,S)[:,1]
|
| 206 |
+
rel_ro_pred = angle_ro_pred.reshape(B,S)[:,1]
|
| 207 |
+
else:
|
| 208 |
+
ref_az_pred = angle_az_pred[0]
|
| 209 |
+
ref_el_pred = angle_el_pred[0]
|
| 210 |
+
ref_ro_pred = angle_ro_pred[0]
|
| 211 |
+
ref_alpha_pred = alpha_pred[0]
|
| 212 |
+
rel_az_pred = 0.
|
| 213 |
+
rel_el_pred = 0.
|
| 214 |
+
rel_ro_pred = 0.
|
| 215 |
+
|
| 216 |
+
ans_dict = {
|
| 217 |
+
'ref_az_pred': ref_az_pred,
|
| 218 |
+
'ref_el_pred': ref_el_pred,
|
| 219 |
+
'ref_ro_pred': ref_ro_pred,
|
| 220 |
+
'ref_alpha_pred' : ref_alpha_pred,
|
| 221 |
+
'rel_az_pred' : rel_az_pred,
|
| 222 |
+
'rel_el_pred' : rel_el_pred,
|
| 223 |
+
'rel_ro_pred' : rel_ro_pred,
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
return ans_dict
|
| 227 |
+
|
| 228 |
+
# input PIL Image
|
| 229 |
+
@torch.no_grad()
|
| 230 |
+
def inf_single_case(model, image_ref, image_tgt):
|
| 231 |
+
if image_tgt is None:
|
| 232 |
+
image_list = [image_ref]
|
| 233 |
+
else:
|
| 234 |
+
image_list = [image_ref, image_tgt]
|
| 235 |
+
image_tensors = preprocess_images(image_list, mode="pad").to(model.get_device())
|
| 236 |
+
ans_dict = inf_single_batch(model=model, batch=image_tensors.unsqueeze(0))
|
| 237 |
+
print(ans_dict)
|
| 238 |
+
return ans_dict
|
orianyV2_demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
paths.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DINO_SMALL = "facebook/dinov2-small"
|
| 2 |
+
DINO_BASE = "facebook/dinov2-base"
|
| 3 |
+
DINO_LARGE = "facebook/dinov2-large"
|
| 4 |
+
DINO_GIANT = "facebook/dinov2-giant"
|
| 5 |
+
|
| 6 |
+
VGGT_1B = "facebook/VGGT-1B"
|
| 7 |
+
|
| 8 |
+
ORIANY_V2 = "Viglong/OriAnyV2_ckpt"
|
| 9 |
+
|
| 10 |
+
REMOTE_CKPT_PATH = "demo_ckpts/acc8mask20lowlr.pt"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
RENDER_FILE = "assets/axis_render.blend"
|
| 14 |
+
REF_AXIS_IMAGE = "assets/axis_ref.png"
|
| 15 |
+
TGT_AXIS_IMAGE = "assets/axis_tgt.png"
|
| 16 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matplotlib
|
| 2 |
+
pydantic==2.10.6
|
| 3 |
+
gradio==5.9.0
|
| 4 |
+
onnxruntime
|
| 5 |
+
rembg
|
| 6 |
+
accelerate==1.8.1
|
| 7 |
+
numpy>=1.24
|
| 8 |
+
einops
|
| 9 |
+
pandas
|
| 10 |
+
pillow
|
| 11 |
+
huggingface_hub>=0.23
|
| 12 |
+
pytorch-lightning
|
| 13 |
+
scipy
|
| 14 |
+
torch
|
| 15 |
+
torchmetrics
|
| 16 |
+
torchvision
|
| 17 |
+
tqdm
|
| 18 |
+
transformers
|
| 19 |
+
scikit-learn
|
| 20 |
+
opencv-python
|
| 21 |
+
timm
|
| 22 |
+
bpy==4.2
|
vggt/heads/camera_head.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from vggt.layers import Mlp
|
| 15 |
+
from vggt.layers.block import Block
|
| 16 |
+
from vggt.heads.head_act import activate_pose
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CameraHead(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 22 |
+
|
| 23 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
dim_in: int = 2048,
|
| 29 |
+
trunk_depth: int = 4,
|
| 30 |
+
pose_encoding_type: str = "absT_quaR_FoV",
|
| 31 |
+
num_heads: int = 16,
|
| 32 |
+
mlp_ratio: int = 4,
|
| 33 |
+
init_values: float = 0.01,
|
| 34 |
+
trans_act: str = "linear",
|
| 35 |
+
quat_act: str = "linear",
|
| 36 |
+
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
|
| 37 |
+
):
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 41 |
+
self.target_dim = 9
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
|
| 44 |
+
|
| 45 |
+
self.trans_act = trans_act
|
| 46 |
+
self.quat_act = quat_act
|
| 47 |
+
self.fl_act = fl_act
|
| 48 |
+
self.trunk_depth = trunk_depth
|
| 49 |
+
|
| 50 |
+
# Build the trunk using a sequence of transformer blocks.
|
| 51 |
+
self.trunk = nn.Sequential(
|
| 52 |
+
*[
|
| 53 |
+
Block(
|
| 54 |
+
dim=dim_in,
|
| 55 |
+
num_heads=num_heads,
|
| 56 |
+
mlp_ratio=mlp_ratio,
|
| 57 |
+
init_values=init_values,
|
| 58 |
+
)
|
| 59 |
+
for _ in range(trunk_depth)
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Normalizations for camera token and trunk output.
|
| 64 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
| 65 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
| 66 |
+
|
| 67 |
+
# Learnable empty camera pose token.
|
| 68 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
| 69 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
| 70 |
+
|
| 71 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
| 72 |
+
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
| 73 |
+
|
| 74 |
+
# Adaptive layer normalization without affine parameters.
|
| 75 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 76 |
+
self.pose_branch = Mlp(
|
| 77 |
+
in_features=dim_in,
|
| 78 |
+
hidden_features=dim_in // 2,
|
| 79 |
+
out_features=self.target_dim,
|
| 80 |
+
drop=0,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
|
| 84 |
+
"""
|
| 85 |
+
Forward pass to predict camera parameters.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
| 89 |
+
the last tensor is used for prediction.
|
| 90 |
+
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
| 94 |
+
"""
|
| 95 |
+
# Use tokens from the last block for camera prediction.
|
| 96 |
+
tokens = aggregated_tokens_list[-1]
|
| 97 |
+
|
| 98 |
+
# Extract the camera tokens
|
| 99 |
+
pose_tokens = tokens[:, :, 0]
|
| 100 |
+
pose_tokens = self.token_norm(pose_tokens)
|
| 101 |
+
|
| 102 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
|
| 103 |
+
return pred_pose_enc_list
|
| 104 |
+
|
| 105 |
+
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
|
| 106 |
+
"""
|
| 107 |
+
Iteratively refine camera pose predictions.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
|
| 111 |
+
num_iterations (int): Number of refinement iterations.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
list: List of activated camera encodings from each iteration.
|
| 115 |
+
"""
|
| 116 |
+
B, S, C = pose_tokens.shape # S is expected to be 1.
|
| 117 |
+
pred_pose_enc = None
|
| 118 |
+
pred_pose_enc_list = []
|
| 119 |
+
|
| 120 |
+
for _ in range(num_iterations):
|
| 121 |
+
# Use a learned empty pose for the first iteration.
|
| 122 |
+
if pred_pose_enc is None:
|
| 123 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
| 124 |
+
else:
|
| 125 |
+
# Detach the previous prediction to avoid backprop through time.
|
| 126 |
+
pred_pose_enc = pred_pose_enc.detach()
|
| 127 |
+
module_input = self.embed_pose(pred_pose_enc)
|
| 128 |
+
|
| 129 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
| 130 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
| 131 |
+
|
| 132 |
+
# Adaptive layer normalization and modulation.
|
| 133 |
+
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
| 134 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
| 135 |
+
|
| 136 |
+
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
|
| 137 |
+
# Compute the delta update for the pose encoding.
|
| 138 |
+
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
| 139 |
+
|
| 140 |
+
if pred_pose_enc is None:
|
| 141 |
+
pred_pose_enc = pred_pose_enc_delta
|
| 142 |
+
else:
|
| 143 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
| 144 |
+
|
| 145 |
+
# Apply final activation functions for translation, quaternion, and field-of-view.
|
| 146 |
+
activated_pose = activate_pose(
|
| 147 |
+
pred_pose_enc,
|
| 148 |
+
trans_act=self.trans_act,
|
| 149 |
+
quat_act=self.quat_act,
|
| 150 |
+
fl_act=self.fl_act,
|
| 151 |
+
)
|
| 152 |
+
pred_pose_enc_list.append(activated_pose)
|
| 153 |
+
|
| 154 |
+
return pred_pose_enc_list
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 158 |
+
"""
|
| 159 |
+
Modulate the input tensor using scaling and shifting parameters.
|
| 160 |
+
"""
|
| 161 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 162 |
+
return x * (1 + scale) + shift
|
vggt/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from .head_act import activate_head
|
| 18 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DPTHead(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
DPT Head for dense prediction tasks.
|
| 24 |
+
|
| 25 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
| 26 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
| 27 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dim_in (int): Input dimension (channels).
|
| 31 |
+
patch_size (int, optional): Patch size. Default is 14.
|
| 32 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
| 33 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
| 34 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
| 35 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
| 36 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
| 37 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
| 38 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
| 39 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
| 40 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim_in: int,
|
| 46 |
+
patch_size: int = 14,
|
| 47 |
+
output_dim: int = 4,
|
| 48 |
+
activation: str = "inv_log",
|
| 49 |
+
conf_activation: str = "expp1",
|
| 50 |
+
features: int = 256,
|
| 51 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 52 |
+
intermediate_layer_idx: List[int] = [4, 11, 17, 23],
|
| 53 |
+
pos_embed: bool = True,
|
| 54 |
+
feature_only: bool = False,
|
| 55 |
+
down_ratio: int = 1,
|
| 56 |
+
) -> None:
|
| 57 |
+
super(DPTHead, self).__init__()
|
| 58 |
+
self.patch_size = patch_size
|
| 59 |
+
self.activation = activation
|
| 60 |
+
self.conf_activation = conf_activation
|
| 61 |
+
self.pos_embed = pos_embed
|
| 62 |
+
self.feature_only = feature_only
|
| 63 |
+
self.down_ratio = down_ratio
|
| 64 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
| 65 |
+
|
| 66 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 67 |
+
|
| 68 |
+
# Projection layers for each output channel from tokens.
|
| 69 |
+
self.projects = nn.ModuleList(
|
| 70 |
+
[
|
| 71 |
+
nn.Conv2d(
|
| 72 |
+
in_channels=dim_in,
|
| 73 |
+
out_channels=oc,
|
| 74 |
+
kernel_size=1,
|
| 75 |
+
stride=1,
|
| 76 |
+
padding=0,
|
| 77 |
+
)
|
| 78 |
+
for oc in out_channels
|
| 79 |
+
]
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Resize layers for upsampling feature maps.
|
| 83 |
+
self.resize_layers = nn.ModuleList(
|
| 84 |
+
[
|
| 85 |
+
nn.ConvTranspose2d(
|
| 86 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
| 87 |
+
),
|
| 88 |
+
nn.ConvTranspose2d(
|
| 89 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
| 90 |
+
),
|
| 91 |
+
nn.Identity(),
|
| 92 |
+
nn.Conv2d(
|
| 93 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
| 94 |
+
),
|
| 95 |
+
]
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
self.scratch = _make_scratch(
|
| 99 |
+
out_channels,
|
| 100 |
+
features,
|
| 101 |
+
expand=False,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Attach additional modules to scratch.
|
| 105 |
+
self.scratch.stem_transpose = None
|
| 106 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 107 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 108 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 109 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 110 |
+
|
| 111 |
+
head_features_1 = features
|
| 112 |
+
head_features_2 = 32
|
| 113 |
+
|
| 114 |
+
if feature_only:
|
| 115 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
|
| 116 |
+
else:
|
| 117 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 118 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 119 |
+
)
|
| 120 |
+
conv2_in_channels = head_features_1 // 2
|
| 121 |
+
|
| 122 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 123 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 124 |
+
nn.ReLU(inplace=True),
|
| 125 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def forward(
|
| 129 |
+
self,
|
| 130 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 131 |
+
images: torch.Tensor,
|
| 132 |
+
patch_start_idx: int,
|
| 133 |
+
frames_chunk_size: int = 8,
|
| 134 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 135 |
+
"""
|
| 136 |
+
Forward pass through the DPT head, supports processing by chunking frames.
|
| 137 |
+
Args:
|
| 138 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 139 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 140 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
| 141 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
| 142 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
| 143 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Tensor or Tuple[Tensor, Tensor]:
|
| 147 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
| 148 |
+
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
| 149 |
+
"""
|
| 150 |
+
B, S, _, H, W = images.shape
|
| 151 |
+
|
| 152 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
| 153 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 154 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
| 155 |
+
|
| 156 |
+
# Otherwise, process frames in chunks to manage memory usage
|
| 157 |
+
assert frames_chunk_size > 0
|
| 158 |
+
|
| 159 |
+
# Process frames in batches
|
| 160 |
+
all_preds = []
|
| 161 |
+
all_conf = []
|
| 162 |
+
|
| 163 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
| 164 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
| 165 |
+
|
| 166 |
+
# Process batch of frames
|
| 167 |
+
if self.feature_only:
|
| 168 |
+
chunk_output = self._forward_impl(
|
| 169 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 170 |
+
)
|
| 171 |
+
all_preds.append(chunk_output)
|
| 172 |
+
else:
|
| 173 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
| 174 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 175 |
+
)
|
| 176 |
+
all_preds.append(chunk_preds)
|
| 177 |
+
all_conf.append(chunk_conf)
|
| 178 |
+
|
| 179 |
+
# Concatenate results along the sequence dimension
|
| 180 |
+
if self.feature_only:
|
| 181 |
+
return torch.cat(all_preds, dim=1)
|
| 182 |
+
else:
|
| 183 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
| 184 |
+
|
| 185 |
+
def _forward_impl(
|
| 186 |
+
self,
|
| 187 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 188 |
+
images: torch.Tensor,
|
| 189 |
+
patch_start_idx: int,
|
| 190 |
+
frames_start_idx: int = None,
|
| 191 |
+
frames_end_idx: int = None,
|
| 192 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 193 |
+
"""
|
| 194 |
+
Implementation of the forward pass through the DPT head.
|
| 195 |
+
|
| 196 |
+
This method processes a specific chunk of frames from the sequence.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 200 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
| 201 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 202 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
| 203 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
| 207 |
+
"""
|
| 208 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 209 |
+
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
| 210 |
+
|
| 211 |
+
B, S, _, H, W = images.shape
|
| 212 |
+
|
| 213 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 214 |
+
|
| 215 |
+
out = []
|
| 216 |
+
dpt_idx = 0
|
| 217 |
+
|
| 218 |
+
for layer_idx in self.intermediate_layer_idx:
|
| 219 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
| 220 |
+
|
| 221 |
+
# Select frames if processing a chunk
|
| 222 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 223 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
| 224 |
+
|
| 225 |
+
x = x.view(B * S, -1, x.shape[-1])
|
| 226 |
+
|
| 227 |
+
x = self.norm(x)
|
| 228 |
+
|
| 229 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 230 |
+
|
| 231 |
+
x = self.projects[dpt_idx](x)
|
| 232 |
+
if self.pos_embed:
|
| 233 |
+
x = self._apply_pos_embed(x, W, H)
|
| 234 |
+
x = self.resize_layers[dpt_idx](x)
|
| 235 |
+
|
| 236 |
+
out.append(x)
|
| 237 |
+
dpt_idx += 1
|
| 238 |
+
|
| 239 |
+
# Fuse features from multiple layers.
|
| 240 |
+
out = self.scratch_forward(out)
|
| 241 |
+
# Interpolate fused output to match target image resolution.
|
| 242 |
+
out = custom_interpolate(
|
| 243 |
+
out,
|
| 244 |
+
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
|
| 245 |
+
mode="bilinear",
|
| 246 |
+
align_corners=True,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
if self.pos_embed:
|
| 250 |
+
out = self._apply_pos_embed(out, W, H)
|
| 251 |
+
|
| 252 |
+
if self.feature_only:
|
| 253 |
+
return out.view(B, S, *out.shape[1:])
|
| 254 |
+
|
| 255 |
+
out = self.scratch.output_conv2(out)
|
| 256 |
+
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
|
| 257 |
+
|
| 258 |
+
preds = preds.view(B, S, *preds.shape[1:])
|
| 259 |
+
conf = conf.view(B, S, *conf.shape[1:])
|
| 260 |
+
return preds, conf
|
| 261 |
+
|
| 262 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 263 |
+
"""
|
| 264 |
+
Apply positional embedding to tensor x.
|
| 265 |
+
"""
|
| 266 |
+
patch_w = x.shape[-1]
|
| 267 |
+
patch_h = x.shape[-2]
|
| 268 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 269 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 270 |
+
pos_embed = pos_embed * ratio
|
| 271 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 272 |
+
return x + pos_embed
|
| 273 |
+
|
| 274 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 275 |
+
"""
|
| 276 |
+
Forward pass through the fusion blocks.
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
features (List[Tensor]): List of feature maps from different layers.
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
Tensor: Fused feature map.
|
| 283 |
+
"""
|
| 284 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
| 285 |
+
|
| 286 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 287 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 288 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 289 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 290 |
+
|
| 291 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 292 |
+
del layer_4_rn, layer_4
|
| 293 |
+
|
| 294 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 295 |
+
del layer_3_rn, layer_3
|
| 296 |
+
|
| 297 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 298 |
+
del layer_2_rn, layer_2
|
| 299 |
+
|
| 300 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 301 |
+
del layer_1_rn, layer_1
|
| 302 |
+
|
| 303 |
+
out = self.scratch.output_conv1(out)
|
| 304 |
+
return out
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
################################################################################
|
| 308 |
+
# Modules
|
| 309 |
+
################################################################################
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
| 313 |
+
return FeatureFusionBlock(
|
| 314 |
+
features,
|
| 315 |
+
nn.ReLU(inplace=True),
|
| 316 |
+
deconv=False,
|
| 317 |
+
bn=False,
|
| 318 |
+
expand=False,
|
| 319 |
+
align_corners=True,
|
| 320 |
+
size=size,
|
| 321 |
+
has_residual=has_residual,
|
| 322 |
+
groups=groups,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
| 327 |
+
scratch = nn.Module()
|
| 328 |
+
out_shape1 = out_shape
|
| 329 |
+
out_shape2 = out_shape
|
| 330 |
+
out_shape3 = out_shape
|
| 331 |
+
if len(in_shape) >= 4:
|
| 332 |
+
out_shape4 = out_shape
|
| 333 |
+
|
| 334 |
+
if expand:
|
| 335 |
+
out_shape1 = out_shape
|
| 336 |
+
out_shape2 = out_shape * 2
|
| 337 |
+
out_shape3 = out_shape * 4
|
| 338 |
+
if len(in_shape) >= 4:
|
| 339 |
+
out_shape4 = out_shape * 8
|
| 340 |
+
|
| 341 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 342 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 343 |
+
)
|
| 344 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 345 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 346 |
+
)
|
| 347 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 348 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 349 |
+
)
|
| 350 |
+
if len(in_shape) >= 4:
|
| 351 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 352 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 353 |
+
)
|
| 354 |
+
return scratch
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class ResidualConvUnit(nn.Module):
|
| 358 |
+
"""Residual convolution module."""
|
| 359 |
+
|
| 360 |
+
def __init__(self, features, activation, bn, groups=1):
|
| 361 |
+
"""Init.
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
features (int): number of features
|
| 365 |
+
"""
|
| 366 |
+
super().__init__()
|
| 367 |
+
|
| 368 |
+
self.bn = bn
|
| 369 |
+
self.groups = groups
|
| 370 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 371 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 372 |
+
|
| 373 |
+
self.norm1 = None
|
| 374 |
+
self.norm2 = None
|
| 375 |
+
|
| 376 |
+
self.activation = activation
|
| 377 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 378 |
+
|
| 379 |
+
def forward(self, x):
|
| 380 |
+
"""Forward pass.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
x (tensor): input
|
| 384 |
+
|
| 385 |
+
Returns:
|
| 386 |
+
tensor: output
|
| 387 |
+
"""
|
| 388 |
+
|
| 389 |
+
out = self.activation(x)
|
| 390 |
+
out = self.conv1(out)
|
| 391 |
+
if self.norm1 is not None:
|
| 392 |
+
out = self.norm1(out)
|
| 393 |
+
|
| 394 |
+
out = self.activation(out)
|
| 395 |
+
out = self.conv2(out)
|
| 396 |
+
if self.norm2 is not None:
|
| 397 |
+
out = self.norm2(out)
|
| 398 |
+
|
| 399 |
+
return self.skip_add.add(out, x)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class FeatureFusionBlock(nn.Module):
|
| 403 |
+
"""Feature fusion block."""
|
| 404 |
+
|
| 405 |
+
def __init__(
|
| 406 |
+
self,
|
| 407 |
+
features,
|
| 408 |
+
activation,
|
| 409 |
+
deconv=False,
|
| 410 |
+
bn=False,
|
| 411 |
+
expand=False,
|
| 412 |
+
align_corners=True,
|
| 413 |
+
size=None,
|
| 414 |
+
has_residual=True,
|
| 415 |
+
groups=1,
|
| 416 |
+
):
|
| 417 |
+
"""Init.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
features (int): number of features
|
| 421 |
+
"""
|
| 422 |
+
super(FeatureFusionBlock, self).__init__()
|
| 423 |
+
|
| 424 |
+
self.deconv = deconv
|
| 425 |
+
self.align_corners = align_corners
|
| 426 |
+
self.groups = groups
|
| 427 |
+
self.expand = expand
|
| 428 |
+
out_features = features
|
| 429 |
+
if self.expand == True:
|
| 430 |
+
out_features = features // 2
|
| 431 |
+
|
| 432 |
+
self.out_conv = nn.Conv2d(
|
| 433 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
if has_residual:
|
| 437 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 438 |
+
|
| 439 |
+
self.has_residual = has_residual
|
| 440 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 441 |
+
|
| 442 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 443 |
+
self.size = size
|
| 444 |
+
|
| 445 |
+
def forward(self, *xs, size=None):
|
| 446 |
+
"""Forward pass.
|
| 447 |
+
|
| 448 |
+
Returns:
|
| 449 |
+
tensor: output
|
| 450 |
+
"""
|
| 451 |
+
output = xs[0]
|
| 452 |
+
|
| 453 |
+
if self.has_residual:
|
| 454 |
+
res = self.resConfUnit1(xs[1])
|
| 455 |
+
output = self.skip_add.add(output, res)
|
| 456 |
+
|
| 457 |
+
output = self.resConfUnit2(output)
|
| 458 |
+
|
| 459 |
+
if (size is None) and (self.size is None):
|
| 460 |
+
modifier = {"scale_factor": 2}
|
| 461 |
+
elif size is None:
|
| 462 |
+
modifier = {"size": self.size}
|
| 463 |
+
else:
|
| 464 |
+
modifier = {"size": size}
|
| 465 |
+
|
| 466 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
| 467 |
+
output = self.out_conv(output)
|
| 468 |
+
|
| 469 |
+
return output
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def custom_interpolate(
|
| 473 |
+
x: torch.Tensor,
|
| 474 |
+
size: Tuple[int, int] = None,
|
| 475 |
+
scale_factor: float = None,
|
| 476 |
+
mode: str = "bilinear",
|
| 477 |
+
align_corners: bool = True,
|
| 478 |
+
) -> torch.Tensor:
|
| 479 |
+
"""
|
| 480 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
| 481 |
+
"""
|
| 482 |
+
if size is None:
|
| 483 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 484 |
+
|
| 485 |
+
INT_MAX = 1610612736
|
| 486 |
+
|
| 487 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 488 |
+
|
| 489 |
+
if input_elements > INT_MAX:
|
| 490 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
| 491 |
+
interpolated_chunks = [
|
| 492 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
| 493 |
+
]
|
| 494 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
| 495 |
+
return x.contiguous()
|
| 496 |
+
else:
|
| 497 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
vggt/heads/head_act.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
|
| 13 |
+
"""
|
| 14 |
+
Activate pose parameters with specified activation functions.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
|
| 18 |
+
trans_act: Activation type for translation component
|
| 19 |
+
quat_act: Activation type for quaternion component
|
| 20 |
+
fl_act: Activation type for focal length component
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
Activated pose parameters tensor
|
| 24 |
+
"""
|
| 25 |
+
T = pred_pose_enc[..., :3]
|
| 26 |
+
quat = pred_pose_enc[..., 3:7]
|
| 27 |
+
fl = pred_pose_enc[..., 7:] # or fov
|
| 28 |
+
|
| 29 |
+
T = base_pose_act(T, trans_act)
|
| 30 |
+
quat = base_pose_act(quat, quat_act)
|
| 31 |
+
fl = base_pose_act(fl, fl_act) # or fov
|
| 32 |
+
|
| 33 |
+
pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
|
| 34 |
+
|
| 35 |
+
return pred_pose_enc
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def base_pose_act(pose_enc, act_type="linear"):
|
| 39 |
+
"""
|
| 40 |
+
Apply basic activation function to pose parameters.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
pose_enc: Tensor containing encoded pose parameters
|
| 44 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Activated pose parameters
|
| 48 |
+
"""
|
| 49 |
+
if act_type == "linear":
|
| 50 |
+
return pose_enc
|
| 51 |
+
elif act_type == "inv_log":
|
| 52 |
+
return inverse_log_transform(pose_enc)
|
| 53 |
+
elif act_type == "exp":
|
| 54 |
+
return torch.exp(pose_enc)
|
| 55 |
+
elif act_type == "relu":
|
| 56 |
+
return F.relu(pose_enc)
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"Unknown act_type: {act_type}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
|
| 62 |
+
"""
|
| 63 |
+
Process network output to extract 3D points and confidence values.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
out: Network output tensor (B, C, H, W)
|
| 67 |
+
activation: Activation type for 3D points
|
| 68 |
+
conf_activation: Activation type for confidence values
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Tuple of (3D points tensor, confidence tensor)
|
| 72 |
+
"""
|
| 73 |
+
# Move channels from last dim to the 4th dimension => (B, H, W, C)
|
| 74 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
| 75 |
+
|
| 76 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
| 77 |
+
xyz = fmap[:, :, :, :-1]
|
| 78 |
+
conf = fmap[:, :, :, -1]
|
| 79 |
+
|
| 80 |
+
if activation == "norm_exp":
|
| 81 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 82 |
+
xyz_normed = xyz / d
|
| 83 |
+
pts3d = xyz_normed * torch.expm1(d)
|
| 84 |
+
elif activation == "norm":
|
| 85 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
| 86 |
+
elif activation == "exp":
|
| 87 |
+
pts3d = torch.exp(xyz)
|
| 88 |
+
elif activation == "relu":
|
| 89 |
+
pts3d = F.relu(xyz)
|
| 90 |
+
elif activation == "inv_log":
|
| 91 |
+
pts3d = inverse_log_transform(xyz)
|
| 92 |
+
elif activation == "xy_inv_log":
|
| 93 |
+
xy, z = xyz.split([2, 1], dim=-1)
|
| 94 |
+
z = inverse_log_transform(z)
|
| 95 |
+
pts3d = torch.cat([xy * z, z], dim=-1)
|
| 96 |
+
elif activation == "sigmoid":
|
| 97 |
+
pts3d = torch.sigmoid(xyz)
|
| 98 |
+
elif activation == "linear":
|
| 99 |
+
pts3d = xyz
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unknown activation: {activation}")
|
| 102 |
+
|
| 103 |
+
if conf_activation == "expp1":
|
| 104 |
+
conf_out = 1 + conf.exp()
|
| 105 |
+
elif conf_activation == "expp0":
|
| 106 |
+
conf_out = conf.exp()
|
| 107 |
+
elif conf_activation == "sigmoid":
|
| 108 |
+
conf_out = torch.sigmoid(conf)
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
| 111 |
+
|
| 112 |
+
return pts3d, conf_out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def inverse_log_transform(y):
|
| 116 |
+
"""
|
| 117 |
+
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
y: Input tensor
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Transformed tensor
|
| 124 |
+
"""
|
| 125 |
+
return torch.sign(y) * (torch.expm1(torch.abs(y)))
|
vggt/heads/track_head.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from .dpt_head import DPTHead
|
| 9 |
+
from .track_modules.base_track_predictor import BaseTrackerPredictor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TrackHead(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
|
| 15 |
+
The tracking is performed iteratively, refining predictions over multiple iterations.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
dim_in,
|
| 21 |
+
patch_size=14,
|
| 22 |
+
features=128,
|
| 23 |
+
iters=4,
|
| 24 |
+
predict_conf=True,
|
| 25 |
+
stride=2,
|
| 26 |
+
corr_levels=7,
|
| 27 |
+
corr_radius=4,
|
| 28 |
+
hidden_size=384,
|
| 29 |
+
):
|
| 30 |
+
"""
|
| 31 |
+
Initialize the TrackHead module.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
dim_in (int): Input dimension of tokens from the backbone.
|
| 35 |
+
patch_size (int): Size of image patches used in the vision transformer.
|
| 36 |
+
features (int): Number of feature channels in the feature extractor output.
|
| 37 |
+
iters (int): Number of refinement iterations for tracking predictions.
|
| 38 |
+
predict_conf (bool): Whether to predict confidence scores for tracked points.
|
| 39 |
+
stride (int): Stride value for the tracker predictor.
|
| 40 |
+
corr_levels (int): Number of correlation pyramid levels
|
| 41 |
+
corr_radius (int): Radius for correlation computation, controlling the search area.
|
| 42 |
+
hidden_size (int): Size of hidden layers in the tracker network.
|
| 43 |
+
"""
|
| 44 |
+
super().__init__()
|
| 45 |
+
|
| 46 |
+
self.patch_size = patch_size
|
| 47 |
+
|
| 48 |
+
# Feature extractor based on DPT architecture
|
| 49 |
+
# Processes tokens into feature maps for tracking
|
| 50 |
+
self.feature_extractor = DPTHead(
|
| 51 |
+
dim_in=dim_in,
|
| 52 |
+
patch_size=patch_size,
|
| 53 |
+
features=features,
|
| 54 |
+
feature_only=True, # Only output features, no activation
|
| 55 |
+
down_ratio=2, # Reduces spatial dimensions by factor of 2
|
| 56 |
+
pos_embed=False,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Tracker module that predicts point trajectories
|
| 60 |
+
# Takes feature maps and predicts coordinates and visibility
|
| 61 |
+
self.tracker = BaseTrackerPredictor(
|
| 62 |
+
latent_dim=features, # Match the output_dim of feature extractor
|
| 63 |
+
predict_conf=predict_conf,
|
| 64 |
+
stride=stride,
|
| 65 |
+
corr_levels=corr_levels,
|
| 66 |
+
corr_radius=corr_radius,
|
| 67 |
+
hidden_size=hidden_size,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.iters = iters
|
| 71 |
+
|
| 72 |
+
def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
|
| 73 |
+
"""
|
| 74 |
+
Forward pass of the TrackHead.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
aggregated_tokens_list (list): List of aggregated tokens from the backbone.
|
| 78 |
+
images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
|
| 79 |
+
B = batch size, S = sequence length.
|
| 80 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 81 |
+
query_points (torch.Tensor, optional): Initial query points to track.
|
| 82 |
+
If None, points are initialized by the tracker.
|
| 83 |
+
iters (int, optional): Number of refinement iterations. If None, uses self.iters.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
tuple:
|
| 87 |
+
- coord_preds (torch.Tensor): Predicted coordinates for tracked points.
|
| 88 |
+
- vis_scores (torch.Tensor): Visibility scores for tracked points.
|
| 89 |
+
- conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
|
| 90 |
+
"""
|
| 91 |
+
B, S, _, H, W = images.shape
|
| 92 |
+
|
| 93 |
+
# Extract features from tokens
|
| 94 |
+
# feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
|
| 95 |
+
feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
|
| 96 |
+
|
| 97 |
+
# Use default iterations if not specified
|
| 98 |
+
if iters is None:
|
| 99 |
+
iters = self.iters
|
| 100 |
+
|
| 101 |
+
# Perform tracking using the extracted features
|
| 102 |
+
coord_preds, vis_scores, conf_scores = self.tracker(
|
| 103 |
+
query_points=query_points,
|
| 104 |
+
fmaps=feature_maps,
|
| 105 |
+
iters=iters,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
return coord_preds, vis_scores, conf_scores
|
vggt/heads/track_modules/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
vggt/heads/track_modules/base_track_predictor.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from einops import rearrange, repeat


from .blocks import EfficientUpdateFormer, CorrBlock
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
from .modules import Mlp


class BaseTrackerPredictor(nn.Module):
    def __init__(
        self,
        stride=1,
        corr_levels=5,
        corr_radius=4,
        latent_dim=128,
        hidden_size=384,
        use_spaceatt=True,
        depth=6,
        max_scale=518,
        predict_conf=True,
    ):
        super(BaseTrackerPredictor, self).__init__()
        """
        The base template to create a track predictor

        Modified from https://github.com/facebookresearch/co-tracker/
        and https://github.com/facebookresearch/vggsfm
        """

        self.stride = stride
        self.latent_dim = latent_dim
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.hidden_size = hidden_size
        self.max_scale = max_scale
        self.predict_conf = predict_conf

        self.flows_emb_dim = latent_dim // 2

        self.corr_mlp = Mlp(
            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
            hidden_features=self.hidden_size,
            out_features=self.latent_dim,
        )

        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4

        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))

        space_depth = depth if use_spaceatt else 0
        time_depth = depth

        self.updateformer = EfficientUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=self.transformer_dim,
            hidden_size=self.hidden_size,
            output_dim=self.latent_dim + 2,
            mlp_ratio=4.0,
            add_space_attn=use_spaceatt,
        )

        self.fmap_norm = nn.LayerNorm(self.latent_dim)
        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)

        # A linear layer to update track feats at each iteration
        self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())

        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

        if predict_conf:
            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

    def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
        """
        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
            note HH and WW is the size of feature maps instead of original images
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # apply a layernorm to fmaps here
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
        # its query_points should be query_points/4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init with coords as the query points
        # It means the search will start from the position of query points at the reference frames
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
        # back up the init coords
        coords_backup = coords.clone()

        fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)

        coord_preds = []

        # Iterative Refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            fcorrs = fcorr_fn.corr_sample(track_feats, coords)

            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            # 2D positional embed
            # TODO: this can be much simplified
            pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
            sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])

            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)

            x = transformer_input + sampled_pos_emb

            # Add the query ref token to the track feats
            query_ref_token = torch.cat(
                [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
            )
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features
            track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)  # BxSxNxC

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)

        # B, S, N
        vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e
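As a sanity check on the shapes the predictor expects, here is a minimal smoke-test sketch (not part of the commit). It assumes the package layout of this Space for the import path; the batch/frame/point counts and the stride are arbitrary illustrative values, and the feature channel count must match latent_dim.

import torch
from vggt.heads.track_modules.base_track_predictor import BaseTrackerPredictor

B, S, N = 1, 4, 16                      # batch, frames, query points (illustrative)
latent_dim, HH, WW = 128, 37, 37        # feature channels must equal latent_dim
tracker = BaseTrackerPredictor(latent_dim=latent_dim, stride=2, predict_conf=True)

fmaps = torch.randn(B, S, latent_dim, HH, WW)
queries = torch.rand(B, N, 2) * 60      # xy in input-image pixels; stride 2 maps them onto the 37x37 features
coord_preds, vis, conf = tracker(queries, fmaps=fmaps, iters=4)
print(len(coord_preds), coord_preds[-1].shape, vis.shape)   # 4, (1, 4, 16, 2), (1, 4, 16)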
vggt/heads/track_modules/blocks.py
ADDED
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# Modified from https://github.com/facebookresearch/co-tracker/

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import bilinear_sampler
from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock


class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        add_space_attn=True,
        num_virtual_tracks=64,
    ):
        super().__init__()

        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.add_space_attn = add_space_attn

        # Add input LayerNorm before linear projection
        self.input_norm = nn.LayerNorm(input_dim)
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)

        # Add output LayerNorm before final projection
        self.output_norm = nn.LayerNorm(hidden_size)
        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
        self.num_virtual_tracks = num_virtual_tracks

        if self.add_space_attn:
            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
        else:
            self.virual_tracks = None

        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=nn.MultiheadAttention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=nn.MultiheadAttention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None):
        # Apply input LayerNorm
        input_tensor = self.input_norm(input_tensor)
        tokens = self.input_transform(input_tensor)

        init_tokens = tokens

        B, _, T, _ = tokens.shape

        if self.add_space_attn:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)

        _, N, _, _ = tokens.shape

        j = 0
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C

            time_tokens = self.time_blocks[i](time_tokens)

            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
                space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)  # B N T C -> (B T) N C
                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
                point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)  # (B T) N C -> B N T C
                j += 1

        if self.add_space_attn:
            tokens = tokens[:, : N - self.num_virtual_tracks]

        tokens = tokens + init_tokens

        # Apply output LayerNorm before final projection
        tokens = self.output_norm(tokens)
        flow = self.flow_head(tokens)

        return flow, None


class CorrBlock:
    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
        """
        Build a pyramid of feature maps from the input.

        fmaps: Tensor (B, S, C, H, W)
        num_levels: number of pyramid levels (each downsampled by factor 2)
        radius: search radius for sampling correlation
        multiple_track_feats: if True, split the target features per pyramid level
        padding_mode: passed to grid_sample / bilinear_sampler
        """
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.num_levels = num_levels
        self.radius = radius
        self.padding_mode = padding_mode
        self.multiple_track_feats = multiple_track_feats

        # Build pyramid: each level is half the spatial resolution of the previous
        self.fmaps_pyramid = [fmaps]  # level 0 is full resolution
        current_fmaps = fmaps
        for i in range(num_levels - 1):
            B, S, C, H, W = current_fmaps.shape
            # Merge batch & sequence dimensions
            current_fmaps = current_fmaps.reshape(B * S, C, H, W)
            # Avg pool down by factor 2
            current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
            _, _, H_new, W_new = current_fmaps.shape
            current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
            self.fmaps_pyramid.append(current_fmaps)

        # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
        # This grid is added to the (scaled) coordinate centroids.
        r = self.radius
        dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
        self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)  # shape: (2r+1, 2r+1, 2)

    def corr_sample(self, targets, coords):
        """
        Instead of storing the entire correlation pyramid, we compute each level's correlation
        volume, sample it immediately, then discard it. This saves GPU memory.

        Args:
            targets: Tensor (B, S, N, C) — features for the current targets.
            coords: Tensor (B, S, N, 2) — coordinates at full resolution.

        Returns:
            Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
        """
        B, S, N, C = targets.shape

        # If you have multiple track features, split them per level.
        if self.multiple_track_feats:
            targets_split = torch.split(targets, C // self.num_levels, dim=-1)

        out_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            # Get current spatial resolution H, W for this pyramid level.
            B, S, C, H, W = fmaps.shape
            # Reshape feature maps for correlation computation:
            # fmap2s: (B, S, C, H*W)
            fmap2s = fmaps.view(B, S, C, H * W)
            # Choose appropriate target features.
            fmap1 = targets_split[i] if self.multiple_track_feats else targets  # shape: (B, S, N, C)

            # Compute correlation directly
            corrs = compute_corr_level(fmap1, fmap2s, C)
            corrs = corrs.view(B, S, N, H, W)

            # Prepare sampling grid:
            # Scale down the coordinates for the current level.
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
            # Make sure our precomputed delta grid is on the same device/dtype.
            delta_lvl = self.delta.to(coords.device).to(coords.dtype)
            # Now the grid for grid_sample is:
            # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
            coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)

            # Sample from the correlation volume using bilinear interpolation.
            # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
            corrs_sampled = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
            )
            # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
            corrs_sampled = corrs_sampled.view(B, S, N, -1)  # Now shape: (B, S, N, (2r+1)^2)
            out_pyramid.append(corrs_sampled)

        # Concatenate all levels along the last dimension.
        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out


def compute_corr_level(fmap1, fmap2s, C):
    # fmap1: (B, S, N, C)
    # fmap2s: (B, S, C, H*W)
    corrs = torch.matmul(fmap1, fmap2s)  # (B, S, N, H*W)
    corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1)  # (B, S, N, H*W)
    return corrs / math.sqrt(C)
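The CorrBlock above builds an on-the-fly correlation pyramid and samples a (2r+1)x(2r+1) window per level. A brief shape walk-through sketch follows (not part of the commit; sizes are illustrative).

import torch
from vggt.heads.track_modules.blocks import CorrBlock

B, S, C, H, W, N = 1, 2, 128, 32, 32, 8
fmaps = torch.randn(B, S, C, H, W)
targets = torch.randn(B, S, N, C)           # per-track feature vectors
coords = torch.rand(B, S, N, 2) * (W - 1)   # xy at feature-map resolution

corr = CorrBlock(fmaps, num_levels=4, radius=3)
sampled = corr.corr_sample(targets, coords)
print(sampled.shape)                        # (1, 2, 8, 4 * (2*3+1)**2) = (1, 2, 8, 196)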
vggt/heads/track_modules/modules.py
ADDED
@@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)


class ResidualBlock(nn.Module):
    """
    ResidualBlock: construct a block of two conv layers with residual connections
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        """
        Self attention block
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # Prepare the mask for PyTorch's attention (it expects a different format)
        # attn_mask = mask if mask is not None else None
        # Normalize before attention
        x = self.norm1(x)

        # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
        # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)

        attn_output, _ = self.attn(x, x, x)

        # Add & Norm
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x


class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        """
        Cross attention block
        """
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        # Normalize inputs
        x = self.norm1(x)
        context = self.norm_context(context)

        # Apply cross attention
        # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)

        # Add & Norm
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
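A quick, self-contained shape check for the pre-norm attention blocks defined above (not part of the commit; sizes are illustrative).

import torch
from vggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

tokens = torch.randn(2, 10, 384)            # (batch, tokens, hidden)
context = torch.randn(2, 64, 384)

self_attn = AttnBlock(hidden_size=384, num_heads=8)
cross_attn = CrossAttnBlock(hidden_size=384, context_dim=384, num_heads=8)
print(self_attn(tokens).shape)              # torch.Size([2, 10, 384])
print(cross_attn(tokens, context).shape)    # torch.Size([2, 10, 384])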
vggt/heads/track_modules/utils.py
ADDED
@@ -0,0 +1,226 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from https://github.com/facebookresearch/vggsfm
# and https://github.com/facebookresearch/co-tracker/tree/main


import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Union


def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
    """
    This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
    It is a wrapper of get_2d_sincos_pos_embed_from_grid.
    Args:
    - embed_dim: The embedding dimension.
    - grid_size: The grid size.
    Returns:
    - pos_embed: The generated 2D positional embedding.
    """
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float)
    grid_w = torch.arange(grid_size_w, dtype=torch.float)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if return_grid:
        return (
            pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
            grid,
        )
    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)


def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension.
    - grid: The grid to generate the embedding from.

    Returns:
    - emb: The generated 2D positional embedding.
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=2)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension.
    - pos: The position to generate the embedding from.

    Returns:
    - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb[None].float()


def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from given coordinates using sine and cosine functions.

    Args:
    - xy: The coordinates to generate the embedding from.
    - C: The size of the embedding.
    - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.

    Returns:
    - pe: The generated 2D positional embedding.
    """
    B, N, D = xy.shape
    assert D == 2

    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)

    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*3)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*3+3)
    return pe


def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It is the same
    as `torch.nn.functional.grid_sample()` but with a different coordinate
    convention.

    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
    :math:`B` is the batch size, :math:`C` is the number of channels,
    :math:`H` is the height of the image, and :math:`W` is the width of the
    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.

    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
    that in this case the order of the components is slightly different
    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.

    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel :math:`W-1` to the center of the right-most
    pixel.

    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel :math:`W` to the right edge of the right-most
    pixel.

    Similar conventions apply to the :math:`y` for the range
    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
    :math:`[0,T-1]` and :math:`[0,T]`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    coords = coords.detach().clone()
    ############################################################
    # IMPORTANT:
    coords = coords.to(input.device).to(input.dtype)
    ############################################################

    sizes = input.shape[2:]

    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    if align_corners:
        scale = torch.tensor(
            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
        )
    else:
        scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)

    coords.mul_(scale)  # coords = coords * scale
    coords.sub_(1)  # coords = coords - 1

    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)


def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """

    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3])  # B C R 1 -> B R C
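The helpers above are mostly about coordinate conventions. The following sketch (not part of the commit) shows the shapes they produce on made-up inputs.

import torch
from vggt.heads.track_modules.utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed

fmap = torch.randn(1, 64, 32, 32)                    # (B, C, H, W)
pts = torch.tensor([[[4.0, 7.0], [20.5, 11.2]]])     # (B, R, 2) xy in pixel units

print(sample_features4d(fmap, pts).shape)            # (1, 2, 64): one 64-d feature per point
print(get_2d_embedding(pts, 32, cat_coords=False).shape)       # (1, 2, 64): sin/cos of x and y
print(get_2d_sincos_pos_embed(128, grid_size=(32, 32)).shape)  # (1, 128, 32, 32)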
vggt/heads/utils.py
ADDED
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    pos_flat = pos_grid.reshape(-1, grid_dim)  # Flatten to (H*W, 2)

    # Process x and y coordinates separately
    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # [1, H*W, D/2]
    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # [1, H*W, D/2]

    # Combine and reshape
    emb = torch.cat([emb_x, emb_y], dim=-1)  # [1, H*W, D]

    return emb.view(H, W, embed_dim)  # [H, W, D]


def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension.
    - pos: The position to generate the embedding from.

    Returns:
    - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb.float()


# Inspired by https://github.com/microsoft/moge


def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (width, height, 2).

    The grid spans horizontally and vertically according to an aspect ratio,
    ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
    corner is at (x_span, y_span), normalized by the diagonal of the plane.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (width, height, 2) tensor of UV coordinates.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Compute normalized spans for X and Y
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Establish the linspace boundaries
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # Create 2D meshgrid (width x height) and stack into UV
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid
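A small sketch (not part of the commit) of how create_uv_grid and position_grid_to_embed compose; the grid size and embedding dimension are illustrative.

import torch
from vggt.heads.utils import create_uv_grid, position_grid_to_embed

uv = create_uv_grid(32, 32, dtype=torch.float32)   # (32, 32, 2), spans normalized by the plane diagonal
pos = position_grid_to_embed(uv, embed_dim=256)    # (32, 32, 256) sinusoidal features
print(uv.shape, pos.shape, float(uv.abs().max()))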
vggt/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
vggt/layers/attention.py
ADDED
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn
import torch.nn.functional as F

XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
        assert pos is None
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
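A minimal check of the Attention layer on random tokens (not part of the commit); dim and sequence length are illustrative, and no RoPE is passed.

import torch
from vggt.layers.attention import Attention

attn = Attention(dim=384, num_heads=6)  # head_dim = 64
tokens = torch.randn(2, 197, 384)       # (batch, sequence, channels)
out = attn(tokens)                      # uses F.scaled_dot_product_attention by default
print(out.shape)                        # torch.Size([2, 197, 384])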
vggt/layers/block.py
ADDED
@@ -0,0 +1,259 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


XFORMERS_AVAILABLE = False


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
    pos=None,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    if pos is not None:
        # if necessary, apply rope to the subset
        pos = pos[brange]
        residual = residual_func(x_subset, pos=pos)
    else:
        residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
| 168 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 169 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 170 |
+
seqlens = []
|
| 171 |
+
for b, x in zip(batch_sizes, x_list):
|
| 172 |
+
for _ in range(b):
|
| 173 |
+
seqlens.append(x.shape[1])
|
| 174 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 175 |
+
attn_bias._batch_sizes = batch_sizes
|
| 176 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 177 |
+
|
| 178 |
+
if branges is not None:
|
| 179 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
| 180 |
+
else:
|
| 181 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 182 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 183 |
+
|
| 184 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def drop_add_residual_stochastic_depth_list(
|
| 188 |
+
x_list: List[Tensor],
|
| 189 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 190 |
+
sample_drop_ratio: float = 0.0,
|
| 191 |
+
scaling_vector=None,
|
| 192 |
+
) -> Tensor:
|
| 193 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 194 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
| 195 |
+
branges = [s[0] for s in branges_scales]
|
| 196 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 197 |
+
|
| 198 |
+
# 2) get attention bias and index+concat the tensors
|
| 199 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 200 |
+
|
| 201 |
+
# 3) apply residual_func to get residual, and split the result
|
| 202 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 203 |
+
|
| 204 |
+
outputs = []
|
| 205 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
| 206 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
| 207 |
+
return outputs
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class NestedTensorBlock(Block):
|
| 211 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 212 |
+
"""
|
| 213 |
+
x_list contains a list of tensors to nest together and run
|
| 214 |
+
"""
|
| 215 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 216 |
+
|
| 217 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 218 |
+
|
| 219 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 220 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 221 |
+
|
| 222 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 223 |
+
return self.mlp(self.norm2(x))
|
| 224 |
+
|
| 225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 226 |
+
x_list,
|
| 227 |
+
residual_func=attn_residual_func,
|
| 228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 229 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 230 |
+
)
|
| 231 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 232 |
+
x_list,
|
| 233 |
+
residual_func=ffn_residual_func,
|
| 234 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 235 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 236 |
+
)
|
| 237 |
+
return x_list
|
| 238 |
+
else:
|
| 239 |
+
|
| 240 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 241 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 242 |
+
|
| 243 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 244 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 245 |
+
|
| 246 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 247 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 248 |
+
x = x + ffn_residual_func(x)
|
| 249 |
+
return attn_bias.split(x)
|
| 250 |
+
|
| 251 |
+
def forward(self, x_or_x_list):
|
| 252 |
+
if isinstance(x_or_x_list, Tensor):
|
| 253 |
+
return super().forward(x_or_x_list)
|
| 254 |
+
elif isinstance(x_or_x_list, list):
|
| 255 |
+
if not XFORMERS_AVAILABLE:
|
| 256 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 257 |
+
return self.forward_nested(x_or_x_list)
|
| 258 |
+
else:
|
| 259 |
+
raise AssertionError
|
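A minimal usage sketch for the Block above, not part of the commit, assuming the Attention class in vggt/layers/attention.py (added earlier in this commit) accepts these constructor defaults. With drop_path left at 0.0 and no RoPE, the forward pass is the plain pre-norm attention + MLP residual path and preserves the token shape.

# Hypothetical illustration only: exercising Block on dummy tokens.
import torch
from vggt.layers.block import Block

blk = Block(dim=64, num_heads=4)   # attention + MLP, each with a residual connection
x = torch.randn(2, 10, 64)         # (batch, tokens, dim)
out = blk(x)                       # same shape as the input
assert out.shape == x.shape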
vggt/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
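An illustration of the stochastic-depth behavior, not part of the commit: drop_path zeroes entire samples along the batch dimension and rescales the survivors by 1 / keep_prob, so the expected value of the output matches the input.

# Hypothetical illustration only.
import torch
from vggt.layers.drop_path import drop_path

x = torch.ones(4, 3, 8)
y = drop_path(x, drop_prob=0.5, training=True)
# Each sample is either all zeros or all 2.0 (= 1 / keep_prob).
print(y[:, 0, 0])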
vggt/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
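For reference, a tiny sketch (not part of the commit) of what LayerScale computes: a learned per-channel gain, initialized near zero so each residual branch starts out almost disabled.

# Hypothetical illustration only.
import torch
from vggt.layers.layer_scale import LayerScale

ls = LayerScale(dim=8, init_values=1e-5)
x = torch.randn(2, 4, 8)
assert torch.allclose(ls(x), x * ls.gamma)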
vggt/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
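A shape-preserving usage sketch (not part of the commit): Mlp is the standard two-layer GELU feed-forward block used by the transformer blocks above.

# Hypothetical illustration only.
import torch
from vggt.layers.mlp import Mlp

ffn = Mlp(in_features=64, hidden_features=256)
x = torch.randn(2, 10, 64)
assert ffn(x).shape == x.shape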
vggt/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
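A shape walkthrough (not part of the commit) with the defaults used elsewhere in this repo: a 518-pixel image and a 14-pixel patch give a 37 x 37 grid, so 1369 patch tokens per image.

# Hypothetical illustration only: (B, C, H, W) -> (B, N, D) bookkeeping.
import torch
from vggt.layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
x = torch.randn(1, 3, 518, 518)
tokens = pe(x)                                   # (1, 37 * 37, 1024)
assert tokens.shape == (1, (518 // 14) ** 2, 1024)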
vggt/layers/rope.py
ADDED
@@ -0,0 +1,188 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


# Implementation of 2D Rotary Position Embeddings (RoPE).

# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.

# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple


class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    This class efficiently manages the generation of spatial coordinates for patches
    in a 2D grid, caching results to avoid redundant computations.

    Attributes:
        position_cache: Dictionary storing precomputed position tensors for different
            grid dimensions.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
            for each position in the grid, repeated for each batch item.
        """
        if (height, width) not in self.position_cache:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[height, width] = positions

        cached_positions = self.position_cache[height, width]
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. It handles the position-dependent rotation of features
    separately for vertical and horizontal dimensions.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies.
        frequency_cache: Cache for storing precomputed frequency components.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors for frequency components.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Embed positions with frequency components
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)
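A small pairing sketch (not part of the commit) matching the shapes the forward docstring asks for: per-head tokens of shape (batch, heads, tokens, head_dim) with head_dim divisible by 4, and integer (y, x) positions from PositionGetter. Half of the head dimension is rotated by the row index, the other half by the column index.

# Hypothetical illustration only.
import torch
from vggt.layers.rope import PositionGetter, RotaryPositionEmbedding2D

rope = RotaryPositionEmbedding2D(frequency=100.0)
get_pos = PositionGetter()

tokens = torch.randn(2, 4, 6 * 6, 32)        # (B, heads, H*W, head_dim)
pos = get_pos(batch_size=2, height=6, width=6, device=tokens.device)
out = rope(tokens, pos)                       # same shape as the input
assert out.shape == tokens.shape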
vggt/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
#     if XFORMERS_ENABLED:
#         from xformers.ops import SwiGLU

#         XFORMERS_AVAILABLE = True
#         warnings.warn("xFormers is available (SwiGLU)")
#     else:
#         warnings.warn("xFormers is disabled (SwiGLU)")
#         raise ImportError
# except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False

# warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
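A worked example (not part of the commit) of the hidden-width adjustment in SwiGLUFFNFused: the requested hidden size is rescaled by 2/3 (since SwiGLU uses two input projections) and rounded up to a multiple of 8.

# Hypothetical illustration only: the fused hidden-width arithmetic.
hidden = 4 * 1024                                  # mlp_ratio * embed_dim
fused_hidden = (int(hidden * 2 / 3) + 7) // 8 * 8
print(fused_hidden)                                # 2736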
vggt/layers/vision_transformer.py
ADDED
@@ -0,0 +1,407 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.nn.init import trunc_normal_
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block

logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        qk_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # tricky but makes it work
        self.use_checkpoint = False
        #

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
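A hedged usage sketch (not part of the commit) of the ViT-L/14 variant that the Aggregator below builds for its "dinov2_vitl14_reg" patch embed, assuming the MemEffAttention exported by vggt/layers/attention.py (added elsewhere in this commit) falls back to standard attention when xFormers is absent. Note that with use_checkpoint left at False the checkpoint branch, which references an otherwise undefined self.use_reentrant, is never taken.

# Hypothetical illustration only.
import torch
from vggt.layers.vision_transformer import vit_large

vit = vit_large(img_size=518, patch_size=14, num_register_tokens=4, block_chunks=0)
images = torch.randn(1, 3, 518, 518)
feats = vit.forward_features(images)           # dict of token groups
patch_tokens = feats["x_norm_patchtokens"]     # (1, 1369, 1024), cls/register tokens excluded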
vggt/models/aggregator.py
ADDED
@@ -0,0 +1,331 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List, Dict, Any

from vggt.layers import PatchEmbed
from vggt.layers.block import Block
from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.


    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
    ):
        super().__init__()

        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

        # Disable gradient updates for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    def forward(
        self,
        images: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed
        images = (images - self._resnet_mean) / self._resnet_std

        # Reshape to [B*S, C, H, W] for patch embedding
        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)

        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)

        pos = None
        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for _ in range(self.aa_block_num):
|
| 244 |
+
for attn_type in self.aa_order:
|
| 245 |
+
if attn_type == "frame":
|
| 246 |
+
tokens, frame_idx, frame_intermediates = self._process_frame_attention(
|
| 247 |
+
tokens, B, S, P, C, frame_idx, pos=pos
|
| 248 |
+
)
|
| 249 |
+
elif attn_type == "global":
|
| 250 |
+
tokens, global_idx, global_intermediates = self._process_global_attention(
|
| 251 |
+
tokens, B, S, P, C, global_idx, pos=pos
|
| 252 |
+
)
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError(f"Unknown attention type: {attn_type}")
|
| 255 |
+
|
| 256 |
+
for i in range(len(frame_intermediates)):
|
| 257 |
+
# concat frame and global intermediates, [B x S x P x 2C]
|
| 258 |
+
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
|
| 259 |
+
output_list.append(concat_inter)
|
| 260 |
+
|
| 261 |
+
del concat_inter
|
| 262 |
+
del frame_intermediates
|
| 263 |
+
del global_intermediates
|
| 264 |
+
return output_list, self.patch_start_idx
|
| 265 |
+
|
| 266 |
+
def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
|
| 267 |
+
"""
|
| 268 |
+
Process frame attention blocks. We keep tokens in shape (B*S, P, C).
|
| 269 |
+
"""
|
| 270 |
+
# If needed, reshape tokens or positions:
|
| 271 |
+
if tokens.shape != (B * S, P, C):
|
| 272 |
+
tokens = tokens.view(B, S, P, C).view(B * S, P, C)
|
| 273 |
+
|
| 274 |
+
if pos is not None and pos.shape != (B * S, P, 2):
|
| 275 |
+
pos = pos.view(B, S, P, 2).view(B * S, P, 2)
|
| 276 |
+
|
| 277 |
+
intermediates = []
|
| 278 |
+
|
| 279 |
+
# by default, self.aa_block_size=1, which processes one block at a time
|
| 280 |
+
for _ in range(self.aa_block_size):
|
| 281 |
+
tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
|
| 282 |
+
frame_idx += 1
|
| 283 |
+
intermediates.append(tokens.view(B, S, P, C))
|
| 284 |
+
|
| 285 |
+
return tokens, frame_idx, intermediates
|
| 286 |
+
|
| 287 |
+
def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
|
| 288 |
+
"""
|
| 289 |
+
Process global attention blocks. We keep tokens in shape (B, S*P, C).
|
| 290 |
+
"""
|
| 291 |
+
if tokens.shape != (B, S * P, C):
|
| 292 |
+
tokens = tokens.view(B, S, P, C).view(B, S * P, C)
|
| 293 |
+
|
| 294 |
+
if pos is not None and pos.shape != (B, S * P, 2):
|
| 295 |
+
pos = pos.view(B, S, P, 2).view(B, S * P, 2)
|
| 296 |
+
|
| 297 |
+
intermediates = []
|
| 298 |
+
|
| 299 |
+
# by default, self.aa_block_size=1, which processes one block at a time
|
| 300 |
+
for _ in range(self.aa_block_size):
|
| 301 |
+
tokens = self.global_blocks[global_idx](tokens, pos=pos)
|
| 302 |
+
global_idx += 1
|
| 303 |
+
intermediates.append(tokens.view(B, S, P, C))
|
| 304 |
+
|
| 305 |
+
return tokens, global_idx, intermediates
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def slice_expand_and_flatten(token_tensor, B, S):
|
| 309 |
+
"""
|
| 310 |
+
Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
|
| 311 |
+
1) Uses the first position (index=0) for the first frame only
|
| 312 |
+
2) Uses the second position (index=1) for all remaining frames (S-1 frames)
|
| 313 |
+
3) Expands both to match batch size B
|
| 314 |
+
4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
|
| 315 |
+
followed by (S-1) second-position tokens
|
| 316 |
+
5) Flattens to (B*S, X, C) for processing
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
torch.Tensor: Processed tokens with shape (B*S, X, C)
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
# Slice out the "query" tokens => shape (1, 1, ...)
|
| 323 |
+
query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
|
| 324 |
+
# Slice out the "other" tokens => shape (1, S-1, ...)
|
| 325 |
+
others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
|
| 326 |
+
# Concatenate => shape (B, S, ...)
|
| 327 |
+
combined = torch.cat([query, others], dim=1)
|
| 328 |
+
|
| 329 |
+
# Finally flatten => shape (B*S, ...)
|
| 330 |
+
combined = combined.view(B * S, *combined.shape[2:])
|
| 331 |
+
return combined
|
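Note (illustrative, not part of the committed file): the aggregator alternates frame-wise and global attention and collects the per-block intermediates, while slice_expand_and_flatten is what turns the two learned camera/register token rows into one special-token set per frame. A minimal sketch of that expansion on a toy tensor, assuming only torch:

import torch

token = torch.arange(2 * 3).float().view(1, 2, 1, 3)   # (1, 2, X=1, C=3): row 0 = first frame, row 1 = other frames
B, S = 2, 4
query = token[:, 0:1].expand(B, 1, 1, 3)                # first-frame token for each batch element
others = token[:, 1:].expand(B, S - 1, 1, 3)            # shared token for the remaining S-1 frames
combined = torch.cat([query, others], dim=1).view(B * S, 1, 3)
print(combined.shape)                                   # torch.Size([8, 1, 3]) -> one special token per frame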
vggt/models/vggt.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin  # used for model hub

from vggt.models.aggregator import Aggregator
from vggt.heads.camera_head import CameraHead
from vggt.heads.dpt_head import DPTHead
from vggt.heads.track_head import TrackHead


class VGGT(nn.Module, PyTorchModelHubMixin):
    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()

        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
        self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
        self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)

    def forward(
        self,
        images: torch.Tensor,
        query_points: torch.Tensor = None,
    ):
        """
        Forward pass of the VGGT model.

        Args:
            images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
                Shape: [N, 2] or [B, N, 2], where N is the number of query points.
                Default: None

        Returns:
            dict: A dictionary containing the following predictions:
                - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
                - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
                - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
                - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
                - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
                - images (torch.Tensor): Original input images, preserved for visualization

                If query_points is provided, also includes:
                - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
                - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
                - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
        """

        # If without batch dimension, add it
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)

        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        predictions = {}

        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration

            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["depth"] = depth
                predictions["depth_conf"] = depth_conf

            if self.point_head is not None:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["world_points"] = pts3d
                predictions["world_points_conf"] = pts3d_conf

        if self.track_head is not None and query_points is not None:
            track_list, vis, conf = self.track_head(
                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
            )
            predictions["track"] = track_list[-1]  # track of the last iteration
            predictions["vis"] = vis
            predictions["conf"] = conf

        predictions["images"] = images

        return predictions
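Usage note (illustrative, not part of the commit): the class above mixes in PyTorchModelHubMixin, so a pretrained checkpoint can in principle be pulled with from_pretrained and run on an [S, 3, H, W] image stack; the repo id and file names below are assumptions for the sketch, not something this diff defines.

import torch
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images

model = VGGT.from_pretrained("facebook/VGGT-1B").eval()          # assumed hub repo id
images = load_and_preprocess_images(["frame0.png", "frame1.png"])  # (S, 3, H, W) in [0, 1]
with torch.no_grad():
    preds = model(images)        # batch dimension is added internally
print(preds["pose_enc"].shape)   # (1, S, 9)
print(preds["depth"].shape)      # (1, S, H, W, 1)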
vggt/utils/geometry.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import numpy as np


def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Args:
        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)

    Returns:
        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
    """
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.cpu().numpy()
    if isinstance(extrinsics_cam, torch.Tensor):
        extrinsics_cam = extrinsics_cam.cpu().numpy()
    if isinstance(intrinsics_cam, torch.Tensor):
        intrinsics_cam = intrinsics_cam.cpu().numpy()

    world_points_list = []
    for frame_idx in range(depth_map.shape[0]):
        cur_world_points, _, _ = depth_to_world_coords_points(
            depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
        )
        world_points_list.append(cur_world_points)
    world_points_array = np.stack(world_points_list, axis=0)

    return world_points_array


def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a depth map to world coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.

    Returns:
        tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
    """
    if depth_map is None:
        return None, None, None

    # Valid depth mask
    point_mask = depth_map > eps

    # Convert depth map to camera coordinates
    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # Multiply with the inverse of extrinsic matrix to transform to world coordinates
    # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]

    R_cam_to_world = cam_to_world_extrinsic[:3, :3]
    t_cam_to_world = cam_to_world_extrinsic[:3, 3]

    # Apply the rotation and translation to the camera coordinates
    world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world  # HxWx3, 3x3 -> HxWx3
    # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world

    return world_coords_points, cam_coords_points, point_mask


def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Convert a depth map to camera coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).

    Returns:
        tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Intrinsic parameters
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Generate grid of pixel coordinates
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Unproject to camera coordinates
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    # Stack to form camera coordinates
    cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    return cam_coords


def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    If `R` and `T` are provided, they must correspond to the rotation and translation
    components of `se3`. Otherwise, they will be extracted from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Inverted SE3 matrices with the same type and device as `se3`.

    Shapes:
        se3: (N, 4, 4)
        R: (N, 3, 3)
        T: (N, 3, 1)
    """
    # Check if se3 is a numpy array or a torch tensor
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes
    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    # Transpose R
    if is_numpy:
        # Compute the transpose of the rotation for NumPy
        R_transposed = np.transpose(R, (0, 2, 1))
        # -R^T t for NumPy
        top_right = -np.matmul(R_transposed, T)
        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right

    return inverted_matrix
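Sanity check (illustrative only): closed_form_inverse_se3 returns [R^T | -R^T t] padded to 4x4, so the inverse of the identity pose is the identity, and a constant depth map should unproject to a plane at that depth in the camera frame. The intrinsics below are made up for the sketch.

import numpy as np
from vggt.utils.geometry import closed_form_inverse_se3, depth_to_cam_coords_points

pose = np.eye(4)[None]                                   # (1, 4, 4) identity camera
assert np.allclose(closed_form_inverse_se3(pose)[0], np.eye(4))

K = np.array([[100.0, 0, 32.0], [0, 100.0, 24.0], [0, 0, 1.0]])
depth = np.full((48, 64), 2.0)
cam_pts = depth_to_cam_coords_points(depth, K)
print(cam_pts.shape, cam_pts[..., 2].mean())             # (48, 64, 3) 2.0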
vggt/utils/load_fn.py
ADDED
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from PIL import Image
from torchvision import transforms as TF


def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=518px while maintaining aspect ratio
          and height is center-cropped if larger than 518px
        - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
          and the smaller dimension is padded to reach a square shape (518x518)
        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518

    # First process all images and collect their shapes
    for image_path in image_path_list:

        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background:
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)

        # Now convert to "RGB" (this step assigns white for transparent areas)
        img = img.convert("RGB")

        width, height = img.size

        if mode == "pad":
            # Make the largest dimension 518px while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / 14) * 14  # Make divisible by 14
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / 14) * 14  # Make divisible by 14
        else:  # mode == "crop"
            # Original behavior: set width to 518px
            new_width = target_size
            # Calculate height maintaining aspect ratio, divisible by 14
            new_height = round(height * (new_width / width) / 14) * 14

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # Center crop height if it's larger than 518 (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0)
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
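Usage note (illustrative): in "crop" mode the helper fixes the width at 518 px and center-crops the height, so portrait and landscape inputs can end up with different heights, while "pad" mode always yields 518x518. The file names below are placeholders.

from vggt.utils.load_fn import load_and_preprocess_images

batch_crop = load_and_preprocess_images(["img1.jpg", "img2.jpg"], mode="crop")
batch_pad = load_and_preprocess_images(["img1.jpg", "img2.jpg"], mode="pad")
print(batch_crop.shape)  # (2, 3, H<=518, 518), H divisible by 14
print(batch_pad.shape)   # (2, 3, 518, 518)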
vggt/utils/pose_enc.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from .rotation import quat_to_mat, mat_to_quat


def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    This function transforms camera parameters into a unified pose encoding format,
    which can be used for various downstream tasks like pose prediction or representation.

    Args:
        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
            where B is batch size and S is sequence length.
            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
            Defined in pixels, with format:
            [[fx, 0, cx],
             [0, fy, cy],
             [0, 0, 1]]
            where fx, fy are focal lengths and (cx, cy) is the principal point
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for computing field of view values. For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding to use. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: Encoded camera pose parameters with shape BxSx9.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D)
    """

    # extrinsics: BxSx3x4
    # intrinsics: BxSx3x3

    if pose_encoding_type == "absT_quaR_FoV":
        R = extrinsics[:, :, :3, :3]  # BxSx3x3
        T = extrinsics[:, :, :3, 3]  # BxSx3

        quat = mat_to_quat(R)
        # Note the order of h and w here
        H, W = image_size_hw
        fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
        fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    else:
        raise NotImplementedError

    return pose_encoding


def pose_encoding_to_extri_intri(
    pose_encoding,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
    build_intrinsics=True,
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    This function performs the inverse operation of extri_intri_to_pose_encoding,
    reconstructing the full camera parameters from the compact encoding.

    Args:
        pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
            where B is batch size and S is sequence length.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D)
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for reconstructing intrinsics from field of view values.
            For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding used. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
            If False, only extrinsics are returned and intrinsics will be None.

    Returns:
        tuple: (extrinsics, intrinsics)
            - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
              In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
              transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
              a 3x1 translation vector.
            - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
              or None if build_intrinsics is False. Defined in pixels, with format:
              [[fx, 0, cx],
               [0, fy, cy],
               [0, 0, 1]]
              where fx, fy are focal lengths and (cx, cy) is the principal point,
              assumed to be at the center of the image (W/2, H/2).
    """

    intrinsics = None

    if pose_encoding_type == "absT_quaR_FoV":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]
        fov_h = pose_encoding[..., 7]
        fov_w = pose_encoding[..., 8]

        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)

        if build_intrinsics:
            H, W = image_size_hw
            fy = (H / 2.0) / torch.tan(fov_h / 2.0)
            fx = (W / 2.0) / torch.tan(fov_w / 2.0)
            intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
            intrinsics[..., 0, 0] = fx
            intrinsics[..., 1, 1] = fy
            intrinsics[..., 0, 2] = W / 2
            intrinsics[..., 1, 2] = H / 2
            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
    else:
        raise NotImplementedError

    return extrinsics, intrinsics
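Round-trip sketch (illustrative only): encoding [R|t] plus intrinsics into the 9-D absT_quaR_FoV vector and decoding it again should reproduce the extrinsics and a centered-principal-point intrinsics matrix. The focal length and image size below are arbitrary.

import torch
from vggt.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri

B, S, H, W = 1, 2, 256, 512
extrinsics = torch.eye(3, 4).expand(B, S, 3, 4).clone()                       # identity [R|t]
intrinsics = torch.tensor([[300.0, 0, W / 2], [0, 300.0, H / 2], [0, 0, 1.0]]).expand(B, S, 3, 3).clone()
enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))   # (B, S, 9)
extr_back, intr_back = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
print(torch.allclose(extr_back, extrinsics, atol=1e-5), torch.allclose(intr_back, intrinsics, atol=1e-3))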
vggt/utils/rotation.py
ADDED
@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d

import torch
import numpy as np
import torch.nn.functional as F


def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Quaternion Order: XYZW or say ijkr, scalar-last

    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))

    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]

    out = standardize_quaternion(out)

    return out


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    if torch.is_grad_enabled():
        ret[positive_mask] = torch.sqrt(x[positive_mask])
    else:
        ret = torch.where(positive_mask, torch.sqrt(x), ret)
    return ret


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
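Quick check (illustrative): the quaternions here are scalar-last (XYZW), so a 90-degree rotation about z maps to (0, 0, sin 45°, cos 45°), and mat_to_quat(quat_to_mat(q)) returns q up to the sign fixed by standardize_quaternion.

import torch
from vggt.utils.rotation import quat_to_mat, mat_to_quat

q = torch.tensor([0.0, 0.0, 0.70710678, 0.70710678])  # 90 deg about z, XYZW order
R = quat_to_mat(q)
print(torch.allclose(R @ torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.0, 1.0, 0.0]), atol=1e-6))
print(torch.allclose(mat_to_quat(R), q, atol=1e-6))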
vggt/utils/visual_track.py
ADDED
@@ -0,0 +1,239 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import cv2
import torch
import numpy as np
import os


def color_from_xy(x, y, W, H, cmap_name="hsv"):
    """
    Map (x, y) -> color in (R, G, B).
    1) Normalize x,y to [0,1].
    2) Combine them into a single scalar c in [0,1].
    3) Use matplotlib's colormap to convert c -> (R,G,B).

    You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
    """
    import matplotlib.cm
    import matplotlib.colors

    x_norm = x / max(W - 1, 1)
    y_norm = y / max(H - 1, 1)
    # Simple combination:
    c = (x_norm + y_norm) / 2.0

    cmap = matplotlib.cm.get_cmap(cmap_name)
    # cmap(c) -> (r,g,b,a) in [0,1]
    rgba = cmap(c)
    r, g, b = rgba[0], rgba[1], rgba[2]
    return (r, g, b)  # in [0,1], RGB order


def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
    """
    Given all tracks in one sample (b), compute a (N,3) array of RGB color values
    in [0,255]. The color is determined by the (x,y) position in the first
    visible frame for each track.

    Args:
        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
        image_width, image_height: used for normalizing (x, y).
        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').

    Returns:
        track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
    """
    S, N, _ = tracks_b.shape
    track_colors = np.zeros((N, 3), dtype=np.uint8)

    if vis_mask_b is None:
        # treat all as visible
        vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)

    for i in range(N):
        # Find first visible frame for track i
        visible_frames = torch.where(vis_mask_b[:, i])[0]
        if len(visible_frames) == 0:
            # track is never visible; just assign black or something
            track_colors[i] = (0, 0, 0)
            continue

        first_s = int(visible_frames[0].item())
        # use that frame's (x,y)
        x, y = tracks_b[first_s, i].tolist()

        # map (x,y) -> (R,G,B) in [0,1]
        r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
        # scale to [0,255]
        r, g, b = int(r * 255), int(g * 255), int(b * 255)
        track_colors[i] = (r, g, b)

    return track_colors


def visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=None,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",  # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
    frames_per_row=4,  # New parameter for grid layout
    save_grid=True,  # Flag to control whether to save the grid image
):
    """
    Visualizes frames in a grid layout with specified frames per row.
    Each track's color is determined by its (x,y) position
    in the first visible frame (or frame 0 if always visible).
    Finally convert the BGR result to RGB before saving.
    Also saves each individual frame as a separate PNG file.

    Args:
        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
        track_vis_mask: torch.Tensor (S, N) or None.
        out_dir: folder to save visualizations.
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
        cmap_name: a matplotlib colormap name for color_from_xy.
        frames_per_row: number of frames to display in each row of the grid.
        save_grid: whether to save all frames in one grid image.

    Returns:
        None (saves images in out_dir).
    """

    if len(tracks.shape) == 4:
        tracks = tracks.squeeze(0)
        images = images.squeeze(0)
        if track_vis_mask is not None:
            track_vis_mask = track_vis_mask.squeeze(0)

    import matplotlib

    matplotlib.use("Agg")  # for non-interactive (optional)

    os.makedirs(out_dir, exist_ok=True)

    S = images.shape[0]
    _, N, _ = tracks.shape  # (S, N, 2)

    # Move to CPU
    images = images.cpu().clone()
    tracks = tracks.cpu().clone()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.cpu().clone()

    # Infer H, W from images shape
    if image_format == "CHW":
        # e.g. images[s].shape = (3, H, W)
        H, W = images.shape[2], images.shape[3]
    else:
        # e.g. images[s].shape = (H, W, 3)
        H, W = images.shape[1], images.shape[2]

    # Pre-compute the color for each track i based on first visible position
    track_colors_rgb = get_track_colors_by_position(
        tracks,  # shape (S, N, 2)
        vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
        image_width=W,
        image_height=H,
        cmap_name=cmap_name,
    )

    # We'll accumulate each frame's drawn image in a list
    frame_images = []

    for s in range(S):
        # shape => either (3, H, W) or (H, W, 3)
        img = images[s]

        # Convert to (H, W, 3)
        if image_format == "CHW":
            img = img.permute(1, 2, 0)  # (H, W, 3)
        # else "HWC", do nothing

        img = img.numpy().astype(np.float32)

        # Scale to [0,255] if needed
        if normalize_mode == "[0,1]":
            img = np.clip(img, 0, 1) * 255.0
        elif normalize_mode == "[-1,1]":
            img = (img + 1.0) * 0.5 * 255.0
            img = np.clip(img, 0, 255.0)
        # else no normalization

        # Convert to uint8
        img = img.astype(np.uint8)

        # For drawing in OpenCV, convert to BGR
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Draw each visible track
        cur_tracks = tracks[s]  # shape (N, 2)
        if track_vis_mask is not None:
            valid_indices = torch.where(track_vis_mask[s])[0]
        else:
            valid_indices = range(N)

        cur_tracks_np = cur_tracks.numpy()
        for i in valid_indices:
            x, y = cur_tracks_np[i]
            pt = (int(round(x)), int(round(y)))

            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
            R, G, B = track_colors_rgb[i]
            color_bgr = (int(B), int(G), int(R))
            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)

        # Convert back to RGB for consistent final saving:
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # Save individual frame
        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
        # Convert to BGR for OpenCV imwrite
        frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(frame_path, frame_bgr)

        frame_images.append(img_rgb)

    # Only create and save the grid image if save_grid is True
    if save_grid:
        # Calculate grid dimensions
        num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division

        # Create a grid of images
        grid_img = None
        for row in range(num_rows):
            start_idx = row * frames_per_row
            end_idx = min(start_idx + frames_per_row, S)

            # Concatenate this row horizontally
            row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)

            # If this row has fewer than frames_per_row images, pad with black
            if end_idx - start_idx < frames_per_row:
                padding_width = (frames_per_row - (end_idx - start_idx)) * W
                padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
                row_img = np.concatenate([row_img, padding], axis=1)

            # Add this row to the grid
            if grid_img is None:
                grid_img = row_img
            else:
                grid_img = np.concatenate([grid_img, row_img], axis=0)

        out_path = os.path.join(out_dir, "tracks_grid.png")
        # Convert back to BGR for OpenCV imwrite
        grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(out_path, grid_img_bgr)
        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")

    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
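Usage note (illustrative): given the tensors produced by the track head (images in [0,1], tracks in pixels, a visibility mask), the helper writes one PNG per frame plus an optional grid. The tensors below are random stand-ins, not real predictions.

import torch
from vggt.utils.visual_track import visualize_tracks_on_images

S, N, H, W = 4, 16, 70, 70
images = torch.rand(S, 3, H, W)                                   # CHW frames in [0, 1]
tracks = torch.rand(S, N, 2) * torch.tensor([W - 1.0, H - 1.0])   # (x, y) in pixel coordinates
vis = torch.ones(S, N, dtype=torch.bool)
visualize_tracks_on_images(images, tracks, vis, out_dir="demo_tracks", frames_per_row=2)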
vision_tower.py
ADDED
@@ -0,0 +1,279 @@
# import sys
# sys.path.append("..")

import torch
from torch import nn
import torch.nn.init as init
import torch.nn.functional as F

from paths import *
from typing import Dict, List, Optional, Set, Tuple, Union
import os

from contextlib import nullcontext
from vggt.models.vggt import VGGT
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.layers import Mlp
from vggt.layers.block import Block
from vggt.heads.head_act import activate_pose


class OriAny_CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.
    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "OriAny",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
    ):
        super().__init__()

        if pose_encoding_type == "OriAny":
            self.target_dim = 360 + 180 + 360 + 2
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_in // 2,
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings, one per iteration (no final activation is applied).
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens.
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of camera encodings, one per iteration.
        """
        B, S, C = pose_tokens.shape  # S is the number of camera tokens per sample.
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # VGGT's final activations (translation, quaternion, field-of-view) are not used here;
            # the raw orientation encoding is appended instead.
            # activated_pose = activate_pose(
            #     pred_pose_enc,
            #     trans_act=self.trans_act,
            #     quat_act=self.quat_act,
            #     fl_act=self.fl_act,
            # )
            # pred_pose_enc_list.append(activated_pose)
            pred_pose_enc_list.append(pred_pose_enc)

        return pred_pose_enc_list


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.
    """
    # Modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    return x * (1 + scale) + shift

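A quick shape check makes the refinement loop above easier to follow. The sketch below is illustrative only and not part of the committed file; the dummy token layout [B, S, N_tokens, C] and the concrete values are assumptions based on the defaults in this module (dim_in=2048, target_dim=902).

# Illustrative sketch (not part of the committed file): drive OriAny_CameraHead with dummy tokens.
head = OriAny_CameraHead(dim_in=2048, trunk_depth=4)
dummy_tokens = [torch.randn(2, 1, 5, 2048)]   # assumed [B, S, N_tokens, C] aggregator output
preds = head(dummy_tokens, num_iterations=4)  # one encoding per refinement iteration
print(len(preds), preds[-1].shape)            # 4, torch.Size([2, 1, 902]) since 360 + 180 + 360 + 2 = 902
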
def load_patch_embed_weights(model, checkpoint_path):
    # 1. Load the checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # 2. Get the state_dict (fall back to the raw checkpoint if there is no "state_dict" key).
    state_dict = checkpoint.get("state_dict", checkpoint)

    # 3. Keep only the parameters that belong to aggregator.patch_embed.
    patch_embed_state = {
        k.replace("aggregator.patch_embed.", ""): v
        for k, v in state_dict.items()
        if k.startswith("aggregator.patch_embed.")
    }

    # 4. Load them into the target module.
    missing_keys, unexpected_keys = model.aggregator.patch_embed.load_state_dict(
        patch_embed_state, strict=False
    )

    print("Loaded patch_embed weights.")
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)

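A typical call is sketched below; the checkpoint path is a placeholder and the constructor arguments are example values only. Any module exposing `aggregator.patch_embed` works, such as the VGGT_OriAny_Ref defined next.

# Illustrative sketch (hypothetical path and example arguments), not part of the committed file.
model = VGGT_OriAny_Ref(dtype=torch.bfloat16, out_dim=902, nopretrain=True)
load_patch_embed_weights(model, "checkpoints/vggt_oriany.pt")  # placeholder checkpoint path
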
class VGGT_OriAny_Ref(nn.Module):
    def __init__(self,
                 dtype,
                 out_dim,
                 nopretrain
                 ) -> None:
        super().__init__()
        self.vggt = VGGT()

        self.dtype = dtype
        self.ref_sampler = MLP_dim(in_dim=2048, out_dim=out_dim)
        self.ref_sampler.apply(init_weights)
        self.tgt_sampler = MLP_dim(in_dim=2048, out_dim=out_dim)
        self.tgt_sampler.apply(init_weights)

    def forward(self, img_inputs):
        device = self.get_device()

        with torch.amp.autocast(device_type='cuda', dtype=self.dtype):
            if len(img_inputs.shape) == 4:
                # (S, 3, H, W) -> add a batch dimension.
                img_inputs = img_inputs[None]
            aggregated_tokens_list, ps_idx = self.vggt.aggregator(img_inputs)

            # Predict cameras
            # pose_enc = self.oriany_camera_head(aggregated_tokens_list)[-1]
            # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
            # extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])

            # Use tokens from the last block for camera prediction.
            tokens = aggregated_tokens_list[-1]
            # Extract the camera tokens.
            pose_tokens = tokens[:, :, 0]
            # tokens = aggregated_tokens_list[-1]

        B, S, C = pose_tokens.shape
        if S > 1:
            # Split the first token of each sample (reference view) from the remaining tokens (target views).
            ref_tokens = pose_tokens[:, 0, :]   # shape: (B, C)
            tgt_tokens = pose_tokens[:, 1:, :]  # shape: (B, S-1, C)

            # Project through the samplers (output channel dimension C').
            ref_feat = self.ref_sampler(ref_tokens)                          # shape: (B, C')
            tgt_feat = self.tgt_sampler(tgt_tokens.reshape(B * (S - 1), C))  # shape: (B*(S-1), C')

            # Merge the results.
            pose_enc = torch.cat([
                ref_feat.unsqueeze(1),       # (B, 1, C')
                tgt_feat.view(B, S - 1, -1)  # (B, S-1, C')
            ], dim=1)  # final shape: (B, S, C')
        else:
            pose_enc = self.ref_sampler(pose_tokens.view(B * S, C))
        return pose_enc

    def get_device(self):
        return next(self.parameters()).device

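When several views are passed in, the first camera token of each sample is treated as the reference view and the rest as targets. A shape-only sketch of that split, using random tensors rather than real aggregator output:

# Illustrative sketch (random tensors, no sampler weights): the reference/target split used above.
B, S, C = 2, 3, 2048                    # batch of 2, three views; view 0 is the reference
pose_tokens = torch.randn(B, S, C)
ref_tokens = pose_tokens[:, 0, :]       # (2, 2048)    -> fed to ref_sampler
tgt_tokens = pose_tokens[:, 1:, :]      # (2, 2, 2048) -> flattened for tgt_sampler
print(ref_tokens.shape, tgt_tokens.reshape(B * (S - 1), C).shape)  # (2, 2048) and (4, 2048)
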
def init_weights(m):
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)


def get_activation(activation):
    if activation.lower() == 'gelu':
        return nn.GELU()
    elif activation.lower() == 'rrelu':
        return nn.RReLU(inplace=True)
    elif activation.lower() == 'selu':
        return nn.SELU(inplace=True)
    elif activation.lower() == 'silu':
        return nn.SiLU(inplace=True)
    elif activation.lower() == 'hardswish':
        return nn.Hardswish(inplace=True)
    elif activation.lower() == 'leakyrelu':
        return nn.LeakyReLU(inplace=True)
    elif activation.lower() == 'sigmoid':
        return nn.Sigmoid()
    elif activation.lower() == 'tanh':
        return nn.Tanh()
    else:
        return nn.ReLU(inplace=True)


class MLP_dim(nn.Module):
    def __init__(
            self, in_dim=512, out_dim=1024, bias=True, activation='relu'):
        super().__init__()
        self.act = get_activation(activation)
        self.net1 = nn.Sequential(
            nn.Linear(in_dim, int(out_dim), bias=bias),
            nn.BatchNorm1d(int(out_dim)),
            self.act
        )
        self.net2 = nn.Sequential(
            nn.Linear(int(out_dim), out_dim, bias=bias),
            nn.BatchNorm1d(out_dim)
        )

    def forward(self, x):
        return self.net2(self.net1(x))

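Putting the pieces together, the end-to-end sketch below shows how the module can be exercised. The input resolution (518×518), the out_dim value, and the use of random weights are assumptions for illustration; the Space's application code handles real preprocessing and checkpoint loading.

# Illustrative end-to-end sketch (assumed shapes and example arguments; no pretrained weights loaded).
if torch.cuda.is_available():
    model = VGGT_OriAny_Ref(dtype=torch.bfloat16, out_dim=360 + 180 + 360 + 2, nopretrain=True)
    model = model.cuda().eval()
    images = torch.rand(1, 2, 3, 518, 518, device="cuda")  # (B, S, 3, H, W): reference view + one target view
    with torch.no_grad():
        pose_enc = model(images)
    print(pose_enc.shape)  # torch.Size([1, 2, 902]): row 0 is the reference view, row 1 the target view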