zhangziang committed on
Commit f783161 · 1 Parent(s): 706e128

initial commit track binary

.gitattributes CHANGED
@@ -1,3 +1,6 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.blend filter=lfs diff=lfs merge=lfs -text
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/axis_ref.png filter=lfs diff=lfs merge=lfs -text
+ assets/axis_tgt.png filter=lfs diff=lfs merge=lfs -text
+ assets/axis_render.blend filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,210 @@
1
+ test_demo/
2
+ test_demo_output/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[codz]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py.cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # UV
101
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ #uv.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ #poetry.lock
112
+ #poetry.toml
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
117
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
118
+ #pdm.lock
119
+ #pdm.toml
120
+ .pdm-python
121
+ .pdm-build/
122
+
123
+ # pixi
124
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
125
+ #pixi.lock
126
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
127
+ # in the .venv directory. It is recommended not to include this directory in version control.
128
+ .pixi
129
+
130
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
131
+ __pypackages__/
132
+
133
+ # Celery stuff
134
+ celerybeat-schedule
135
+ celerybeat.pid
136
+
137
+ # SageMath parsed files
138
+ *.sage.py
139
+
140
+ # Environments
141
+ .env
142
+ .envrc
143
+ .venv
144
+ env/
145
+ venv/
146
+ ENV/
147
+ env.bak/
148
+ venv.bak/
149
+
150
+ # Spyder project settings
151
+ .spyderproject
152
+ .spyproject
153
+
154
+ # Rope project settings
155
+ .ropeproject
156
+
157
+ # mkdocs documentation
158
+ /site
159
+
160
+ # mypy
161
+ .mypy_cache/
162
+ .dmypy.json
163
+ dmypy.json
164
+
165
+ # Pyre type checker
166
+ .pyre/
167
+
168
+ # pytype static type analyzer
169
+ .pytype/
170
+
171
+ # Cython debug symbols
172
+ cython_debug/
173
+
174
+ # PyCharm
175
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
176
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
177
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
178
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
179
+ #.idea/
180
+
181
+ # Abstra
182
+ # Abstra is an AI-powered process automation framework.
183
+ # Ignore directories containing user credentials, local state, and settings.
184
+ # Learn more at https://abstra.io/docs
185
+ .abstra/
186
+
187
+ # Visual Studio Code
188
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
189
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
190
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
191
+ # you could uncomment the following to ignore the entire vscode folder
192
+ # .vscode/
193
+
194
+ # Ruff stuff:
195
+ .ruff_cache/
196
+
197
+ # PyPI configuration file
198
+ .pypirc
199
+
200
+ # Cursor
201
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
202
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
203
+ # refer to https://docs.cursor.com/context/ignore-files
204
+ .cursorignore
205
+ .cursorindexingignore
206
+
207
+ # Marimo
208
+ marimo/_static/
209
+ marimo/_lsp/
210
+ __marimo__/
LICENSE ADDED
@@ -0,0 +1,395 @@
1
+ Attribution 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution 4.0 International Public License
58
+
59
+ By exercising the Licensed Rights (defined below), You accept and agree
60
+ to be bound by the terms and conditions of this Creative Commons
61
+ Attribution 4.0 International Public License ("Public License"). To the
62
+ extent this Public License may be interpreted as a contract, You are
63
+ granted the Licensed Rights in consideration of Your acceptance of
64
+ these terms and conditions, and the Licensor grants You such rights in
65
+ consideration of benefits the Licensor receives from making the
66
+ Licensed Material available under these terms and conditions.
67
+
68
+
69
+ Section 1 -- Definitions.
70
+
71
+ a. Adapted Material means material subject to Copyright and Similar
72
+ Rights that is derived from or based upon the Licensed Material
73
+ and in which the Licensed Material is translated, altered,
74
+ arranged, transformed, or otherwise modified in a manner requiring
75
+ permission under the Copyright and Similar Rights held by the
76
+ Licensor. For purposes of this Public License, where the Licensed
77
+ Material is a musical work, performance, or sound recording,
78
+ Adapted Material is always produced where the Licensed Material is
79
+ synched in timed relation with a moving image.
80
+
81
+ b. Adapter's License means the license You apply to Your Copyright
82
+ and Similar Rights in Your contributions to Adapted Material in
83
+ accordance with the terms and conditions of this Public License.
84
+
85
+ c. Copyright and Similar Rights means copyright and/or similar rights
86
+ closely related to copyright including, without limitation,
87
+ performance, broadcast, sound recording, and Sui Generis Database
88
+ Rights, without regard to how the rights are labeled or
89
+ categorized. For purposes of this Public License, the rights
90
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
91
+ Rights.
92
+
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. Share means to provide material to the public by any means or
116
+ process that requires permission under the Licensed Rights, such
117
+ as reproduction, public display, public performance, distribution,
118
+ dissemination, communication, or importation, and to make material
119
+ available to the public including in ways that members of the
120
+ public may access the material from a place and at a time
121
+ individually chosen by them.
122
+
123
+ j. Sui Generis Database Rights means rights other than copyright
124
+ resulting from Directive 96/9/EC of the European Parliament and of
125
+ the Council of 11 March 1996 on the legal protection of databases,
126
+ as amended and/or succeeded, as well as other essentially
127
+ equivalent rights anywhere in the world.
128
+
129
+ k. You means the individual or entity exercising the Licensed Rights
130
+ under this Public License. Your has a corresponding meaning.
131
+
132
+
133
+ Section 2 -- Scope.
134
+
135
+ a. License grant.
136
+
137
+ 1. Subject to the terms and conditions of this Public License,
138
+ the Licensor hereby grants You a worldwide, royalty-free,
139
+ non-sublicensable, non-exclusive, irrevocable license to
140
+ exercise the Licensed Rights in the Licensed Material to:
141
+
142
+ a. reproduce and Share the Licensed Material, in whole or
143
+ in part; and
144
+
145
+ b. produce, reproduce, and Share Adapted Material.
146
+
147
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
148
+ Exceptions and Limitations apply to Your use, this Public
149
+ License does not apply, and You do not need to comply with
150
+ its terms and conditions.
151
+
152
+ 3. Term. The term of this Public License is specified in Section
153
+ 6(a).
154
+
155
+ 4. Media and formats; technical modifications allowed. The
156
+ Licensor authorizes You to exercise the Licensed Rights in
157
+ all media and formats whether now known or hereafter created,
158
+ and to make technical modifications necessary to do so. The
159
+ Licensor waives and/or agrees not to assert any right or
160
+ authority to forbid You from making technical modifications
161
+ necessary to exercise the Licensed Rights, including
162
+ technical modifications necessary to circumvent Effective
163
+ Technological Measures. For purposes of this Public License,
164
+ simply making modifications authorized by this Section 2(a)
165
+ (4) never produces Adapted Material.
166
+
167
+ 5. Downstream recipients.
168
+
169
+ a. Offer from the Licensor -- Licensed Material. Every
170
+ recipient of the Licensed Material automatically
171
+ receives an offer from the Licensor to exercise the
172
+ Licensed Rights under the terms and conditions of this
173
+ Public License.
174
+
175
+ b. No downstream restrictions. You may not offer or impose
176
+ any additional or different terms or conditions on, or
177
+ apply any Effective Technological Measures to, the
178
+ Licensed Material if doing so restricts exercise of the
179
+ Licensed Rights by any recipient of the Licensed
180
+ Material.
181
+
182
+ 6. No endorsement. Nothing in this Public License constitutes or
183
+ may be construed as permission to assert or imply that You
184
+ are, or that Your use of the Licensed Material is, connected
185
+ with, or sponsored, endorsed, or granted official status by,
186
+ the Licensor or others designated to receive attribution as
187
+ provided in Section 3(a)(1)(A)(i).
188
+
189
+ b. Other rights.
190
+
191
+ 1. Moral rights, such as the right of integrity, are not
192
+ licensed under this Public License, nor are publicity,
193
+ privacy, and/or other similar personality rights; however, to
194
+ the extent possible, the Licensor waives and/or agrees not to
195
+ assert any such rights held by the Licensor to the limited
196
+ extent necessary to allow You to exercise the Licensed
197
+ Rights, but not otherwise.
198
+
199
+ 2. Patent and trademark rights are not licensed under this
200
+ Public License.
201
+
202
+ 3. To the extent possible, the Licensor waives any right to
203
+ collect royalties from You for the exercise of the Licensed
204
+ Rights, whether directly or through a collecting society
205
+ under any voluntary or waivable statutory or compulsory
206
+ licensing scheme. In all other cases the Licensor expressly
207
+ reserves any right to collect such royalties.
208
+
209
+
210
+ Section 3 -- License Conditions.
211
+
212
+ Your exercise of the Licensed Rights is expressly made subject to the
213
+ following conditions.
214
+
215
+ a. Attribution.
216
+
217
+ 1. If You Share the Licensed Material (including in modified
218
+ form), You must:
219
+
220
+ a. retain the following if it is supplied by the Licensor
221
+ with the Licensed Material:
222
+
223
+ i. identification of the creator(s) of the Licensed
224
+ Material and any others designated to receive
225
+ attribution, in any reasonable manner requested by
226
+ the Licensor (including by pseudonym if
227
+ designated);
228
+
229
+ ii. a copyright notice;
230
+
231
+ iii. a notice that refers to this Public License;
232
+
233
+ iv. a notice that refers to the disclaimer of
234
+ warranties;
235
+
236
+ v. a URI or hyperlink to the Licensed Material to the
237
+ extent reasonably practicable;
238
+
239
+ b. indicate if You modified the Licensed Material and
240
+ retain an indication of any previous modifications; and
241
+
242
+ c. indicate the Licensed Material is licensed under this
243
+ Public License, and include the text of, or the URI or
244
+ hyperlink to, this Public License.
245
+
246
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
247
+ reasonable manner based on the medium, means, and context in
248
+ which You Share the Licensed Material. For example, it may be
249
+ reasonable to satisfy the conditions by providing a URI or
250
+ hyperlink to a resource that includes the required
251
+ information.
252
+
253
+ 3. If requested by the Licensor, You must remove any of the
254
+ information required by Section 3(a)(1)(A) to the extent
255
+ reasonably practicable.
256
+
257
+ 4. If You Share Adapted Material You produce, the Adapter's
258
+ License You apply must not prevent recipients of the Adapted
259
+ Material from complying with this Public License.
260
+
261
+
262
+ Section 4 -- Sui Generis Database Rights.
263
+
264
+ Where the Licensed Rights include Sui Generis Database Rights that
265
+ apply to Your use of the Licensed Material:
266
+
267
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
268
+ to extract, reuse, reproduce, and Share all or a substantial
269
+ portion of the contents of the database;
270
+
271
+ b. if You include all or a substantial portion of the database
272
+ contents in a database in which You have Sui Generis Database
273
+ Rights, then the database in which You have Sui Generis Database
274
+ Rights (but not its individual contents) is Adapted Material; and
275
+
276
+ c. You must comply with the conditions in Section 3(a) if You Share
277
+ all or a substantial portion of the contents of the database.
278
+
279
+ For the avoidance of doubt, this Section 4 supplements and does not
280
+ replace Your obligations under this Public License where the Licensed
281
+ Rights include other Copyright and Similar Rights.
282
+
283
+
284
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
285
+
286
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
287
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
288
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
289
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
290
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
291
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
292
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
293
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
294
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
295
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
296
+
297
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
298
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
299
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
300
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
301
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
302
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
303
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
304
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
305
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
306
+
307
+ c. The disclaimer of warranties and limitation of liability provided
308
+ above shall be interpreted in a manner that, to the extent
309
+ possible, most closely approximates an absolute disclaimer and
310
+ waiver of all liability.
311
+
312
+
313
+ Section 6 -- Term and Termination.
314
+
315
+ a. This Public License applies for the term of the Copyright and
316
+ Similar Rights licensed here. However, if You fail to comply with
317
+ this Public License, then Your rights under this Public License
318
+ terminate automatically.
319
+
320
+ b. Where Your right to use the Licensed Material has terminated under
321
+ Section 6(a), it reinstates:
322
+
323
+ 1. automatically as of the date the violation is cured, provided
324
+ it is cured within 30 days of Your discovery of the
325
+ violation; or
326
+
327
+ 2. upon express reinstatement by the Licensor.
328
+
329
+ For the avoidance of doubt, this Section 6(b) does not affect any
330
+ right the Licensor may have to seek remedies for Your violations
331
+ of this Public License.
332
+
333
+ c. For the avoidance of doubt, the Licensor may also offer the
334
+ Licensed Material under separate terms or conditions or stop
335
+ distributing the Licensed Material at any time; however, doing so
336
+ will not terminate this Public License.
337
+
338
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
339
+ License.
340
+
341
+
342
+ Section 7 -- Other Terms and Conditions.
343
+
344
+ a. The Licensor shall not be bound by any additional or different
345
+ terms or conditions communicated by You unless expressly agreed.
346
+
347
+ b. Any arrangements, understandings, or agreements regarding the
348
+ Licensed Material not stated herein are separate from and
349
+ independent of the terms and conditions of this Public License.
350
+
351
+
352
+ Section 8 -- Interpretation.
353
+
354
+ a. For the avoidance of doubt, this Public License does not, and
355
+ shall not be interpreted to, reduce, limit, restrict, or impose
356
+ conditions on any use of the Licensed Material that could lawfully
357
+ be made without permission under this Public License.
358
+
359
+ b. To the extent possible, if any provision of this Public License is
360
+ deemed unenforceable, it shall be automatically reformed to the
361
+ minimum extent necessary to make it enforceable. If the provision
362
+ cannot be reformed, it shall be severed from this Public License
363
+ without affecting the enforceability of the remaining terms and
364
+ conditions.
365
+
366
+ c. No term or condition of this Public License will be waived and no
367
+ failure to comply consented to unless expressly agreed to by the
368
+ Licensor.
369
+
370
+ d. Nothing in this Public License constitutes or may be interpreted
371
+ as a limitation upon, or waiver of, any privileges and immunities
372
+ that apply to the Licensor or You, including from the legal
373
+ processes of any jurisdiction or authority.
374
+
375
+
376
+ =======================================================================
377
+
378
+ Creative Commons is not a party to its public
379
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
380
+ its public licenses to material it publishes and in those instances
381
+ will be considered the “Licensor.” The text of the Creative Commons
382
+ public licenses is dedicated to the public domain under the CC0 Public
383
+ Domain Dedication. Except for the limited purpose of indicating that
384
+ material is shared under a Creative Commons public license or as
385
+ otherwise permitted by the Creative Commons policies published at
386
+ creativecommons.org/policies, Creative Commons does not authorize the
387
+ use of the trademark "Creative Commons" or any other trademark or logo
388
+ of Creative Commons without its prior written consent including,
389
+ without limitation, in connection with any unauthorized modifications
390
+ to any of its public licenses or any other arrangements,
391
+ understandings, or agreements concerning use of licensed material. For
392
+ the avoidance of doubt, this paragraph does not form part of the
393
+ public licenses.
394
+
395
+ Creative Commons may be contacted at creativecommons.org.
app.py ADDED
@@ -0,0 +1,199 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+
6
+ # ====== Original imports and model loading (unchanged) ======
7
+ from paths import *
8
+ from vision_tower import VGGT_OriAny_Ref
9
+ from inference import *
10
+ from app_utils import *
11
+ from axis_renderer import BlendRenderer
12
+
13
+ from huggingface_hub import hf_hub_download
14
+ ckpt_path = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model", cache_dir='./', resume_download=True)
15
+ print(ckpt_path)
16
+
17
+
18
+ mark_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
19
+ # device = 'cuda:0'
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ model = VGGT_OriAny_Ref(out_dim=900, dtype=mark_dtype, nopretrain=True)
23
+ model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
24
+ model.eval()
25
+ model = model.to(device)
26
+ print('Model loaded.')
27
+
28
+ axis_renderer = BlendRenderer(RENDER_FILE)
29
+
30
+
31
+ # ====== Helper: safe image handling ======
32
+ def safe_image_input(image):
33
+ """确保返回合法的 numpy 数组或 None"""
34
+ if image is None:
35
+ return None
36
+ if isinstance(image, np.ndarray):
37
+ return image
38
+ try:
39
+ return np.array(image)
40
+ except Exception:
41
+ return None
42
+
43
+
44
+ # ====== Inference function ======
45
+ @torch.no_grad()
46
+ def run_inference(image_ref, image_tgt, do_rm_bkg):
47
+ image_ref = safe_image_input(image_ref)
48
+ image_tgt = safe_image_input(image_tgt)
49
+
50
+ if image_ref is None:
51
+ raise gr.Error("Please upload a reference image before running inference.")
52
+
53
+ # Convert to PIL (used for background removal and the later overlay)
54
+ pil_ref = Image.fromarray(image_ref.astype(np.uint8)).convert("RGB")
55
+ pil_tgt = None
56
+
57
+ if image_tgt is not None:
58
+ pil_tgt = Image.fromarray(image_tgt.astype(np.uint8)).convert("RGB")
59
+ if do_rm_bkg:
60
+ pil_ref = background_preprocess(pil_ref, True)
61
+ pil_tgt = background_preprocess(pil_tgt, True)
62
+ else:
63
+ if do_rm_bkg:
64
+ pil_ref = background_preprocess(pil_ref, True)
65
+
66
+ try:
67
+ ans_dict = inf_single_case(model, pil_ref, pil_tgt)
68
+ except Exception as e:
69
+ print("Inference error:", e)
70
+ raise gr.Error(f"Inference failed: {str(e)}")
71
+
72
+ def safe_float(val, default=0.0):
73
+ try:
74
+ return float(val)
75
+ except:
76
+ return float(default)
77
+
78
+ az = safe_float(ans_dict.get('ref_az_pred', 0))
79
+ el = safe_float(ans_dict.get('ref_el_pred', 0))
80
+ ro = safe_float(ans_dict.get('ref_ro_pred', 0))
81
+ alpha = int(ans_dict.get('ref_alpha_pred', 1)) # note: the target defaults to alpha=1, but the reference may not
82
+
83
+ # ===== Render the axis overlay for the reference image =====
84
+ axis_renderer.render_axis(az, el, ro, alpha, save_path=REF_AXIS_IMAGE)
85
+ axis_ref = Image.open(REF_AXIS_IMAGE).convert("RGBA")
86
+
87
+ # Composite the axis overlay onto the reference image
88
+ # make sure the sizes match
89
+ if axis_ref.size != pil_ref.size:
90
+ axis_ref = axis_ref.resize(pil_ref.size, Image.LANCZOS)
91
+ pil_ref_rgba = pil_ref.convert("RGBA")
92
+ overlaid_ref = Image.alpha_composite(pil_ref_rgba, axis_ref).convert("RGB")
93
+
94
+ # ===== Process the target image (if provided) =====
95
+ if pil_tgt is not None:
96
+ rel_az = safe_float(ans_dict.get('rel_az_pred', 0))
97
+ rel_el = safe_float(ans_dict.get('rel_el_pred', 0))
98
+ rel_ro = safe_float(ans_dict.get('rel_ro_pred', 0))
99
+
100
+ tgt_azi, tgt_ele, tgt_rot = Get_target_azi_ele_rot(az, el, ro, rel_az, rel_el, rel_ro)
101
+ print("Target: Azi",tgt_azi,"Ele",tgt_ele,"Rot",tgt_rot)
102
+
103
+ # The target uses alpha=1 by default
104
+ axis_renderer.render_axis(tgt_azi, tgt_ele, tgt_rot, alpha=1, save_path=TGT_AXIS_IMAGE)
105
+ axis_tgt = Image.open(TGT_AXIS_IMAGE).convert("RGBA")
106
+
107
+ if axis_tgt.size != pil_tgt.size:
108
+ axis_tgt = axis_tgt.resize(pil_tgt.size, Image.LANCZOS)
109
+ pil_tgt_rgba = pil_tgt.convert("RGBA")
110
+ overlaid_tgt = Image.alpha_composite(pil_tgt_rgba, axis_tgt).convert("RGB")
111
+ else:
112
+ overlaid_tgt = None
113
+ rel_az = rel_el = rel_ro = 0.0
114
+
115
+ return [
116
+ overlaid_ref, # reference image with the rendered axis overlay
117
+ overlaid_tgt, # target image with the rendered axis overlay (may be None)
118
+ f"{az:.2f}",
119
+ f"{el:.2f}",
120
+ f"{ro:.2f}",
121
+ str(alpha),
122
+ f"{rel_az:.2f}",
123
+ f"{rel_el:.2f}",
124
+ f"{rel_ro:.2f}",
125
+ ]
126
+
127
+
128
+ # ====== Gradio Blocks UI ======
129
+ with gr.Blocks(title="Orient-Anything Demo") as demo:
130
+ gr.Markdown("# Orient-Anything Demo")
131
+ gr.Markdown("Upload a **reference image** (required). Optionally upload a **target image** for relative pose.")
132
+
133
+ with gr.Row():
134
+ # Left column: input images (reference + target, in one row)
135
+ with gr.Column():
136
+ with gr.Row():
137
+ ref_img = gr.Image(
138
+ label="Reference Image (required)",
139
+ type="numpy",
140
+ height=256,
141
+ width=256,
142
+ value=None,
143
+ interactive=True
144
+ )
145
+ tgt_img = gr.Image(
146
+ label="Target Image (optional)",
147
+ type="numpy",
148
+ height=256,
149
+ width=256,
150
+ value=None,
151
+ interactive=True
152
+ )
153
+ rm_bkg = gr.Checkbox(label="Remove Background", value=True)
154
+ run_btn = gr.Button("Run Inference", variant="primary")
155
+
156
+ # Right column: result images + text outputs
157
+ with gr.Column():
158
+ # Result images: reference result + target result (optional)
159
+ with gr.Row():
160
+ res_ref_img = gr.Image(
161
+ label="Rendered Reference",
162
+ type="pil",
163
+ height=256,
164
+ width=256,
165
+ interactive=False
166
+ )
167
+ res_tgt_img = gr.Image(
168
+ label="Rendered Target (if provided)",
169
+ type="pil",
170
+ height=256,
171
+ width=256,
172
+ interactive=False
173
+ )
174
+
175
+ # Text outputs below the images
176
+ with gr.Row():
177
+ with gr.Column():
178
+ gr.Markdown("### Absolute Pose (Reference)")
179
+ az_out = gr.Textbox(label="Azimuth (0~360°)")
180
+ el_out = gr.Textbox(label="Polar (-90~90°)")
181
+ ro_out = gr.Textbox(label="Rotation (-90~90°)")
182
+ alpha_out = gr.Textbox(label="Number of Directions (0/1/2/4)")
183
+ with gr.Column():
184
+ gr.Markdown("### Relative Pose (Target w.r.t Reference)")
185
+ rel_az_out = gr.Textbox(label="Relative Azimuth (0~360°)")
186
+ rel_el_out = gr.Textbox(label="Relative Polar (-90~90°)")
187
+ rel_ro_out = gr.Textbox(label="Relative Rotation (-90~90°)")
188
+
189
+ # Wire up the click event
190
+ run_btn.click(
191
+ fn=run_inference,
192
+ inputs=[ref_img, tgt_img, rm_bkg],
193
+ outputs=[res_ref_img, res_tgt_img, az_out, el_out, ro_out, alpha_out, rel_az_out, rel_el_out, rel_ro_out],
194
+ preprocess=True,
195
+ postprocess=True
196
+ )
197
+
198
+ # Launch (API disabled to avoid schema errors)
199
+ demo.launch(show_api=False)
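Note (editorial, not part of the commit): `run_inference` above is a plain Python function, so the pipeline can also be exercised without the Gradio UI, for example from a notebook before `demo.launch()` is called. A minimal sketch, assuming the model and renderer initialized at the top of app.py are in scope and using a placeholder image path:

```python
# Call run_inference directly; the outputs mirror the UI fields.
import numpy as np
from PIL import Image

ref = np.array(Image.open("ref.png").convert("RGB"))   # placeholder path
outs = run_inference(ref, None, True)                   # no target image, remove background
overlaid_ref, overlaid_tgt, az, el, ro, alpha, rel_az, rel_el, rel_ro = outs
print(az, el, ro, alpha)   # absolute pose of the reference, formatted as strings
```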
app_utils.py ADDED
@@ -0,0 +1,202 @@
1
+ import rembg
2
+ import random
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image, ImageOps
6
+ import PIL
7
+ from typing import Any
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+
11
+ def resize_foreground(
12
+ image: Image,
13
+ ratio: float,
14
+ ) -> Image:
15
+ image = np.array(image)
16
+ assert image.shape[-1] == 4
17
+ alpha = np.where(image[..., 3] > 0)
18
+ y1, y2, x1, x2 = (
19
+ alpha[0].min(),
20
+ alpha[0].max(),
21
+ alpha[1].min(),
22
+ alpha[1].max(),
23
+ )
24
+ # crop the foreground
25
+ fg = image[y1:y2, x1:x2]
26
+ # pad to square
27
+ size = max(fg.shape[0], fg.shape[1])
28
+ ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
29
+ ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
30
+ new_image = np.pad(
31
+ fg,
32
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
33
+ mode="constant",
34
+ constant_values=((0, 0), (0, 0), (0, 0)),
35
+ )
36
+
37
+ # compute padding according to the ratio
38
+ new_size = int(new_image.shape[0] / ratio)
39
+ # pad to size, double side
40
+ ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
41
+ ph1, pw1 = new_size - size - ph0, new_size - size - pw0
42
+ new_image = np.pad(
43
+ new_image,
44
+ ((ph0, ph1), (pw0, pw1), (0, 0)),
45
+ mode="constant",
46
+ constant_values=((0, 0), (0, 0), (0, 0)),
47
+ )
48
+ new_image = Image.fromarray(new_image)
49
+ return new_image
50
+
51
+ def remove_background(image: Image,
52
+ rembg_session: Any = None,
53
+ force: bool = False,
54
+ **rembg_kwargs,
55
+ ) -> Image:
56
+ do_remove = True
57
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
58
+ do_remove = False
59
+ do_remove = do_remove or force
60
+ if do_remove:
61
+ image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
62
+ return image
63
+
64
+ def background_preprocess(input_image, do_remove_background):
65
+ if input_image is None:
66
+ return None
67
+ rembg_session = rembg.new_session() if do_remove_background else None
68
+
69
+ if do_remove_background:
70
+ input_image = remove_background(input_image, rembg_session)
71
+ input_image = resize_foreground(input_image, 0.85)
72
+
73
+ return input_image
74
+
75
+ def axis_angle_rotation_batch(axis: torch.Tensor, theta: torch.Tensor, homogeneous: bool = False) -> torch.Tensor:
76
+ """
77
+ Batched version:
78
+ Args:
79
+ axis: (3,) or (N,3)
80
+ theta: scalar or (N,)
81
+ homogeneous: whether to return 4x4 homogeneous matrices
82
+
83
+ Returns:
84
+ (N,3,3) or (N,4,4)
85
+ """
86
+ axis = torch.as_tensor(axis).float()
87
+ theta = torch.as_tensor(theta).float()
88
+
89
+ if axis.ndim == 1:
90
+ axis = axis.unsqueeze(0) # (1,3)
91
+ if theta.ndim == 0:
92
+ theta = theta.unsqueeze(0) # (1,)
93
+
94
+ N = axis.shape[0]
95
+
96
+ # normalize axis
97
+ axis = axis / torch.norm(axis, dim=1, keepdim=True)
98
+
99
+ x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
100
+ cos_t = torch.cos(theta)
101
+ sin_t = torch.sin(theta)
102
+ one_minus_cos = 1 - cos_t
103
+
104
+ # Rodrigues' rotation formula, expanded
105
+ rot = torch.zeros((N, 3, 3), dtype=axis.dtype, device=axis.device)
106
+ rot[:, 0, 0] = cos_t + x*x*one_minus_cos
107
+ rot[:, 0, 1] = x*y*one_minus_cos - z*sin_t
108
+ rot[:, 0, 2] = x*z*one_minus_cos + y*sin_t
109
+ rot[:, 1, 0] = y*x*one_minus_cos + z*sin_t
110
+ rot[:, 1, 1] = cos_t + y*y*one_minus_cos
111
+ rot[:, 1, 2] = y*z*one_minus_cos - x*sin_t
112
+ rot[:, 2, 0] = z*x*one_minus_cos - y*sin_t
113
+ rot[:, 2, 1] = z*y*one_minus_cos + x*sin_t
114
+ rot[:, 2, 2] = cos_t + z*z*one_minus_cos
115
+
116
+ if homogeneous:
117
+ rot_homo = torch.eye(4, dtype=axis.dtype, device=axis.device).unsqueeze(0).repeat(N, 1, 1)
118
+ rot_homo[:, :3, :3] = rot
119
+ return rot_homo
120
+
121
+ return rot
122
+
123
+ def azi_ele_rot_to_Obj_Rmatrix_batch(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor) -> torch.Tensor:
124
+ """支持batch输入的: (azi, ele, rot) -> R matrix (N,3,3)"""
125
+ # 转成tensor
126
+ azi = torch.as_tensor(azi).float() * torch.pi / 180.
127
+ ele = torch.as_tensor(ele).float() * torch.pi / 180.
128
+ rot = torch.as_tensor(rot).float() * torch.pi / 180.
129
+
130
+ # Make sure there is a batch dimension
131
+ if azi.ndim == 0:
132
+ azi = azi.unsqueeze(0)
133
+ if ele.ndim == 0:
134
+ ele = ele.unsqueeze(0)
135
+ if rot.ndim == 0:
136
+ rot = rot.unsqueeze(0)
137
+
138
+ N = azi.shape[0]
139
+
140
+ device = azi.device
141
+ dtype = azi.dtype
142
+
143
+ z0_axis = torch.tensor([0.,0.,1.], device=device, dtype=dtype).expand(N, -1)
144
+ y0_axis = torch.tensor([0.,1.,0.], device=device, dtype=dtype).expand(N, -1)
145
+ x0_axis = torch.tensor([1.,0.,0.], device=device, dtype=dtype).expand(N, -1)
146
+ # print(z0_axis.shape, azi.shape)
147
+ R_azi = axis_angle_rotation_batch(z0_axis, -1 * azi)
148
+ R_ele = axis_angle_rotation_batch(y0_axis, ele)
149
+ R_rot = axis_angle_rotation_batch(x0_axis, rot)
150
+
151
+ R_res = R_rot @ R_ele @ R_azi
152
+ return R_res
153
+
154
+ def Cam_Rmatrix_to_azi_ele_rot_batch(R: torch.Tensor):
155
+ """支持batch输入的: R matrix -> (azi, ele, rot),角度制 (度)"""
156
+ R = torch.as_tensor(R).float()
157
+
158
+ # 如果是(3,3),补batch维度
159
+ if R.ndim == 2:
160
+ R = R.unsqueeze(0)
161
+
162
+ r0 = R[:, :, 0] # shape (N,3)
163
+ r1 = R[:, :, 1]
164
+ r2 = R[:, :, 2]
165
+
166
+ ele = torch.asin(r0[:, 2]) # r0.z
167
+ cos_ele = torch.cos(ele)
168
+
169
+ # Initialize default azi and rot
170
+ azi = torch.zeros_like(ele)
171
+ rot = torch.zeros_like(ele)
172
+
173
+ # Regular case
174
+ normal_mask = (cos_ele.abs() >= 1e-6)
175
+ if normal_mask.any():
176
+ azi[normal_mask] = torch.atan2(r0[normal_mask, 1], r0[normal_mask, 0])
177
+ rot[normal_mask] = torch.atan2(-r1[normal_mask, 2], r2[normal_mask, 2])
178
+
179
+ # Gimbal-lock special case
180
+ gimbal_mask = ~normal_mask
181
+ if gimbal_mask.any():
182
+ # Set azi to 0 in this case
183
+ azi[gimbal_mask] = 0.0
184
+ rot[gimbal_mask] = torch.atan2(-r1[gimbal_mask, 0], r1[gimbal_mask, 1])
185
+
186
+ # Radians to degrees
187
+ azi = azi * 180. / torch.pi
188
+ ele = ele * 180. / torch.pi
189
+ rot = rot * 180. / torch.pi
190
+
191
+ return azi, ele, rot
192
+
193
+ def Get_target_azi_ele_rot(azi: torch.Tensor, ele: torch.Tensor, rot: torch.Tensor, rel_azi: torch.Tensor, rel_ele: torch.Tensor, rel_rot: torch.Tensor):
194
+ Rmat0 = azi_ele_rot_to_Obj_Rmatrix_batch(azi = azi , ele = ele , rot = rot)
195
+ Rmat_rel = azi_ele_rot_to_Obj_Rmatrix_batch(azi = rel_azi, ele = rel_ele, rot = rel_rot)
196
+ # Rmat_rel = Rmat1 @ Rmat0.permute(0, 2, 1)
197
+ # azi_out, ele_out, rot_out = Cam_Rmatrix_to_azi_ele_rot_batch(Rmat_rel.permute(0, 2, 1))
198
+
199
+ Rmat1 = Rmat_rel @ Rmat0
200
+ azi_out, ele_out, rot_out = Cam_Rmatrix_to_azi_ele_rot_batch(Rmat1.permute(0, 2, 1))
201
+
202
+ return azi_out, ele_out, rot_out
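A quick sanity check of the conventions above (editorial note, not part of the commit): applying `Cam_Rmatrix_to_azi_ele_rot_batch` to the transpose of a matrix built by `azi_ele_rot_to_Obj_Rmatrix_batch` recovers the input angles, which is exactly how `Get_target_azi_ele_rot` turns the composed matrix back into angles. A sketch, assuming this file is importable as `app_utils`:

```python
import torch
from app_utils import (
    azi_ele_rot_to_Obj_Rmatrix_batch,
    Cam_Rmatrix_to_azi_ele_rot_batch,
    Get_target_azi_ele_rot,
)

azi, ele, rot = torch.tensor([30.]), torch.tensor([20.]), torch.tensor([10.])
R = azi_ele_rot_to_Obj_Rmatrix_batch(azi, ele, rot)          # (1, 3, 3)
print(Cam_Rmatrix_to_azi_ele_rot_batch(R.permute(0, 2, 1)))  # ~ (30, 20, 10) degrees

# With a zero relative pose, the target pose equals the reference pose.
print(Get_target_azi_ele_rot(azi, ele, rot, 0., 0., 0.))     # ~ (30, 20, 10) degrees
```

The round trip holds for elevations strictly inside (-90°, 90°); at ±90° the gimbal-lock branch above sets the azimuth to 0.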
assets/axis_ref.png ADDED

Git LFS Details

  • SHA256: 4ac0eb370f3d33fb8d6fc5c4e309b35f38a879d4a51e34ab490a35d39d09b1fa
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
assets/axis_render.blend ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76f8fd3b4ce574a6973ed9637a6c8194fcf46edf72f9266786036c21cf7023a1
3
+ size 2136460
assets/axis_tgt.png ADDED

Git LFS Details

  • SHA256: 39fb1b1e9ef7e16ff25c4d1d9df8fd53ec14f9e7006356db0609ab9c2ee9c048
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
axis_renderer.py ADDED
@@ -0,0 +1,136 @@
1
+ import bpy
2
+ import math
3
+ import os
4
+ from paths import *
5
+
6
+ class BlendRenderer:
7
+ def __init__(self, blend_file_path=RENDER_FILE):
8
+ """
9
+ Initialize the renderer: load the given .blend file and apply the base render settings.
10
+
11
+ :param blend_file_path: full path to the .blend file to load
12
+ """
13
+ if not os.path.isfile(blend_file_path):
14
+ raise FileNotFoundError(f"Blend file not found: {blend_file_path}")
15
+
16
+ # Load the blend file
17
+ bpy.ops.wm.open_mainfile(filepath=blend_file_path)
18
+
19
+ # Set the render engine to Cycles
20
+ bpy.context.scene.render.engine = 'CYCLES'
21
+
22
+ # Render on the CPU
23
+ bpy.context.scene.cycles.device = 'CPU'
24
+
25
+ # Set the sample count to 4
26
+ bpy.context.scene.cycles.samples = 4
27
+
28
+ # Set all bounce counts to 4 (diffuse, glossy, transmission, etc.)
29
+ bpy.context.scene.cycles.max_bounces = 4
30
+
31
+ # Set the render resolution
32
+ bpy.context.scene.render.resolution_x = 512
33
+ bpy.context.scene.render.resolution_y = 512
34
+ bpy.context.scene.render.resolution_percentage = 100
35
+
36
+ # Enable a transparent background (RGBA)
37
+ bpy.context.scene.render.film_transparent = True
38
+
39
+ # Walk all objects and initialize their render visibility
40
+ for obj in bpy.data.objects:
41
+ if obj.type == 'LIGHT':
42
+ obj.hide_render = False
43
+ elif obj.type == 'CAMERA':
44
+ obj.hide_render = False
45
+ elif obj.type == 'MESH':
46
+ obj.hide_render = True # by default, no mesh is rendered
47
+
48
+ # Set the active camera (pick the first one)
49
+ cameras = [obj for obj in bpy.data.objects if obj.type == 'CAMERA']
50
+ if cameras:
51
+ bpy.context.scene.camera = cameras[0]
52
+
53
+ print(f"Loaded blend file: {blend_file_path}")
54
+ print("Render settings applied: 512x512, CPU, samples=4, bounces=4, transparent background.")
55
+
56
+ self.alpha_axis_map = {  # alpha -> Blender object name (kept verbatim): 0 single-axis plane, 1 three axes, 2 two-way, 4 four-way
57
+ 0: "单轴平面",
58
+ 1: "三轴",
59
+ 2: "双向标注",
60
+ 4: "四向标注"
61
+ }
62
+
63
+
64
+ def _get_all_children(self, obj):
65
+ """递归获取对象的所有子对象(包括嵌套子级)"""
66
+ children = []
67
+ for child in obj.children:
68
+ children.append(child)
69
+ children.extend(self._get_all_children(child))
70
+ return children
71
+
72
+ def render_axis(self, azi, ele, rot, alpha, save_path):
73
+ """
74
+ Render the axis gizmo at a given orientation.
75
+
76
+ :param azi: azimuth (rotation about the Z axis, in degrees)
77
+ :param ele: elevation (rotation about the Y axis, in degrees)
78
+ :param rot: in-plane rotation (about the X axis, in degrees)
79
+ :param save_path: where to save the rendered image (e.g. '/output/render.png')
80
+ """
81
+ # Walk all objects and reset their render visibility
82
+ for obj in bpy.data.objects:
83
+ if obj.type == 'LIGHT':
84
+ obj.hide_render = False
85
+ elif obj.type == 'CAMERA':
86
+ obj.hide_render = False
87
+ elif obj.type == 'MESH':
88
+ obj.hide_render = True # by default, no mesh is rendered
89
+ # Pick the target object based on alpha
90
+ target_name = self.alpha_axis_map.get(alpha, "单轴平面")
91
+ target_obj = None
92
+ for obj in bpy.data.objects:
93
+ # if obj.type == 'MESH' and obj.name == target_name:
94
+ if obj.name == target_name:
95
+ target_obj = obj
96
+ break
97
+
98
+ if target_obj is None:
99
+ raise ValueError(f'Object named "{target_name}" not found in the scene.')
100
+
101
+ # Collect this object and all of its children
102
+ all_objects_to_render = [target_obj] + self._get_all_children(target_obj)
103
+
104
+ # Mark them as renderable
105
+ for obj in all_objects_to_render:
106
+ if obj.type == 'MESH':
107
+ obj.hide_render = False
108
+
109
+ # Set the rotation (ZYX order: Z=azi, Y=ele, X=rot, i.e. Euler = (rot, ele, azi))
110
+ # Note: Blender expects radians, so the degree inputs are converted below
111
+ target_obj.rotation_mode = 'ZYX' # use the ZYX Euler mode
112
+ target_obj.rotation_euler = (rot*math.pi/180, ele*math.pi/180, -azi*math.pi/180)
113
+
114
+ # Make sure the output directory exists
115
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
116
+
117
+ # Set the output path
118
+ bpy.context.scene.render.filepath = save_path
119
+
120
+ # Render and write the still image
121
+ bpy.ops.render.render(write_still=True)
122
+
123
+ print(f"Rendered and saved to: {save_path}")
124
+
125
+
126
+ if __name__ == "__main__":
127
+ renderer = BlendRenderer(RENDER_FILE)
128
+ # Example usage:
129
+ renderer.render_axis(45, 0, 0, 1, "./test_demo_output/render_1_dir_azi45.png")
130
+ renderer.render_axis(0, 45, 0, 2, "./test_demo_output/render_2_dir_ele45.png")
131
+ renderer.render_axis(0, 0, 45, 4, "./test_demo_output/render_4_dir_rot45.png")
132
+ # renderer.render_1_dir()
133
+ # renderer.render_2_dir()
134
+ # renderer.render_4_dir()
135
+
136
+
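Since `film_transparent` is enabled, the rendered axis image is RGBA with a transparent background, and `render_axis` expects its angles in degrees (they are converted to radians before being written to `rotation_euler`). The output can therefore be composited onto a photo with PIL exactly as app.py does; a small sketch with placeholder file names:

```python
# Overlay a rendered axis image onto a photo (paths are placeholders).
from PIL import Image

photo = Image.open("object_photo.jpg").convert("RGBA")
axis = Image.open("assets/axis_ref.png").convert("RGBA")
if axis.size != photo.size:
    axis = axis.resize(photo.size, Image.LANCZOS)
Image.alpha_composite(photo, axis).convert("RGB").save("overlay.jpg")
```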
inference.py ADDED
@@ -0,0 +1,238 @@
1
+ import torch
2
+ from PIL import Image
3
+ from app_utils import *
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from torchvision import transforms as TF
7
+
8
+ from scipy.special import i0
9
+ from scipy.optimize import curve_fit
10
+ from scipy.integrate import trapezoid
11
+ from functools import partial
12
+
13
+ def von_mises_pdf_alpha_numpy(alpha, x, mu, kappa):
14
+ normalization = 2 * np.pi
15
+ pdf = np.exp(kappa * np.cos(alpha * (x - mu))) / normalization
16
+ return pdf
17
+
18
+ def val_fit_alpha(distribute):
19
+ fit_alphas = []
20
+ for y_noise in distribute:
21
+ x = np.linspace(0, 2 * np.pi, 360)
22
+ y_noise /= trapezoid(y_noise, x) + 1e-8
23
+
24
+ initial_guess = [x[np.argmax(y_noise)], 1]
25
+
26
+ # support 1,2,4
27
+ alphas = [1.0, 2.0, 4.0]
28
+ saved_params = []
29
+ saved_r_squared = []
30
+
31
+ for alpha in alphas:
32
+ try:
33
+ von_mises_pdf_alpha_partial = partial(von_mises_pdf_alpha_numpy, alpha)
34
+ params, covariance = curve_fit(von_mises_pdf_alpha_partial, x, y_noise, p0=initial_guess)
35
+
36
+ residuals = y_noise - von_mises_pdf_alpha_partial(x, *params)
37
+ ss_res = np.sum(residuals**2)
38
+ ss_tot = np.sum((y_noise - np.mean(y_noise))**2)
39
+ r_squared = 1 - (ss_res / (ss_tot+1e-8))
40
+
41
+ saved_params.append(params)
42
+ saved_r_squared.append(r_squared)
43
+ if r_squared > 0.8:
44
+ break
45
+ except:
46
+ saved_params.append((0.,0.))
47
+ saved_r_squared.append(0.)
48
+
49
+ max_index = np.argmax(saved_r_squared)
50
+ alpha = alphas[max_index]
51
+ mu_fit, kappa_fit = saved_params[max_index]
52
+ r_squared = saved_r_squared[max_index]
53
+
54
+ if alpha == 1. and kappa_fit>=0.5 and r_squared>=0.5:
55
+ pass
56
+ elif alpha == 2. and kappa_fit>=0.35 and r_squared>=0.35:
57
+ pass
58
+ elif alpha == 4. and kappa_fit>=0.25 and r_squared>=0.25:
59
+ pass
60
+ else:
61
+ alpha=0.
62
+ fit_alphas.append(alpha)
63
+ return torch.tensor(fit_alphas)
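In short (editorial note): for each predicted 360-bin azimuth distribution, the fit above tries the symmetry orders α ∈ {1, 2, 4} in turn (stopping early once R² exceeds 0.8), keeps the best-scoring fit, and falls back to α = 0 unless the fitted κ and R² clear the per-α thresholds. The curve being fitted is the von-Mises-style density implemented by `von_mises_pdf_alpha_numpy`,

$$f(x;\mu,\kappa,\alpha)=\frac{\exp\big(\kappa\cos(\alpha\,(x-\mu))\big)}{2\pi},\qquad x\in[0,2\pi),$$

where α is the number of equally likely directions. Note the fixed 2π normalization in place of the usual 2π·I₀(κ) Bessel factor, which appears to be why `scipy.special.i0` is imported but not used here.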
64
+
65
+ def preprocess_images(image_list, mode="crop"):
66
+
67
+ # Check for empty list
68
+ if len(image_list) == 0:
69
+ raise ValueError("At least 1 image is required")
70
+
71
+ # Validate mode
72
+ if mode not in ["crop", "pad"]:
73
+ raise ValueError("Mode must be either 'crop' or 'pad'")
74
+
75
+ images = []
76
+ shapes = set()
77
+ to_tensor = TF.ToTensor()
78
+ target_size = 518
79
+
80
+ # First process all images and collect their shapes
81
+ # for image_path in image_path_list:
82
+ for img in image_list:
83
+ # If there's an alpha channel, blend onto white background:
84
+ if img.mode == "RGBA":
85
+ # Create white background
86
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
87
+ # Alpha composite onto the white background
88
+ img = Image.alpha_composite(background, img)
89
+
90
+ # Now convert to "RGB" (this step assigns white for transparent areas)
91
+ img = img.convert("RGB")
92
+ width, height = img.size
93
+
94
+ if mode == "pad":
95
+ # Make the largest dimension 518px while maintaining aspect ratio
96
+ if width >= height:
97
+ new_width = target_size
98
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
99
+ else:
100
+ new_height = target_size
101
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
102
+ else: # mode == "crop"
103
+ # Original behavior: set width to 518px
104
+ new_width = target_size
105
+ # Calculate height maintaining aspect ratio, divisible by 14
106
+ new_height = round(height * (new_width / width) / 14) * 14
107
+
108
+ # Resize with new dimensions (width, height)
109
+ try:
110
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
111
+ img = to_tensor(img) # Convert to tensor (0, 1)
112
+ except Exception as e:
113
+ print(e)
114
+ print(width, height)
115
+ print(new_width, new_height)
116
+ assert False
117
+
118
+ # Center crop height if it's larger than 518 (only in crop mode)
119
+ if mode == "crop" and new_height > target_size:
120
+ start_y = (new_height - target_size) // 2
121
+ img = img[:, start_y : start_y + target_size, :]
122
+
123
+ # For pad mode, pad to make a square of target_size x target_size
124
+ if mode == "pad":
125
+ h_padding = target_size - img.shape[1]
126
+ w_padding = target_size - img.shape[2]
127
+
128
+ if h_padding > 0 or w_padding > 0:
129
+ pad_top = h_padding // 2
130
+ pad_bottom = h_padding - pad_top
131
+ pad_left = w_padding // 2
132
+ pad_right = w_padding - pad_left
133
+
134
+ # Pad with white (value=1.0)
135
+ img = torch.nn.functional.pad(
136
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
137
+ )
138
+
139
+ shapes.add((img.shape[1], img.shape[2]))
140
+ images.append(img)
141
+
142
+ # Check if we have different shapes
143
+ # In theory our model can also work well with different shapes
144
+ if len(shapes) > 1:
145
+ print(f"Warning: Found images with different shapes: {shapes}")
146
+ # Find maximum dimensions
147
+ max_height = max(shape[0] for shape in shapes)
148
+ max_width = max(shape[1] for shape in shapes)
149
+
150
+ # Pad images if necessary
151
+ padded_images = []
152
+ for img in images:
153
+ h_padding = max_height - img.shape[1]
154
+ w_padding = max_width - img.shape[2]
155
+
156
+ if h_padding > 0 or w_padding > 0:
157
+ pad_top = h_padding // 2
158
+ pad_bottom = h_padding - pad_top
159
+ pad_left = w_padding // 2
160
+ pad_right = w_padding - pad_left
161
+
162
+ img = torch.nn.functional.pad(
163
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
164
+ )
165
+ padded_images.append(img)
166
+ images = padded_images
167
+
168
+ images = torch.stack(images) # concatenate images
169
+
170
+ # Ensure correct shape when single image
171
+ if len(image_list) == 1:
172
+ # Verify shape is (1, C, H, W)
173
+ if images.dim() == 3:
174
+ images = images.unsqueeze(0)
175
+
176
+ return images
177
+
178
+ @torch.no_grad()
179
+ def inf_single_batch(model, batch):
180
+ device = model.get_device()
181
+ batch_img_inputs = batch # (B, S, 3, H, W)
182
+ # print(batch_img_inputs.shape)
183
+ B, S, C, H, W = batch_img_inputs.shape
184
+ pose_enc = model(batch_img_inputs) # (B, S, D) S = 1
185
+
186
+ pose_enc = pose_enc.view(B*S, -1)
187
+ angle_az_pred = torch.argmax(pose_enc[:, 0:360] , dim=-1)
188
+ angle_el_pred = torch.argmax(pose_enc[:, 360:360+180] , dim=-1) - 90
189
+ angle_ro_pred = torch.argmax(pose_enc[:, 360+180:360+180+360] , dim=-1) - 180
190
+
191
+ # ori_val
192
+ # trained with BCE loss
193
+ distribute = F.sigmoid(pose_enc[:, 0:360]).cpu().float().numpy()
194
+ # trained with CE loss
195
+ # distribute = pose_enc[:, 0:360].cpu().float().numpy()
196
+ alpha_pred = val_fit_alpha(distribute = distribute)
197
+
198
+ # ref_val
199
+ if S > 1:
200
+ ref_az_pred = angle_az_pred.reshape(B,S)[:,0]
201
+ ref_el_pred = angle_el_pred.reshape(B,S)[:,0]
202
+ ref_ro_pred = angle_ro_pred.reshape(B,S)[:,0]
203
+ ref_alpha_pred = alpha_pred.reshape(B,S)[:,0]
204
+ rel_az_pred = angle_az_pred.reshape(B,S)[:,1]
205
+ rel_el_pred = angle_el_pred.reshape(B,S)[:,1]
206
+ rel_ro_pred = angle_ro_pred.reshape(B,S)[:,1]
207
+ else:
208
+ ref_az_pred = angle_az_pred[0]
209
+ ref_el_pred = angle_el_pred[0]
210
+ ref_ro_pred = angle_ro_pred[0]
211
+ ref_alpha_pred = alpha_pred[0]
212
+ rel_az_pred = 0.
213
+ rel_el_pred = 0.
214
+ rel_ro_pred = 0.
215
+
216
+ ans_dict = {
217
+ 'ref_az_pred': ref_az_pred,
218
+ 'ref_el_pred': ref_el_pred,
219
+ 'ref_ro_pred': ref_ro_pred,
220
+ 'ref_alpha_pred' : ref_alpha_pred,
221
+ 'rel_az_pred' : rel_az_pred,
222
+ 'rel_el_pred' : rel_el_pred,
223
+ 'rel_ro_pred' : rel_ro_pred,
224
+ }
225
+
226
+ return ans_dict
227
+
228
+ # input PIL Image
229
+ @torch.no_grad()
230
+ def inf_single_case(model, image_ref, image_tgt):
231
+ if image_tgt is None:
232
+ image_list = [image_ref]
233
+ else:
234
+ image_list = [image_ref, image_tgt]
235
+ image_tensors = preprocess_images(image_list, mode="pad").to(model.get_device())
236
+ ans_dict = inf_single_batch(model=model, batch=image_tensors.unsqueeze(0))
237
+ print(ans_dict)
238
+ return ans_dict
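For reference, a minimal end-to-end sketch of how these helpers are driven, mirroring the setup at the top of app.py (the local image paths are placeholders; `vision_tower` and `paths` are the modules app.py imports):

```python
import torch
from PIL import Image
from huggingface_hub import hf_hub_download

from paths import ORIANY_V2, REMOTE_CKPT_PATH
from vision_tower import VGGT_OriAny_Ref
from inference import inf_single_case

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.get_device_capability()[0] >= 8 else torch.float16

ckpt = hf_hub_download(repo_id=ORIANY_V2, filename=REMOTE_CKPT_PATH, repo_type="model")
model = VGGT_OriAny_Ref(out_dim=900, dtype=dtype, nopretrain=True)
model.load_state_dict(torch.load(ckpt, map_location='cpu'))
model = model.eval().to(device)

# The target image is optional; pass None for a single-image (absolute pose) query.
ans = inf_single_case(model, Image.open("ref.png"), Image.open("tgt.png"))
print(ans['ref_az_pred'], ans['ref_el_pred'], ans['ref_ro_pred'], ans['ref_alpha_pred'])
print(ans['rel_az_pred'], ans['rel_el_pred'], ans['rel_ro_pred'])
```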
orianyV2_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
paths.py ADDED
@@ -0,0 +1,16 @@
1
+ DINO_SMALL = "facebook/dinov2-small"
2
+ DINO_BASE = "facebook/dinov2-base"
3
+ DINO_LARGE = "facebook/dinov2-large"
4
+ DINO_GIANT = "facebook/dinov2-giant"
5
+
6
+ VGGT_1B = "facebook/VGGT-1B"
7
+
8
+ ORIANY_V2 = "Viglong/OriAnyV2_ckpt"
9
+
10
+ REMOTE_CKPT_PATH = "demo_ckpts/acc8mask20lowlr.pt"
11
+
12
+
13
+ RENDER_FILE = "assets/axis_render.blend"
14
+ REF_AXIS_IMAGE = "assets/axis_ref.png"
15
+ TGT_AXIS_IMAGE = "assets/axis_tgt.png"
16
+
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ matplotlib
2
+ pydantic==2.10.6
3
+ gradio==5.9.0
4
+ onnxruntime
5
+ rembg
6
+ accelerate==1.8.1
7
+ numpy>=1.24
8
+ einops
9
+ pandas
10
+ pillow
11
+ huggingface_hub>=0.23
12
+ pytorch-lightning
13
+ scipy
14
+ torch
15
+ torchmetrics
16
+ torchvision
17
+ tqdm
18
+ transformers
19
+ scikit-learn
20
+ opencv-python
21
+ timm
22
+ bpy==4.2
vggt/heads/camera_head.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from vggt.layers import Mlp
15
+ from vggt.layers.block import Block
16
+ from vggt.heads.head_act import activate_pose
17
+
18
+
19
+ class CameraHead(nn.Module):
20
+ """
21
+ CameraHead predicts camera parameters from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field-of-view activation: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ if pose_encoding_type == "absT_quaR_FoV":
41
+ self.target_dim = 9
42
+ else:
43
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
+
45
+ self.trans_act = trans_act
46
+ self.quat_act = quat_act
47
+ self.fl_act = fl_act
48
+ self.trunk_depth = trunk_depth
49
+
50
+ # Build the trunk using a sequence of transformer blocks.
51
+ self.trunk = nn.Sequential(
52
+ *[
53
+ Block(
54
+ dim=dim_in,
55
+ num_heads=num_heads,
56
+ mlp_ratio=mlp_ratio,
57
+ init_values=init_values,
58
+ )
59
+ for _ in range(trunk_depth)
60
+ ]
61
+ )
62
+
63
+ # Normalizations for camera token and trunk output.
64
+ self.token_norm = nn.LayerNorm(dim_in)
65
+ self.trunk_norm = nn.LayerNorm(dim_in)
66
+
67
+ # Learnable empty camera pose token.
68
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
69
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
70
+
71
+ # Module for producing modulation parameters: shift, scale, and a gate.
72
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
73
+
74
+ # Adaptive layer normalization without affine parameters.
75
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
76
+ self.pose_branch = Mlp(
77
+ in_features=dim_in,
78
+ hidden_features=dim_in // 2,
79
+ out_features=self.target_dim,
80
+ drop=0,
81
+ )
82
+
83
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
84
+ """
85
+ Forward pass to predict camera parameters.
86
+
87
+ Args:
88
+ aggregated_tokens_list (list): List of token tensors from the network;
89
+ the last tensor is used for prediction.
90
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
91
+
92
+ Returns:
93
+ list: A list of predicted camera encodings (post-activation) from each iteration.
94
+ """
95
+ # Use tokens from the last block for camera prediction.
96
+ tokens = aggregated_tokens_list[-1]
97
+
98
+ # Extract the camera tokens
99
+ pose_tokens = tokens[:, :, 0]
100
+ pose_tokens = self.token_norm(pose_tokens)
101
+
102
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
103
+ return pred_pose_enc_list
104
+
105
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
106
+ """
107
+ Iteratively refine camera pose predictions.
108
+
109
+ Args:
110
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
111
+ num_iterations (int): Number of refinement iterations.
112
+
113
+ Returns:
114
+ list: List of activated camera encodings from each iteration.
115
+ """
116
+ B, S, C = pose_tokens.shape # S is expected to be 1.
117
+ pred_pose_enc = None
118
+ pred_pose_enc_list = []
119
+
120
+ for _ in range(num_iterations):
121
+ # Use a learned empty pose for the first iteration.
122
+ if pred_pose_enc is None:
123
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
124
+ else:
125
+ # Detach the previous prediction to avoid backprop through time.
126
+ pred_pose_enc = pred_pose_enc.detach()
127
+ module_input = self.embed_pose(pred_pose_enc)
128
+
129
+ # Generate modulation parameters and split them into shift, scale, and gate components.
130
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
131
+
132
+ # Adaptive layer normalization and modulation.
133
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
134
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
135
+
136
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
137
+ # Compute the delta update for the pose encoding.
138
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
139
+
140
+ if pred_pose_enc is None:
141
+ pred_pose_enc = pred_pose_enc_delta
142
+ else:
143
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
144
+
145
+ # Apply final activation functions for translation, quaternion, and field-of-view.
146
+ activated_pose = activate_pose(
147
+ pred_pose_enc,
148
+ trans_act=self.trans_act,
149
+ quat_act=self.quat_act,
150
+ fl_act=self.fl_act,
151
+ )
152
+ pred_pose_enc_list.append(activated_pose)
153
+
154
+ return pred_pose_enc_list
155
+
156
+
157
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
158
+ """
159
+ Modulate the input tensor using scaling and shifting parameters.
160
+ """
161
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
162
+ return x * (1 + scale) + shift
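
A quick shape-check sketch for CameraHead, assuming the vggt package from this commit is importable. The token tensor is random; only index 0 along the token axis (the camera token) is used by the head.

import torch
from vggt.heads.camera_head import CameraHead

head = CameraHead(dim_in=2048)
# One entry of aggregated_tokens_list with shape [B, S, num_tokens, C].
tokens = [torch.randn(2, 1, 5, 2048)]
with torch.no_grad():
    preds = head(tokens, num_iterations=4)
print(len(preds), preds[-1].shape)  # 4 refinement steps, each [2, 1, 9] (absT_quaR_FoV)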
vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [
71
+ nn.Conv2d(
72
+ in_channels=dim_in,
73
+ out_channels=oc,
74
+ kernel_size=1,
75
+ stride=1,
76
+ padding=0,
77
+ )
78
+ for oc in out_channels
79
+ ]
80
+ )
81
+
82
+ # Resize layers for upsampling feature maps.
83
+ self.resize_layers = nn.ModuleList(
84
+ [
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
87
+ ),
88
+ nn.ConvTranspose2d(
89
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
90
+ ),
91
+ nn.Identity(),
92
+ nn.Conv2d(
93
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
94
+ ),
95
+ ]
96
+ )
97
+
98
+ self.scratch = _make_scratch(
99
+ out_channels,
100
+ features,
101
+ expand=False,
102
+ )
103
+
104
+ # Attach additional modules to scratch.
105
+ self.scratch.stem_transpose = None
106
+ self.scratch.refinenet1 = _make_fusion_block(features)
107
+ self.scratch.refinenet2 = _make_fusion_block(features)
108
+ self.scratch.refinenet3 = _make_fusion_block(features)
109
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
110
+
111
+ head_features_1 = features
112
+ head_features_2 = 32
113
+
114
+ if feature_only:
115
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
116
+ else:
117
+ self.scratch.output_conv1 = nn.Conv2d(
118
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
119
+ )
120
+ conv2_in_channels = head_features_1 // 2
121
+
122
+ self.scratch.output_conv2 = nn.Sequential(
123
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
124
+ nn.ReLU(inplace=True),
125
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
126
+ )
127
+
128
+ def forward(
129
+ self,
130
+ aggregated_tokens_list: List[torch.Tensor],
131
+ images: torch.Tensor,
132
+ patch_start_idx: int,
133
+ frames_chunk_size: int = 8,
134
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
135
+ """
136
+ Forward pass through the DPT head, supports processing by chunking frames.
137
+ Args:
138
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
139
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
140
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
141
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
142
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
143
+ If None or larger than S, all frames are processed at once. Default: 8.
144
+
145
+ Returns:
146
+ Tensor or Tuple[Tensor, Tensor]:
147
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
148
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
149
+ """
150
+ B, S, _, H, W = images.shape
151
+
152
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
153
+ if frames_chunk_size is None or frames_chunk_size >= S:
154
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
155
+
156
+ # Otherwise, process frames in chunks to manage memory usage
157
+ assert frames_chunk_size > 0
158
+
159
+ # Process frames in batches
160
+ all_preds = []
161
+ all_conf = []
162
+
163
+ for frames_start_idx in range(0, S, frames_chunk_size):
164
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
165
+
166
+ # Process batch of frames
167
+ if self.feature_only:
168
+ chunk_output = self._forward_impl(
169
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
170
+ )
171
+ all_preds.append(chunk_output)
172
+ else:
173
+ chunk_preds, chunk_conf = self._forward_impl(
174
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
175
+ )
176
+ all_preds.append(chunk_preds)
177
+ all_conf.append(chunk_conf)
178
+
179
+ # Concatenate results along the sequence dimension
180
+ if self.feature_only:
181
+ return torch.cat(all_preds, dim=1)
182
+ else:
183
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
184
+
185
+ def _forward_impl(
186
+ self,
187
+ aggregated_tokens_list: List[torch.Tensor],
188
+ images: torch.Tensor,
189
+ patch_start_idx: int,
190
+ frames_start_idx: int = None,
191
+ frames_end_idx: int = None,
192
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
193
+ """
194
+ Implementation of the forward pass through the DPT head.
195
+
196
+ This method processes a specific chunk of frames from the sequence.
197
+
198
+ Args:
199
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
200
+ images (Tensor): Input images with shape [B, S, 3, H, W].
201
+ patch_start_idx (int): Starting index for patch tokens.
202
+ frames_start_idx (int, optional): Starting index for frames to process.
203
+ frames_end_idx (int, optional): Ending index for frames to process.
204
+
205
+ Returns:
206
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
207
+ """
208
+ if frames_start_idx is not None and frames_end_idx is not None:
209
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
210
+
211
+ B, S, _, H, W = images.shape
212
+
213
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
214
+
215
+ out = []
216
+ dpt_idx = 0
217
+
218
+ for layer_idx in self.intermediate_layer_idx:
219
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
220
+
221
+ # Select frames if processing a chunk
222
+ if frames_start_idx is not None and frames_end_idx is not None:
223
+ x = x[:, frames_start_idx:frames_end_idx]
224
+
225
+ x = x.view(B * S, -1, x.shape[-1])
226
+
227
+ x = self.norm(x)
228
+
229
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
230
+
231
+ x = self.projects[dpt_idx](x)
232
+ if self.pos_embed:
233
+ x = self._apply_pos_embed(x, W, H)
234
+ x = self.resize_layers[dpt_idx](x)
235
+
236
+ out.append(x)
237
+ dpt_idx += 1
238
+
239
+ # Fuse features from multiple layers.
240
+ out = self.scratch_forward(out)
241
+ # Interpolate fused output to match target image resolution.
242
+ out = custom_interpolate(
243
+ out,
244
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
245
+ mode="bilinear",
246
+ align_corners=True,
247
+ )
248
+
249
+ if self.pos_embed:
250
+ out = self._apply_pos_embed(out, W, H)
251
+
252
+ if self.feature_only:
253
+ return out.view(B, S, *out.shape[1:])
254
+
255
+ out = self.scratch.output_conv2(out)
256
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
257
+
258
+ preds = preds.view(B, S, *preds.shape[1:])
259
+ conf = conf.view(B, S, *conf.shape[1:])
260
+ return preds, conf
261
+
262
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
263
+ """
264
+ Apply positional embedding to tensor x.
265
+ """
266
+ patch_w = x.shape[-1]
267
+ patch_h = x.shape[-2]
268
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
269
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
270
+ pos_embed = pos_embed * ratio
271
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
272
+ return x + pos_embed
273
+
274
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
275
+ """
276
+ Forward pass through the fusion blocks.
277
+
278
+ Args:
279
+ features (List[Tensor]): List of feature maps from different layers.
280
+
281
+ Returns:
282
+ Tensor: Fused feature map.
283
+ """
284
+ layer_1, layer_2, layer_3, layer_4 = features
285
+
286
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
287
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
288
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
289
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
290
+
291
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
292
+ del layer_4_rn, layer_4
293
+
294
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
295
+ del layer_3_rn, layer_3
296
+
297
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
298
+ del layer_2_rn, layer_2
299
+
300
+ out = self.scratch.refinenet1(out, layer_1_rn)
301
+ del layer_1_rn, layer_1
302
+
303
+ out = self.scratch.output_conv1(out)
304
+ return out
305
+
306
+
307
+ ################################################################################
308
+ # Modules
309
+ ################################################################################
310
+
311
+
312
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
313
+ return FeatureFusionBlock(
314
+ features,
315
+ nn.ReLU(inplace=True),
316
+ deconv=False,
317
+ bn=False,
318
+ expand=False,
319
+ align_corners=True,
320
+ size=size,
321
+ has_residual=has_residual,
322
+ groups=groups,
323
+ )
324
+
325
+
326
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
327
+ scratch = nn.Module()
328
+ out_shape1 = out_shape
329
+ out_shape2 = out_shape
330
+ out_shape3 = out_shape
331
+ if len(in_shape) >= 4:
332
+ out_shape4 = out_shape
333
+
334
+ if expand:
335
+ out_shape1 = out_shape
336
+ out_shape2 = out_shape * 2
337
+ out_shape3 = out_shape * 4
338
+ if len(in_shape) >= 4:
339
+ out_shape4 = out_shape * 8
340
+
341
+ scratch.layer1_rn = nn.Conv2d(
342
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
343
+ )
344
+ scratch.layer2_rn = nn.Conv2d(
345
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
346
+ )
347
+ scratch.layer3_rn = nn.Conv2d(
348
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
349
+ )
350
+ if len(in_shape) >= 4:
351
+ scratch.layer4_rn = nn.Conv2d(
352
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
353
+ )
354
+ return scratch
355
+
356
+
357
+ class ResidualConvUnit(nn.Module):
358
+ """Residual convolution module."""
359
+
360
+ def __init__(self, features, activation, bn, groups=1):
361
+ """Init.
362
+
363
+ Args:
364
+ features (int): number of features
365
+ """
366
+ super().__init__()
367
+
368
+ self.bn = bn
369
+ self.groups = groups
370
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
371
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
372
+
373
+ self.norm1 = None
374
+ self.norm2 = None
375
+
376
+ self.activation = activation
377
+ self.skip_add = nn.quantized.FloatFunctional()
378
+
379
+ def forward(self, x):
380
+ """Forward pass.
381
+
382
+ Args:
383
+ x (tensor): input
384
+
385
+ Returns:
386
+ tensor: output
387
+ """
388
+
389
+ out = self.activation(x)
390
+ out = self.conv1(out)
391
+ if self.norm1 is not None:
392
+ out = self.norm1(out)
393
+
394
+ out = self.activation(out)
395
+ out = self.conv2(out)
396
+ if self.norm2 is not None:
397
+ out = self.norm2(out)
398
+
399
+ return self.skip_add.add(out, x)
400
+
401
+
402
+ class FeatureFusionBlock(nn.Module):
403
+ """Feature fusion block."""
404
+
405
+ def __init__(
406
+ self,
407
+ features,
408
+ activation,
409
+ deconv=False,
410
+ bn=False,
411
+ expand=False,
412
+ align_corners=True,
413
+ size=None,
414
+ has_residual=True,
415
+ groups=1,
416
+ ):
417
+ """Init.
418
+
419
+ Args:
420
+ features (int): number of features
421
+ """
422
+ super(FeatureFusionBlock, self).__init__()
423
+
424
+ self.deconv = deconv
425
+ self.align_corners = align_corners
426
+ self.groups = groups
427
+ self.expand = expand
428
+ out_features = features
429
+ if self.expand == True:
430
+ out_features = features // 2
431
+
432
+ self.out_conv = nn.Conv2d(
433
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
434
+ )
435
+
436
+ if has_residual:
437
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
438
+
439
+ self.has_residual = has_residual
440
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
441
+
442
+ self.skip_add = nn.quantized.FloatFunctional()
443
+ self.size = size
444
+
445
+ def forward(self, *xs, size=None):
446
+ """Forward pass.
447
+
448
+ Returns:
449
+ tensor: output
450
+ """
451
+ output = xs[0]
452
+
453
+ if self.has_residual:
454
+ res = self.resConfUnit1(xs[1])
455
+ output = self.skip_add.add(output, res)
456
+
457
+ output = self.resConfUnit2(output)
458
+
459
+ if (size is None) and (self.size is None):
460
+ modifier = {"scale_factor": 2}
461
+ elif size is None:
462
+ modifier = {"size": self.size}
463
+ else:
464
+ modifier = {"size": size}
465
+
466
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
467
+ output = self.out_conv(output)
468
+
469
+ return output
470
+
471
+
472
+ def custom_interpolate(
473
+ x: torch.Tensor,
474
+ size: Tuple[int, int] = None,
475
+ scale_factor: float = None,
476
+ mode: str = "bilinear",
477
+ align_corners: bool = True,
478
+ ) -> torch.Tensor:
479
+ """
480
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
481
+ """
482
+ if size is None:
483
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
484
+
485
+ INT_MAX = 1610612736
486
+
487
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
488
+
489
+ if input_elements > INT_MAX:
490
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
491
+ interpolated_chunks = [
492
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
493
+ ]
494
+ x = torch.cat(interpolated_chunks, dim=0)
495
+ return x.contiguous()
496
+ else:
497
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
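
The helpers at the bottom of this file are self-contained; for small inputs custom_interpolate simply defers to nn.functional.interpolate and only chunks along the batch dimension when the output element count would exceed INT_MAX. A minimal sketch:

import torch
from vggt.heads.dpt_head import custom_interpolate

x = torch.randn(1, 16, 32, 32)
y = custom_interpolate(x, size=(64, 64), mode="bilinear", align_corners=True)
print(y.shape)  # torch.Size([1, 16, 64, 64])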
vggt/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
13
+ """
14
+ Activate pose parameters with specified activation functions.
15
+
16
+ Args:
17
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
18
+ trans_act: Activation type for translation component
19
+ quat_act: Activation type for quaternion component
20
+ fl_act: Activation type for focal length component
21
+
22
+ Returns:
23
+ Activated pose parameters tensor
24
+ """
25
+ T = pred_pose_enc[..., :3]
26
+ quat = pred_pose_enc[..., 3:7]
27
+ fl = pred_pose_enc[..., 7:] # or fov
28
+
29
+ T = base_pose_act(T, trans_act)
30
+ quat = base_pose_act(quat, quat_act)
31
+ fl = base_pose_act(fl, fl_act) # or fov
32
+
33
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
34
+
35
+ return pred_pose_enc
36
+
37
+
38
+ def base_pose_act(pose_enc, act_type="linear"):
39
+ """
40
+ Apply basic activation function to pose parameters.
41
+
42
+ Args:
43
+ pose_enc: Tensor containing encoded pose parameters
44
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
45
+
46
+ Returns:
47
+ Activated pose parameters
48
+ """
49
+ if act_type == "linear":
50
+ return pose_enc
51
+ elif act_type == "inv_log":
52
+ return inverse_log_transform(pose_enc)
53
+ elif act_type == "exp":
54
+ return torch.exp(pose_enc)
55
+ elif act_type == "relu":
56
+ return F.relu(pose_enc)
57
+ else:
58
+ raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
62
+ """
63
+ Process network output to extract 3D points and confidence values.
64
+
65
+ Args:
66
+ out: Network output tensor (B, C, H, W)
67
+ activation: Activation type for 3D points
68
+ conf_activation: Activation type for confidence values
69
+
70
+ Returns:
71
+ Tuple of (3D points tensor, confidence tensor)
72
+ """
73
+ # Move channels to the last dimension: (B, C, H, W) -> (B, H, W, C)
74
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
75
+
76
+ # Split into xyz (first C-1 channels) and confidence (last channel)
77
+ xyz = fmap[:, :, :, :-1]
78
+ conf = fmap[:, :, :, -1]
79
+
80
+ if activation == "norm_exp":
81
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
82
+ xyz_normed = xyz / d
83
+ pts3d = xyz_normed * torch.expm1(d)
84
+ elif activation == "norm":
85
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
86
+ elif activation == "exp":
87
+ pts3d = torch.exp(xyz)
88
+ elif activation == "relu":
89
+ pts3d = F.relu(xyz)
90
+ elif activation == "inv_log":
91
+ pts3d = inverse_log_transform(xyz)
92
+ elif activation == "xy_inv_log":
93
+ xy, z = xyz.split([2, 1], dim=-1)
94
+ z = inverse_log_transform(z)
95
+ pts3d = torch.cat([xy * z, z], dim=-1)
96
+ elif activation == "sigmoid":
97
+ pts3d = torch.sigmoid(xyz)
98
+ elif activation == "linear":
99
+ pts3d = xyz
100
+ else:
101
+ raise ValueError(f"Unknown activation: {activation}")
102
+
103
+ if conf_activation == "expp1":
104
+ conf_out = 1 + conf.exp()
105
+ elif conf_activation == "expp0":
106
+ conf_out = conf.exp()
107
+ elif conf_activation == "sigmoid":
108
+ conf_out = torch.sigmoid(conf)
109
+ else:
110
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
111
+
112
+ return pts3d, conf_out
113
+
114
+
115
+ def inverse_log_transform(y):
116
+ """
117
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
118
+
119
+ Args:
120
+ y: Input tensor
121
+
122
+ Returns:
123
+ Transformed tensor
124
+ """
125
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
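
inverse_log_transform is the exact inverse of the signed log compression sign(x) * log(1 + |x|); a quick numerical check (sketch only):

import torch
from vggt.heads.head_act import inverse_log_transform

x = torch.linspace(-5.0, 5.0, steps=11)
compressed = torch.sign(x) * torch.log1p(torch.abs(x))
print(torch.allclose(inverse_log_transform(compressed), x, atol=1e-6))  # True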
vggt/heads/track_head.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+ from .dpt_head import DPTHead
9
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
10
+
11
+
12
+ class TrackHead(nn.Module):
13
+ """
14
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
15
+ The tracking is performed iteratively, refining predictions over multiple iterations.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ dim_in,
21
+ patch_size=14,
22
+ features=128,
23
+ iters=4,
24
+ predict_conf=True,
25
+ stride=2,
26
+ corr_levels=7,
27
+ corr_radius=4,
28
+ hidden_size=384,
29
+ ):
30
+ """
31
+ Initialize the TrackHead module.
32
+
33
+ Args:
34
+ dim_in (int): Input dimension of tokens from the backbone.
35
+ patch_size (int): Size of image patches used in the vision transformer.
36
+ features (int): Number of feature channels in the feature extractor output.
37
+ iters (int): Number of refinement iterations for tracking predictions.
38
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
39
+ stride (int): Stride value for the tracker predictor.
40
+ corr_levels (int): Number of correlation pyramid levels
41
+ corr_radius (int): Radius for correlation computation, controlling the search area.
42
+ hidden_size (int): Size of hidden layers in the tracker network.
43
+ """
44
+ super().__init__()
45
+
46
+ self.patch_size = patch_size
47
+
48
+ # Feature extractor based on DPT architecture
49
+ # Processes tokens into feature maps for tracking
50
+ self.feature_extractor = DPTHead(
51
+ dim_in=dim_in,
52
+ patch_size=patch_size,
53
+ features=features,
54
+ feature_only=True, # Only output features, no activation
55
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
56
+ pos_embed=False,
57
+ )
58
+
59
+ # Tracker module that predicts point trajectories
60
+ # Takes feature maps and predicts coordinates and visibility
61
+ self.tracker = BaseTrackerPredictor(
62
+ latent_dim=features, # Match the output_dim of feature extractor
63
+ predict_conf=predict_conf,
64
+ stride=stride,
65
+ corr_levels=corr_levels,
66
+ corr_radius=corr_radius,
67
+ hidden_size=hidden_size,
68
+ )
69
+
70
+ self.iters = iters
71
+
72
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
73
+ """
74
+ Forward pass of the TrackHead.
75
+
76
+ Args:
77
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
78
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
79
+ B = batch size, S = sequence length.
80
+ patch_start_idx (int): Starting index for patch tokens.
81
+ query_points (torch.Tensor, optional): Initial query points to track.
82
+ If None, points are initialized by the tracker.
83
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
84
+
85
+ Returns:
86
+ tuple:
87
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
88
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
89
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
90
+ """
91
+ B, S, _, H, W = images.shape
92
+
93
+ # Extract features from tokens
94
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
95
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
96
+
97
+ # Use default iterations if not specified
98
+ if iters is None:
99
+ iters = self.iters
100
+
101
+ # Perform tracking using the extracted features
102
+ coord_preds, vis_scores, conf_scores = self.tracker(
103
+ query_points=query_points,
104
+ fmaps=feature_maps,
105
+ iters=iters,
106
+ )
107
+
108
+ return coord_preds, vis_scores, conf_scores
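
For reference, query_points passed to TrackHead.forward are pixel (x, y) coordinates in the first (reference) frame with shape (B, N, 2). A sketch of building a random set of queries for a 518x518 input; the values are arbitrary:

import torch

B, N, H, W = 1, 16, 518, 518
query_points = torch.rand(B, N, 2) * torch.tensor([W - 1.0, H - 1.0])
print(query_points.shape)  # torch.Size([1, 16, 2])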
vggt/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
vggt/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+
12
+ from .blocks import EfficientUpdateFormer, CorrBlock
13
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
14
+ from .modules import Mlp
15
+
16
+
17
+ class BaseTrackerPredictor(nn.Module):
18
+ def __init__(
19
+ self,
20
+ stride=1,
21
+ corr_levels=5,
22
+ corr_radius=4,
23
+ latent_dim=128,
24
+ hidden_size=384,
25
+ use_spaceatt=True,
26
+ depth=6,
27
+ max_scale=518,
28
+ predict_conf=True,
29
+ ):
30
+ super(BaseTrackerPredictor, self).__init__()
31
+ """
32
+ The base template to create a track predictor
33
+
34
+ Modified from https://github.com/facebookresearch/co-tracker/
35
+ and https://github.com/facebookresearch/vggsfm
36
+ """
37
+
38
+ self.stride = stride
39
+ self.latent_dim = latent_dim
40
+ self.corr_levels = corr_levels
41
+ self.corr_radius = corr_radius
42
+ self.hidden_size = hidden_size
43
+ self.max_scale = max_scale
44
+ self.predict_conf = predict_conf
45
+
46
+ self.flows_emb_dim = latent_dim // 2
47
+
48
+ self.corr_mlp = Mlp(
49
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
50
+ hidden_features=self.hidden_size,
51
+ out_features=self.latent_dim,
52
+ )
53
+
54
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
55
+
56
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
57
+
58
+ space_depth = depth if use_spaceatt else 0
59
+ time_depth = depth
60
+
61
+ self.updateformer = EfficientUpdateFormer(
62
+ space_depth=space_depth,
63
+ time_depth=time_depth,
64
+ input_dim=self.transformer_dim,
65
+ hidden_size=self.hidden_size,
66
+ output_dim=self.latent_dim + 2,
67
+ mlp_ratio=4.0,
68
+ add_space_attn=use_spaceatt,
69
+ )
70
+
71
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
72
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
73
+
74
+ # A linear layer to update track feats at each iteration
75
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
76
+
77
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
78
+
79
+ if predict_conf:
80
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
81
+
82
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
83
+ """
84
+ query_points: B x N x 2 (batch size, number of tracks, and xy coordinates)
85
+ fmaps: B x S x C x HH x WW (batch size, number of frames, channels, and spatial size).
86
+ Note that HH and WW refer to the feature-map size, not the original image size.
87
+ """
88
+ B, N, D = query_points.shape
89
+ B, S, C, HH, WW = fmaps.shape
90
+
91
+ assert D == 2, "Input points must be 2D coordinates"
92
+
93
+ # apply a layernorm to fmaps here
94
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
95
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
96
+
97
+ # Scale the input query_points because we may downsample the images
98
+ # by down_ratio or self.stride
99
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
100
+ # its query_points should be query_points/4
101
+ if down_ratio > 1:
102
+ query_points = query_points / float(down_ratio)
103
+
104
+ query_points = query_points / float(self.stride)
105
+
106
+ # Init with coords as the query points
107
+ # It means the search will start from the position of query points at the reference frames
108
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
109
+
110
+ # Sample/extract the features of the query points in the query frame
111
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
112
+
113
+ # init track feats by query feats
114
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
115
+ # back up the init coords
116
+ coords_backup = coords.clone()
117
+
118
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
119
+
120
+ coord_preds = []
121
+
122
+ # Iterative Refinement
123
+ for _ in range(iters):
124
+ # Detach the gradients from the last iteration
125
+ # (in my experience, not very important for performance)
126
+ coords = coords.detach()
127
+
128
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
129
+
130
+ corr_dim = fcorrs.shape[3]
131
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
132
+ fcorrs_ = self.corr_mlp(fcorrs_)
133
+
134
+ # Movement of current coords relative to query points
135
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
136
+
137
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
138
+
139
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
140
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
141
+
142
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
143
+
144
+ # Concatenate them as the input for the transformers
145
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
146
+
147
+ # 2D positional embed
148
+ # TODO: this can be much simplified
149
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
150
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
151
+
152
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
153
+
154
+ x = transformer_input + sampled_pos_emb
155
+
156
+ # Add the query ref token to the track feats
157
+ query_ref_token = torch.cat(
158
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
159
+ )
160
+ x = x + query_ref_token.to(x.device).to(x.dtype)
161
+
162
+ # B, N, S, C
163
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
164
+
165
+ # Compute the delta coordinates and delta track features
166
+ delta, _ = self.updateformer(x)
167
+
168
+ # BN, S, C
169
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
170
+ delta_coords_ = delta[:, :, :2]
171
+ delta_feats_ = delta[:, :, 2:]
172
+
173
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
174
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
175
+
176
+ # Update the track features
177
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
178
+
179
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
180
+
181
+ # B x S x N x 2
182
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
183
+
184
+ # Force coord0 as query
185
+ # because we assume the query points should not be changed
186
+ coords[:, 0] = coords_backup[:, 0]
187
+
188
+ # The predicted tracks are in the original image scale
189
+ if down_ratio > 1:
190
+ coord_preds.append(coords * self.stride * down_ratio)
191
+ else:
192
+ coord_preds.append(coords * self.stride)
193
+
194
+ # B, S, N
195
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
196
+ if apply_sigmoid:
197
+ vis_e = torch.sigmoid(vis_e)
198
+
199
+ if self.predict_conf:
200
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
201
+ if apply_sigmoid:
202
+ conf_e = torch.sigmoid(conf_e)
203
+ else:
204
+ conf_e = None
205
+
206
+ if return_feat:
207
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
208
+ else:
209
+ return coord_preds, vis_e, conf_e
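
The coordinate bookkeeping above is easy to miss: query points are divided by down_ratio and then by stride before refinement on the feature grid, and each prediction is multiplied back, so the returned tracks are in the original image scale. A plain-arithmetic sketch, with no model involved:

stride, down_ratio = 2, 1

x_img = 300.0                            # query x in image pixels
x_feat = x_img / down_ratio / stride     # 150.0 on the feature grid
x_refined = x_feat + 3.5                 # pretend the tracker shifted it
x_out = x_refined * stride * down_ratio  # 307.0 back in image pixels
print(x_feat, x_out)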
vggt/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,246 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .utils import bilinear_sampler
16
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
17
+
18
+
19
+ class EfficientUpdateFormer(nn.Module):
20
+ """
21
+ Transformer model that updates track estimates.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ space_depth=6,
27
+ time_depth=6,
28
+ input_dim=320,
29
+ hidden_size=384,
30
+ num_heads=8,
31
+ output_dim=130,
32
+ mlp_ratio=4.0,
33
+ add_space_attn=True,
34
+ num_virtual_tracks=64,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.out_channels = 2
39
+ self.num_heads = num_heads
40
+ self.hidden_size = hidden_size
41
+ self.add_space_attn = add_space_attn
42
+
43
+ # Add input LayerNorm before linear projection
44
+ self.input_norm = nn.LayerNorm(input_dim)
45
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
46
+
47
+ # Add output LayerNorm before final projection
48
+ self.output_norm = nn.LayerNorm(hidden_size)
49
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
50
+ self.num_virtual_tracks = num_virtual_tracks
51
+
52
+ if self.add_space_attn:
53
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
54
+ else:
55
+ self.virual_tracks = None
56
+
57
+ self.time_blocks = nn.ModuleList(
58
+ [
59
+ AttnBlock(
60
+ hidden_size,
61
+ num_heads,
62
+ mlp_ratio=mlp_ratio,
63
+ attn_class=nn.MultiheadAttention,
64
+ )
65
+ for _ in range(time_depth)
66
+ ]
67
+ )
68
+
69
+ if add_space_attn:
70
+ self.space_virtual_blocks = nn.ModuleList(
71
+ [
72
+ AttnBlock(
73
+ hidden_size,
74
+ num_heads,
75
+ mlp_ratio=mlp_ratio,
76
+ attn_class=nn.MultiheadAttention,
77
+ )
78
+ for _ in range(space_depth)
79
+ ]
80
+ )
81
+ self.space_point2virtual_blocks = nn.ModuleList(
82
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
83
+ )
84
+ self.space_virtual2point_blocks = nn.ModuleList(
85
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
86
+ )
87
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
88
+ self.initialize_weights()
89
+
90
+ def initialize_weights(self):
91
+ def _basic_init(module):
92
+ if isinstance(module, nn.Linear):
93
+ torch.nn.init.xavier_uniform_(module.weight)
94
+ if module.bias is not None:
95
+ nn.init.constant_(module.bias, 0)
96
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
97
+
98
+ self.apply(_basic_init)
99
+
100
+ def forward(self, input_tensor, mask=None):
101
+ # Apply input LayerNorm
102
+ input_tensor = self.input_norm(input_tensor)
103
+ tokens = self.input_transform(input_tensor)
104
+
105
+ init_tokens = tokens
106
+
107
+ B, _, T, _ = tokens.shape
108
+
109
+ if self.add_space_attn:
110
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
111
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
112
+
113
+ _, N, _, _ = tokens.shape
114
+
115
+ j = 0
116
+ for i in range(len(self.time_blocks)):
117
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
118
+
119
+ time_tokens = self.time_blocks[i](time_tokens)
120
+
121
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
122
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
123
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
124
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
125
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
126
+
127
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
128
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
129
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
130
+
131
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
132
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
133
+ j += 1
134
+
135
+ if self.add_space_attn:
136
+ tokens = tokens[:, : N - self.num_virtual_tracks]
137
+
138
+ tokens = tokens + init_tokens
139
+
140
+ # Apply output LayerNorm before final projection
141
+ tokens = self.output_norm(tokens)
142
+ flow = self.flow_head(tokens)
143
+
144
+ return flow, None
145
+
146
+
147
+ class CorrBlock:
148
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
149
+ """
150
+ Build a pyramid of feature maps from the input.
151
+
152
+ fmaps: Tensor (B, S, C, H, W)
153
+ num_levels: number of pyramid levels (each downsampled by factor 2)
154
+ radius: search radius for sampling correlation
155
+ multiple_track_feats: if True, split the target features per pyramid level
156
+ padding_mode: passed to grid_sample / bilinear_sampler
157
+ """
158
+ B, S, C, H, W = fmaps.shape
159
+ self.S, self.C, self.H, self.W = S, C, H, W
160
+ self.num_levels = num_levels
161
+ self.radius = radius
162
+ self.padding_mode = padding_mode
163
+ self.multiple_track_feats = multiple_track_feats
164
+
165
+ # Build pyramid: each level is half the spatial resolution of the previous
166
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
167
+ current_fmaps = fmaps
168
+ for i in range(num_levels - 1):
169
+ B, S, C, H, W = current_fmaps.shape
170
+ # Merge batch & sequence dimensions
171
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
172
+ # Avg pool down by factor 2
173
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
174
+ _, _, H_new, W_new = current_fmaps.shape
175
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
176
+ self.fmaps_pyramid.append(current_fmaps)
177
+
178
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
179
+ # This grid is added to the (scaled) coordinate centroids.
180
+ r = self.radius
181
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
182
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
183
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
184
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
185
+
186
+ def corr_sample(self, targets, coords):
187
+ """
188
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
189
+ volume, sample it immediately, then discard it. This saves GPU memory.
190
+
191
+ Args:
192
+ targets: Tensor (B, S, N, C) — features for the current targets.
193
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
194
+
195
+ Returns:
196
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
197
+ """
198
+ B, S, N, C = targets.shape
199
+
200
+ # If you have multiple track features, split them per level.
201
+ if self.multiple_track_feats:
202
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
203
+
204
+ out_pyramid = []
205
+ for i, fmaps in enumerate(self.fmaps_pyramid):
206
+ # Get current spatial resolution H, W for this pyramid level.
207
+ B, S, C, H, W = fmaps.shape
208
+ # Reshape feature maps for correlation computation:
209
+ # fmap2s: (B, S, C, H*W)
210
+ fmap2s = fmaps.view(B, S, C, H * W)
211
+ # Choose appropriate target features.
212
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
213
+
214
+ # Compute correlation directly
215
+ corrs = compute_corr_level(fmap1, fmap2s, C)
216
+ corrs = corrs.view(B, S, N, H, W)
217
+
218
+ # Prepare sampling grid:
219
+ # Scale down the coordinates for the current level.
220
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
221
+ # Make sure our precomputed delta grid is on the same device/dtype.
222
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
223
+ # Now the grid for grid_sample is:
224
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
225
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
226
+
227
+ # Sample from the correlation volume using bilinear interpolation.
228
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
229
+ corrs_sampled = bilinear_sampler(
230
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
231
+ )
232
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
233
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
234
+ out_pyramid.append(corrs_sampled)
235
+
236
+ # Concatenate all levels along the last dimension.
237
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
238
+ return out
239
+
240
+
241
+ def compute_corr_level(fmap1, fmap2s, C):
242
+ # fmap1: (B, S, N, C)
243
+ # fmap2s: (B, S, C, H*W)
244
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
245
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
246
+ return corrs / math.sqrt(C)
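
compute_corr_level is just a scaled dot product between per-track features and the flattened feature map; a quick shape check with random tensors (no semantics intended):

import torch
from vggt.heads.track_modules.blocks import compute_corr_level

B, S, N, C, H, W = 2, 3, 5, 16, 8, 8
fmap1 = torch.randn(B, S, N, C)
fmap2s = torch.randn(B, S, C, H * W)
corrs = compute_corr_level(fmap1, fmap2s, C)
print(corrs.shape)  # torch.Size([2, 3, 5, 64])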
vggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ """
41
+ ResidualBlock: construct a block of two conv layers with residual connections
42
+ """
43
+
44
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
45
+ super(ResidualBlock, self).__init__()
46
+
47
+ self.conv1 = nn.Conv2d(
48
+ in_planes,
49
+ planes,
50
+ kernel_size=kernel_size,
51
+ padding=1,
52
+ stride=stride,
53
+ padding_mode="zeros",
54
+ )
55
+ self.conv2 = nn.Conv2d(
56
+ planes,
57
+ planes,
58
+ kernel_size=kernel_size,
59
+ padding=1,
60
+ padding_mode="zeros",
61
+ )
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ num_groups = planes // 8
65
+
66
+ if norm_fn == "group":
67
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
68
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
69
+ if not stride == 1:
70
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
71
+
72
+ elif norm_fn == "batch":
73
+ self.norm1 = nn.BatchNorm2d(planes)
74
+ self.norm2 = nn.BatchNorm2d(planes)
75
+ if not stride == 1:
76
+ self.norm3 = nn.BatchNorm2d(planes)
77
+
78
+ elif norm_fn == "instance":
79
+ self.norm1 = nn.InstanceNorm2d(planes)
80
+ self.norm2 = nn.InstanceNorm2d(planes)
81
+ if not stride == 1:
82
+ self.norm3 = nn.InstanceNorm2d(planes)
83
+
84
+ elif norm_fn == "none":
85
+ self.norm1 = nn.Sequential()
86
+ self.norm2 = nn.Sequential()
87
+ if not stride == 1:
88
+ self.norm3 = nn.Sequential()
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ if stride == 1:
93
+ self.downsample = None
94
+ else:
95
+ self.downsample = nn.Sequential(
96
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
97
+ self.norm3,
98
+ )
99
+
100
+ def forward(self, x):
101
+ y = x
102
+ y = self.relu(self.norm1(self.conv1(y)))
103
+ y = self.relu(self.norm2(self.conv2(y)))
104
+
105
+ if self.downsample is not None:
106
+ x = self.downsample(x)
107
+
108
+ return self.relu(x + y)
109
+
110
+
111
+ class Mlp(nn.Module):
112
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
113
+
114
+ def __init__(
115
+ self,
116
+ in_features,
117
+ hidden_features=None,
118
+ out_features=None,
119
+ act_layer=nn.GELU,
120
+ norm_layer=None,
121
+ bias=True,
122
+ drop=0.0,
123
+ use_conv=False,
124
+ ):
125
+ super().__init__()
126
+ out_features = out_features or in_features
127
+ hidden_features = hidden_features or in_features
128
+ bias = to_2tuple(bias)
129
+ drop_probs = to_2tuple(drop)
130
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
131
+
132
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
133
+ self.act = act_layer()
134
+ self.drop1 = nn.Dropout(drop_probs[0])
135
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
136
+ self.drop2 = nn.Dropout(drop_probs[1])
137
+
138
+ def forward(self, x):
139
+ x = self.fc1(x)
140
+ x = self.act(x)
141
+ x = self.drop1(x)
142
+ x = self.fc2(x)
143
+ x = self.drop2(x)
144
+ return x
145
+
146
+
147
+ class AttnBlock(nn.Module):
148
+ def __init__(
149
+ self,
150
+ hidden_size,
151
+ num_heads,
152
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
153
+ mlp_ratio=4.0,
154
+ **block_kwargs
155
+ ):
156
+ """
157
+ Self attention block
158
+ """
159
+ super().__init__()
160
+
161
+ self.norm1 = nn.LayerNorm(hidden_size)
162
+ self.norm2 = nn.LayerNorm(hidden_size)
163
+
164
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
165
+
166
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
167
+
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
169
+
170
+ def forward(self, x, mask=None):
171
+ # Prepare the mask for PyTorch's attention (it expects a different format)
172
+ # attn_mask = mask if mask is not None else None
173
+ # Normalize before attention
174
+ x = self.norm1(x)
175
+
176
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
177
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
178
+
179
+ attn_output, _ = self.attn(x, x, x)
180
+
181
+ # Add & Norm
182
+ x = x + attn_output
183
+ x = x + self.mlp(self.norm2(x))
184
+ return x
185
+
186
+
187
+ class CrossAttnBlock(nn.Module):
188
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
189
+ """
190
+ Cross attention block
191
+ """
192
+ super().__init__()
193
+
194
+ self.norm1 = nn.LayerNorm(hidden_size)
195
+ self.norm_context = nn.LayerNorm(hidden_size)
196
+ self.norm2 = nn.LayerNorm(hidden_size)
197
+
198
+ self.cross_attn = nn.MultiheadAttention(
199
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
200
+ )
201
+
202
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
203
+
204
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
205
+
206
+ def forward(self, x, context, mask=None):
207
+ # Normalize inputs
208
+ x = self.norm1(x)
209
+ context = self.norm_context(context)
210
+
211
+ # Apply cross attention
212
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
213
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
214
+
215
+ # Add & Norm
216
+ x = x + attn_output
217
+ x = x + self.mlp(self.norm2(x))
218
+ return x
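
The blocks in this file are thin pre-norm wrappers around nn.MultiheadAttention plus an MLP; a minimal forward sketch with random inputs (module path taken from this commit):

import torch
from vggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

x = torch.randn(2, 10, 384)    # (batch, tokens, hidden)
ctx = torch.randn(2, 7, 384)   # context tokens for cross attention

self_block = AttnBlock(hidden_size=384, num_heads=8)
cross_block = CrossAttnBlock(hidden_size=384, context_dim=384, num_heads=8)
print(self_block(x).shape, cross_block(x, ctx).shape)  # both torch.Size([2, 10, 384])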
vggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/vggsfm
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+
18
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
19
+ """
20
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
21
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
22
+ Args:
23
+ - embed_dim: The embedding dimension.
24
+ - grid_size: The grid size.
25
+ Returns:
26
+ - pos_embed: The generated 2D positional embedding.
27
+ """
28
+ if isinstance(grid_size, tuple):
29
+ grid_size_h, grid_size_w = grid_size
30
+ else:
31
+ grid_size_h = grid_size_w = grid_size
32
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
33
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
34
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
35
+ grid = torch.stack(grid, dim=0)
36
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
37
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
38
+ if return_grid:
39
+ return (
40
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
41
+ grid,
42
+ )
43
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
44
+
45
+
46
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
49
+
50
+ Args:
51
+ - embed_dim: The embedding dimension.
52
+ - grid: The grid to generate the embedding from.
53
+
54
+ Returns:
55
+ - emb: The generated 2D positional embedding.
56
+ """
57
+ assert embed_dim % 2 == 0
58
+
59
+ # use half of dimensions to encode grid_h
60
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
61
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
62
+
63
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
64
+ return emb
65
+
66
+
67
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
68
+ """
69
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
70
+
71
+ Args:
72
+ - embed_dim: The embedding dimension.
73
+ - pos: The position to generate the embedding from.
74
+
75
+ Returns:
76
+ - emb: The generated 1D positional embedding.
77
+ """
78
+ assert embed_dim % 2 == 0
79
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = torch.sin(out) # (M, D/2)
87
+ emb_cos = torch.cos(out) # (M, D/2)
88
+
89
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
90
+ return emb[None].float()
91
+
92
+
93
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
94
+ """
95
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
96
+
97
+ Args:
98
+ - xy: The coordinates to generate the embedding from.
99
+ - C: The size of the embedding.
100
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
101
+
102
+ Returns:
103
+ - pe: The generated 2D positional embedding.
104
+ """
105
+ B, N, D = xy.shape
106
+ assert D == 2
107
+
108
+ x = xy[:, :, 0:1]
109
+ y = xy[:, :, 1:2]
110
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
111
+
112
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
113
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
114
+
115
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
116
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
117
+
118
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
119
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
120
+
121
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
122
+ if cat_coords:
123
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
124
+ return pe
125
+
126
+
127
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
128
+ r"""Sample a tensor using bilinear interpolation
129
+
130
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
131
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
132
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
133
+ convention.
134
+
135
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
136
+ :math:`B` is the batch size, :math:`C` is the number of channels,
137
+ :math:`H` is the height of the image, and :math:`W` is the width of the
138
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
139
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
140
+
141
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
142
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
143
+ that in this case the order of the components is slightly different
144
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
145
+
146
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
147
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
148
+ left-most image pixel and :math:`W-1` to the center of the right-most
149
+ pixel.
150
+
151
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
152
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
153
+ the left-most pixel and :math:`W` to the right edge of the right-most
154
+ pixel.
155
+
156
+ Similar conventions apply to the :math:`y` for the range
157
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
158
+ :math:`[0,T-1]` and :math:`[0,T]`.
159
+
160
+ Args:
161
+ input (Tensor): batch of input images.
162
+ coords (Tensor): batch of coordinates.
163
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
164
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
165
+
166
+ Returns:
167
+ Tensor: sampled points.
168
+ """
169
+ coords = coords.detach().clone()
170
+ ############################################################
171
+ # IMPORTANT:
172
+ coords = coords.to(input.device).to(input.dtype)
173
+ ############################################################
174
+
175
+ sizes = input.shape[2:]
176
+
177
+ assert len(sizes) in [2, 3]
178
+
179
+ if len(sizes) == 3:
180
+ # t x y -> x y t to match dimensions T H W in grid_sample
181
+ coords = coords[..., [1, 2, 0]]
182
+
183
+ if align_corners:
184
+ scale = torch.tensor(
185
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
186
+ )
187
+ else:
188
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
189
+
190
+ coords.mul_(scale) # coords = coords * scale
191
+ coords.sub_(1) # coords = coords - 1
192
+
193
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
194
+
195
+
196
+ def sample_features4d(input, coords):
197
+ r"""Sample spatial features
198
+
199
+ `sample_features4d(input, coords)` samples the spatial features
200
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
201
+
202
+ The field is sampled at coordinates :attr:`coords` using bilinear
203
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
204
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
205
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
206
+
207
+ The output tensor has one feature per point, and has shape :math:`(B,
208
+ R, C)`.
209
+
210
+ Args:
211
+ input (Tensor): spatial features.
212
+ coords (Tensor): points.
213
+
214
+ Returns:
215
+ Tensor: sampled features.
216
+ """
217
+
218
+ B, _, _, _ = input.shape
219
+
220
+ # B R 2 -> B R 1 2
221
+ coords = coords.unsqueeze(2)
222
+
223
+ # B C R 1
224
+ feats = bilinear_sampler(input, coords)
225
+
226
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
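A minimal shape sketch of the sampling helpers above; the tensor sizes are illustrative assumptions:

    import torch
    from vggt.heads.track_modules.utils import get_2d_embedding, sample_features4d

    feat_map = torch.randn(2, 64, 32, 32)   # (B, C, H, W) feature map
    points = torch.rand(2, 10, 2) * 31      # (B, R, 2) pixel coordinates (x, y)

    point_feats = sample_features4d(feat_map, points)
    print(point_feats.shape)                # torch.Size([2, 10, 64])

    pos_enc = get_2d_embedding(points, C=16, cat_coords=True)
    print(pos_enc.shape)                    # torch.Size([2, 10, 34]); 34 = 2*C + 2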
vggt/heads/utils.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
12
+ """
13
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
14
+
15
+ Args:
16
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
17
+ embed_dim: Output channel dimension for embeddings
18
+
19
+ Returns:
20
+ Tensor of shape (H, W, embed_dim) with positional embeddings
21
+ """
22
+ H, W, grid_dim = pos_grid.shape
23
+ assert grid_dim == 2
24
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
25
+
26
+ # Process x and y coordinates separately
27
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
28
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
29
+
30
+ # Combine and reshape
31
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
32
+
33
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
37
+ """
38
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - pos: The position to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 1D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
49
+ omega /= embed_dim / 2.0
50
+ omega = 1.0 / omega_0**omega # (D/2,)
51
+
52
+ pos = pos.reshape(-1) # (M,)
53
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
54
+
55
+ emb_sin = torch.sin(out) # (M, D/2)
56
+ emb_cos = torch.cos(out) # (M, D/2)
57
+
58
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
59
+ return emb.float()
60
+
61
+
62
+ # Inspired by https://github.com/microsoft/moge
63
+
64
+
65
+ def create_uv_grid(
66
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
67
+ ) -> torch.Tensor:
68
+ """
69
+ Create a normalized UV grid of shape (height, width, 2).
70
+
71
+ The grid spans horizontally and vertically according to an aspect ratio,
72
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
73
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
74
+
75
+ Args:
76
+ width (int): Number of points horizontally.
77
+ height (int): Number of points vertically.
78
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
79
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
80
+ device (torch.device, optional): Device on which the tensor is created.
81
+
82
+ Returns:
83
+ torch.Tensor: A (height, width, 2) tensor of UV coordinates.
84
+ """
85
+ # Derive aspect ratio if not explicitly provided
86
+ if aspect_ratio is None:
87
+ aspect_ratio = float(width) / float(height)
88
+
89
+ # Compute normalized spans for X and Y
90
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
91
+ span_x = aspect_ratio / diag_factor
92
+ span_y = 1.0 / diag_factor
93
+
94
+ # Establish the linspace boundaries
95
+ left_x = -span_x * (width - 1) / width
96
+ right_x = span_x * (width - 1) / width
97
+ top_y = -span_y * (height - 1) / height
98
+ bottom_y = span_y * (height - 1) / height
99
+
100
+ # Generate 1D coordinates
101
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
102
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
103
+
104
+ # Create 2D meshgrid (width x height) and stack into UV
105
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
106
+ uv_grid = torch.stack((uu, vv), dim=-1)
107
+
108
+ return uv_grid
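A minimal sketch of how the two helpers above compose, using an illustrative square grid and embedding size:

    import torch
    from vggt.heads.utils import create_uv_grid, position_grid_to_embed

    uv = create_uv_grid(16, 16, dtype=torch.float32)   # 16x16 grid of normalized (u, v) coordinates
    pos = position_grid_to_embed(uv, embed_dim=128)    # sinusoidal embedding, 128 channels per grid point
    print(uv.shape, pos.shape)                         # torch.Size([16, 16, 2]) torch.Size([16, 16, 128])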
vggt/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
vggt/layers/attention.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+ XFORMERS_AVAILABLE = False
19
+
20
+
21
+ class Attention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int = 8,
26
+ qkv_bias: bool = True,
27
+ proj_bias: bool = True,
28
+ attn_drop: float = 0.0,
29
+ proj_drop: float = 0.0,
30
+ norm_layer: nn.Module = nn.LayerNorm,
31
+ qk_norm: bool = False,
32
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
33
+ rope=None,
34
+ ) -> None:
35
+ super().__init__()
36
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
37
+ self.num_heads = num_heads
38
+ self.head_dim = dim // num_heads
39
+ self.scale = self.head_dim**-0.5
40
+ self.fused_attn = fused_attn
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+ self.rope = rope
49
+
50
+ def forward(self, x: Tensor, pos=None) -> Tensor:
51
+ B, N, C = x.shape
52
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
53
+ q, k, v = qkv.unbind(0)
54
+ q, k = self.q_norm(q), self.k_norm(k)
55
+
56
+ if self.rope is not None:
57
+ q = self.rope(q, pos)
58
+ k = self.rope(k, pos)
59
+
60
+ if self.fused_attn:
61
+ x = F.scaled_dot_product_attention(
62
+ q,
63
+ k,
64
+ v,
65
+ dropout_p=self.attn_drop.p if self.training else 0.0,
66
+ )
67
+ else:
68
+ q = q * self.scale
69
+ attn = q @ k.transpose(-2, -1)
70
+ attn = attn.softmax(dim=-1)
71
+ attn = self.attn_drop(attn)
72
+ x = attn @ v
73
+
74
+ x = x.transpose(1, 2).reshape(B, N, C)
75
+ x = self.proj(x)
76
+ x = self.proj_drop(x)
77
+ return x
78
+
79
+
80
+ class MemEffAttention(Attention):
81
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
82
+ assert pos is None
83
+ if not XFORMERS_AVAILABLE:
84
+ if attn_bias is not None:
85
+ raise AssertionError("xFormers is required for using nested tensors")
86
+ return super().forward(x)
87
+
88
+ B, N, C = x.shape
89
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
90
+
91
+ q, k, v = unbind(qkv, 2)
92
+
93
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
94
+ x = x.reshape([B, N, C])
95
+
96
+ x = self.proj(x)
97
+ x = self.proj_drop(x)
98
+ return x
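A minimal usage sketch of the Attention module above; the token count and width are illustrative assumptions:

    import torch
    from vggt.layers.attention import Attention

    attn = Attention(dim=384, num_heads=6)   # fused_attn=True -> F.scaled_dot_product_attention
    tokens = torch.randn(2, 197, 384)        # (B, N, C) token sequence
    out = attn(tokens)                       # same shape: (2, 197, 384)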
vggt/layers/block.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim,
71
+ hidden_features=mlp_hidden_dim,
72
+ act_layer=act_layer,
73
+ drop=drop,
74
+ bias=ffn_bias,
75
+ )
76
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
77
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
78
+
79
+ self.sample_drop_ratio = drop_path
80
+
81
+ def forward(self, x: Tensor, pos=None) -> Tensor:
82
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
83
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
84
+
85
+ def ffn_residual_func(x: Tensor) -> Tensor:
86
+ return self.ls2(self.mlp(self.norm2(x)))
87
+
88
+ if self.training and self.sample_drop_ratio > 0.1:
89
+ # the overhead is compensated only for a drop path rate larger than 0.1
90
+ x = drop_add_residual_stochastic_depth(
91
+ x,
92
+ pos=pos,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x, pos=pos)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ pos=None,
115
+ ) -> Tensor:
116
+ # 1) extract subset using permutation
117
+ b, n, d = x.shape
118
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
119
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
120
+ x_subset = x[brange]
121
+
122
+ # 2) apply residual_func to get residual
123
+ if pos is not None:
124
+ # if necessary, apply rope to the subset
125
+ pos = pos[brange]
126
+ residual = residual_func(x_subset, pos=pos)
127
+ else:
128
+ residual = residual_func(x_subset)
129
+
130
+ x_flat = x.flatten(1)
131
+ residual = residual.flatten(1)
132
+
133
+ residual_scale_factor = b / sample_subset_size
134
+
135
+ # 3) add the residual
136
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
168
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
169
+ if all_shapes not in attn_bias_cache.keys():
170
+ seqlens = []
171
+ for b, x in zip(batch_sizes, x_list):
172
+ for _ in range(b):
173
+ seqlens.append(x.shape[1])
174
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
175
+ attn_bias._batch_sizes = batch_sizes
176
+ attn_bias_cache[all_shapes] = attn_bias
177
+
178
+ if branges is not None:
179
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
180
+ else:
181
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
182
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
183
+
184
+ return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
+ def drop_add_residual_stochastic_depth_list(
188
+ x_list: List[Tensor],
189
+ residual_func: Callable[[Tensor, Any], Tensor],
190
+ sample_drop_ratio: float = 0.0,
191
+ scaling_vector=None,
192
+ ) -> Tensor:
193
+ # 1) generate random set of indices for dropping samples in the batch
194
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
195
+ branges = [s[0] for s in branges_scales]
196
+ residual_scale_factors = [s[1] for s in branges_scales]
197
+
198
+ # 2) get attention bias and index+concat the tensors
199
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
200
+
201
+ # 3) apply residual_func to get residual, and split the result
202
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
203
+
204
+ outputs = []
205
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
206
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
207
+ return outputs
208
+
209
+
210
+ class NestedTensorBlock(Block):
211
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
212
+ """
213
+ x_list contains a list of tensors to nest together and run
214
+ """
215
+ assert isinstance(self.attn, MemEffAttention)
216
+
217
+ if self.training and self.sample_drop_ratio > 0.0:
218
+
219
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
220
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
221
+
222
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
223
+ return self.mlp(self.norm2(x))
224
+
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=attn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ x_list = drop_add_residual_stochastic_depth_list(
232
+ x_list,
233
+ residual_func=ffn_residual_func,
234
+ sample_drop_ratio=self.sample_drop_ratio,
235
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
236
+ )
237
+ return x_list
238
+ else:
239
+
240
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
241
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
242
+
243
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_bias, x = get_attn_bias_and_cat(x_list)
247
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
248
+ x = x + ffn_residual_func(x)
249
+ return attn_bias.split(x)
250
+
251
+ def forward(self, x_or_x_list):
252
+ if isinstance(x_or_x_list, Tensor):
253
+ return super().forward(x_or_x_list)
254
+ elif isinstance(x_or_x_list, list):
255
+ if not XFORMERS_AVAILABLE:
256
+ raise AssertionError("xFormers is required for using nested tensors")
257
+ return self.forward_nested(x_or_x_list)
258
+ else:
259
+ raise AssertionError
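A minimal usage sketch of the plain-tensor path through Block; the dimensions and the LayerScale init value are illustrative assumptions:

    import torch
    from vggt.layers.block import Block

    block = Block(dim=384, num_heads=6, mlp_ratio=4.0, init_values=1e-5)
    x = torch.randn(2, 197, 384)
    y = block(x)        # pre-norm attention + MLP with residual connections, shape preserved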
vggt/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
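A short sketch of the stochastic-depth behaviour: during training each sample's residual branch is dropped with probability drop_prob and the survivors are rescaled by 1 / (1 - drop_prob); at evaluation time the module is the identity. The shapes are illustrative assumptions:

    import torch
    from vggt.layers.drop_path import DropPath

    dp = DropPath(drop_prob=0.2)
    x = torch.ones(8, 4, 16)

    dp.train()
    y = dp(x)                        # roughly 20% of the 8 samples zeroed, the rest scaled by 1 / 0.8

    dp.eval()
    assert torch.equal(dp(x), x)     # identity at inference time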
vggt/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
vggt/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
vggt/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
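A minimal shape walkthrough for PatchEmbed with the default 224 / 16 configuration; the batch size is an illustrative assumption:

    import torch
    from vggt.layers.patch_embed import PatchEmbed

    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    images = torch.randn(2, 3, 224, 224)
    tokens = embed(images)
    print(tokens.shape)      # torch.Size([2, 196, 768]); (224 / 16) ** 2 = 196 patch tokens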
vggt/layers/rope.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
8
+
9
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
10
+ # which extends the original RoPE concept to handle 2D spatial positions.
11
+
12
+ # Inspired by:
13
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
14
+ # https://github.com/naver-ai/rope-vit
15
+
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Tuple
22
+
23
+
24
+ class PositionGetter:
25
+ """Generates and caches 2D spatial positions for patches in a grid.
26
+
27
+ This class efficiently manages the generation of spatial coordinates for patches
28
+ in a 2D grid, caching results to avoid redundant computations.
29
+
30
+ Attributes:
31
+ position_cache: Dictionary storing precomputed position tensors for different
32
+ grid dimensions.
33
+ """
34
+
35
+ def __init__(self):
36
+ """Initializes the position generator with an empty cache."""
37
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
38
+
39
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
40
+ """Generates spatial positions for a batch of patches.
41
+
42
+ Args:
43
+ batch_size: Number of samples in the batch.
44
+ height: Height of the grid in patches.
45
+ width: Width of the grid in patches.
46
+ device: Target device for the position tensor.
47
+
48
+ Returns:
49
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
50
+ for each position in the grid, repeated for each batch item.
51
+ """
52
+ if (height, width) not in self.position_cache:
53
+ y_coords = torch.arange(height, device=device)
54
+ x_coords = torch.arange(width, device=device)
55
+ positions = torch.cartesian_prod(y_coords, x_coords)
56
+ self.position_cache[height, width] = positions
57
+
58
+ cached_positions = self.position_cache[height, width]
59
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
60
+
61
+
62
+ class RotaryPositionEmbedding2D(nn.Module):
63
+ """2D Rotary Position Embedding implementation.
64
+
65
+ This module applies rotary position embeddings to input tokens based on their
66
+ 2D spatial positions. It handles the position-dependent rotation of features
67
+ separately for vertical and horizontal dimensions.
68
+
69
+ Args:
70
+ frequency: Base frequency for the position embeddings. Default: 100.0
71
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
72
+
73
+ Attributes:
74
+ base_frequency: Base frequency for computing position embeddings.
75
+ scaling_factor: Factor to scale the computed frequencies.
76
+ frequency_cache: Cache for storing precomputed frequency components.
77
+ """
78
+
79
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
80
+ """Initializes the 2D RoPE module."""
81
+ super().__init__()
82
+ self.base_frequency = frequency
83
+ self.scaling_factor = scaling_factor
84
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
85
+
86
+ def _compute_frequency_components(
87
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """Computes frequency components for rotary embeddings.
90
+
91
+ Args:
92
+ dim: Feature dimension (must be even).
93
+ seq_len: Maximum sequence length.
94
+ device: Target device for computations.
95
+ dtype: Data type for the computed tensors.
96
+
97
+ Returns:
98
+ Tuple of (cosine, sine) tensors for frequency components.
99
+ """
100
+ cache_key = (dim, seq_len, device, dtype)
101
+ if cache_key not in self.frequency_cache:
102
+ # Compute frequency bands
103
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
104
+ inv_freq = 1.0 / (self.base_frequency**exponents)
105
+
106
+ # Generate position-dependent frequencies
107
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
108
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
109
+
110
+ # Compute and cache frequency components
111
+ angles = angles.to(dtype)
112
+ angles = torch.cat((angles, angles), dim=-1)
113
+ cos_components = angles.cos().to(dtype)
114
+ sin_components = angles.sin().to(dtype)
115
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
116
+
117
+ return self.frequency_cache[cache_key]
118
+
119
+ @staticmethod
120
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
121
+ """Performs feature rotation by splitting and recombining feature dimensions.
122
+
123
+ Args:
124
+ x: Input tensor to rotate.
125
+
126
+ Returns:
127
+ Rotated feature tensor.
128
+ """
129
+ feature_dim = x.shape[-1]
130
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
131
+ return torch.cat((-x2, x1), dim=-1)
132
+
133
+ def _apply_1d_rope(
134
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
135
+ ) -> torch.Tensor:
136
+ """Applies 1D rotary position embeddings along one dimension.
137
+
138
+ Args:
139
+ tokens: Input token features.
140
+ positions: Position indices.
141
+ cos_comp: Cosine components for rotation.
142
+ sin_comp: Sine components for rotation.
143
+
144
+ Returns:
145
+ Tokens with applied rotary position embeddings.
146
+ """
147
+ # Embed positions with frequency components
148
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
149
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
150
+
151
+ # Apply rotation
152
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
153
+
154
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
155
+ """Applies 2D rotary position embeddings to input tokens.
156
+
157
+ Args:
158
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
159
+ The feature dimension (dim) must be divisible by 4.
160
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
161
+ the y and x coordinates for each token.
162
+
163
+ Returns:
164
+ Tensor of same shape as input with applied 2D rotary position embeddings.
165
+
166
+ Raises:
167
+ AssertionError: If input dimensions are invalid or positions are malformed.
168
+ """
169
+ # Validate inputs
170
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
171
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
172
+
173
+ # Compute feature dimension for each spatial direction
174
+ feature_dim = tokens.size(-1) // 2
175
+
176
+ # Get frequency components
177
+ max_position = int(positions.max()) + 1
178
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
179
+
180
+ # Split features for vertical and horizontal processing
181
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
182
+
183
+ # Apply RoPE separately for each dimension
184
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
185
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
186
+
187
+ # Combine processed features
188
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
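A minimal usage sketch of the 2D RoPE module together with PositionGetter; the grid size and head dimension are illustrative assumptions (the per-head dimension must be divisible by 4):

    import torch
    from vggt.layers.rope import PositionGetter, RotaryPositionEmbedding2D

    rope = RotaryPositionEmbedding2D(frequency=100.0)
    get_positions = PositionGetter()

    B, heads, H, W, head_dim = 2, 6, 14, 14, 64
    tokens = torch.randn(B, heads, H * W, head_dim)
    positions = get_positions(B, H, W, device=tokens.device)   # (B, H*W, 2) integer (y, x) coordinates
    rotated = rope(tokens, positions)                          # same shape, position-dependent rotation applied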
vggt/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ # try:
39
+ # if XFORMERS_ENABLED:
40
+ # from xformers.ops import SwiGLU
41
+
42
+ # XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ # else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ # raise ImportError
47
+ # except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
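A minimal usage sketch of the fused variant; the widths are illustrative assumptions. The forward pass computes silu(w1 x) * (w2 x) followed by the output projection w3, with the requested hidden width reduced to 2/3 and rounded up to a multiple of 8:

    import torch
    from vggt.layers.swiglu_ffn import SwiGLUFFNFused

    ffn = SwiGLUFFNFused(in_features=384, hidden_features=4 * 384)
    x = torch.randn(2, 197, 384)
    y = ffn(x)               # (2, 197, 384); effective hidden width is 1024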
vggt/layers/vision_transformer.py ADDED
@@ -0,0 +1,407 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
20
+
21
+ logger = logging.getLogger("dinov2")
22
+
23
+
24
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
25
+ if not depth_first and include_root:
26
+ fn(module=module, name=name)
27
+ for child_name, child_module in module.named_children():
28
+ child_name = ".".join((name, child_name)) if name else child_name
29
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
30
+ if depth_first and include_root:
31
+ fn(module=module, name=name)
32
+ return module
33
+
34
+
35
+ class BlockChunk(nn.ModuleList):
36
+ def forward(self, x):
37
+ for b in self:
38
+ x = b(x)
39
+ return x
40
+
41
+
42
+ class DinoVisionTransformer(nn.Module):
43
+ def __init__(
44
+ self,
45
+ img_size=224,
46
+ patch_size=16,
47
+ in_chans=3,
48
+ embed_dim=768,
49
+ depth=12,
50
+ num_heads=12,
51
+ mlp_ratio=4.0,
52
+ qkv_bias=True,
53
+ ffn_bias=True,
54
+ proj_bias=True,
55
+ drop_path_rate=0.0,
56
+ drop_path_uniform=False,
57
+ init_values=None, # for layerscale: None or 0 => no layerscale
58
+ embed_layer=PatchEmbed,
59
+ act_layer=nn.GELU,
60
+ block_fn=Block,
61
+ ffn_layer="mlp",
62
+ block_chunks=1,
63
+ num_register_tokens=0,
64
+ interpolate_antialias=False,
65
+ interpolate_offset=0.1,
66
+ qk_norm=False,
67
+ ):
68
+ """
69
+ Args:
70
+ img_size (int, tuple): input image size
71
+ patch_size (int, tuple): patch size
72
+ in_chans (int): number of input channels
73
+ embed_dim (int): embedding dimension
74
+ depth (int): depth of transformer
75
+ num_heads (int): number of attention heads
76
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
77
+ qkv_bias (bool): enable bias for qkv if True
78
+ proj_bias (bool): enable bias for proj in attn if True
79
+ ffn_bias (bool): enable bias for ffn if True
80
+ drop_path_rate (float): stochastic depth rate
81
+ drop_path_uniform (bool): apply uniform drop rate across blocks
82
+ weight_init (str): weight init scheme
83
+ init_values (float): layer-scale init values
84
+ embed_layer (nn.Module): patch embedding layer
85
+ act_layer (nn.Module): MLP activation layer
86
+ block_fn (nn.Module): transformer block class
87
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
88
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
89
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
90
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
91
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
92
+ """
93
+ super().__init__()
94
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
95
+
96
+ # tricky but makes it work
97
+ self.use_checkpoint = False
98
+ #
99
+
100
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
101
+ self.num_tokens = 1
102
+ self.n_blocks = depth
103
+ self.num_heads = num_heads
104
+ self.patch_size = patch_size
105
+ self.num_register_tokens = num_register_tokens
106
+ self.interpolate_antialias = interpolate_antialias
107
+ self.interpolate_offset = interpolate_offset
108
+
109
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
110
+ num_patches = self.patch_embed.num_patches
111
+
112
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
113
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
114
+ assert num_register_tokens >= 0
115
+ self.register_tokens = (
116
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
117
+ )
118
+
119
+ if drop_path_uniform is True:
120
+ dpr = [drop_path_rate] * depth
121
+ else:
122
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
123
+
124
+ if ffn_layer == "mlp":
125
+ logger.info("using MLP layer as FFN")
126
+ ffn_layer = Mlp
127
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
128
+ logger.info("using SwiGLU layer as FFN")
129
+ ffn_layer = SwiGLUFFNFused
130
+ elif ffn_layer == "identity":
131
+ logger.info("using Identity layer as FFN")
132
+
133
+ def f(*args, **kwargs):
134
+ return nn.Identity()
135
+
136
+ ffn_layer = f
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ blocks_list = [
141
+ block_fn(
142
+ dim=embed_dim,
143
+ num_heads=num_heads,
144
+ mlp_ratio=mlp_ratio,
145
+ qkv_bias=qkv_bias,
146
+ proj_bias=proj_bias,
147
+ ffn_bias=ffn_bias,
148
+ drop_path=dpr[i],
149
+ norm_layer=norm_layer,
150
+ act_layer=act_layer,
151
+ ffn_layer=ffn_layer,
152
+ init_values=init_values,
153
+ qk_norm=qk_norm,
154
+ )
155
+ for i in range(depth)
156
+ ]
157
+ if block_chunks > 0:
158
+ self.chunked_blocks = True
159
+ chunked_blocks = []
160
+ chunksize = depth // block_chunks
161
+ for i in range(0, depth, chunksize):
162
+ # this is to keep the block index consistent if we chunk the block list
163
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
164
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
165
+ else:
166
+ self.chunked_blocks = False
167
+ self.blocks = nn.ModuleList(blocks_list)
168
+
169
+ self.norm = norm_layer(embed_dim)
170
+ self.head = nn.Identity()
171
+
172
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
173
+
174
+ self.init_weights()
175
+
176
+ def init_weights(self):
177
+ trunc_normal_(self.pos_embed, std=0.02)
178
+ nn.init.normal_(self.cls_token, std=1e-6)
179
+ if self.register_tokens is not None:
180
+ nn.init.normal_(self.register_tokens, std=1e-6)
181
+ named_apply(init_weights_vit_timm, self)
182
+
183
+ def interpolate_pos_encoding(self, x, w, h):
184
+ previous_dtype = x.dtype
185
+ npatch = x.shape[1] - 1
186
+ N = self.pos_embed.shape[1] - 1
187
+ if npatch == N and w == h:
188
+ return self.pos_embed
189
+ pos_embed = self.pos_embed.float()
190
+ class_pos_embed = pos_embed[:, 0]
191
+ patch_pos_embed = pos_embed[:, 1:]
192
+ dim = x.shape[-1]
193
+ w0 = w // self.patch_size
194
+ h0 = h // self.patch_size
195
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
196
+ assert N == M * M
197
+ kwargs = {}
198
+ if self.interpolate_offset:
199
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
200
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
201
+ sx = float(w0 + self.interpolate_offset) / M
202
+ sy = float(h0 + self.interpolate_offset) / M
203
+ kwargs["scale_factor"] = (sx, sy)
204
+ else:
205
+ # Simply specify an output size instead of a scale factor
206
+ kwargs["size"] = (w0, h0)
207
+ patch_pos_embed = nn.functional.interpolate(
208
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
209
+ mode="bicubic",
210
+ antialias=self.interpolate_antialias,
211
+ **kwargs,
212
+ )
213
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
214
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
215
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
216
+
217
+ def prepare_tokens_with_masks(self, x, masks=None):
218
+ B, nc, w, h = x.shape
219
+ x = self.patch_embed(x)
220
+ if masks is not None:
221
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
222
+
223
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
224
+ x = x + self.interpolate_pos_encoding(x, w, h)
225
+
226
+ if self.register_tokens is not None:
227
+ x = torch.cat(
228
+ (
229
+ x[:, :1],
230
+ self.register_tokens.expand(x.shape[0], -1, -1),
231
+ x[:, 1:],
232
+ ),
233
+ dim=1,
234
+ )
235
+
236
+ return x
237
+
238
+ def forward_features_list(self, x_list, masks_list):
239
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
240
+
241
+ for blk in self.blocks:
242
+ if self.use_checkpoint:
243
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
244
+ else:
245
+ x = blk(x)
246
+
247
+ all_x = x
248
+ output = []
249
+ for x, masks in zip(all_x, masks_list):
250
+ x_norm = self.norm(x)
251
+ output.append(
252
+ {
253
+ "x_norm_clstoken": x_norm[:, 0],
254
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
255
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
256
+ "x_prenorm": x,
257
+ "masks": masks,
258
+ }
259
+ )
260
+ return output
261
+
262
+ def forward_features(self, x, masks=None):
263
+ if isinstance(x, list):
264
+ return self.forward_features_list(x, masks)
265
+
266
+ x = self.prepare_tokens_with_masks(x, masks)
267
+
268
+ for blk in self.blocks:
269
+ if self.use_checkpoint:
270
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
271
+ else:
272
+ x = blk(x)
273
+
274
+ x_norm = self.norm(x)
275
+ return {
276
+ "x_norm_clstoken": x_norm[:, 0],
277
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
278
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
279
+ "x_prenorm": x,
280
+ "masks": masks,
281
+ }
282
+
283
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ # If n is an int, take the n last blocks. If it's a list, take them
286
+ output, total_block_len = [], len(self.blocks)
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for i, blk in enumerate(self.blocks):
289
+ x = blk(x)
290
+ if i in blocks_to_take:
291
+ output.append(x)
292
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
293
+ return output
294
+
295
+ def _get_intermediate_layers_chunked(self, x, n=1):
296
+ x = self.prepare_tokens_with_masks(x)
297
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
298
+ # If n is an int, take the n last blocks. If it's a list, take them
299
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
300
+ for block_chunk in self.blocks:
301
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
302
+ x = blk(x)
303
+ if i in blocks_to_take:
304
+ output.append(x)
305
+ i += 1
306
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
307
+ return output
308
+
309
+ def get_intermediate_layers(
310
+ self,
311
+ x: torch.Tensor,
312
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
313
+ reshape: bool = False,
314
+ return_class_token: bool = False,
315
+ norm=True,
316
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
317
+ if self.chunked_blocks:
318
+ outputs = self._get_intermediate_layers_chunked(x, n)
319
+ else:
320
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
321
+ if norm:
322
+ outputs = [self.norm(out) for out in outputs]
323
+ class_tokens = [out[:, 0] for out in outputs]
324
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
325
+ if reshape:
326
+ B, _, w, h = x.shape
327
+ outputs = [
328
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
329
+ for out in outputs
330
+ ]
331
+ if return_class_token:
332
+ return tuple(zip(outputs, class_tokens))
333
+ return tuple(outputs)
334
+
335
+ def forward(self, *args, is_training=True, **kwargs):
336
+ ret = self.forward_features(*args, **kwargs)
337
+ if is_training:
338
+ return ret
339
+ else:
340
+ return self.head(ret["x_norm_clstoken"])
341
+
342
+
343
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
344
+ """ViT weight initialization, original timm impl (for reproducibility)"""
345
+ if isinstance(module, nn.Linear):
346
+ trunc_normal_(module.weight, std=0.02)
347
+ if module.bias is not None:
348
+ nn.init.zeros_(module.bias)
349
+
350
+
351
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
352
+ model = DinoVisionTransformer(
353
+ patch_size=patch_size,
354
+ embed_dim=384,
355
+ depth=12,
356
+ num_heads=6,
357
+ mlp_ratio=4,
358
+ block_fn=partial(Block, attn_class=MemEffAttention),
359
+ num_register_tokens=num_register_tokens,
360
+ **kwargs,
361
+ )
362
+ return model
363
+
364
+
365
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
366
+ model = DinoVisionTransformer(
367
+ patch_size=patch_size,
368
+ embed_dim=768,
369
+ depth=12,
370
+ num_heads=12,
371
+ mlp_ratio=4,
372
+ block_fn=partial(Block, attn_class=MemEffAttention),
373
+ num_register_tokens=num_register_tokens,
374
+ **kwargs,
375
+ )
376
+ return model
377
+
378
+
379
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
380
+ model = DinoVisionTransformer(
381
+ patch_size=patch_size,
382
+ embed_dim=1024,
383
+ depth=24,
384
+ num_heads=16,
385
+ mlp_ratio=4,
386
+ block_fn=partial(Block, attn_class=MemEffAttention),
387
+ num_register_tokens=num_register_tokens,
388
+ **kwargs,
389
+ )
390
+ return model
391
+
392
+
393
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
394
+ """
395
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
396
+ """
397
+ model = DinoVisionTransformer(
398
+ patch_size=patch_size,
399
+ embed_dim=1536,
400
+ depth=40,
401
+ num_heads=24,
402
+ mlp_ratio=4,
403
+ block_fn=partial(Block, attn_class=MemEffAttention),
404
+ num_register_tokens=num_register_tokens,
405
+ **kwargs,
406
+ )
407
+ return model
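Illustrative usage sketch (not part of the committed file): extracting intermediate features from a backbone built by the factories above. It assumes the DinoVisionTransformer constructor, defined earlier in this file, accepts the usual DINOv2 keyword arguments such as img_size.

    import torch
    from vggt.layers.vision_transformer import vit_small

    # Build a small backbone with 4 register tokens (random weights, no checkpoint).
    model = vit_small(patch_size=14, num_register_tokens=4, img_size=518)
    model.eval()

    images = torch.randn(2, 3, 518, 518)
    with torch.no_grad():
        # Last 4 blocks, normalized and reshaped to (B, C, H/14, W/14) feature maps.
        feats = model.get_intermediate_layers(images, n=4, reshape=True, norm=True)

    print([tuple(f.shape) for f in feats])  # expected: 4 tensors of shape (2, 384, 37, 37)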
vggt/models/aggregator.py ADDED
@@ -0,0 +1,331 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from vggt.layers import PatchEmbed
14
+ from vggt.layers.block import Block
15
+ from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
16
+ from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
21
+ _RESNET_STD = [0.229, 0.224, 0.225]
22
+
23
+
24
+ class Aggregator(nn.Module):
25
+ """
26
+ The Aggregator applies alternating-attention over input frames,
27
+ as described in VGGT: Visual Geometry Grounded Transformer.
28
+
29
+
30
+ Args:
31
+ img_size (int): Image size in pixels.
32
+ patch_size (int): Size of each patch for PatchEmbed.
33
+ embed_dim (int): Dimension of the token embeddings.
34
+ depth (int): Number of blocks.
35
+ num_heads (int): Number of attention heads.
36
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
37
+ num_register_tokens (int): Number of register tokens.
38
+ block_fn (nn.Module): The block type used for attention (Block by default).
39
+ qkv_bias (bool): Whether to include bias in QKV projections.
40
+ proj_bias (bool): Whether to include bias in the output projection.
41
+ ffn_bias (bool): Whether to include bias in MLP layers.
42
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
43
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
44
+ aa_block_size (int): How many blocks to group under each attention type before switching. Set to 1 if no grouping is needed.
45
+ qk_norm (bool): Whether to apply QK normalization.
46
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
47
+ init_values (float): Init scale for layer scale.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ img_size=518,
53
+ patch_size=14,
54
+ embed_dim=1024,
55
+ depth=24,
56
+ num_heads=16,
57
+ mlp_ratio=4.0,
58
+ num_register_tokens=4,
59
+ block_fn=Block,
60
+ qkv_bias=True,
61
+ proj_bias=True,
62
+ ffn_bias=True,
63
+ patch_embed="dinov2_vitl14_reg",
64
+ aa_order=["frame", "global"],
65
+ aa_block_size=1,
66
+ qk_norm=True,
67
+ rope_freq=100,
68
+ init_values=0.01,
69
+ ):
70
+ super().__init__()
71
+
72
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
73
+
74
+ # Initialize rotary position embedding if frequency > 0
75
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
76
+ self.position_getter = PositionGetter() if self.rope is not None else None
77
+
78
+ self.frame_blocks = nn.ModuleList(
79
+ [
80
+ block_fn(
81
+ dim=embed_dim,
82
+ num_heads=num_heads,
83
+ mlp_ratio=mlp_ratio,
84
+ qkv_bias=qkv_bias,
85
+ proj_bias=proj_bias,
86
+ ffn_bias=ffn_bias,
87
+ init_values=init_values,
88
+ qk_norm=qk_norm,
89
+ rope=self.rope,
90
+ )
91
+ for _ in range(depth)
92
+ ]
93
+ )
94
+
95
+ self.global_blocks = nn.ModuleList(
96
+ [
97
+ block_fn(
98
+ dim=embed_dim,
99
+ num_heads=num_heads,
100
+ mlp_ratio=mlp_ratio,
101
+ qkv_bias=qkv_bias,
102
+ proj_bias=proj_bias,
103
+ ffn_bias=ffn_bias,
104
+ init_values=init_values,
105
+ qk_norm=qk_norm,
106
+ rope=self.rope,
107
+ )
108
+ for _ in range(depth)
109
+ ]
110
+ )
111
+
112
+ self.depth = depth
113
+ self.aa_order = aa_order
114
+ self.patch_size = patch_size
115
+ self.aa_block_size = aa_block_size
116
+
117
+ # Validate that depth is divisible by aa_block_size
118
+ if self.depth % self.aa_block_size != 0:
119
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
120
+
121
+ self.aa_block_num = self.depth // self.aa_block_size
122
+
123
+ # Note: We have two camera tokens, one for the first frame and one for the rest
124
+ # The same applies for register tokens
125
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
126
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
127
+
128
+ # The patch tokens start after the camera and register tokens
129
+ self.patch_start_idx = 1 + num_register_tokens
130
+
131
+ # Initialize parameters with small values
132
+ nn.init.normal_(self.camera_token, std=1e-6)
133
+ nn.init.normal_(self.register_token, std=1e-6)
134
+
135
+ # Register normalization constants as buffers
136
+ for name, value in (
137
+ ("_resnet_mean", _RESNET_MEAN),
138
+ ("_resnet_std", _RESNET_STD),
139
+ ):
140
+ self.register_buffer(
141
+ name,
142
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
143
+ persistent=False,
144
+ )
145
+
146
+ def __build_patch_embed__(
147
+ self,
148
+ patch_embed,
149
+ img_size,
150
+ patch_size,
151
+ num_register_tokens,
152
+ interpolate_antialias=True,
153
+ interpolate_offset=0.0,
154
+ block_chunks=0,
155
+ init_values=1.0,
156
+ embed_dim=1024,
157
+ ):
158
+ """
159
+ Build the patch embed layer. If 'conv', we use a
160
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
161
+ """
162
+
163
+ if "conv" in patch_embed:
164
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
165
+ else:
166
+ vit_models = {
167
+ "dinov2_vitl14_reg": vit_large,
168
+ "dinov2_vitb14_reg": vit_base,
169
+ "dinov2_vits14_reg": vit_small,
170
+ "dinov2_vitg2_reg": vit_giant2,
171
+ }
172
+
173
+ self.patch_embed = vit_models[patch_embed](
174
+ img_size=img_size,
175
+ patch_size=patch_size,
176
+ num_register_tokens=num_register_tokens,
177
+ interpolate_antialias=interpolate_antialias,
178
+ interpolate_offset=interpolate_offset,
179
+ block_chunks=block_chunks,
180
+ init_values=init_values,
181
+ )
182
+
183
+ # Disable gradient updates for mask token
184
+ if hasattr(self.patch_embed, "mask_token"):
185
+ self.patch_embed.mask_token.requires_grad_(False)
186
+
187
+ def forward(
188
+ self,
189
+ images: torch.Tensor,
190
+ ) -> Tuple[List[torch.Tensor], int]:
191
+ """
192
+ Args:
193
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
194
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
195
+
196
+ Returns:
197
+ (list[torch.Tensor], int):
198
+ The list of outputs from the attention blocks,
199
+ and the patch_start_idx indicating where patch tokens begin.
200
+ """
201
+ B, S, C_in, H, W = images.shape
202
+
203
+ if C_in != 3:
204
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
205
+
206
+ # Normalize images and reshape for patch embed
207
+ images = (images - self._resnet_mean) / self._resnet_std
208
+
209
+ # Reshape to [B*S, C, H, W] for patch embedding
210
+ images = images.view(B * S, C_in, H, W)
211
+ patch_tokens = self.patch_embed(images)
212
+
213
+ if isinstance(patch_tokens, dict):
214
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
215
+
216
+ _, P, C = patch_tokens.shape
217
+
218
+ # Expand camera and register tokens to match batch size and sequence length
219
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
220
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
221
+
222
+ # Concatenate special tokens with patch tokens
223
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
224
+
225
+ pos = None
226
+ if self.rope is not None:
227
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
228
+
229
+ if self.patch_start_idx > 0:
230
+ # do not use position embedding for special tokens (camera and register tokens)
231
+ # so set pos to 0 for the special tokens
232
+ pos = pos + 1
233
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
234
+ pos = torch.cat([pos_special, pos], dim=1)
235
+
236
+ # update P because we added special tokens
237
+ _, P, C = tokens.shape
238
+
239
+ frame_idx = 0
240
+ global_idx = 0
241
+ output_list = []
242
+
243
+ for _ in range(self.aa_block_num):
244
+ for attn_type in self.aa_order:
245
+ if attn_type == "frame":
246
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
247
+ tokens, B, S, P, C, frame_idx, pos=pos
248
+ )
249
+ elif attn_type == "global":
250
+ tokens, global_idx, global_intermediates = self._process_global_attention(
251
+ tokens, B, S, P, C, global_idx, pos=pos
252
+ )
253
+ else:
254
+ raise ValueError(f"Unknown attention type: {attn_type}")
255
+
256
+ for i in range(len(frame_intermediates)):
257
+ # concat frame and global intermediates, [B x S x P x 2C]
258
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
259
+ output_list.append(concat_inter)
260
+
261
+ del concat_inter
262
+ del frame_intermediates
263
+ del global_intermediates
264
+ return output_list, self.patch_start_idx
265
+
266
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
267
+ """
268
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
269
+ """
270
+ # If needed, reshape tokens or positions:
271
+ if tokens.shape != (B * S, P, C):
272
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
273
+
274
+ if pos is not None and pos.shape != (B * S, P, 2):
275
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
276
+
277
+ intermediates = []
278
+
279
+ # by default, self.aa_block_size=1, which processes one block at a time
280
+ for _ in range(self.aa_block_size):
281
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
282
+ frame_idx += 1
283
+ intermediates.append(tokens.view(B, S, P, C))
284
+
285
+ return tokens, frame_idx, intermediates
286
+
287
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
288
+ """
289
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
290
+ """
291
+ if tokens.shape != (B, S * P, C):
292
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
293
+
294
+ if pos is not None and pos.shape != (B, S * P, 2):
295
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
296
+
297
+ intermediates = []
298
+
299
+ # by default, self.aa_block_size=1, which processes one block at a time
300
+ for _ in range(self.aa_block_size):
301
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
302
+ global_idx += 1
303
+ intermediates.append(tokens.view(B, S, P, C))
304
+
305
+ return tokens, global_idx, intermediates
306
+
307
+
308
+ def slice_expand_and_flatten(token_tensor, B, S):
309
+ """
310
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
311
+ 1) Uses the first position (index=0) for the first frame only
312
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
313
+ 3) Expands both to match batch size B
314
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
315
+ followed by (S-1) second-position tokens
316
+ 5) Flattens to (B*S, X, C) for processing
317
+
318
+ Returns:
319
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
320
+ """
321
+
322
+ # Slice out the "query" tokens => shape (1, 1, ...)
323
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
324
+ # Slice out the "other" tokens => shape (1, S-1, ...)
325
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
326
+ # Concatenate => shape (B, S, ...)
327
+ combined = torch.cat([query, others], dim=1)
328
+
329
+ # Finally flatten => shape (B*S, ...)
330
+ combined = combined.view(B * S, *combined.shape[2:])
331
+ return combined
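Illustrative sketch (not part of the committed file) of how slice_expand_and_flatten broadcasts the two learned token slots defined above: slot 0 is used for the first frame of every sequence, and slot 1 is reused for the remaining S-1 frames.

    import torch
    from vggt.models.aggregator import slice_expand_and_flatten

    B, S, X, C = 2, 5, 4, 8
    register_token = torch.randn(1, 2, X, C)  # slot 0: first frame, slot 1: the rest

    out = slice_expand_and_flatten(register_token, B, S)
    print(out.shape)  # torch.Size([10, 4, 8]) == (B*S, X, C)

    # Frame 0 of every sequence uses slot 0; frames 1..S-1 reuse slot 1.
    out = out.view(B, S, X, C)
    assert torch.allclose(out[:, 0], register_token[0, 0].expand(B, X, C))
    assert torch.allclose(out[:, 1:], register_token[0, 1].expand(B, S - 1, X, C))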
vggt/models/vggt.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
10
+
11
+ from vggt.models.aggregator import Aggregator
12
+ from vggt.heads.camera_head import CameraHead
13
+ from vggt.heads.dpt_head import DPTHead
14
+ from vggt.heads.track_head import TrackHead
15
+
16
+
17
+ class VGGT(nn.Module, PyTorchModelHubMixin):
18
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
19
+ super().__init__()
20
+
21
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
22
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
23
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
24
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
25
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
26
+
27
+ def forward(
28
+ self,
29
+ images: torch.Tensor,
30
+ query_points: torch.Tensor = None,
31
+ ):
32
+ """
33
+ Forward pass of the VGGT model.
34
+
35
+ Args:
36
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
37
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
38
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
39
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
40
+ Default: None
41
+
42
+ Returns:
43
+ dict: A dictionary containing the following predictions:
44
+ - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
45
+ - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
46
+ - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
47
+ - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
48
+ - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
49
+ - images (torch.Tensor): Original input images, preserved for visualization
50
+
51
+ If query_points is provided, also includes:
52
+ - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
53
+ - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
54
+ - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
55
+ """
56
+
57
+ # Add a batch dimension if it is missing
58
+ if len(images.shape) == 4:
59
+ images = images.unsqueeze(0)
60
+ if query_points is not None and len(query_points.shape) == 2:
61
+ query_points = query_points.unsqueeze(0)
62
+
63
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
64
+
65
+ predictions = {}
66
+
67
+ with torch.cuda.amp.autocast(enabled=False):
68
+ if self.camera_head is not None:
69
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
70
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
71
+
72
+ if self.depth_head is not None:
73
+ depth, depth_conf = self.depth_head(
74
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
75
+ )
76
+ predictions["depth"] = depth
77
+ predictions["depth_conf"] = depth_conf
78
+
79
+ if self.point_head is not None:
80
+ pts3d, pts3d_conf = self.point_head(
81
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
82
+ )
83
+ predictions["world_points"] = pts3d
84
+ predictions["world_points_conf"] = pts3d_conf
85
+
86
+ if self.track_head is not None and query_points is not None:
87
+ track_list, vis, conf = self.track_head(
88
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
89
+ )
90
+ predictions["track"] = track_list[-1] # track of the last iteration
91
+ predictions["vis"] = vis
92
+ predictions["conf"] = conf
93
+
94
+ predictions["images"] = images
95
+
96
+ return predictions
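Minimal usage sketch (not part of the committed file): a forward pass of VGGT with randomly initialized weights, only to show the input and output conventions documented above. The spatial size 518 and the query point are arbitrary placeholders, and the expected shapes simply restate the docstring.

    import torch
    from vggt.models.vggt import VGGT

    model = VGGT().eval()                              # random init, no checkpoint
    images = torch.rand(1, 2, 3, 518, 518)             # [B, S, 3, H, W] in [0, 1]
    query_points = torch.tensor([[[100.0, 200.0]]])    # [B, N, 2] pixel coordinates

    with torch.no_grad():
        preds = model(images, query_points=query_points)

    # Per the docstring: pose_enc [1, 2, 9], depth [1, 2, 518, 518, 1],
    # world_points [1, 2, 518, 518, 3], track [1, 2, 1, 2].
    print({k: tuple(v.shape) for k, v in preds.items()})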
vggt/utils/geometry.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
+ def unproject_depth_map_to_point_map(
13
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
14
+ ) -> np.ndarray:
15
+ """
16
+ Unproject a batch of depth maps to 3D world coordinates.
17
+
18
+ Args:
19
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
20
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
21
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
22
+
23
+ Returns:
24
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
25
+ """
26
+ if isinstance(depth_map, torch.Tensor):
27
+ depth_map = depth_map.cpu().numpy()
28
+ if isinstance(extrinsics_cam, torch.Tensor):
29
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
30
+ if isinstance(intrinsics_cam, torch.Tensor):
31
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
32
+
33
+ world_points_list = []
34
+ for frame_idx in range(depth_map.shape[0]):
35
+ cur_world_points, _, _ = depth_to_world_coords_points(
36
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
37
+ )
38
+ world_points_list.append(cur_world_points)
39
+ world_points_array = np.stack(world_points_list, axis=0)
40
+
41
+ return world_points_array
42
+
43
+
44
+ def depth_to_world_coords_points(
45
+ depth_map: np.ndarray,
46
+ extrinsic: np.ndarray,
47
+ intrinsic: np.ndarray,
48
+ eps=1e-8,
49
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
50
+ """
51
+ Convert a depth map to world coordinates.
52
+
53
+ Args:
54
+ depth_map (np.ndarray): Depth map of shape (H, W).
55
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
56
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
57
+
58
+ Returns:
59
+ tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
60
+ """
61
+ if depth_map is None:
62
+ return None, None, None
63
+
64
+ # Valid depth mask
65
+ point_mask = depth_map > eps
66
+
67
+ # Convert depth map to camera coordinates
68
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
69
+
70
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
71
+ # cam_to_world_extrinsic is 4x4 (note that closed_form_inverse_se3 is batched, so the output is (N, 4, 4))
72
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
73
+
74
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
75
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
76
+
77
+ # Apply the rotation and translation to the camera coordinates
78
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
79
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
80
+
81
+ return world_coords_points, cam_coords_points, point_mask
82
+
83
+
84
+ def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
85
+ """
86
+ Convert a depth map to camera coordinates.
87
+
88
+ Args:
89
+ depth_map (np.ndarray): Depth map of shape (H, W).
90
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
91
+
92
+ Returns:
93
+ np.ndarray: Camera coordinates (H, W, 3)
94
+ """
95
+ H, W = depth_map.shape
96
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
97
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
98
+
99
+ # Intrinsic parameters
100
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
101
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
102
+
103
+ # Generate grid of pixel coordinates
104
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
105
+
106
+ # Unproject to camera coordinates
107
+ x_cam = (u - cu) * depth_map / fu
108
+ y_cam = (v - cv) * depth_map / fv
109
+ z_cam = depth_map
110
+
111
+ # Stack to form camera coordinates
112
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
113
+
114
+ return cam_coords
115
+
116
+
117
+ def closed_form_inverse_se3(se3, R=None, T=None):
118
+ """
119
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
120
+
121
+ If `R` and `T` are provided, they must correspond to the rotation and translation
122
+ components of `se3`. Otherwise, they will be extracted from `se3`.
123
+
124
+ Args:
125
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
126
+ R (optional): Nx3x3 array or tensor of rotation matrices.
127
+ T (optional): Nx3x1 array or tensor of translation vectors.
128
+
129
+ Returns:
130
+ Inverted SE3 matrices with the same type and device as `se3`.
131
+
132
+ Shapes:
133
+ se3: (N, 4, 4)
134
+ R: (N, 3, 3)
135
+ T: (N, 3, 1)
136
+ """
137
+ # Check if se3 is a numpy array or a torch tensor
138
+ is_numpy = isinstance(se3, np.ndarray)
139
+
140
+ # Validate shapes
141
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
142
+ raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")
143
+
144
+ # Extract R and T if not provided
145
+ if R is None:
146
+ R = se3[:, :3, :3] # (N,3,3)
147
+ if T is None:
148
+ T = se3[:, :3, 3:] # (N,3,1)
149
+
150
+ # Transpose R
151
+ if is_numpy:
152
+ # Compute the transpose of the rotation for NumPy
153
+ R_transposed = np.transpose(R, (0, 2, 1))
154
+ # -R^T t for NumPy
155
+ top_right = -np.matmul(R_transposed, T)
156
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
157
+ else:
158
+ R_transposed = R.transpose(1, 2) # (N,3,3)
159
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
160
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
161
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
162
+
163
+ inverted_matrix[:, :3, :3] = R_transposed
164
+ inverted_matrix[:, :3, 3:] = top_right
165
+
166
+ return inverted_matrix
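A quick sanity check (not part of the committed file) for closed_form_inverse_se3 defined above: for a valid SE(3) matrix, the closed-form inverse must match a generic 4x4 matrix inverse.

    import numpy as np
    from vggt.utils.geometry import closed_form_inverse_se3

    theta = np.deg2rad(30.0)
    R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
    se3 = np.eye(4)
    se3[:3, :3] = R                  # rotation about the z axis
    se3[:3, 3] = [0.5, -1.0, 2.0]    # translation

    inv_closed = closed_form_inverse_se3(se3[None])[0]  # batched input, shape (1, 4, 4)
    assert np.allclose(inv_closed, np.linalg.inv(se3), atol=1e-8)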
vggt/utils/load_fn.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision import transforms as TF
10
+
11
+
12
+ def load_and_preprocess_images(image_path_list, mode="crop"):
13
+ """
14
+ A quick start function to load and preprocess images for model input.
15
+ This assumes all images share the same shape for easier batching; the model itself can also handle inputs of different shapes.
16
+
17
+ Args:
18
+ image_path_list (list): List of paths to image files
19
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
20
+ - "crop" (default): Sets width to 518px and center crops height if needed.
21
+ - "pad": Preserves all pixels by making the largest dimension 518px
22
+ and padding the smaller dimension to reach a square shape.
23
+
24
+ Returns:
25
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
26
+
27
+ Raises:
28
+ ValueError: If the input list is empty or if mode is invalid
29
+
30
+ Notes:
31
+ - Images with different dimensions will be padded with white (value=1.0)
32
+ - A warning is printed when images have different shapes
33
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
34
+ and height is center-cropped if larger than 518px
35
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
36
+ and the smaller dimension is padded to reach a square shape (518x518)
37
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
38
+ """
39
+ # Check for empty list
40
+ if len(image_path_list) == 0:
41
+ raise ValueError("At least 1 image is required")
42
+
43
+ # Validate mode
44
+ if mode not in ["crop", "pad"]:
45
+ raise ValueError("Mode must be either 'crop' or 'pad'")
46
+
47
+ images = []
48
+ shapes = set()
49
+ to_tensor = TF.ToTensor()
50
+ target_size = 518
51
+
52
+ # First process all images and collect their shapes
53
+ for image_path in image_path_list:
54
+
55
+ # Open image
56
+ img = Image.open(image_path)
57
+
58
+ # If there's an alpha channel, blend onto white background:
59
+ if img.mode == "RGBA":
60
+ # Create white background
61
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
62
+ # Alpha composite onto the white background
63
+ img = Image.alpha_composite(background, img)
64
+
65
+ # Now convert to "RGB" (this step assigns white for transparent areas)
66
+ img = img.convert("RGB")
67
+
68
+ width, height = img.size
69
+
70
+ if mode == "pad":
71
+ # Make the largest dimension 518px while maintaining aspect ratio
72
+ if width >= height:
73
+ new_width = target_size
74
+ new_height = round(height * (new_width / width) / 14) * 14 # Make divisible by 14
75
+ else:
76
+ new_height = target_size
77
+ new_width = round(width * (new_height / height) / 14) * 14 # Make divisible by 14
78
+ else: # mode == "crop"
79
+ # Original behavior: set width to 518px
80
+ new_width = target_size
81
+ # Calculate height maintaining aspect ratio, divisible by 14
82
+ new_height = round(height * (new_width / width) / 14) * 14
83
+
84
+ # Resize with new dimensions (width, height)
85
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
86
+ img = to_tensor(img) # Convert to tensor (0, 1)
87
+
88
+ # Center crop height if it's larger than 518 (only in crop mode)
89
+ if mode == "crop" and new_height > target_size:
90
+ start_y = (new_height - target_size) // 2
91
+ img = img[:, start_y : start_y + target_size, :]
92
+
93
+ # For pad mode, pad to make a square of target_size x target_size
94
+ if mode == "pad":
95
+ h_padding = target_size - img.shape[1]
96
+ w_padding = target_size - img.shape[2]
97
+
98
+ if h_padding > 0 or w_padding > 0:
99
+ pad_top = h_padding // 2
100
+ pad_bottom = h_padding - pad_top
101
+ pad_left = w_padding // 2
102
+ pad_right = w_padding - pad_left
103
+
104
+ # Pad with white (value=1.0)
105
+ img = torch.nn.functional.pad(
106
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
107
+ )
108
+
109
+ shapes.add((img.shape[1], img.shape[2]))
110
+ images.append(img)
111
+
112
+ # Check if we have different shapes
113
+ # In theory our model can also work well with different shapes
114
+ if len(shapes) > 1:
115
+ print(f"Warning: Found images with different shapes: {shapes}")
116
+ # Find maximum dimensions
117
+ max_height = max(shape[0] for shape in shapes)
118
+ max_width = max(shape[1] for shape in shapes)
119
+
120
+ # Pad images if necessary
121
+ padded_images = []
122
+ for img in images:
123
+ h_padding = max_height - img.shape[1]
124
+ w_padding = max_width - img.shape[2]
125
+
126
+ if h_padding > 0 or w_padding > 0:
127
+ pad_top = h_padding // 2
128
+ pad_bottom = h_padding - pad_top
129
+ pad_left = w_padding // 2
130
+ pad_right = w_padding - pad_left
131
+
132
+ img = torch.nn.functional.pad(
133
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
134
+ )
135
+ padded_images.append(img)
136
+ images = padded_images
137
+
138
+ images = torch.stack(images) # concatenate images
139
+
140
+ # Ensure correct shape when single image
141
+ if len(image_path_list) == 1:
142
+ # Verify shape is (1, C, H, W)
143
+ if images.dim() == 3:
144
+ images = images.unsqueeze(0)
145
+
146
+ return images
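Usage sketch (not part of the committed file); the file names below are hypothetical placeholders.

    from vggt.utils.load_fn import load_and_preprocess_images

    # "crop" fixes the width at 518 and center-crops the height if needed.
    images = load_and_preprocess_images(["frame_000.png", "frame_001.png"], mode="crop")
    print(images.shape)  # (2, 3, H, 518) with H divisible by 14 and at most 518

    # "pad" letterboxes every image onto a white 518x518 square instead.
    images_pad = load_and_preprocess_images(["frame_000.png", "frame_001.png"], mode="pad")
    print(images_pad.shape)  # (2, 3, 518, 518)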
vggt/utils/pose_enc.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from .rotation import quat_to_mat, mat_to_quat
9
+
10
+
11
+ def extri_intri_to_pose_encoding(
12
+ extrinsics,
13
+ intrinsics,
14
+ image_size_hw=None, # e.g., (256, 512)
15
+ pose_encoding_type="absT_quaR_FoV",
16
+ ):
17
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
18
+
19
+ This function transforms camera parameters into a unified pose encoding format,
20
+ which can be used for various downstream tasks like pose prediction or representation.
21
+
22
+ Args:
23
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
24
+ where B is batch size and S is sequence length.
25
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
26
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
27
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
28
+ Defined in pixels, with format:
29
+ [[fx, 0, cx],
30
+ [0, fy, cy],
31
+ [0, 0, 1]]
32
+ where fx, fy are focal lengths and (cx, cy) is the principal point
33
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
34
+ Required for computing field of view values. For example: (256, 512).
35
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
36
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
37
+
38
+ Returns:
39
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
40
+ For "absT_quaR_FoV" type, the 9 dimensions are:
41
+ - [:3] = absolute translation vector T (3D)
42
+ - [3:7] = rotation as quaternion quat (4D)
43
+ - [7:] = field of view (2D)
44
+ """
45
+
46
+ # extrinsics: BxSx3x4
47
+ # intrinsics: BxSx3x3
48
+
49
+ if pose_encoding_type == "absT_quaR_FoV":
50
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
51
+ T = extrinsics[:, :, :3, 3] # BxSx3
52
+
53
+ quat = mat_to_quat(R)
54
+ # Note the order of h and w here
55
+ H, W = image_size_hw
56
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
57
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
58
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
59
+ else:
60
+ raise NotImplementedError
61
+
62
+ return pose_encoding
63
+
64
+
65
+ def pose_encoding_to_extri_intri(
66
+ pose_encoding,
67
+ image_size_hw=None, # e.g., (256, 512)
68
+ pose_encoding_type="absT_quaR_FoV",
69
+ build_intrinsics=True,
70
+ ):
71
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
72
+
73
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
74
+ reconstructing the full camera parameters from the compact encoding.
75
+
76
+ Args:
77
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
78
+ where B is batch size and S is sequence length.
79
+ For "absT_quaR_FoV" type, the 9 dimensions are:
80
+ - [:3] = absolute translation vector T (3D)
81
+ - [3:7] = rotation as quaternion quat (4D)
82
+ - [7:] = field of view (2D)
83
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
84
+ Required for reconstructing intrinsics from field of view values.
85
+ For example: (256, 512).
86
+ pose_encoding_type (str): Type of pose encoding used. Currently only
87
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
88
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
89
+ If False, only extrinsics are returned and intrinsics will be None.
90
+
91
+ Returns:
92
+ tuple: (extrinsics, intrinsics)
93
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
94
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
95
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
96
+ a 3x1 translation vector.
97
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
98
+ or None if build_intrinsics is False. Defined in pixels, with format:
99
+ [[fx, 0, cx],
100
+ [0, fy, cy],
101
+ [0, 0, 1]]
102
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
103
+ assumed to be at the center of the image (W/2, H/2).
104
+ """
105
+
106
+ intrinsics = None
107
+
108
+ if pose_encoding_type == "absT_quaR_FoV":
109
+ T = pose_encoding[..., :3]
110
+ quat = pose_encoding[..., 3:7]
111
+ fov_h = pose_encoding[..., 7]
112
+ fov_w = pose_encoding[..., 8]
113
+
114
+ R = quat_to_mat(quat)
115
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
116
+
117
+ if build_intrinsics:
118
+ H, W = image_size_hw
119
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
120
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
121
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
122
+ intrinsics[..., 0, 0] = fx
123
+ intrinsics[..., 1, 1] = fy
124
+ intrinsics[..., 0, 2] = W / 2
125
+ intrinsics[..., 1, 2] = H / 2
126
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
127
+ else:
128
+ raise NotImplementedError
129
+
130
+ return extrinsics, intrinsics
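Round-trip sanity check (not part of the committed file) for the two conversions above, using a single identity camera with the principal point at the image center.

    import torch
    from vggt.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri

    B, S, H, W = 1, 1, 256, 512
    extrinsics = torch.zeros(B, S, 3, 4)
    extrinsics[..., :3, :3] = torch.eye(3)            # identity rotation, zero translation
    intrinsics = torch.tensor([[[[300.0, 0.0, W / 2],
                                 [0.0, 300.0, H / 2],
                                 [0.0, 0.0, 1.0]]]])  # (B, S, 3, 3)

    enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
    print(enc.shape)  # torch.Size([1, 1, 9])

    extri_rec, intri_rec = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
    assert torch.allclose(extri_rec, extrinsics, atol=1e-5)
    assert torch.allclose(intri_rec, intrinsics, atol=1e-3)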
vggt/utils/rotation.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
+
9
+ import torch
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Quaternion Order: XYZW or say ijkr, scalar-last
17
+
18
+ Convert rotations given as quaternions to rotation matrices.
19
+ Args:
20
+ quaternions: quaternions with real part last,
21
+ as tensor of shape (..., 4).
22
+
23
+ Returns:
24
+ Rotation matrices as tensor of shape (..., 3, 3).
25
+ """
26
+ i, j, k, r = torch.unbind(quaternions, -1)
27
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
28
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
29
+
30
+ o = torch.stack(
31
+ (
32
+ 1 - two_s * (j * j + k * k),
33
+ two_s * (i * j - k * r),
34
+ two_s * (i * k + j * r),
35
+ two_s * (i * j + k * r),
36
+ 1 - two_s * (i * i + k * k),
37
+ two_s * (j * k - i * r),
38
+ two_s * (i * k - j * r),
39
+ two_s * (j * k + i * r),
40
+ 1 - two_s * (i * i + j * j),
41
+ ),
42
+ -1,
43
+ )
44
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
45
+
46
+
47
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert rotations given as rotation matrices to quaternions.
50
+
51
+ Args:
52
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
53
+
54
+ Returns:
55
+ quaternions with real part last, as tensor of shape (..., 4).
56
+ Quaternion Order: XYZW or say ijkr, scalar-last
57
+ """
58
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
59
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
60
+
61
+ batch_dim = matrix.shape[:-2]
62
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
63
+
64
+ q_abs = _sqrt_positive_part(
65
+ torch.stack(
66
+ [
67
+ 1.0 + m00 + m11 + m22,
68
+ 1.0 + m00 - m11 - m22,
69
+ 1.0 - m00 + m11 - m22,
70
+ 1.0 - m00 - m11 + m22,
71
+ ],
72
+ dim=-1,
73
+ )
74
+ )
75
+
76
+ # we produce the desired quaternion multiplied by each of r, i, j, k
77
+ quat_by_rijk = torch.stack(
78
+ [
79
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
80
+ # `int`.
81
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
82
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
83
+ # `int`.
84
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
85
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
86
+ # `int`.
87
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
88
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
89
+ # `int`.
90
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
91
+ ],
92
+ dim=-2,
93
+ )
94
+
95
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
96
+ # the candidate won't be picked.
97
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
98
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
99
+
100
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
101
+ # forall i; we pick the best-conditioned one (with the largest denominator)
102
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
103
+
104
+ # Convert from rijk to ijkr
105
+ out = out[..., [1, 2, 3, 0]]
106
+
107
+ out = standardize_quaternion(out)
108
+
109
+ return out
110
+
111
+
112
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
113
+ """
114
+ Returns torch.sqrt(torch.max(0, x))
115
+ but with a zero subgradient where x is 0.
116
+ """
117
+ ret = torch.zeros_like(x)
118
+ positive_mask = x > 0
119
+ if torch.is_grad_enabled():
120
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
121
+ else:
122
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
123
+ return ret
124
+
125
+
126
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
127
+ """
128
+ Convert a unit quaternion to a standard form: one in which the real
129
+ part is non negative.
130
+
131
+ Args:
132
+ quaternions: Quaternions with real part last,
133
+ as tensor of shape (..., 4).
134
+
135
+ Returns:
136
+ Standardized quaternions as tensor of shape (..., 4).
137
+ """
138
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
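Quick check (not part of the committed file) that quat_to_mat and mat_to_quat above are mutual inverses for a standardized unit quaternion.

    import torch
    from vggt.utils.rotation import quat_to_mat, mat_to_quat

    axis = torch.tensor([1.0, 2.0, 3.0])
    axis = axis / axis.norm()
    angle = torch.tensor(1.0)  # radians
    # Scalar-last (x, y, z, w) quaternion for a rotation of `angle` about `axis`.
    q = torch.cat([torch.sin(angle / 2) * axis, torch.cos(angle / 2)[None]])

    R = quat_to_mat(q)
    assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)   # R is orthonormal
    assert torch.allclose(mat_to_quat(R), q, atol=1e-5)       # round trip recovers q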
vggt/utils/visual_track.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ import os
11
+
12
+
13
+ def color_from_xy(x, y, W, H, cmap_name="hsv"):
14
+ """
15
+ Map (x, y) -> color in (R, G, B).
16
+ 1) Normalize x,y to [0,1].
17
+ 2) Combine them into a single scalar c in [0,1].
18
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
19
+
20
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
21
+ """
22
+ import matplotlib.cm
23
+ import matplotlib.colors
24
+
25
+ x_norm = x / max(W - 1, 1)
26
+ y_norm = y / max(H - 1, 1)
27
+ # Simple combination:
28
+ c = (x_norm + y_norm) / 2.0
29
+
30
+ cmap = matplotlib.cm.get_cmap(cmap_name)
31
+ # cmap(c) -> (r,g,b,a) in [0,1]
32
+ rgba = cmap(c)
33
+ r, g, b = rgba[0], rgba[1], rgba[2]
34
+ return (r, g, b) # in [0,1], RGB order
35
+
36
+
37
+ def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
38
+ """
39
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
40
+ in [0,255]. The color is determined by the (x,y) position in the first
41
+ visible frame for each track.
42
+
43
+ Args:
44
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
45
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
46
+ image_width, image_height: used for normalizing (x, y).
47
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
48
+
49
+ Returns:
50
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
51
+ """
52
+ S, N, _ = tracks_b.shape
53
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
54
+
55
+ if vis_mask_b is None:
56
+ # treat all as visible
57
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
58
+
59
+ for i in range(N):
60
+ # Find first visible frame for track i
61
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
62
+ if len(visible_frames) == 0:
63
+ # track is never visible; just assign black or something
64
+ track_colors[i] = (0, 0, 0)
65
+ continue
66
+
67
+ first_s = int(visible_frames[0].item())
68
+ # use that frame's (x,y)
69
+ x, y = tracks_b[first_s, i].tolist()
70
+
71
+ # map (x,y) -> (R,G,B) in [0,1]
72
+ r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
73
+ # scale to [0,255]
74
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
75
+ track_colors[i] = (r, g, b)
76
+
77
+ return track_colors
78
+
79
+
80
+ def visualize_tracks_on_images(
81
+ images,
82
+ tracks,
83
+ track_vis_mask=None,
84
+ out_dir="track_visuals_concat_by_xy",
85
+ image_format="CHW", # "CHW" or "HWC"
86
+ normalize_mode="[0,1]",
87
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
88
+ frames_per_row=4, # Number of frames per row in the grid layout
89
+ save_grid=True, # Flag to control whether to save the grid image
90
+ ):
91
+ """
92
+ Visualizes frames in a grid layout with specified frames per row.
93
+ Each track's color is determined by its (x,y) position
94
+ in the first visible frame (or frame 0 if always visible).
95
+ Finally convert the BGR result to RGB before saving.
96
+ Also saves each individual frame as a separate PNG file.
97
+
98
+ Args:
99
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
100
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
101
+ track_vis_mask: torch.Tensor (S, N) or None.
102
+ out_dir: folder to save visualizations.
103
+ image_format: "CHW" or "HWC".
104
+ normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
105
+ cmap_name: a matplotlib colormap name for color_from_xy.
106
+ frames_per_row: number of frames to display in each row of the grid.
107
+ save_grid: whether to save all frames in one grid image.
108
+
109
+ Returns:
110
+ None (saves images in out_dir).
111
+ """
112
+
113
+ if len(tracks.shape) == 4:
114
+ tracks = tracks.squeeze(0)
115
+ images = images.squeeze(0)
116
+ if track_vis_mask is not None:
117
+ track_vis_mask = track_vis_mask.squeeze(0)
118
+
119
+ import matplotlib
120
+
121
+ matplotlib.use("Agg") # for non-interactive (optional)
122
+
123
+ os.makedirs(out_dir, exist_ok=True)
124
+
125
+ S = images.shape[0]
126
+ _, N, _ = tracks.shape # (S, N, 2)
127
+
128
+ # Move to CPU
129
+ images = images.cpu().clone()
130
+ tracks = tracks.cpu().clone()
131
+ if track_vis_mask is not None:
132
+ track_vis_mask = track_vis_mask.cpu().clone()
133
+
134
+ # Infer H, W from images shape
135
+ if image_format == "CHW":
136
+ # e.g. images[s].shape = (3, H, W)
137
+ H, W = images.shape[2], images.shape[3]
138
+ else:
139
+ # e.g. images[s].shape = (H, W, 3)
140
+ H, W = images.shape[1], images.shape[2]
141
+
142
+ # Pre-compute the color for each track i based on first visible position
143
+ track_colors_rgb = get_track_colors_by_position(
144
+ tracks, # shape (S, N, 2)
145
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
146
+ image_width=W,
147
+ image_height=H,
148
+ cmap_name=cmap_name,
149
+ )
150
+
151
+ # We'll accumulate each frame's drawn image in a list
152
+ frame_images = []
153
+
154
+ for s in range(S):
155
+ # shape => either (3, H, W) or (H, W, 3)
156
+ img = images[s]
157
+
158
+ # Convert to (H, W, 3)
159
+ if image_format == "CHW":
160
+ img = img.permute(1, 2, 0) # (H, W, 3)
161
+ # else "HWC", do nothing
162
+
163
+ img = img.numpy().astype(np.float32)
164
+
165
+ # Scale to [0,255] if needed
166
+ if normalize_mode == "[0,1]":
167
+ img = np.clip(img, 0, 1) * 255.0
168
+ elif normalize_mode == "[-1,1]":
169
+ img = (img + 1.0) * 0.5 * 255.0
170
+ img = np.clip(img, 0, 255.0)
171
+ # else no normalization
172
+
173
+ # Convert to uint8
174
+ img = img.astype(np.uint8)
175
+
176
+ # For drawing in OpenCV, convert to BGR
177
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
178
+
179
+ # Draw each visible track
180
+ cur_tracks = tracks[s] # shape (N, 2)
181
+ if track_vis_mask is not None:
182
+ valid_indices = torch.where(track_vis_mask[s])[0]
183
+ else:
184
+ valid_indices = range(N)
185
+
186
+ cur_tracks_np = cur_tracks.numpy()
187
+ for i in valid_indices:
188
+ x, y = cur_tracks_np[i]
189
+ pt = (int(round(x)), int(round(y)))
190
+
191
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
192
+ R, G, B = track_colors_rgb[i]
193
+ color_bgr = (int(B), int(G), int(R))
194
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
195
+
196
+ # Convert back to RGB for consistent final saving:
197
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
198
+
199
+ # Save individual frame
200
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
201
+ # Convert to BGR for OpenCV imwrite
202
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
203
+ cv2.imwrite(frame_path, frame_bgr)
204
+
205
+ frame_images.append(img_rgb)
206
+
207
+ # Only create and save the grid image if save_grid is True
208
+ if save_grid:
209
+ # Calculate grid dimensions
210
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
211
+
212
+ # Create a grid of images
213
+ grid_img = None
214
+ for row in range(num_rows):
215
+ start_idx = row * frames_per_row
216
+ end_idx = min(start_idx + frames_per_row, S)
217
+
218
+ # Concatenate this row horizontally
219
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
220
+
221
+ # If this row has fewer than frames_per_row images, pad with black
222
+ if end_idx - start_idx < frames_per_row:
223
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
224
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
225
+ row_img = np.concatenate([row_img, padding], axis=1)
226
+
227
+ # Add this row to the grid
228
+ if grid_img is None:
229
+ grid_img = row_img
230
+ else:
231
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
232
+
233
+ out_path = os.path.join(out_dir, "tracks_grid.png")
234
+ # Convert back to BGR for OpenCV imwrite
235
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
236
+ cv2.imwrite(out_path, grid_img_bgr)
237
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
238
+
239
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
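Usage sketch (not part of the committed file): drawing synthetic tracks on random frames. The output directory name is a placeholder, and running it requires opencv-python and matplotlib.

    import torch
    from vggt.utils.visual_track import visualize_tracks_on_images

    S, N, H, W = 4, 16, 256, 256
    images = torch.rand(S, 3, H, W)                      # CHW frames in [0, 1]
    tracks = torch.rand(S, N, 2) * torch.tensor([W, H])  # (x, y) in pixel coordinates
    vis_mask = torch.ones(S, N, dtype=torch.bool)

    visualize_tracks_on_images(
        images, tracks, track_vis_mask=vis_mask,
        out_dir="demo_track_vis", image_format="CHW",
        normalize_mode="[0,1]", frames_per_row=2,
    )
    # Writes demo_track_vis/frame_0000.png ... frame_0003.png and tracks_grid.png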
vision_tower.py ADDED
@@ -0,0 +1,279 @@
1
+ # import sys
2
+ # sys.path.append("..")
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.init as init
7
+ import torch.nn.functional as F
8
+
9
+ from paths import *
10
+ from typing import Dict, List, Optional, Set, Tuple, Union
11
+ import os
12
+
13
+ from contextlib import nullcontext
14
+ from vggt.models.vggt import VGGT
15
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
16
+ from vggt.layers import Mlp
17
+ from vggt.layers.block import Block
18
+ from vggt.heads.head_act import activate_pose
19
+
20
+ class OriAny_CameraHead(nn.Module):
21
+ """
22
+ CameraHead predicts camera parameters from token representations using iterative refinement.
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+ def __init__(
26
+ self,
27
+ dim_in: int = 2048,
28
+ trunk_depth: int = 4,
29
+ pose_encoding_type: str = "OriAny",
30
+ num_heads: int = 16,
31
+ mlp_ratio: int = 4,
32
+ init_values: float = 0.01,
33
+ ):
34
+ super().__init__()
35
+
36
+ if pose_encoding_type == "OriAny":
37
+ self.target_dim = 360+180+360+2
38
+ else:
39
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
40
+
41
+ self.trunk_depth = trunk_depth
42
+
43
+ # Build the trunk using a sequence of transformer blocks.
44
+ self.trunk = nn.Sequential(
45
+ *[
46
+ Block(
47
+ dim=dim_in,
48
+ num_heads=num_heads,
49
+ mlp_ratio=mlp_ratio,
50
+ init_values=init_values,
51
+ )
52
+ for _ in range(trunk_depth)
53
+ ]
54
+ )
55
+
56
+ # Normalizations for camera token and trunk output.
57
+ self.token_norm = nn.LayerNorm(dim_in)
58
+ self.trunk_norm = nn.LayerNorm(dim_in)
59
+
60
+ # Learnable empty camera pose token.
61
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
62
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
63
+
64
+ # Module for producing modulation parameters: shift, scale, and a gate.
65
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
66
+
67
+ # Adaptive layer normalization without affine parameters.
68
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
69
+ self.pose_branch = Mlp(
70
+ in_features=dim_in,
71
+ hidden_features=dim_in // 2,
72
+ out_features=self.target_dim,
73
+ drop=0,
74
+ )
75
+
76
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
77
+ """
78
+ Forward pass to predict camera parameters.
79
+
80
+ Args:
81
+ aggregated_tokens_list (list): List of token tensors from the network;
82
+ the last tensor is used for prediction.
83
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
84
+
85
+ Returns:
86
+ list: A list of predicted camera encodings (post-activation) from each iteration.
87
+ """
88
+ # Use tokens from the last block for camera prediction.
89
+ tokens = aggregated_tokens_list[-1]
90
+
91
+ # Extract the camera tokens
92
+ pose_tokens = tokens[:, :, 0]
93
+ pose_tokens = self.token_norm(pose_tokens)
94
+
95
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
96
+ return pred_pose_enc_list
97
+
98
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
99
+ """
100
+ Iteratively refine camera pose predictions.
101
+
102
+ Args:
103
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
104
+ num_iterations (int): Number of refinement iterations.
105
+
106
+ Returns:
107
+ list: List of activated camera encodings from each iteration.
108
+ """
109
+ B, S, C = pose_tokens.shape # S is expected to be 1.
110
+ pred_pose_enc = None
111
+ pred_pose_enc_list = []
112
+
113
+ for _ in range(num_iterations):
114
+ # Use a learned empty pose for the first iteration.
115
+ if pred_pose_enc is None:
116
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
117
+ else:
118
+ # Detach the previous prediction to avoid backprop through time.
119
+ pred_pose_enc = pred_pose_enc.detach()
120
+ module_input = self.embed_pose(pred_pose_enc)
121
+
122
+ # Generate modulation parameters and split them into shift, scale, and gate components.
123
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
124
+
125
+ # Adaptive layer normalization and modulation.
126
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
127
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
128
+
129
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
130
+ # Compute the delta update for the pose encoding.
131
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
132
+
133
+ if pred_pose_enc is None:
134
+ pred_pose_enc = pred_pose_enc_delta
135
+ else:
136
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
137
+
138
+ # Apply final activation functions for translation, quaternion, and field-of-view.
139
+ # activated_pose = activate_pose(
140
+ # pred_pose_enc,
141
+ # trans_act=self.trans_act,
142
+ # quat_act=self.quat_act,
143
+ # fl_act=self.fl_act,
144
+ # )
145
+ # pred_pose_enc_list.append(activated_pose)
146
+ pred_pose_enc_list.append(pred_pose_enc)
147
+
148
+ return pred_pose_enc_list
149
+
150
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
151
+ """
152
+ Modulate the input tensor using scaling and shifting parameters.
153
+ """
154
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
155
+ return x * (1 + scale) + shift
156
+
157
+ def load_patch_embed_weights(model, checkpoint_path):
158
+ # 1. Load the checkpoint
159
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
160
+
161
+ # 2. Get the state_dict
162
+ state_dict = checkpoint.get("state_dict", checkpoint)
163
+
164
+ # 3. Keep only the parameters belonging to aggregator.patch_embed
165
+ patch_embed_state = {
166
+ k.replace("aggregator.patch_embed.", ""): v
167
+ for k, v in state_dict.items()
168
+ if k.startswith("aggregator.patch_embed.")
169
+ }
170
+
171
+ # 4. 加载到目标模块
172
+ missing_keys, unexpected_keys = model.aggregator.patch_embed.load_state_dict(
173
+ patch_embed_state, strict=False
174
+ )
175
+
176
+ print("Loaded patch_embed weights.")
177
+ print("Missing keys:", missing_keys)
178
+ print("Unexpected keys:", unexpected_keys)
179
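+
+ # Usage sketch (illustrative only; the checkpoint path below is a placeholder, and `model`
+ # is assumed to be any module exposing `.aggregator.patch_embed`, e.g. a plain VGGT instance):
+ #
+ #     vggt = VGGT()
+ #     load_patch_embed_weights(vggt, "checkpoints/vggt.pt")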
+
+ class VGGT_OriAny_Ref(nn.Module):
+     def __init__(self,
+                  dtype,
+                  out_dim,
+                  nopretrain
+                  ) -> None:
+         super().__init__()
+         self.vggt = VGGT()
+
+         self.dtype = dtype
+         self.ref_sampler = MLP_dim(in_dim=2048, out_dim=out_dim)
+         self.ref_sampler.apply(init_weights)
+         self.tgt_sampler = MLP_dim(in_dim=2048, out_dim=out_dim)
+         self.tgt_sampler.apply(init_weights)
+
+     def forward(self, img_inputs):
+         device = self.get_device()
+
+         with torch.amp.autocast(device_type='cuda', dtype=self.dtype):
+             # Add a leading batch dimension if a single [S, 3, H, W] sequence is passed in.
+             if img_inputs.ndim == 4:
+                 img_inputs = img_inputs[None]
+             aggregated_tokens_list, ps_idx = self.vggt.aggregator(img_inputs)
+
+             # Predict Cameras
+             # pose_enc = self.oriany_camera_head(aggregated_tokens_list)[-1]
+             # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
+             # extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
+
+             # Use tokens from the last block for camera prediction.
+             tokens = aggregated_tokens_list[-1]
+             # Extract the camera tokens.
+             pose_tokens = tokens[:, :, 0]
+             # tokens = aggregated_tokens_list[-1]
+
+             B, S, C = pose_tokens.shape
+             if S > 1:
+                 # Separate the first (reference) token of each batch element from the remaining (target) tokens.
+                 ref_tokens = pose_tokens[:, 0, :]   # shape: (B, C)
+                 tgt_tokens = pose_tokens[:, 1:, :]  # shape: (B, S-1, C)
+
+                 # Downsample (project) each set of tokens.
+                 ref_feat = self.ref_sampler(ref_tokens)  # shape: (B, C'), where C' is the sampler output dim
+                 tgt_feat = self.tgt_sampler(tgt_tokens.reshape(B * (S - 1), C))  # shape: (B*(S-1), C')
+
+                 # Merge the results.
+                 pose_enc = torch.cat([
+                     ref_feat.unsqueeze(1),       # (B, 1, C')
+                     tgt_feat.view(B, S - 1, -1)  # (B, S-1, C')
+                 ], dim=1)  # final shape: (B, S, C')
+             else:
+                 pose_enc = self.ref_sampler(pose_tokens.view(B * S, C))
+         return pose_enc
+
+     def get_device(self):
+         return next(self.parameters()).device
+
+ def init_weights(m):
+     if isinstance(m, nn.Linear):
+         init.xavier_uniform_(m.weight)
+         if m.bias is not None:
+             init.constant_(m.bias, 0)
+
+ def get_activation(activation):
+     if activation.lower() == 'gelu':
+         return nn.GELU()
+     elif activation.lower() == 'rrelu':
+         return nn.RReLU(inplace=True)
+     elif activation.lower() == 'selu':
+         return nn.SELU(inplace=True)
+     elif activation.lower() == 'silu':
+         return nn.SiLU(inplace=True)
+     elif activation.lower() == 'hardswish':
+         return nn.Hardswish(inplace=True)
+     elif activation.lower() == 'leakyrelu':
+         return nn.LeakyReLU(inplace=True)
+     elif activation.lower() == 'sigmoid':
+         return nn.Sigmoid()
+     elif activation.lower() == 'tanh':
+         return nn.Tanh()
+     else:
+         return nn.ReLU(inplace=True)
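+
+ # For example, get_activation('silu') returns nn.SiLU(inplace=True); names that are not
+ # recognized fall back to nn.ReLU(inplace=True).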
+
+ class MLP_dim(nn.Module):
+     def __init__(
+             self, in_dim=512, out_dim=1024, bias=True, activation='relu'):
+         super().__init__()
+         self.act = get_activation(activation)
+         self.net1 = nn.Sequential(
+             nn.Linear(in_dim, int(out_dim), bias=bias),
+             nn.BatchNorm1d(int(out_dim)),
+             self.act
+         )
+         self.net2 = nn.Sequential(
+             nn.Linear(int(out_dim), out_dim, bias=bias),
+             nn.BatchNorm1d(out_dim)
+         )
+
+     def forward(self, x):
+         return self.net2(self.net1(x))
+
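+
+ # Minimal smoke-test sketch: out_dim=4 and the 518x518 resolution are arbitrary
+ # illustrative choices (not values taken from this repo's configs), the VGGT backbone is
+ # randomly initialized, and a CUDA GPU with enough memory for the full aggregator is
+ # assumed; this only checks output shapes.
+ if __name__ == "__main__" and torch.cuda.is_available():
+     # Prefer bfloat16 on Ampere (SM80+) GPUs, float16 otherwise.
+     amp_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+     model = VGGT_OriAny_Ref(dtype=amp_dtype, out_dim=4, nopretrain=True).cuda()
+     model.eval()  # eval mode so BatchNorm1d in MLP_dim accepts single-sample batches
+
+     # One batch with one reference view and one target view: [B, S, 3, H, W].
+     images = torch.rand(1, 2, 3, 518, 518, device="cuda")
+     with torch.no_grad():
+         pose_enc = model(images)
+     print(pose_enc.shape)  # expected: torch.Size([1, 2, 4])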