Kalpit committed
Commit · d39b279
feat: Add model files with LFS

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +4 -0
- .gitignore +2 -0
- LICENSE +201 -0
- Preprocess/compress.py +38 -0
- Preprocess/folder2csv.py +72 -0
- Preprocess/video2frame.py +57 -0
- README.md +88 -0
- __pycache__/dataloader2.cpython-39.pyc +0 -0
- __pycache__/util.cpython-39.pyc +0 -0
- commands.md +3 -0
- create_csv.py +44 -0
- create_submission.py +35 -0
- dataloader.py +281 -0
- dataloader2.py +246 -0
- eval.py +48 -0
- eval2.py +110 -0
- extract_frames.py +74 -0
- models/DeMamba.py +176 -0
- models/F3Net.py +434 -0
- models/FTCN.py +143 -0
- models/MINTIME +269 -0
- models/NPR.py +284 -0
- models/STIL.py +641 -0
- models/TALL.py +935 -0
- models/VideoMAE.py +67 -0
- models/XCLIP.py +33 -0
- models/__init__.py +5 -0
- models/__pycache__/DeMamba.cpython-39.pyc +0 -0
- models/__pycache__/F3Net.cpython-39.pyc +0 -0
- models/__pycache__/NPR.cpython-39.pyc +0 -0
- models/__pycache__/STIL.cpython-39.pyc +0 -0
- models/__pycache__/XCLIP.cpython-39.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/mamba_base.cpython-39.pyc +0 -0
- models/__pycache__/pscan.cpython-39.pyc +0 -0
- models/clip/__init__.py +1 -0
- models/clip/__pycache__/__init__.cpython-39.pyc +0 -0
- models/clip/__pycache__/clip.cpython-39.pyc +0 -0
- models/clip/__pycache__/model.cpython-39.pyc +0 -0
- models/clip/__pycache__/simple_tokenizer.cpython-39.pyc +0 -0
- models/clip/clip.py +233 -0
- models/clip/model.py +432 -0
- models/clip/simple_tokenizer.py +132 -0
- models/mamba_base.py +352 -0
- models/pscan.py +232 -0
- models/time_transformer +256 -0
- requirements.txt +59 -0
- results.csv +6 -0
- script.py +22 -0
- submission.csv +2 -0
.gitattributes
ADDED
@@ -0,0 +1,4 @@
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
+results/kling_9k_9k/best_acc.pth
+models/clip/bpe_simple_vocab_16e6.txt.gz
LICENSE
ADDED
@@ -0,0 +1,201 @@
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+   "License" shall mean the terms and conditions for use, reproduction,
+   and distribution as defined by Sections 1 through 9 of this document.
+
+   "Licensor" shall mean the copyright owner or entity authorized by
+   the copyright owner that is granting the License.
+
+   "Legal Entity" shall mean the union of the acting entity and all
+   other entities that control, are controlled by, or are under common
+   control with that entity. For the purposes of this definition,
+   "control" means (i) the power, direct or indirect, to cause the
+   direction or management of such entity, whether by contract or
+   otherwise, or (ii) ownership of fifty percent (50%) or more of the
+   outstanding shares, or (iii) beneficial ownership of such entity.
+
+   "You" (or "Your") shall mean an individual or Legal Entity
+   exercising permissions granted by this License.
+
+   "Source" form shall mean the preferred form for making modifications,
+   including but not limited to software source code, documentation
+   source, and configuration files.
+
+   "Object" form shall mean any form resulting from mechanical
+   transformation or translation of a Source form, including but
+   not limited to compiled object code, generated documentation,
+   and conversions to other media types.
+
+   "Work" shall mean the work of authorship, whether in Source or
+   Object form, made available under the License, as indicated by a
+   copyright notice that is included in or attached to the work
+   (an example is provided in the Appendix below).
+
+   "Derivative Works" shall mean any work, whether in Source or Object
+   form, that is based on (or derived from) the Work and for which the
+   editorial revisions, annotations, elaborations, or other modifications
+   represent, as a whole, an original work of authorship. For the purposes
+   of this License, Derivative Works shall not include works that remain
+   separable from, or merely link (or bind by name) to the interfaces of,
+   the Work and Derivative Works thereof.
+
+   "Contribution" shall mean any work of authorship, including
+   the original version of the Work and any modifications or additions
+   to that Work or Derivative Works thereof, that is intentionally
+   submitted to Licensor for inclusion in the Work by the copyright owner
+   or by an individual or Legal Entity authorized to submit on behalf of
+   the copyright owner. For the purposes of this definition, "submitted"
+   means any form of electronic, verbal, or written communication sent
+   to the Licensor or its representatives, including but not limited to
+   communication on electronic mailing lists, source code control systems,
+   and issue tracking systems that are managed by, or on behalf of, the
+   Licensor for the purpose of discussing and improving the Work, but
+   excluding communication that is conspicuously marked or otherwise
+   designated in writing by the copyright owner as "Not a Contribution."
+
+   "Contributor" shall mean Licensor and any individual or Legal Entity
+   on behalf of whom a Contribution has been received by Licensor and
+   subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+   this License, each Contributor hereby grants to You a perpetual,
+   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+   copyright license to reproduce, prepare Derivative Works of,
+   publicly display, publicly perform, sublicense, and distribute the
+   Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+   this License, each Contributor hereby grants to You a perpetual,
+   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+   (except as stated in this section) patent license to make, have made,
+   use, offer to sell, sell, import, and otherwise transfer the Work,
+   where such license applies only to those patent claims licensable
+   by such Contributor that are necessarily infringed by their
+   Contribution(s) alone or by combination of their Contribution(s)
+   with the Work to which such Contribution(s) was submitted. If You
+   institute patent litigation against any entity (including a
+   cross-claim or counterclaim in a lawsuit) alleging that the Work
+   or a Contribution incorporated within the Work constitutes direct
+   or contributory patent infringement, then any patent licenses
+   granted to You under this License for that Work shall terminate
+   as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+   Work or Derivative Works thereof in any medium, with or without
+   modifications, and in Source or Object form, provided that You
+   meet the following conditions:
+
+   (a) You must give any other recipients of the Work or
+       Derivative Works a copy of this License; and
+
+   (b) You must cause any modified files to carry prominent notices
+       stating that You changed the files; and
+
+   (c) You must retain, in the Source form of any Derivative Works
+       that You distribute, all copyright, patent, trademark, and
+       attribution notices from the Source form of the Work,
+       excluding those notices that do not pertain to any part of
+       the Derivative Works; and
+
+   (d) If the Work includes a "NOTICE" text file as part of its
+       distribution, then any Derivative Works that You distribute must
+       include a readable copy of the attribution notices contained
+       within such NOTICE file, excluding those notices that do not
+       pertain to any part of the Derivative Works, in at least one
+       of the following places: within a NOTICE text file distributed
+       as part of the Derivative Works; within the Source form or
+       documentation, if provided along with the Derivative Works; or,
+       within a display generated by the Derivative Works, if and
+       wherever such third-party notices normally appear. The contents
+       of the NOTICE file are for informational purposes only and
+       do not modify the License. You may add Your own attribution
+       notices within Derivative Works that You distribute, alongside
+       or as an addendum to the NOTICE text from the Work, provided
+       that such additional attribution notices cannot be construed
+       as modifying the License.
+
+   You may add Your own copyright statement to Your modifications and
+   may provide additional or different license terms and conditions
+   for use, reproduction, or distribution of Your modifications, or
+   for any such Derivative Works as a whole, provided Your use,
+   reproduction, and distribution of the Work otherwise complies with
+   the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+   any Contribution intentionally submitted for inclusion in the Work
+   by You to the Licensor shall be under the terms and conditions of
+   this License, without any additional terms or conditions.
+   Notwithstanding the above, nothing herein shall supersede or modify
+   the terms of any separate license agreement you may have executed
+   with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+   names, trademarks, service marks, or product names of the Licensor,
+   except as required for reasonable and customary use in describing the
+   origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+   agreed to in writing, Licensor provides the Work (and each
+   Contributor provides its Contributions) on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+   implied, including, without limitation, any warranties or conditions
+   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+   PARTICULAR PURPOSE. You are solely responsible for determining the
+   appropriateness of using or redistributing the Work and assume any
+   risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+   whether in tort (including negligence), contract, or otherwise,
+   unless required by applicable law (such as deliberate and grossly
+   negligent acts) or agreed to in writing, shall any Contributor be
+   liable to You for damages, including any direct, indirect, special,
+   incidental, or consequential damages of any character arising as a
+   result of this License or out of the use or inability to use the
+   Work (including but not limited to damages for loss of goodwill,
+   work stoppage, computer failure or malfunction, or any and all
+   other commercial damages or losses), even if such Contributor
+   has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+   the Work or Derivative Works thereof, You may choose to offer,
+   and charge a fee for, acceptance of support, warranty, indemnity,
+   or other liability obligations and/or rights consistent with this
+   License. However, in accepting such obligations, You may act only
+   on Your own behalf and on Your sole responsibility, not on behalf
+   of any other Contributor, and only if You agree to indemnify,
+   defend, and hold each Contributor harmless for any liability
+   incurred by, or claims asserted against, such Contributor by reason
+   of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+   To apply the Apache License to your work, attach the following
+   boilerplate notice, with the fields enclosed by brackets "[]"
+   replaced with your own identifying information. (Don't include
+   the brackets!) The text should be enclosed in the appropriate
+   comment syntax for the file format. We also recommend that a
+   file or class name and description of purpose be included on the
+   same "printed page" as the copyright notice for easier
+   identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
Preprocess/compress.py
ADDED
@@ -0,0 +1,38 @@
+import subprocess
+from glob import glob
+
+def convert_and_compress_video(input_video_path, output_video_path, crf=23, preset='medium'):
+    """
+    Convert and compress a video with ffmpeg.
+    :param input_video_path: path to the input video file
+    :param output_video_path: path to the compressed output video file
+    :param crf: Constant Rate Factor; smaller values give better quality but larger files, 23 is a balanced default
+    :param preset: preset controlling compression speed vs. file size, e.g. 'ultrafast', 'fast', 'medium', 'slow', 'veryslow'; default 'medium'
+    """
+    # Check whether the input is a GIF
+    if input_video_path.lower().endswith('.gif'):
+        # Build the ffmpeg commands:
+        # first generate a palette, then convert the GIF to MP4
+        convert_cmd = f'ffmpeg -i "{input_video_path}" -vf "palettegen" -y palette.png'
+        subprocess.run(convert_cmd, shell=True, check=True)
+        convert_cmd = f'ffmpeg -i "{input_video_path}" -i palette.png -lavfi "paletteuse" -c:v libx264 -preset {preset} -crf {crf} -c:a copy "{output_video_path}"'
+    else:
+        # Compress the video directly
+        convert_cmd = f'ffmpeg -i "{input_video_path}" -c:v libx264 -preset {preset} -crf {crf} -c:a copy "{output_video_path}"'
+
+    # Run the ffmpeg command
+    try:
+        subprocess.run(convert_cmd, shell=True, check=True)
+        print(f"Video compression finished, output file: {output_video_path}")
+    except subprocess.CalledProcessError as e:
+        print(f"Video compression failed: {e}")
+
+video_paths = '/home3/Transformer'
+all_dirs = glob(video_paths+'/*')
+output_dirs = '/home3/robust/compress/Transformer_'
+
+for path in all_dirs:
+    name = path.split('/')[-1]
+    out_path = output_dirs + name.replace('.gif', '.mp4') if path.lower().endswith('.gif') else output_dirs + name
+    print(name)
+    convert_and_compress_video(path, out_path, crf=28)
Preprocess/folder2csv.py
ADDED
@@ -0,0 +1,72 @@
+import os
+import csv
+import pandas as pd
+from pandas import Series, DataFrame
+from glob import glob
+import os
+
+def count_images_in_folder(folder_path):
+    image_count = 0
+    image_names = []
+    for file_name in os.listdir(folder_path):
+        if file_name.endswith('.png') or file_name.endswith('.jpg') or file_name.endswith('.jpeg'):
+            image_count += 1
+            image_names.append(int(file_name.split('.')[0]))
+    image_names.sort()
+    return image_count, image_names
+
+
+folder_path = './SD_frames'
+all_dirs = []
+
+for root, dirs, files in os.walk(folder_path):
+    for dir in dirs:
+        all_dirs.append(os.path.join(root, dir))
+
+label = list()
+save_path = list()
+frame_counts = list()
+frame_seq_counts = list()
+content_paths = list()
+chinese_labels = list()
+
+
+for video_path in all_dirs:
+    frame_paths = glob(video_path + '/*')
+    temp_frame_count, temp_frame_seqs = count_images_in_folder(video_path)
+    if temp_frame_count == 0:
+        continue
+
+    for frame in frame_paths:
+        content_path = frame.split('/')[1:-1]
+        content_path = '/'.join(content_path)
+        # input your own path
+        content_path = '/home/AIGC_Video_Det/SD/' + content_path
+
+        frame_path = frame.split('/')[1:]
+        frame_path = '/'.join(frame_path)
+        frame_path = '/home/AIGC_Video_Det/SD/' + frame_path
+
+        print(content_path, frame_path)
+        label.append(str(1))
+        frame_counts.append(int(temp_frame_count))
+        frame_seq_counts.append(temp_frame_seqs)
+        save_path.append(frame_path)
+        content_paths.append(content_path)
+        chinese_labels.append('AIGC视频')
+        # chinese_labels.append('真实视频')
+        break
+
+dic={
+    'content_path': Series(data=content_paths),
+    'image_path': Series(data=save_path),
+    'type_id': Series(data=chinese_labels),
+    'label': Series(data=label),
+    'frame_len': Series(data=frame_counts),
+    'frame_seq': Series(data=frame_seq_counts)
+}
+
+print(dic)
+pd.DataFrame(dic).to_csv('SD.csv', encoding='utf-8', index=False)
+
+
Preprocess/video2frame.py
ADDED
@@ -0,0 +1,57 @@
+import os
+from glob import glob
+from moviepy.editor import VideoFileClip
+import multiprocessing
+import cv2, math
+
+def get_video_length(file_path):
+    video = VideoFileClip(file_path)
+    return video.duration
+
+def process_video(video_path):
+    video_name = video_path.split('/')[-1]
+    video_name = video_name.split('.')[:-1]
+    video_name = '.'.join(video_name)
+
+    path = video_path.split('/')[1:-1]
+    path = '/'.join(path)
+
+    image_path = './SD_frames/'+path+'/'+ video_name+'/'
+    print(video_name)
+    if os.path.exists(image_path):
+        print("Path already exists")
+    else:
+        print(video_name, "path does not exist")
+        try:
+            try:
+                video_length = get_video_length(video_path)
+                print(video_name, f"video length: {video_length} seconds")
+                os.makedirs(os.path.dirname(image_path), exist_ok=True)
+
+                if video_length >= 4 :
+                    inter_val = 2
+                    os.system(f"cd {image_path} | ffmpeg -loglevel quiet -i {video_path} -r {inter_val} {image_path}%d.jpg")
+                else:
+                    inter_val = math.ceil(8 / video_length)
+                    os.system(f"cd {image_path} | ffmpeg -loglevel quiet -i {video_path} -r {inter_val} {image_path}%d.jpg")
+
+            except Exception as e:
+                print("Exception occurred:", str(e))
+        except:
+            print("Skip")
+
+if __name__ == '__main__':
+    print("Getting frames!!")
+    video_paths = './SD'
+    all_dirs = []
+    all_dirs = glob(video_paths+'/*')
+
+    print(all_dirs)
+
+    pool = multiprocessing.Pool(processes=8)
+    pool.map(process_video, all_dirs)
+    pool.close()
+    pool.join()
+
+
+
README.md
ADDED
@@ -0,0 +1,88 @@
+This is the official code of the paper 'DeMamba: AI-Generated Video Detection on Million-Scale GenVideo Benchmark'.
+
+## :boom: News!!!
+- (25.09.24) We have released a lightweight version of GenVideo, named [GenVideo-100K](https://modelscope.cn/datasets/cccnju/GenVideo-100K), with 10,000 samples for each forgery category.
+- (24.12.11) We are pleased to announce that our AI-generated content (AIGC) video detection model, developed based on GenVideo, has passed the evaluation by the China Academy of Information and Communications Technology (CAICT) and achieved the highest rating, making us the first organization in China to be officially registered and approved. [[certificate](https://github.com/chenhaoxing/DeMamba/blob/main/figs/xty.jpg)] [[report](https://mp.weixin.qq.com/s/OoW7EI1QoSrQ3FIftfbudg)]
+
+
+## :dart: Todo
+
+- [x] Release source code.
+- [x] Release dataset.
+
+
+## :file_folder: Dataset download
+
+
+### Data preparation process
+- Download the original videos.
+
+  - Generated videos: all generated videos can be downloaded from [ModelScope](https://modelscope.cn/collections/Gen-Video-7de46cd6846f4e).
+
+  - Real videos: the data from the MSR-VTT dataset is contained within the GenVideo-Val.zip file. We also provide the selected Youku videos at the previous link. For Kinetics-400, you will need to download it yourself from [https://github.com/cvdfoundation/kinetics-dataset](https://github.com/cvdfoundation/kinetics-dataset).
+
+- Preprocess the videos and generate the data-list CSV file.
+
+Statistics of real and generated videos in the GenVideo dataset:
+| **Video Source** | **Type** | **Task** | **Time** | **Resolution** | **FPS** | **Length** | **Training Set** | **Testing Set** |
+|-----------------------------------------------------|----------|----------|----------|----------------|---------|------------|------------------|----------------|
+| Kinetics-400 | Real | - | 17.05 | 224-340 | - | 5-10s | 260,232 | - |
+| Youku-mPLUG | Real | - | 23.07 | - | - | 10-120s | 953,279 | - |
+| MSR-VTT | Real | - | 16.05 | - | - | 10-30s | - | 10,000 |
+| ZeroScope | Fake | T2V | 23.07 | 1024×576 | 8 | 3s | 133,169 | - |
+| I2VGen-XL | Fake | I2V | 23.12 | 1280×720 | 8 | 2s | 61,975 | - |
+| SVD | Fake | I2V | 23.12 | 1024×576 | 8 | 4s | 149,026 | - |
+| VideoCrafter | Fake | T2V | 24.01 | 1024×576 | 8 | 2s | 39,485 | - |
+| Pika | Fake | T2V&I2V | 24.02 | 1088×640 | 24 | 3s | 98,377 | - |
+| DynamiCrafter | Fake | I2V | 24.03 | 1024×576 | 8 | 3s | 46,205 | - |
+| SD | Fake | T2V&I2V | 23-24 | 512-1024 | 8 | 2-6s | 200,720 | - |
+| SEINE | Fake | I2V | 24.04 | 1024×576 | 8 | 2-4s | 24,737 | - |
+| Latte | Fake | T2V | 24.03 | 512×512 | 8 | 2s | 149,979 | - |
+| OpenSora | Fake | T2V | 24.03 | 512×512 | 8 | 2s | 177,410 | - |
+| ModelScope | Fake | T2V | 23.03 | 256×256 | 8 | 4s | - | 700 |
+| MorphStudio | Fake | T2V | 23.08 | 1280×720 | 8 | 2s | - | 700 |
+| MoonValley | Fake | T2V | 24.01 | 1024×576 | 16 | 3s | - | 626 |
+| HotShot | Fake | T2V | 23.10 | 672×384 | 8 | 1s | - | 700 |
+| Show_1 | Fake | T2V | 23.10 | 576×320 | 8 | 4s | - | 700 |
+| Gen2 | Fake | I2V&T2V | 23.09 | 896×512 | 24 | 4s | - | 1,380 |
+| Crafter | Fake | T2V | 23.04 | 256×256 | 8 | 4s | - | 1,400 |
+| Lavie | Fake | T2V | 23.09 | 1280×2048 | 8 | 2s | - | 1,400 |
+| Sora | Fake | T2V | 24.02 | - | - | -60s | - | 56 |
+| WildScrape | Fake | T2V&I2V | 24 | 512-1024 | 8-16 | 2-6s | - | 926 |
+| **Total Count** | - | - | - | - | - | - | 2,294,594 | 19,588 |
+
+## :snake: Detail Mamba (DeMamba)
+
+
+<p align="center"><em>In memory of Kobe Bryant (generated by GPT-4o)</em></p>
+
+> "Determination wins games, but Detail wins championships." — *Kobe Bryant, in his Show Detail, 2018*
+
+
+<p align="center"><em>The overall framework of our Detail Mamba (DeMamba)</em></p>
+
+
+## :space_invader: Citing GenVideo&DeMamba
+If you use GenVideo or DeMamba in your research or use the codebase here, please use the following BibTeX entry.
+
+```BibTeX
+@article{DeMamba,
+  title={DeMamba: AI-Generated Video Detection on Million-Scale GenVideo Benchmark},
+  author={Haoxing Chen and Yan Hong and Zizheng Huang and Zhuoer Xu and Zhangxuan Gu and Yaohui Li and Jun Lan and Huijia Zhu and Jianfu Zhang and Weiqiang Wang and Huaxiong Li},
+  journal={arXiv preprint arXiv:2405.19707},
+  year={2024}
+}
+```
+
+## Star History
+
+[](https://star-history.com/#chenhaoxing/DeMamba&Date)
+
+## Acknowledgement
+Many thanks to the nice work of [STIL](https://github.com/wizyoung/STIL-DeepFake-Video-Detection), [CLIP](https://github.com/openai/CLIP), [XCLIP](https://github.com/microsoft/VideoX/tree/master/X-CLIP), [NPR](https://github.com/chuangchuangtan/NPR-DeepfakeDetection/tree/main) and [VideoMAE](https://github.com/MCG-NJU/VideoMAE-Action-Detection).
+
+## :email: Contact
+If you have any questions, feel free to contact us: hx.chen@hotmail.com.
+
+
__pycache__/dataloader2.cpython-39.pyc
ADDED
Binary file (6.62 kB)

__pycache__/util.cpython-39.pyc
ADDED
Binary file (3.91 kB)
commands.md
ADDED
@@ -0,0 +1,3 @@
+python train.py --config /home/kalpit/workspace/aigc/repos/DeMamba/configs/my_config.yaml
+
+python train.py --config /home/kalpit/workspace/aigc/repos/DeMamba/configs/vid_grand_v1_train.yaml
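The YAML config files referenced above are not part of this commit, so their exact contents are unknown. As a hedged reference only, the sketch below lists the configuration keys that the dataloaders in this commit (`dataloader.py` / `dataloader2.py`) actually read when `generate_dataset_loader(cfg)` is called; the values shown are illustrative, not the project's defaults.

```python
# Hypothetical config sketch: key names come from dataloader.py / dataloader2.py,
# every value is illustrative. train.py presumably parses the YAML into a dict like this.
cfg = {
    "task": "normal",          # or e.g. 'robust_compress', 'JPEG_Compress_Attack', 'CROP_Attack', ...
    "train_sub_set": None,     # optional training-subset filter, e.g. 'pika', 'SEINE', 'OpenSora', 'Latte'
    "train_batch_size": 8,     # illustrative
    "val_batch_size": 8,       # illustrative
    "num_workers": 4,          # illustrative
}
```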
create_csv.py
ADDED
@@ -0,0 +1,44 @@
+import os
+import csv
+
+# Source directories
+real_dir = "/home/kalpit/workspace/aigc/data/ShareVeo3/test/0_real"
+fake_dir = "/home/kalpit/workspace/aigc/data/ShareVeo3/test/1_fake"
+
+# Output CSV path
+output_csv = "/home/kalpit/workspace/aigc/repos/DeMamba/veo_test.csv"
+
+# Function to get all image paths from a directory
+def get_image_paths(directory, label):
+    image_paths = []
+    for root, _, files in os.walk(directory):
+        for file in files:
+            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
+                full_path = os.path.join(root, file)
+                image_paths.append({
+                    "content_path": full_path,
+                    "frame_seq": [full_path],  # list containing the single frame path
+                    "label": label
+                })
+    return image_paths
+
+# Collect all images
+data = []
+data.extend(get_image_paths(real_dir, 0))
+data.extend(get_image_paths(fake_dir, 1))
+
+# Write to CSV
+with open(output_csv, 'w', newline='', encoding='utf-8') as csvfile:
+    fieldnames = ["content_path", "frame_seq", "label"]
+    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+    writer.writeheader()
+    for row in data:
+        # Convert frame_seq list to string as shown in your example
+        writer.writerow({
+            "content_path": row["content_path"],
+            "frame_seq": str(row["frame_seq"]),
+            "label": row["label"]
+        })
+
+print(f"CSV saved at: {output_csv}")
create_submission.py
ADDED
@@ -0,0 +1,35 @@
+import pandas as pd
+import re
+
+# Input and output files
+input_csv = "results.csv"
+output_csv = "submission.csv"
+
+# Load frame-level results
+df = pd.read_csv(input_csv)
+
+# Extract base video ID (remove _f###.jpg)
+def extract_video_id(filename):
+    return re.sub(r"_f\d+\.\w+$", "", filename)
+
+df["video_id"] = df["file_name"].apply(extract_video_id)
+
+# Aggregate probabilities per video (mean of frame scores)
+video_scores = (
+    df.groupby("video_id")["predicted_prob"]
+    .mean()
+    .reset_index()
+    .rename(columns={"video_id": "id", "predicted_prob": "score"})
+)
+
+# Assign label: generated if >0.5 else real
+video_scores["pred"] = video_scores["score"].apply(
+    lambda x: "generated" if x > 0.5 else "real"
+)
+
+# Reorder columns
+video_scores = video_scores[["id", "pred", "score"]]
+
+# Save submission file
+video_scores.to_csv(output_csv, index=False)
+print(f"Saved submission file to {output_csv}")
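To make the expected I/O of create_submission.py concrete: the script reads a frame-level `results.csv` with columns `file_name` and `predicted_prob`, averages the frame probabilities per video, and writes `id`, `pred`, `score`. The toy example below reproduces that aggregation in memory; the video names and probabilities are invented, only the column names and the 0.5 threshold follow from the script above.

```python
# Illustrative only: a toy version of the aggregation done by create_submission.py.
import pandas as pd

toy = pd.DataFrame({
    "file_name": ["vid001_f001.jpg", "vid001_f002.jpg", "vid002_f001.jpg"],  # invented names
    "predicted_prob": [0.80, 0.60, 0.10],                                    # invented scores
})
toy["video_id"] = toy["file_name"].str.replace(r"_f\d+\.\w+$", "", regex=True)
agg = toy.groupby("video_id")["predicted_prob"].mean().reset_index()
agg["pred"] = agg["predicted_prob"].apply(lambda p: "generated" if p > 0.5 else "real")
print(agg)  # vid001 -> generated (0.70), vid002 -> real (0.10)
```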
dataloader.py
ADDED
@@ -0,0 +1,281 @@
+import torch.utils.data as data
+from torch.utils.data import Dataset, DataLoader
+import pandas as pd
+import torch
+import albumentations
+import random
+import os
+import numpy as np
+import cv2
+import math
+import warnings
+
+# NOTE: download_oss_file is used below but is not defined or imported in this commit;
+# it is expected to be provided by an external OSS (object storage) helper.
+
+def crop_center_by_percentage(image, percentage):
+    height, width = image.shape[:2]
+
+    if width > height:
+        left_pixels = int(width * percentage)
+        right_pixels = int(width * percentage)
+        start_x = left_pixels
+        end_x = width - right_pixels
+        cropped_image = image[:, start_x:end_x]
+    else:
+        up_pixels = int(height * percentage)
+        down_pixels = int(height * percentage)
+        start_y = up_pixels
+        end_y = height - down_pixels
+        cropped_image = image[start_y:end_y, :]
+
+    return cropped_image
+
+class Ours_Dataset_train(Dataset):
+    def __init__(self, index_list=None, df=None):
+        self.index_list = index_list
+        self.df = df
+        self.positive_indices = df[df['label'] == 1].index.tolist()
+        self.negative_indices = df[df['label'] == 0].index.tolist()
+        self.balanced_indices = []
+        self.resample()
+
+    def resample(self):
+        # Ensure each epoch uses a balanced dataset
+        min_samples = min(len(self.positive_indices), len(self.negative_indices))
+        self.balanced_indices.clear()
+        self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
+        self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
+        random.shuffle(self.balanced_indices)  # Shuffle to mix positive and negative samples
+
+    def __getitem__(self, idx):
+        real_idx = self.balanced_indices[idx]
+        row = self.df.iloc[real_idx]
+        video_id = row['content_path']
+        label = row['label']
+        frame_list = eval(row['frame_seq'])
+        label_onehot = [0]*2
+        select_frame_nums = 8
+
+        aug_list = [
+            albumentations.Resize(224, 224)
+        ]
+
+        if random.random() < 0.5:
+            aug_list.append(albumentations.HorizontalFlip(p=1.0))
+        if random.random() < 0.5:
+            quality_score = random.randint(50, 100)
+            aug_list.append(albumentations.ImageCompression(quality_lower=quality_score, quality_upper=quality_score))
+        if random.random() < 0.3:
+            aug_list.append(albumentations.GaussNoise(p=1.0))
+        if random.random() < 0.3:
+            aug_list.append(albumentations.GaussianBlur(blur_limit=(3, 5), p=1.0))
+        if random.random() < 0.001:
+            aug_list.append(albumentations.ToGray(p=1.0))
+
+        aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
+        trans = albumentations.Compose(aug_list)
+
+        if len(frame_list) >= select_frame_nums:
+            start_frame = random.randint(0, len(frame_list)-select_frame_nums)
+            select_frames = frame_list[start_frame:start_frame+select_frame_nums]
+            frames = []
+            for x in frame_list[start_frame:start_frame+select_frame_nums]:
+                while True:
+                    try:
+                        temp_image_path = video_id+'/'+str(x)+'.jpg'
+                        image = download_oss_file('GenVideo/'+ temp_image_path)
+                        if video_id.startswith("real/youku"):
+                            image = crop_center_by_percentage(image, 0.15)
+                        break
+                    except Exception as e:
+                        if x+1 < len(frame_list):
+                            x = x + 1
+                        elif x - 1 >= 0:
+                            x = x - 1
+                augmented = trans(image=image)
+                image = augmented["image"]
+                frames.append(image.transpose(2,0,1)[np.newaxis,:])
+        else:
+            pad_num = select_frame_nums-len(frame_list)
+            frames = []
+            for x in frame_list:
+                temp_image_path = video_id+'/'+str(x)+'.jpg'
+                image = download_oss_file('GenVideo/'+temp_image_path)
+                if video_id.startswith("real/youku"):
+                    image = crop_center_by_percentage(image, 0.15)
+                augmented = trans(image=image)
+                image = augmented["image"]
+                frames.append(image.transpose(2,0,1)[np.newaxis,:])
+            for i in range(pad_num):
+                frames.append(np.zeros((224,224,3)).transpose(2,0,1)[np.newaxis,:])
+
+        label_onehot[int(label)] = 1
+        frames = np.concatenate(frames, 0)
+        frames = torch.tensor(frames[np.newaxis,:])
+        label_onehot = torch.FloatTensor(label_onehot)
+        binary_label = torch.FloatTensor([int(label)])
+
+        return self.index_list[idx], frames, label_onehot, binary_label
+
+    def __len__(self):
+        return len(self.balanced_indices)
+
+
+class Ours_Dataset_val(data.Dataset):
+    def __init__(self, cfg, index_list=None, df=None):
+        self.index_list = index_list
+        self.cfg = cfg
+        self.df = df
+        self.frame_dir = df['image_path'].tolist()
+
+    def __getitem__(self, idx):
+        aug_list = [
+            albumentations.Resize(224, 224),
+        ]
+
+        if self.cfg['task'] == 'JPEG_Compress_Attack':
+            aug_list.append(albumentations.JpegCompression(quality_lower=35, quality_upper=35, p=1.0))
+        if self.cfg['task'] == 'FLIP_Attack':
+            if random.random() < 0.5:
+                aug_list.append(albumentations.HorizontalFlip(p=1.0))
+            else:
+                aug_list.append(albumentations.VerticalFlip(p=1.0))
+        if self.cfg['task'] == 'CROP_Attack':
+            random_crop_x = random.randint(0, 16)
+            random_crop_y = random.randint(0, 16)
+            crop_width = random.randint(160, 208)
+            crop_height = random.randint(160, 208)
+            aug_list.append(albumentations.Crop(x_min=random_crop_x, y_min=random_crop_y, x_max=random_crop_x+crop_width, y_max=random_crop_y+crop_height))
+            aug_list.append(albumentations.Resize(224, 224))
+
+        if self.cfg['task'] == 'Color_Attack':
+            index = random.choice([i for i in range(4)])
+            dicts = {0:[0.5,0,0,0], 1:[0,0.5,0,0], 2:[0,0,0.5,0], 3:[0,0,0,0.5]}
+            brightness, contrast, saturation, hue = dicts[index]
+            aug_list.append(albumentations.ColorJitter(
+                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue))
+
+        if self.cfg['task'] == 'Gaussian_Attack':
+            aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))
+
+        aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
+        trans = albumentations.Compose(aug_list)
+
+        df_v = self.df.loc[self.index_list[idx]]
+        video_id = df_v['content_path']
+        activity_id = df_v['activity_id']
+        label = df_v['label']
+        label_onehot = [0]*2
+        frame_list = eval(df_v['frame_seq'])
+
+        select_frame_nums = 8
+
+        if len(frame_list) >= select_frame_nums:
+            start_frame = random.randint(0, len(frame_list)-select_frame_nums)
+            select_frames = frame_list[start_frame:start_frame+select_frame_nums]
+            frames = []
+            for x in frame_list[start_frame:start_frame+select_frame_nums]:
+                while True:
+                    try:
+                        temp_image_path = video_id+'/'+str(x)+'.jpg'
+                        image = download_oss_file('GenVideo/'+ temp_image_path)
+                        image = crop_center_by_percentage(image, 0.1)
+                        break
+                    except Exception as e:
+                        if x+1 < len(frame_list):
+                            x = x + 1
+                        elif x - 1 >= 0:
+                            x = x - 1
+                augmented = trans(image=image)
+                image = augmented["image"]
+                frames.append(image.transpose(2,0,1)[np.newaxis,:])
+        else:
+            pad_num = select_frame_nums-len(frame_list)
+            frames = []
+            for x in frame_list:
+                temp_image_path = video_id+'/'+str(x)+'.jpg'
+                image = download_oss_file('GenVideo/'+temp_image_path)
+                image = crop_center_by_percentage(image, 0.1)
+                augmented = trans(image=image)
+                image = augmented["image"]
+                frames.append(image.transpose(2,0,1)[np.newaxis,:])
+            for i in range(pad_num):
+                frames.append(np.zeros((224,224,3)).transpose(2,0,1)[np.newaxis,:])
+
+        label_onehot[int(label)] = 1
+        frames = np.concatenate(frames, 0)
+        frames = torch.tensor(frames[np.newaxis,:])
+        label_onehot = torch.FloatTensor(label_onehot)
+        binary_label = torch.FloatTensor([int(label)])
+        return self.index_list[idx], frames, label_onehot, binary_label, video_id
+
+    def __len__(self):
+        return len(self.index_list)
+
+
+def generate_dataset_loader(cfg):
+    df_train = pd.read_csv('GenVideo/datasets/train.csv')
+
+    if cfg['task'] == 'normal':
+        df_val = pd.read_csv('GenVideo/datasets/val_id.csv')
+    elif cfg['task'] == 'robust_compress':
+        df_val = pd.read_csv('GenVideo/datasets/com_28.csv')
+    elif cfg['task'] == 'Image_Water_Attack':
+        df_val = pd.read_csv('GenVideo/datasets/imgwater.csv')
+    elif cfg['task'] == 'Text_Water_Attack':
+        df_val = pd.read_csv('GenVideo/datasets/textwater.csv')
+    elif cfg['task'] == 'one2many':
+        df_val = pd.read_csv('GenVideo/datasets/val_ood.csv')
+        if cfg['train_sub_set'] == 'pika':
+            prefixes = ["fake/pika", "real"]
+            video_condition = df_train['content_path'].str.startswith(prefixes[0])
+            for prefix in prefixes[1:]:
+                video_condition |= df_train['content_path'].str.startswith(prefix)
+            df_train = df_train[video_condition]
+        elif cfg['train_sub_set'] == 'SEINE':
+            prefixes = ["fake/SEINE", "real"]
+            video_condition = df_train['content_path'].str.startswith(prefixes[0])
+            for prefix in prefixes[1:]:
+                video_condition |= df_train['content_path'].str.startswith(prefix)
+            df_train = df_train[video_condition]
+        elif cfg['train_sub_set'] == 'OpenSora':
+            prefixes = ["fake/OpenSora", "real"]
+            video_condition = df_train['content_path'].str.startswith(prefixes[0])
+            for prefix in prefixes[1:]:
+                video_condition |= df_train['content_path'].str.startswith(prefix)
+            df_train = df_train[video_condition]
+        elif cfg['train_sub_set'] == 'Latte':
+            prefixes = ["fake/Latte", "real"]
+            video_condition = df_train['content_path'].str.startswith(prefixes[0])
+            for prefix in prefixes[1:]:
+                video_condition |= df_train['content_path'].str.startswith(prefix)
+            df_train = df_train[video_condition]
+    else:
+        df_val = pd.read_csv('GenVideo/datasets/val_ood.csv')
+
+    df_train.reset_index(drop=True, inplace=True)
+    df_val.reset_index(drop=True, inplace=True)
+
+    index_val = df_val.index.tolist()
+    index_val = index_val[:]
+
+    val_dataset = Ours_Dataset_val(cfg, index_val, df_val)
+    val_loader = torch.utils.data.DataLoader(
+        val_dataset, batch_size=cfg['val_batch_size'], shuffle=False, num_workers=cfg['num_workers'], pin_memory=True, drop_last=False
+    )
+
+    index_train = df_train.index.tolist()
+    index_train = index_train[:]
+    train_dataset = Ours_Dataset_train(index_train, df_train)
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=cfg['train_batch_size'], shuffle=True, num_workers=cfg['num_workers'], pin_memory=True, drop_last=True
+    )
+
+    print("******* Training Video IDs", str(len(index_train)), " Training Batch size ", str(cfg['train_batch_size']), " *******")
+    print("******* Testing Video IDs", str(len(index_val)), " Testing Batch size ", str(cfg['val_batch_size']), " *******")
+
+    return train_loader, val_loader
+
dataloader2.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# new_dataloader.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import albumentations
|
| 9 |
+
import numpy as np
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
|
| 12 |
+
# =========== TRANSFORMATION HELPERS ===========
|
| 13 |
+
|
| 14 |
+
def get_train_transforms():
|
| 15 |
+
"""Defines the probabilistic augmentations for training."""
|
| 16 |
+
return albumentations.Compose([
|
| 17 |
+
albumentations.Resize(224, 224),
|
| 18 |
+
albumentations.HorizontalFlip(p=0.5),
|
| 19 |
+
albumentations.ImageCompression(quality_lower=50, quality_upper=100, p=0.5),
|
| 20 |
+
albumentations.GaussNoise(p=0.3),
|
| 21 |
+
albumentations.GaussianBlur(blur_limit=(3, 5), p=0.3),
|
| 22 |
+
albumentations.ToGray(p=0.01),
|
| 23 |
+
albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)
|
| 24 |
+
])
|
| 25 |
+
|
| 26 |
+
def get_val_transforms(cfg):
|
| 27 |
+
"""Defines augmentations for validation, handling different attack tasks from the config."""
|
| 28 |
+
aug_list = [albumentations.Resize(224, 224)]
|
| 29 |
+
|
| 30 |
+
task = cfg.get('task', 'normal') # Use .get for safety
|
| 31 |
+
|
| 32 |
+
if task == 'JPEG_Compress_Attack':
|
| 33 |
+
aug_list.append(albumentations.JpegCompression(quality_lower=35, quality_upper=35, p=1.0))
|
| 34 |
+
elif task == 'FLIP_Attack':
|
| 35 |
+
aug_list.append(albumentations.HorizontalFlip(p=0.5)) # Original had random choice, 50% HFlip is common
|
| 36 |
+
elif task == 'CROP_Attack':
|
| 37 |
+
aug_list.append(albumentations.RandomCrop(height=192, width=192, p=1.0))
|
| 38 |
+
aug_list.append(albumentations.Resize(224, 224))
|
| 39 |
+
elif task == 'Color_Attack':
|
| 40 |
+
aug_list.append(albumentations.ColorJitter(p=1.0))
|
| 41 |
+
elif task == 'Gaussian_Attack':
|
| 42 |
+
aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))
|
| 43 |
+
|
| 44 |
+
aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
|
| 45 |
+
return albumentations.Compose(aug_list)
|
| 46 |
+
|
| 47 |
+
# =========== TRAINING DATASET ===========
|
| 48 |
+
|
| 49 |
+
class VideoDataset(Dataset):
|
| 50 |
+
"""
|
| 51 |
+
A PyTorch Dataset for loading video frame sequences based on a DataFrame.
|
| 52 |
+
Handles class balancing for each epoch.
|
| 53 |
+
"""
|
| 54 |
+
def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
|
| 55 |
+
self.df = df
|
| 56 |
+
self.index_list = index_list
|
| 57 |
+
self.base_data_path = base_data_path
|
| 58 |
+
self.transform = transform
|
| 59 |
+
self.select_frame_nums = select_frame_nums
|
| 60 |
+
|
| 61 |
+
self.positive_indices = self.df[self.df['label'] == 1].index.tolist()
|
| 62 |
+
self.negative_indices = self.df[self.df['label'] == 0].index.tolist()
|
| 63 |
+
|
| 64 |
+
self.balanced_indices = []
|
| 65 |
+
self.resample()
|
| 66 |
+
|
| 67 |
+
def resample(self):
|
| 68 |
+
min_samples = min(len(self.positive_indices), len(self.negative_indices))
|
| 69 |
+
self.balanced_indices.clear()
|
| 70 |
+
self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
|
| 71 |
+
self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
|
| 72 |
+
random.shuffle(self.balanced_indices)
|
| 73 |
+
|
| 74 |
+
def __len__(self):
|
| 75 |
+
return len(self.balanced_indices)
|
| 76 |
+
    def __getitem__(self, idx):
        real_idx = self.balanced_indices[idx]
        row = self.df.iloc[real_idx]

        video_id = row['content_path']
        label = int(row['label'])
        frame_list = eval(row['frame_seq'])

        frames = []

        if len(frame_list) >= self.select_frame_nums:
            start_index = random.randint(0, len(frame_list) - self.select_frame_nums)
            selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
        else:
            selected_frames = frame_list

        for frame_path in selected_frames:
            try:
                image = cv2.imread(frame_path)
                if image is None:
                    raise ValueError("Failed to load")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                image = np.zeros((224, 224, 3), dtype=np.uint8)

            if self.transform:
                image = self.transform(image=image)['image']

            frames.append(image.transpose(2, 0, 1)[np.newaxis, :])

        pad_num = self.select_frame_nums - len(frames)
        if pad_num > 0:
            for _ in range(pad_num):
                frames.append(np.zeros((1, 3, 224, 224)))

        frames_tensor = np.concatenate(frames, axis=0)
        frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)

        label_onehot = torch.zeros(2)
        label_onehot[label] = 1.0
        binary_label = torch.FloatTensor([label])

        original_index = self.index_list[idx]
        return original_index, frames_tensor, label_onehot, binary_label

# =========== VALIDATION DATASET ===========

class VideoDatasetVal(Dataset):
    """A compatible validation dataset loader."""
    def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
        self.df = df
        self.index_list = index_list
        self.base_data_path = base_data_path
        self.transform = transform
        self.select_frame_nums = select_frame_nums

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, idx):
        # Validation does not use balanced sampling, it uses the provided index directly
        real_idx = self.index_list[idx]
        row = self.df.iloc[real_idx]

        video_id = row['content_path']
        label = int(row['label'])
        frame_list = eval(row['frame_seq'])

        # This part is identical to the training dataset's __getitem__
        frames = []
        if len(frame_list) >= self.select_frame_nums:
            start_index = random.randint(0, len(frame_list) - self.select_frame_nums)
            selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
        else:
            selected_frames = frame_list

        for frame_path in selected_frames:
            try:
                image = cv2.imread(frame_path)
                if image is None:
                    raise ValueError("Failed to load")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                image = np.zeros((224, 224, 3), dtype=np.uint8)

            if self.transform:
                image = self.transform(image=image)['image']

            frames.append(image.transpose(2, 0, 1)[np.newaxis, :])

        pad_num = self.select_frame_nums - len(frames)
        if pad_num > 0:
            for _ in range(pad_num):
                frames.append(np.zeros((1, 3, 224, 224)))

        frames_tensor = np.concatenate(frames, axis=0)
        frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)

        label_onehot = torch.zeros(2)
        label_onehot[label] = 1.0
        binary_label = torch.FloatTensor([label])

        # The original validation loader returned video_id at the end
        return self.index_list[idx], frames_tensor, label_onehot, binary_label, video_id

# =========== DATALOADER GENERATOR FUNCTION ===========

def generate_dataset_loader(cfg):
    """
    The main function to create train and validation dataloaders using the new classes.
    """
    df_train = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_train.csv')

    # This logic for selecting different validation sets is preserved
    task = cfg.get('task', 'normal')
    if task == 'normal':
        df_val = pd.read_csv('GenVideo/datasets/val_id.csv')
    elif task == 'robust_compress':
        df_val = pd.read_csv('GenVideo/datasets/com_28.csv')
    # ... (add other elif conditions from your original script if needed) ...
    else:
        df_val = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_test.csv')

    # This logic for subsetting the training data is also preserved
    if cfg.get('train_sub_set'):
        prefixes = [f"fake/{cfg['train_sub_set']}", "real"]
        condition = df_train['content_path'].str.startswith(tuple(prefixes))
        df_train = df_train[condition]

    df_train.reset_index(drop=True, inplace=True)
    df_val.reset_index(drop=True, inplace=True)

    index_train = df_train.index.tolist()
    index_val = df_val.index.tolist()

    # --- Use the new VideoDataset classes ---
    base_data_path = 'GenVideo'

    train_dataset = VideoDataset(
        df=df_train,
        index_list=index_train,
        base_data_path=base_data_path,
        transform=get_train_transforms(),
        select_frame_nums=8
    )

    val_dataset = VideoDatasetVal(
        df=df_val,
        index_list=index_val,
        base_data_path=base_data_path,
        transform=get_val_transforms(cfg),
        select_frame_nums=8
    )

    train_loader = DataLoader(
        train_dataset, batch_size=cfg['train_batch_size'], shuffle=True,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=cfg['val_batch_size'], shuffle=False,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=False
    )

    print(f"******* Training Videos {len(index_train)}, Batch size {cfg['train_batch_size']} *******")
    print(f"******* Testing Videos {len(index_val)}, Batch size {cfg['val_batch_size']} *******")

    return train_loader, val_loader
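For reference, a minimal sketch of the cfg dictionary that would satisfy generate_dataset_loader as written above. Only the keys the function actually reads are shown; the example values are assumptions, not repo defaults.

cfg = {
    'task': 'normal',        # picks which validation CSV is loaded
    'train_sub_set': None,   # e.g. 'kling' to keep only fake/kling + real rows
    'train_batch_size': 8,
    'val_batch_size': 8,
    'num_workers': 4,
}
train_loader, val_loader = generate_dataset_loader(cfg)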
eval.py
ADDED
|
@@ -0,0 +1,48 @@
import models
import time
import torch
import math
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score, average_precision_score, roc_auc_score

def eval_model(cfg, model, val_loader, loss_ce, val_batch_size):
    model.eval()
    outpred_list = []
    gt_label_list = []
    video_list = []
    valLoss = 0
    lossTrainNorm = 0
    print("******** Start Testing. ********")

    with torch.no_grad():  # No need to track gradients during validation
        for i, (_, input, target, binary_label, video_id) in enumerate(tqdm(val_loader, desc="Validation", total=len(val_loader))):
            if i == 0:
                ss_time = time.time()

            input = input[:, 0]
            varInput = torch.autograd.Variable(input.float().cuda())
            varTarget = torch.autograd.Variable(target.contiguous().cuda())
            var_Binary_Target = torch.autograd.Variable(binary_label.contiguous().cuda())

            logit = model(varInput)
            lossvalue = loss_ce(logit, var_Binary_Target)

            valLoss += lossvalue.item()
            lossTrainNorm += 1
            outpred_list.append(logit[:, 0].sigmoid().cpu().detach().numpy())
            gt_label_list.append(varTarget.cpu().detach().numpy())
            video_list.append(video_id)

    valLoss = valLoss / lossTrainNorm

    outpred = np.concatenate(outpred_list, 0)
    gt_label = np.concatenate(gt_label_list, 0)
    video_list = np.concatenate(video_list, 0)
    pred_labels = [1 if item > 0.5 else 0 for item in outpred]
    true_labels = np.argmax(gt_label, axis=1)

    pred_accuracy = accuracy_score(true_labels, pred_labels)

    return pred_accuracy, video_list, pred_labels, true_labels, outpred
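A hedged wiring sketch (not repo code; cfg, model and val_loader come from the training scripts): eval_model works on raw logits and applies a sigmoid to logit[:, 0], so a loss such as BCEWithLogitsLoss matches that convention.

import torch.nn as nn
from eval import eval_model

loss_ce = nn.BCEWithLogitsLoss()
acc, videos, pred_labels, true_labels, probs = eval_model(
    cfg, model.cuda(), val_loader, loss_ce, cfg['val_batch_size']
)
print(f"validation accuracy: {acc:.4f}")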
eval2.py
ADDED
|
@@ -0,0 +1,110 @@
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score

import models  # assuming models.py is in your PYTHONPATH or same dir

# -------- Build model as per your code ----------
def build_model(model_name):
    if model_name == 'F3Net':
        model = models.Det_F3_Net()
    elif model_name == 'NPR':
        model = models.resnet50_npr()
    elif model_name == 'STIL':
        model = models.Det_STIL()
    elif model_name == 'XCLIP_DeMamba':
        model = models.XCLIP_DeMamba()
    elif model_name == 'CLIP_DeMamba':
        model = models.CLIP_DeMamba()
    elif model_name == 'XCLIP':
        model = models.XCLIP()
    elif model_name == 'CLIP':
        model = models.CLIP_Base()
    elif model_name == 'ViT_B_MINTIME':
        model = models.ViT_B_MINTIME()
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model


# --------- Evaluation loop -------------
def eval_on_frames(model, frames_dir, device):
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    frame_paths = []
    for root, _, files in os.walk(frames_dir):
        for f in files:
            if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                frame_paths.append(os.path.join(root, f))
    frame_paths.sort()

    results = []

    with torch.no_grad():
        for fp in tqdm(frame_paths, desc="Evaluating frames"):
            img = Image.open(fp).convert("RGB")

            # Transform and add batch dimension
            x = transform(img).unsqueeze(0).to(device)  # [1, C, H, W]

            # Add temporal dimension expected by DeMamba/XCLIP (T=8)
            x = x.unsqueeze(1).repeat(1, 8, 1, 1, 1)  # [1, 8, C, H, W]

            # Forward pass
            logit = model(x)
            prob = torch.sigmoid(logit[:, 0]).item()
            pred_label = int(prob > 0.5)

            results.append({
                "file_name": os.path.basename(fp),
                "predicted_prob": prob,
                "predicted_label": pred_label
            })

    return pd.DataFrame(results)


if __name__ == "__main__":
    # ----- config -----
    model_name = "XCLIP_DeMamba"  # change if needed
    model_path = "/home/kalpit/workspace/aigc/repos/SAFE_challenge/seetrails_aigvdet_v2.0.0/results/kling_9k_9k/best_acc.pth"
    frames_dir = "./frames"
    output_csv = "results.csv"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- load model ----
    model = build_model(model_name).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    print(f"Loaded model weights from {model_path}")

    # ---- evaluate ----
    df_results = eval_on_frames(model, frames_dir, device)
    df_results.to_csv(output_csv, index=False)
    print(f"Saved framewise results to {output_csv}")

    # ---- optional: basic metrics if you have GT ----
    if "label" in df_results.columns:
        y_true = df_results["label"].values
        y_pred = df_results["predicted_label"].values
        y_prob = df_results["predicted_prob"].values

        acc = accuracy_score(y_true, y_pred)
        auc = roc_auc_score(y_true, y_prob)
        ap = average_precision_score(y_true, y_prob)

        print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")
extract_frames.py
ADDED
|
@@ -0,0 +1,74 @@
import os
import pandas as pd
from datasets import load_dataset
import numpy as np
import tqdm.auto as tqdm
import io
import torch
import av
from torchvision.utils import save_image


def preprocess_and_save_frames(
    file_like: io.BytesIO,
    video_id: str,
    output_root: str = "./frames",
    crop_size: int = -1,
    every: int = 10,
    max_memory: int = 50 * 1024 * 1024,
    device: str = "cpu"
):
    """
    Loads a video and saves frames as images in output_root/<video_id>/.
    """

    # Ensure the base frames directory exists
    os.makedirs(output_root, exist_ok=True)

    # Create subfolder for this specific video
    video_dir = os.path.join(output_root, video_id)
    os.makedirs(video_dir, exist_ok=True)

    center_crop_transform = None
    if crop_size > 0:
        from torchvision import transforms
        center_crop_transform = transforms.CenterCrop(crop_size)

    file_like.seek(0)
    container = av.open(file_like)
    current_memory = 0

    for i, frame in enumerate(container.decode(video=0)):
        if i % every == 0:
            frame_array = frame.to_ndarray(format="rgb24")
            frame_tensor = torch.from_numpy(frame_array).permute(2, 0, 1).float() / 255.0

            if center_crop_transform is not None:
                frame_tensor = center_crop_transform(frame_tensor)

            frame_path = os.path.join(video_dir, f"frame_{i:05d}.png")
            save_image(frame_tensor, frame_path)

            frame_bytes = frame_tensor.numel() * 4
            current_memory += frame_bytes
            if current_memory >= max_memory:
                break


# -------- Main section --------
if __name__ == "__main__":
    DATASET_PATH = "/tmp/data"
    dataset_remote = load_dataset(DATASET_PATH, split="test", streaming=True)

    # Make sure ./frames exists before processing
    os.makedirs("./frames", exist_ok=True)

    for el in tqdm.tqdm(dataset_remote, desc="Extracting frames"):
        try:
            video_id = str(el["id"])
            file_like = io.BytesIO(el["video"]["bytes"])
            preprocess_and_save_frames(file_like, video_id)
        except Exception as e:
            print(f"⚠️ Failed {el['id']}: {e}")

    print("✅ All frames saved under ./frames/<video_id>/")
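Rough arithmetic behind the 50 MB per-video cap above (an illustration, not repo code): frames are kept at native resolution and counted as float32 tensors, so a 1280x720 RGB frame costs 3 * 720 * 1280 * 4 bytes, about 10.5 MB, and the loop stops after roughly four or five saved frames of a 720p video (more for smaller resolutions).

frame_bytes = 3 * 720 * 1280 * 4
print(frame_bytes / 2**20)               # ~10.5 MB per saved 720p frame
print(50 * 1024 * 1024 / frame_bytes)    # ~4.7 frames before the cap triggers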
models/DeMamba.py
ADDED
|
@@ -0,0 +1,176 @@
from transformers import XCLIPVisionModel
import os
import sys
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from .mamba_base import MambaConfig, ResidualBlock
import torch.nn.init as init
from .clip import clip
import math
from transformers import XCLIPVisionConfig, XCLIPVisionModel

def create_reorder_index(N, device):
    new_order = []
    for col in range(N):
        if col % 2 == 0:
            new_order.extend(range(col, N*N, N))
        else:
            new_order.extend(range(col + N*(N-1), col-1, -N))
    return torch.tensor(new_order, device=device)

def reorder_data(data, N):
    assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
    device = data.device
    new_order = create_reorder_index(N, device)
    B, t, _, _ = data.shape
    index = new_order.repeat(B, t, 1).unsqueeze(-1)
    reordered_data = torch.gather(data, 2, index.expand_as(data))
    return reordered_data

class XCLIP_DeMamba(nn.Module):
    def __init__(
        self, channel_size=768, class_num=1
    ):
        super(XCLIP_DeMamba, self).__init__()
        # self.encoder = XCLIPVisionModel.from_pretrained("GenVideo/pretrained_weights/xclip")
        # my code for training from scratch
        config = XCLIPVisionConfig()
        self.encoder = XCLIPVisionModel(config)

        blocks = []
        channel = 768
        self.fusing_ratios = 1
        self.patch_nums = (14//self.fusing_ratios)**2
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)
        # self.fc1 = nn.Linear((self.patch_nums+1)*channel, class_num)
        self.fc1 = nn.Linear(38400, class_num)  # my code
        # self.fc_norm = nn.LayerNorm(self.patch_nums*channel)
        self.fc_norm = None  # my code
        self.fc_norm2 = nn.LayerNorm(768)
        self.initialize_weights(self.fc1)
        self.dropout = nn.Dropout(p=0.0)

    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        outputs = self.encoder(images, output_hidden_states=True)
        sequence_output = outputs['last_hidden_state'][:, 1:, :]
        _, _, c = sequence_output.shape

        global_feat = outputs['pooler_output'].reshape(b, t, -1)
        global_feat = global_feat.mean(1)
        global_feat = self.fc_norm2(global_feat)

        sequence_output = sequence_output.view(b, t, -1, c)
        _, _, f_w, _ = sequence_output.shape
        f_h, f_w = int(math.sqrt(f_w)), int(math.sqrt(f_w))

        s = f_h//self.fusing_ratios
        sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b*s*s, t, -1, c)
        b_l = b*s*s

        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)

        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)

        # my code
        if self.fc_norm is None:
            self.fc_norm = nn.LayerNorm(video_level_features.size(-1)).to(video_level_features.device)

        video_level_features = self.fc_norm(video_level_features)
        video_level_features = torch.cat((global_feat, video_level_features), dim=1)

        pred = self.fc1(video_level_features)
        pred = self.dropout(pred)

        return pred


class CLIP_DeMamba(nn.Module):
    def __init__(
        self, channel_size=512, class_num=1
    ):
        super(CLIP_DeMamba, self).__init__()
        self.clip_model, preprocess = clip.load('ViT-B-14')
        self.clip_model = self.clip_model.float()
        blocks = []
        channel = 512
        self.fusing_ratios = 2
        self.patch_nums = (14//self.fusing_ratios)**2
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)

        self.fc1 = nn.Linear(channel*(self.patch_nums+1), class_num)

        self.bn1 = nn.BatchNorm1d(channel)
        self.initialize_weights(self.fc1)

    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.view(b, t, -1, c)

        global_feat = sequence_output.reshape(b, -1, c)
        global_feat = global_feat.mean(1)

        _, _, f_w, _ = sequence_output.shape
        f_h, f_w = int(math.sqrt(f_w)), int(math.sqrt(f_w))

        s = f_h//self.fusing_ratios
        sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b*s*s, t, -1, c)
        b_l = b*s*s

        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)
        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)

        video_level_features = torch.cat((global_feat, video_level_features), dim=1)
        x = self.fc1(video_level_features)

        return x

if __name__ == '__main__':
    model = CLIP_DeMamba()
    print(model)
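To make the patch scan order concrete, a small check (not part of the repo) of what create_reorder_index and reorder_data above produce: for a 2x2 patch grid stored row-major as [0, 1, 2, 3], the index walks the columns in a serpentine order (down the first column, up the next) before the patches are handed to the Mamba block.

import torch
print(create_reorder_index(2, device="cpu"))       # tensor([0, 2, 3, 1])
data = torch.arange(4).view(1, 1, 4, 1).float()    # [batch, frames, patches, channels]
print(reorder_data(data, 2).flatten())             # tensor([0., 2., 3., 1.])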
models/F3Net.py
ADDED
|
@@ -0,0 +1,434 @@
| 1 |
+
"""
|
| 2 |
+
F3Net: Fusion, Feedback and Focus for Salient Object Detection @ AAAI'2020
|
| 3 |
+
Copyright (c) University of Chinese Academy of Sciences and its affiliates.
|
| 4 |
+
Modified by Jun Wei from https://github.com/weijun88/F3Net
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.model_zoo as model_zoo
|
| 14 |
+
|
| 15 |
+
pretrained_settings = {
|
| 16 |
+
'xception': {
|
| 17 |
+
'imagenet': {
|
| 18 |
+
'dir': '/ossfs/workspace/aigc_video/weights/xception-b5690688.pth',
|
| 19 |
+
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
|
| 20 |
+
'input_space': 'RGB',
|
| 21 |
+
'input_size': [3, 299, 299],
|
| 22 |
+
'input_range': [0, 1],
|
| 23 |
+
'mean': [0.5, 0.5, 0.5],
|
| 24 |
+
'std': [0.5, 0.5, 0.5],
|
| 25 |
+
'num_classes': 1000,
|
| 26 |
+
'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
class F3Net(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
Implementation is mainly referenced from https://github.com/yyk-wew/F3Net
|
| 34 |
+
"""
|
| 35 |
+
def __init__(self,
|
| 36 |
+
num_classes: int=2,
|
| 37 |
+
img_width: int=299,
|
| 38 |
+
img_height: int=299,
|
| 39 |
+
LFS_window_size: int=10,
|
| 40 |
+
LFS_M: int=6) -> None:
|
| 41 |
+
super(F3Net, self).__init__()
|
| 42 |
+
assert img_width == img_height
|
| 43 |
+
self.img_size = img_width
|
| 44 |
+
self.num_classes = num_classes
|
| 45 |
+
self._LFS_window_size = LFS_window_size
|
| 46 |
+
self._LFS_M = LFS_M
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
self.fad_head = FAD_Head(self.img_size)
|
| 50 |
+
self.lfs_head = LFS_Head(self.img_size, self._LFS_window_size, self._LFS_M)
|
| 51 |
+
|
| 52 |
+
self.fad_excep = self._init_xcep_fad()
|
| 53 |
+
self.lfs_excep = self._init_xcep_lfs()
|
| 54 |
+
|
| 55 |
+
self.mix_block7 = MixBlock(c_in=728, width=19, height=19)
|
| 56 |
+
self.mix_block12 = MixBlock(c_in=1024, width=10, height=10)
|
| 57 |
+
self.excep_forwards = ['conv1', 'bn1', 'relu', 'conv2', 'bn2', 'relu',
|
| 58 |
+
'block1', 'block2', 'block3', 'block4', 'block5', 'block6',
|
| 59 |
+
'block7', 'block8', 'block9', 'block10' , 'block11', 'block12',
|
| 60 |
+
'conv3', 'bn3', 'relu', 'conv4', 'bn4']
|
| 61 |
+
|
| 62 |
+
# classifier
|
| 63 |
+
self.relu = nn.ReLU(inplace=True)
|
| 64 |
+
self.fc = nn.Linear(4096, num_classes)
|
| 65 |
+
self.dp = nn.Dropout(p=0.2)
|
| 66 |
+
|
| 67 |
+
def _init_xcep_fad(self):
|
| 68 |
+
fad_excep = return_pytorch04_xception(True)
|
| 69 |
+
conv1_data = fad_excep.conv1.weight.data
|
| 70 |
+
# let new conv1 use old param to balance the network
|
| 71 |
+
fad_excep.conv1 = nn.Conv2d(12, 32, 3, 2, 0, bias=False)
|
| 72 |
+
for i in range(4):
|
| 73 |
+
fad_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / 4.0
|
| 74 |
+
return fad_excep
|
| 75 |
+
|
| 76 |
+
def _init_xcep_lfs(self):
|
| 77 |
+
lfs_excep = return_pytorch04_xception(True)
|
| 78 |
+
conv1_data = lfs_excep.conv1.weight.data
|
| 79 |
+
# let new conv1 use old param to balance the network
|
| 80 |
+
lfs_excep.conv1 = nn.Conv2d(self._LFS_M, 32, 3, 1, 0, bias=False)
|
| 81 |
+
for i in range(int(self._LFS_M / 3)):
|
| 82 |
+
lfs_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / float(self._LFS_M / 3.0)
|
| 83 |
+
return lfs_excep
|
| 84 |
+
|
| 85 |
+
def _features(self, x_fad, x_fls):
|
| 86 |
+
for forward_func in self.excep_forwards:
|
| 87 |
+
x_fad = getattr(self.fad_excep, forward_func)(x_fad)
|
| 88 |
+
x_fls = getattr(self.lfs_excep, forward_func)(x_fls)
|
| 89 |
+
if forward_func == 'block7':
|
| 90 |
+
x_fad, x_fls = self.mix_block7(x_fad, x_fls)
|
| 91 |
+
if forward_func == 'block12':
|
| 92 |
+
x_fad, x_fls = self.mix_block12(x_fad, x_fls)
|
| 93 |
+
return x_fad, x_fls
|
| 94 |
+
|
| 95 |
+
def _norm_feature(self, x):
|
| 96 |
+
x = self.relu(x)
|
| 97 |
+
x = F.adaptive_avg_pool2d(x, (1,1))
|
| 98 |
+
x = x.view(x.size(0), -1)
|
| 99 |
+
return x
|
| 100 |
+
|
| 101 |
+
def forward(self, x):
|
| 102 |
+
fad_input = self.fad_head(x)
|
| 103 |
+
lfs_input = self.lfs_head(x)
|
| 104 |
+
x_fad, x_fls = self._features(fad_input, lfs_input)
|
| 105 |
+
x_fad = self._norm_feature(x_fad)
|
| 106 |
+
x_fls = self._norm_feature(x_fls)
|
| 107 |
+
x_cat = torch.cat((x_fad, x_fls), dim=1)
|
| 108 |
+
x_drop = self.dp(x_cat)
|
| 109 |
+
logit = self.fc(x_drop)
|
| 110 |
+
return logit
|
| 111 |
+
|
| 112 |
+
class SeparableConv2d(nn.Module):
|
| 113 |
+
def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
|
| 114 |
+
super(SeparableConv2d,self).__init__()
|
| 115 |
+
|
| 116 |
+
self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
|
| 117 |
+
self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
|
| 118 |
+
|
| 119 |
+
def forward(self,x):
|
| 120 |
+
x = self.conv1(x)
|
| 121 |
+
x = self.pointwise(x)
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
class Block(nn.Module):
|
| 125 |
+
def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
|
| 126 |
+
super(Block, self).__init__()
|
| 127 |
+
|
| 128 |
+
if out_filters != in_filters or strides!=1:
|
| 129 |
+
self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
|
| 130 |
+
self.skipbn = nn.BatchNorm2d(out_filters)
|
| 131 |
+
else:
|
| 132 |
+
self.skip=None
|
| 133 |
+
|
| 134 |
+
self.relu = nn.ReLU(inplace=True)
|
| 135 |
+
rep=[]
|
| 136 |
+
|
| 137 |
+
filters=in_filters
|
| 138 |
+
if grow_first:
|
| 139 |
+
rep.append(self.relu)
|
| 140 |
+
rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
|
| 141 |
+
rep.append(nn.BatchNorm2d(out_filters))
|
| 142 |
+
filters = out_filters
|
| 143 |
+
|
| 144 |
+
for i in range(reps-1):
|
| 145 |
+
rep.append(self.relu)
|
| 146 |
+
rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
|
| 147 |
+
rep.append(nn.BatchNorm2d(filters))
|
| 148 |
+
|
| 149 |
+
if not grow_first:
|
| 150 |
+
rep.append(self.relu)
|
| 151 |
+
rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
|
| 152 |
+
rep.append(nn.BatchNorm2d(out_filters))
|
| 153 |
+
|
| 154 |
+
if not start_with_relu:
|
| 155 |
+
rep = rep[1:]
|
| 156 |
+
else:
|
| 157 |
+
rep[0] = nn.ReLU(inplace=False)
|
| 158 |
+
|
| 159 |
+
if strides != 1:
|
| 160 |
+
rep.append(nn.MaxPool2d(3,strides,1))
|
| 161 |
+
self.rep = nn.Sequential(*rep)
|
| 162 |
+
|
| 163 |
+
def forward(self,inp):
|
| 164 |
+
x = self.rep(inp)
|
| 165 |
+
|
| 166 |
+
if self.skip is not None:
|
| 167 |
+
skip = self.skip(inp)
|
| 168 |
+
skip = self.skipbn(skip)
|
| 169 |
+
else:
|
| 170 |
+
skip = inp
|
| 171 |
+
|
| 172 |
+
x+=skip
|
| 173 |
+
return x
|
| 174 |
+
|
| 175 |
+
class Xception(nn.Module):
|
| 176 |
+
"""
|
| 177 |
+
Xception optimized for the ImageNet dataset, as specified in
|
| 178 |
+
https://arxiv.org/pdf/1610.02357.pdf
|
| 179 |
+
"""
|
| 180 |
+
def __init__(self, num_classes=1000):
|
| 181 |
+
""" Constructor
|
| 182 |
+
Args:
|
| 183 |
+
num_classes: number of classes
|
| 184 |
+
"""
|
| 185 |
+
super(Xception, self).__init__()
|
| 186 |
+
self.num_classes = num_classes
|
| 187 |
+
|
| 188 |
+
self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)
|
| 189 |
+
self.bn1 = nn.BatchNorm2d(32)
|
| 190 |
+
self.relu = nn.ReLU(inplace=True)
|
| 191 |
+
|
| 192 |
+
self.conv2 = nn.Conv2d(32,64,3,bias=False)
|
| 193 |
+
self.bn2 = nn.BatchNorm2d(64)
|
| 194 |
+
#do relu here
|
| 195 |
+
|
| 196 |
+
self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
|
| 197 |
+
self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
|
| 198 |
+
self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)
|
| 199 |
+
|
| 200 |
+
self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 201 |
+
self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 202 |
+
self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 203 |
+
self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 204 |
+
|
| 205 |
+
self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 206 |
+
self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 207 |
+
self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 208 |
+
self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)
|
| 209 |
+
|
| 210 |
+
self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
|
| 211 |
+
|
| 212 |
+
self.conv3 = SeparableConv2d(1024,1536,3,1,1)
|
| 213 |
+
self.bn3 = nn.BatchNorm2d(1536)
|
| 214 |
+
|
| 215 |
+
#do relu here
|
| 216 |
+
self.conv4 = SeparableConv2d(1536,2048,3,1,1)
|
| 217 |
+
self.bn4 = nn.BatchNorm2d(2048)
|
| 218 |
+
|
| 219 |
+
def xception(num_classes=1000, pretrained='imagenet'):
|
| 220 |
+
model = Xception(num_classes=num_classes)
|
| 221 |
+
if pretrained:
|
| 222 |
+
settings = pretrained_settings['xception'][pretrained]
|
| 223 |
+
assert num_classes == settings['num_classes'], \
|
| 224 |
+
"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
|
| 225 |
+
|
| 226 |
+
model = Xception(num_classes=num_classes)
|
| 227 |
+
model.load_state_dict(settings['dir'])
|
| 228 |
+
|
| 229 |
+
model.input_space = settings['input_space']
|
| 230 |
+
model.input_size = settings['input_size']
|
| 231 |
+
model.input_range = settings['input_range']
|
| 232 |
+
model.mean = settings['mean']
|
| 233 |
+
model.std = settings['std']
|
| 234 |
+
|
| 235 |
+
return model
|
| 236 |
+
|
| 237 |
+
def return_pytorch04_xception(pretrained=True):
|
| 238 |
+
model = xception(pretrained=False)
|
| 239 |
+
if pretrained:
|
| 240 |
+
state_dict = torch.load(
|
| 241 |
+
'/ossfs/workspace/GenVideo/weights/xception-b5690688.pth')
|
| 242 |
+
for name, weights in state_dict.items():
|
| 243 |
+
if 'pointwise' in name:
|
| 244 |
+
state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
|
| 245 |
+
model.load_state_dict(state_dict, strict=False)
|
| 246 |
+
|
| 247 |
+
return model
|
| 248 |
+
|
| 249 |
+
class Filter(nn.Module):
|
| 250 |
+
def __init__(self, size,
|
| 251 |
+
band_start,
|
| 252 |
+
band_end,
|
| 253 |
+
use_learnable=True,
|
| 254 |
+
norm=False):
|
| 255 |
+
super(Filter, self).__init__()
|
| 256 |
+
self.use_learnable = use_learnable
|
| 257 |
+
|
| 258 |
+
self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False)
|
| 259 |
+
if self.use_learnable:
|
| 260 |
+
self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
|
| 261 |
+
self.learnable.data.normal_(0., 0.1)
|
| 262 |
+
# Todo
|
| 263 |
+
# self.learnable = nn.Parameter(torch.rand((size, size)) * 0.2 - 0.1, requires_grad=True)
|
| 264 |
+
|
| 265 |
+
self.norm = norm
|
| 266 |
+
if norm:
|
| 267 |
+
self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def forward(self, x):
|
| 271 |
+
if self.use_learnable:
|
| 272 |
+
filt = self.base + norm_sigma(self.learnable)
|
| 273 |
+
else:
|
| 274 |
+
filt = self.base
|
| 275 |
+
|
| 276 |
+
if self.norm:
|
| 277 |
+
y = x * filt / self.ft_num
|
| 278 |
+
else:
|
| 279 |
+
y = x * filt
|
| 280 |
+
return y
|
| 281 |
+
|
| 282 |
+
class FAD_Head(nn.Module):
|
| 283 |
+
def __init__(self, size):
|
| 284 |
+
super(FAD_Head, self).__init__()
|
| 285 |
+
# init DCT matrix
|
| 286 |
+
self._DCT_all = nn.Parameter(torch.tensor(DCT_mat(size)).float(), requires_grad=False)
|
| 287 |
+
self._DCT_all_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(size)).float(), 0, 1), requires_grad=False)
|
| 288 |
+
|
| 289 |
+
# define base filters and learnable
|
| 290 |
+
# 0 - 1/16 || 1/16 - 1/8 || 1/8 - 1
|
| 291 |
+
low_filter = Filter(size, 0, size // 16)
|
| 292 |
+
middle_filter = Filter(size, size // 16, size // 8)
|
| 293 |
+
high_filter = Filter(size, size // 8, size)
|
| 294 |
+
all_filter = Filter(size, 0, size * 2)
|
| 295 |
+
|
| 296 |
+
self.filters = nn.ModuleList([low_filter, middle_filter, high_filter, all_filter])
|
| 297 |
+
|
| 298 |
+
def forward(self, x):
|
| 299 |
+
# DCT
|
| 300 |
+
x_freq = self._DCT_all @ x @ self._DCT_all_T # [N, 3, 299, 299]
|
| 301 |
+
|
| 302 |
+
# 4 kernel
|
| 303 |
+
y_list = []
|
| 304 |
+
for i in range(4):
|
| 305 |
+
x_pass = self.filters[i](x_freq) # [N, 3, 299, 299]
|
| 306 |
+
y = self._DCT_all_T @ x_pass @ self._DCT_all # [N, 3, 299, 299]
|
| 307 |
+
y_list.append(y)
|
| 308 |
+
out = torch.cat(y_list, dim=1) # [N, 12, 299, 299]
|
| 309 |
+
return out
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class LFS_Head(nn.Module):
|
| 313 |
+
def __init__(self, size, window_size, M):
|
| 314 |
+
super(LFS_Head, self).__init__()
|
| 315 |
+
|
| 316 |
+
self.window_size = window_size
|
| 317 |
+
self._M = M
|
| 318 |
+
|
| 319 |
+
# init DCT matrix
|
| 320 |
+
self._DCT_patch = nn.Parameter(torch.tensor(DCT_mat(window_size)).float(), requires_grad=False)
|
| 321 |
+
self._DCT_patch_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(window_size)).float(), 0, 1), requires_grad=False)
|
| 322 |
+
|
| 323 |
+
self.unfold = nn.Unfold(kernel_size=(window_size, window_size), stride=2, padding=4)
|
| 324 |
+
|
| 325 |
+
# init filters
|
| 326 |
+
self.filters = nn.ModuleList([Filter(window_size, window_size * 2. / M * i, window_size * 2. / M * (i+1), norm=True) for i in range(M)])
|
| 327 |
+
|
| 328 |
+
def forward(self, x):
|
| 329 |
+
# turn RGB into Gray
|
| 330 |
+
x_gray = 0.299*x[:,0,:,:] + 0.587*x[:,1,:,:] + 0.114*x[:,2,:,:]
|
| 331 |
+
x = x_gray.unsqueeze(1)
|
| 332 |
+
|
| 333 |
+
# rescale to 0 - 255
|
| 334 |
+
x = (x + 1.) * 122.5
|
| 335 |
+
|
| 336 |
+
# calculate size
|
| 337 |
+
N, C, W, H = x.size()
|
| 338 |
+
S = self.window_size
|
| 339 |
+
size_after = int((W - S + 8)/2) + 1
|
| 340 |
+
assert size_after == 149
|
| 341 |
+
|
| 342 |
+
# sliding window unfold and DCT
|
| 343 |
+
x_unfold = self.unfold(x) # [N, C * S * S, L] L:block num
|
| 344 |
+
L = x_unfold.size()[2]
|
| 345 |
+
x_unfold = x_unfold.transpose(1, 2).reshape(N, L, C, S, S) # [N, L, C, S, S]
|
| 346 |
+
x_dct = self._DCT_patch @ x_unfold @ self._DCT_patch_T
|
| 347 |
+
|
| 348 |
+
# M kernels filtering
|
| 349 |
+
y_list = []
|
| 350 |
+
for i in range(self._M):
|
| 351 |
+
# y = self.filters[i](x_dct) # [N, L, C, S, S]
|
| 352 |
+
# y = torch.abs(y)
|
| 353 |
+
# y = torch.sum(y, dim=[2,3,4]) # [N, L]
|
| 354 |
+
# y = torch.log10(y + 1e-15)
|
| 355 |
+
y = torch.abs(x_dct)
|
| 356 |
+
y = torch.log10(y + 1e-15)
|
| 357 |
+
y = self.filters[i](y)
|
| 358 |
+
y = torch.sum(y, dim=[2,3,4])
|
| 359 |
+
y = y.reshape(N, size_after, size_after).unsqueeze(dim=1) # [N, 1, 149, 149]
|
| 360 |
+
y_list.append(y)
|
| 361 |
+
out = torch.cat(y_list, dim=1) # [N, M, 149, 149]
|
| 362 |
+
return out
|
| 363 |
+
|
| 364 |
+
class MixBlock(nn.Module):
|
| 365 |
+
|
| 366 |
+
def __init__(self, c_in, width, height):
|
| 367 |
+
super(MixBlock, self).__init__()
|
| 368 |
+
self.FAD_query = nn.Conv2d(c_in, c_in, (1,1))
|
| 369 |
+
self.LFS_query = nn.Conv2d(c_in, c_in, (1,1))
|
| 370 |
+
|
| 371 |
+
self.FAD_key = nn.Conv2d(c_in, c_in, (1,1))
|
| 372 |
+
self.LFS_key = nn.Conv2d(c_in, c_in, (1,1))
|
| 373 |
+
|
| 374 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 375 |
+
self.relu = nn.ReLU()
|
| 376 |
+
|
| 377 |
+
self.FAD_gamma = nn.Parameter(torch.zeros(1))
|
| 378 |
+
self.LFS_gamma = nn.Parameter(torch.zeros(1))
|
| 379 |
+
|
| 380 |
+
self.FAD_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in)
|
| 381 |
+
self.FAD_bn = nn.BatchNorm2d(c_in)
|
| 382 |
+
self.LFS_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in)
|
| 383 |
+
self.LFS_bn = nn.BatchNorm2d(c_in)
|
| 384 |
+
|
| 385 |
+
def forward(self, x_FAD, x_LFS):
|
| 386 |
+
B, C, W, H = x_FAD.size()
|
| 387 |
+
assert W == H
|
| 388 |
+
|
| 389 |
+
q_FAD = self.FAD_query(x_FAD).view(-1, W, H) # [BC, W, H]
|
| 390 |
+
q_LFS = self.LFS_query(x_LFS).view(-1, W, H)
|
| 391 |
+
M_query = torch.cat([q_FAD, q_LFS], dim=2) # [BC, W, 2H]
|
| 392 |
+
|
| 393 |
+
k_FAD = self.FAD_key(x_FAD).view(-1, W, H).transpose(1, 2) # [BC, H, W]
|
| 394 |
+
k_LFS = self.LFS_key(x_LFS).view(-1, W, H).transpose(1, 2)
|
| 395 |
+
M_key = torch.cat([k_FAD, k_LFS], dim=1) # [BC, 2H, W]
|
| 396 |
+
|
| 397 |
+
energy = torch.bmm(M_query, M_key) #[BC, W, W]
|
| 398 |
+
attention = self.softmax(energy).view(B, C, W, W)
|
| 399 |
+
|
| 400 |
+
att_LFS = x_LFS * attention * (torch.sigmoid(self.LFS_gamma) * 2.0 - 1.0)
|
| 401 |
+
y_FAD = x_FAD + self.FAD_bn(self.FAD_conv(att_LFS))
|
| 402 |
+
|
| 403 |
+
att_FAD = x_FAD * attention * (torch.sigmoid(self.FAD_gamma) * 2.0 - 1.0)
|
| 404 |
+
y_LFS = x_LFS + self.LFS_bn(self.LFS_conv(att_FAD))
|
| 405 |
+
return y_FAD, y_LFS
|
| 406 |
+
|
| 407 |
+
def DCT_mat(size):
|
| 408 |
+
m = [[ (np.sqrt(1./size) if i == 0 else np.sqrt(2./size)) * np.cos((j + 0.5) * np.pi * i / size) for j in range(size)] for i in range(size)]
|
| 409 |
+
return m
|
| 410 |
+
|
| 411 |
+
def generate_filter(start, end, size):
|
| 412 |
+
return [[0. if i + j > end or i + j <= start else 1. for j in range(size)] for i in range(size)]
|
| 413 |
+
|
| 414 |
+
def norm_sigma(x):
|
| 415 |
+
return 2. * torch.sigmoid(x) - 1.
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
class Det_F3_Net(nn.Module):
|
| 419 |
+
def __init__(self):
|
| 420 |
+
super(Det_F3_Net, self).__init__()
|
| 421 |
+
self.f3net = F3Net(num_classes=1)
|
| 422 |
+
|
| 423 |
+
def forward(self, x):
|
| 424 |
+
b, t, _, h, w = x.shape
|
| 425 |
+
images = x.view(b * t, 3, h, w)
|
| 426 |
+
sequence_output = self.f3net(images)
|
| 427 |
+
sequence_output = sequence_output.view(b, t, -1)
|
| 428 |
+
sequence_output = sequence_output.mean(1)
|
| 429 |
+
|
| 430 |
+
return sequence_output
|
| 431 |
+
|
| 432 |
+
if __name__ == '__main__':
|
| 433 |
+
model = F3Net()
|
| 434 |
+
print(model)
|
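As a quick illustration of the frequency decomposition in F3Net's FAD head above (a check, not repo code): generate_filter(start, end, size) keeps DCT coefficient (i, j) when start < i + j <= end, i.e. an anti-diagonal band of the size x size spectrum.

print(generate_filter(0, 2, 4))
# [[0.0, 1.0, 1.0, 0.0],
#  [1.0, 1.0, 0.0, 0.0],
#  [1.0, 0.0, 0.0, 0.0],
#  [0.0, 0.0, 0.0, 0.0]]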
models/FTCN.py
ADDED
|
@@ -0,0 +1,143 @@
| 1 |
+
'''
|
| 2 |
+
Exploring Temporal Coherence for More General Video Face Forgery Detection @ ICCV'2021
|
| 3 |
+
Copyright (c) Xiamen University and its affiliates.
|
| 4 |
+
Modified by Yinglin Zheng from https://github.com/yinglinzheng/FTCN
|
| 5 |
+
'''
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from .time_transformer import TimeTransformer
|
| 10 |
+
from .clip import clip
|
| 11 |
+
|
| 12 |
+
class RandomPatchPool(nn.Module):
|
| 13 |
+
def __init__(self):
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
def forward(self, x):
|
| 17 |
+
# batch,channel,16,7x7
|
| 18 |
+
b, c, t, h, w = x.shape
|
| 19 |
+
x = x.reshape(b, c, t, h * w)
|
| 20 |
+
if self.training and my_cfg.model.transformer.random_select:
|
| 21 |
+
while True:
|
| 22 |
+
idx = random.randint(0, h * w - 1)
|
| 23 |
+
i = idx // h
|
| 24 |
+
j = idx % h
|
| 25 |
+
if j == 0 or i == h - 1 or j == h - 1:
|
| 26 |
+
continue
|
| 27 |
+
else:
|
| 28 |
+
break
|
| 29 |
+
else:
|
| 30 |
+
idx = h * w // 2
|
| 31 |
+
x = x[..., idx]
|
| 32 |
+
return x
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def valid_idx(idx, h):
|
| 36 |
+
i = idx // h
|
| 37 |
+
j = idx % h
|
| 38 |
+
if j == 0 or i == h - 1 or j == h - 1:
|
| 39 |
+
return False
|
| 40 |
+
else:
|
| 41 |
+
return True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RandomAvgPool(nn.Module):
|
| 45 |
+
def __init__(self):
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
# batch,channel,16,7x7
|
| 50 |
+
b, c, t, h, w = x.shape
|
| 51 |
+
x = x.reshape(b, c, t, h * w)
|
| 52 |
+
candidates = list(range(h * w))
|
| 53 |
+
candidates = [idx for idx in candidates if valid_idx(idx, h)]
|
| 54 |
+
max_k = len(candidates)
|
| 55 |
+
if self.training and my_cfg.model.transformer.random_select:
|
| 56 |
+
k = my_cfg.model.transformer.k
|
| 57 |
+
else:
|
| 58 |
+
k = max_k
|
| 59 |
+
candidates = random.sample(candidates, k)
|
| 60 |
+
x = x[..., candidates].mean(-1)
|
| 61 |
+
return x
|
| 62 |
+
|
| 63 |
+
class TransformerHead(nn.Module):
|
| 64 |
+
def __init__(self, spatial_size=7, time_size=8, in_channels=2048):
|
| 65 |
+
super().__init__()
|
| 66 |
+
# if my_cfg.model.inco.no_time_pool:
|
| 67 |
+
# time_size = time_size * 2
|
| 68 |
+
patch_type = 'time'
|
| 69 |
+
if patch_type == "time":
|
| 70 |
+
self.pool = nn.AvgPool3d((1, spatial_size, spatial_size))
|
| 71 |
+
self.num_patches = time_size
|
| 72 |
+
elif patch_type == "spatial":
|
| 73 |
+
self.pool = nn.AvgPool3d((time_size, 1, 1))
|
| 74 |
+
self.num_patches = spatial_size ** 2
|
| 75 |
+
elif patch_type == "random":
|
| 76 |
+
self.pool = RandomPatchPool()
|
| 77 |
+
self.num_patches = time_size
|
| 78 |
+
elif patch_type == "random_avg":
|
| 79 |
+
self.pool = RandomAvgPool()
|
| 80 |
+
self.num_patches = time_size
|
| 81 |
+
elif patch_type == "all":
|
| 82 |
+
self.pool = nn.Identity()
|
| 83 |
+
self.num_patches = time_size * spatial_size * spatial_size
|
| 84 |
+
else:
|
| 85 |
+
raise NotImplementedError(patch_type)
|
| 86 |
+
|
| 87 |
+
self.dim = -1
|
| 88 |
+
if self.dim == -1:
|
| 89 |
+
self.dim = in_channels
|
| 90 |
+
|
| 91 |
+
self.in_channels = in_channels
|
| 92 |
+
|
| 93 |
+
if self.dim != self.in_channels:
|
| 94 |
+
self.fc = nn.Linear(self.in_channels, self.dim)
|
| 95 |
+
|
| 96 |
+
default_params = dict(
|
| 97 |
+
dim=self.dim, depth=6, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
self.time_T = TimeTransformer(
|
| 101 |
+
num_patches=self.num_patches, num_classes=1, **default_params
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def forward(self, x):
|
| 106 |
+
x = self.pool(x)
|
| 107 |
+
x = x.reshape(-1, self.in_channels, self.num_patches)
|
| 108 |
+
x = x.permute(0, 2, 1)
|
| 109 |
+
if self.dim != self.in_channels:
|
| 110 |
+
x = self.fc(x.reshape(-1, self.in_channels))
|
| 111 |
+
x = x.reshape(-1, self.num_patches, self.dim)
|
| 112 |
+
x = self.time_T(x)
|
| 113 |
+
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ViT_B_FTCN(nn.Module):
|
| 118 |
+
def __init__(
|
| 119 |
+
self, channel_size=512, class_num=1
|
| 120 |
+
):
|
| 121 |
+
super(ViT_B_FTCN, self).__init__()
|
| 122 |
+
self.clip_model, preprocess = clip.load('ViT-B-16')
|
| 123 |
+
self.clip_model = self.clip_model.float()
|
| 124 |
+
|
| 125 |
+
self.head = TransformerHead(spatial_size=14, time_size=8, in_channels=512)
|
| 126 |
+
|
| 127 |
+
def forward(self, x):
|
| 128 |
+
b, t, _, h, w = x.shape
|
| 129 |
+
images = x.view(b * t, 3, h, w)
|
| 130 |
+
sequence_output = self.clip_model.encode_image(images)
|
| 131 |
+
_, _, c = sequence_output.shape
|
| 132 |
+
sequence_output = sequence_output.view(b, t, 14, 14, c)
|
| 133 |
+
sequence_output = sequence_output.permute(0, 4, 1, 2, 3)
|
| 134 |
+
|
| 135 |
+
res = self.head(sequence_output)
|
| 136 |
+
|
| 137 |
+
return res
|
| 138 |
+
if __name__ == '__main__':
|
| 139 |
+
model = ViT_B_FTCN()
|
| 140 |
+
model = model.cuda()
|
| 141 |
+
dummy_input = torch.randn(4,8,3,224,224)
|
| 142 |
+
dummy_input = dummy_input.cuda()
|
| 143 |
+
model(dummy_input)
|
models/MINTIME
ADDED
|
@@ -0,0 +1,269 @@
| 1 |
+
"""
|
| 2 |
+
MINTIME: Multi-Identity size-iNvariant TIMEsformer for Video Deepfake Detection@TIFS'2024
|
| 3 |
+
Copyright (c) ISTI-CNR and its affiliates.
|
| 4 |
+
Modified by Davide Alessandro Coccomini from https://github.com/davide-coccomini/MINTIME-Multi-Identity-size-iNvariant-TIMEsformer-for-Video-Deepfake-Detection
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn, einsum
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from einops import rearrange, repeat
|
| 11 |
+
from statistics import mean
|
| 12 |
+
from torch.nn.init import trunc_normal_
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
from random import random
|
| 16 |
+
from .clip import clip
|
| 17 |
+
from einops.layers.torch import Rearrange
|
| 18 |
+
|
| 19 |
+
# helpers
|
| 20 |
+
def exists(val):
|
| 21 |
+
return val is not None
|
| 22 |
+
|
| 23 |
+
# classes
|
| 24 |
+
class PreNorm(nn.Module):
|
| 25 |
+
def __init__(self, dim, fn):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.fn = fn
|
| 28 |
+
self.norm = nn.LayerNorm(dim)
|
| 29 |
+
|
| 30 |
+
def forward(self, x, *args, **kwargs):
|
| 31 |
+
x = self.norm(x)
|
| 32 |
+
return self.fn(x, *args, **kwargs)
|
| 33 |
+
|
| 34 |
+
# time token shift
|
| 35 |
+
def shift(t, amt):
|
| 36 |
+
if amt == 0:
|
| 37 |
+
return t
|
| 38 |
+
return F.pad(t, (0, 0, 0, 0, amt, -amt))
|
| 39 |
+
|
| 40 |
+
class PreTokenShift(nn.Module):
|
| 41 |
+
def __init__(self, frames, fn):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.frames = frames
|
| 44 |
+
self.fn = fn
|
| 45 |
+
|
| 46 |
+
def forward(self, x, *args, **kwargs):
|
| 47 |
+
f, dim = self.frames, x.shape[-1]
|
| 48 |
+
cls_x, x = x[:, :1], x[:, 1:]
|
| 49 |
+
x = rearrange(x, 'b (f n) d -> b f n d', f = f)
|
| 50 |
+
|
| 51 |
+
# shift along time frame before and after
|
| 52 |
+
dim_chunk = (dim // 3)
|
| 53 |
+
chunks = x.split(dim_chunk, dim = -1)
|
| 54 |
+
chunks_to_shift, rest = chunks[:3], chunks[3:]
|
| 55 |
+
shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
|
| 56 |
+
x = torch.cat((*shifted_chunks, *rest), dim = -1)
|
| 57 |
+
|
| 58 |
+
x = rearrange(x, 'b f n d -> b (f n) d')
|
| 59 |
+
x = torch.cat((cls_x, x), dim = 1)
|
| 60 |
+
return self.fn(x, *args, **kwargs)
|
| 61 |
+
|
| 62 |
+
# feedforward
|
| 63 |
+
class GEGLU(nn.Module):
|
| 64 |
+
def forward(self, x):
|
| 65 |
+
x, gates = x.chunk(2, dim = -1)
|
| 66 |
+
return x * F.gelu(gates)
|
| 67 |
+
|
| 68 |
+
class FeedForward(nn.Module):
|
| 69 |
+
def __init__(self, dim, mult = 4, dropout = 0.):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.net = nn.Sequential(
|
| 72 |
+
nn.Linear(dim, dim * mult * 2),
|
| 73 |
+
GEGLU(),
|
| 74 |
+
nn.Dropout(dropout),
|
| 75 |
+
nn.Linear(dim * mult, dim)
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
return self.net(x)
|
| 80 |
+
|
| 81 |
+
# attention
|
| 82 |
+
def attn(q, k, v):
|
| 83 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k)
|
| 84 |
+
attn = sim.softmax(dim = -1)
|
| 85 |
+
out = torch.einsum('b i j, b j d -> b i d', attn, v)
|
| 86 |
+
return out, attn
|
| 87 |
+
|
| 88 |
+
class Attention(nn.Module):
|
| 89 |
+
def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.heads = heads
|
| 92 |
+
self.scale = dim_head ** -0.5
|
| 93 |
+
inner_dim = dim_head * heads
|
| 94 |
+
|
| 95 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
| 96 |
+
self.to_out = nn.Sequential(
|
| 97 |
+
nn.Linear(inner_dim, dim),
|
| 98 |
+
nn.Dropout(dropout)
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def forward(self, x, einops_from, einops_to, **einops_dims):
|
| 102 |
+
h = self.heads
|
| 103 |
+
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
| 104 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
|
| 105 |
+
|
| 106 |
+
q = q * self.scale
|
| 107 |
+
|
| 108 |
+
# splice out classification token at index 1
|
| 109 |
+
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))
|
| 110 |
+
|
| 111 |
+
# let classification token attend to key / values of all patches across time and space
|
| 112 |
+
cls_out, cls_attentions = attn(cls_q, k, v)
|
| 113 |
+
|
| 114 |
+
# rearrange across time or space
|
| 115 |
+
q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))
|
| 116 |
+
|
| 117 |
+
# expand cls token keys and values across time or space and concat
|
| 118 |
+
r = q_.shape[0] // cls_k.shape[0]
|
| 119 |
+
cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))
|
| 120 |
+
|
| 121 |
+
k_ = torch.cat((cls_k, k_), dim = 1)
|
| 122 |
+
v_ = torch.cat((cls_v, v_), dim = 1)
|
| 123 |
+
|
| 124 |
+
# attention
|
| 125 |
+
out, attentions = attn(q_, k_, v_)
|
| 126 |
+
|
| 127 |
+
# merge back time or space
|
| 128 |
+
out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
|
| 129 |
+
|
| 130 |
+
# concat back the cls token
|
| 131 |
+
out = torch.cat((cls_out, out), dim = 1)
|
| 132 |
+
|
| 133 |
+
# merge back the heads
|
| 134 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
| 135 |
+
|
| 136 |
+
# combine heads out
|
| 137 |
+
return self.to_out(out), cls_attentions
|
| 138 |
+
|
| 139 |
+
class SizeInvariantTimeSformer(nn.Module):
|
| 140 |
+
def __init__(
|
| 141 |
+
self,
|
| 142 |
+
*,
|
| 143 |
+
require_attention = False
|
| 144 |
+
):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.dim = 512
|
| 147 |
+
self.num_frames = 8
|
| 148 |
+
self.max_identities = 1
|
| 149 |
+
self.image_size = 224
|
| 150 |
+
self.num_classes = 1
|
| 151 |
+
self.patch_size = 1
|
| 152 |
+
self.num_patches = 196
|
| 153 |
+
self.channels = 512
|
| 154 |
+
self.depth = 9
|
| 155 |
+
self.heads = 8
|
| 156 |
+
self.dim_head = 64
|
| 157 |
+
self.attn_dropout = 0.
|
| 158 |
+
self.ff_dropout = 0.
|
| 159 |
+
self.shift_tokens = False
|
| 160 |
+
self.enable_size_emb = True
|
| 161 |
+
self.enable_pos_emb = True
|
| 162 |
+
self.require_attention = require_attention
|
| 163 |
+
|
| 164 |
+
num_positions = self.num_frames * self.channels
|
| 165 |
+
self.to_patch_embedding = nn.Linear(self.channels, self.dim)
|
| 166 |
+
self.cls_token = nn.Parameter(torch.randn(1, self.dim))
|
| 167 |
+
self.pos_emb = nn.Embedding(num_positions + 1, self.dim)
|
| 168 |
+
|
| 169 |
+
if self.enable_size_emb:
|
| 170 |
+
self.size_emb = nn.Embedding(num_positions + 1, self.dim)
|
| 171 |
+
|
| 172 |
+
self.layers = nn.ModuleList([])
|
| 173 |
+
for _ in range(self.depth):
|
| 174 |
+
ff = FeedForward(self.dim, dropout = self.ff_dropout)
|
| 175 |
+
time_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
|
| 176 |
+
spatial_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
|
| 177 |
+
if self.shift_tokens:
|
| 178 |
+
time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(self.num_frames, t), (time_attn, spatial_attn, ff))
|
| 179 |
+
|
| 180 |
+
time_attn, spatial_attn, ff = map(lambda t: PreNorm(self.dim, t), (time_attn, spatial_attn, ff))
|
| 181 |
+
self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))
|
| 182 |
+
|
| 183 |
+
self.to_out = nn.Sequential(
|
| 184 |
+
nn.LayerNorm(self.dim),
|
| 185 |
+
nn.Linear(self.dim, self.num_classes)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Initialization
|
| 189 |
+
trunc_normal_(self.pos_emb.weight, std=.02)
|
| 190 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 191 |
+
if self.enable_size_emb:
|
| 192 |
+
trunc_normal_(self.size_emb.weight, std=.02)
|
| 193 |
+
self.apply(self._init_weights)
|
| 194 |
+
|
| 195 |
+
def _init_weights(self, m):
|
| 196 |
+
if isinstance(m, nn.Linear):
|
| 197 |
+
trunc_normal_(m.weight, std=.02)
|
| 198 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 199 |
+
nn.init.constant_(m.bias, 0)
|
| 200 |
+
elif isinstance(m, nn.LayerNorm):
|
| 201 |
+
nn.init.constant_(m.bias, 0)
|
| 202 |
+
nn.init.constant_(m.weight, 1.0)
|
| 203 |
+
|
| 204 |
+
@torch.jit.ignore
|
| 205 |
+
def no_weight_decay(self):
|
| 206 |
+
return {'pos_emb', 'cls_token'}
|
| 207 |
+
|
| 208 |
+
def forward(self, x):
|
| 209 |
+
b, f, c, h, w = x.shape
|
| 210 |
+
n = h * w
|
| 211 |
+
device = x.device
|
| 212 |
+
|
| 213 |
+
x = rearrange(x, 'b f c h w -> b (f h w) c') # B x F*P*P x C
|
| 214 |
+
tokens = self.to_patch_embedding(x) # B x 8*7*7 x dim
|
| 215 |
+
|
| 216 |
+
# Add cls token
|
| 217 |
+
cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
|
| 218 |
+
x = torch.cat((cls_token, tokens), dim = 1)
|
| 219 |
+
|
| 220 |
+
# Positional embedding
|
| 221 |
+
x += self.pos_emb(torch.arange(x.shape[1], device=device))
|
| 222 |
+
|
| 223 |
+
# Time and space attention
|
| 224 |
+
for (time_attn, spatial_attn, ff) in self.layers:
|
| 225 |
+
y, _ = time_attn(x, 'b (f n) d', '(b n) f d', n = n)
|
| 226 |
+
x = x + y
|
| 227 |
+
y, _ = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f)
|
| 228 |
+
x = x + y
|
| 229 |
+
x = ff(x) + x
|
| 230 |
+
|
| 231 |
+
cls_token = x[:, 0]
|
| 232 |
+
|
| 233 |
+
if self.require_attention:  # NOTE: attention maps are not collected in this forward pass, so both branches return only the logits
|
| 234 |
+
return self.to_out(cls_token)
|
| 235 |
+
else:
|
| 236 |
+
return self.to_out(cls_token)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class ViT_B_MINTIME(nn.Module):
|
| 240 |
+
def __init__(
|
| 241 |
+
self, channel_size=512, class_num=1
|
| 242 |
+
):
|
| 243 |
+
super(ViT_B_MINTIME, self).__init__()
|
| 244 |
+
self.clip_model, preprocess = clip.load('ViT-B-16')
|
| 245 |
+
self.clip_model = self.clip_model.float()
|
| 246 |
+
self.head = SizeInvariantTimeSformer()
|
| 247 |
+
|
| 248 |
+
def forward(self, x):
|
| 249 |
+
b, t, _, h, w = x.shape
|
| 250 |
+
images = x.view(b * t, 3, h, w)
|
| 251 |
+
sequence_output = self.clip_model.encode_image(images)
|
| 252 |
+
_, _, c = sequence_output.shape
|
| 253 |
+
sequence_output = sequence_output.view(b, t, 14, 14, c)
|
| 254 |
+
sequence_output = sequence_output.permute(0, 1, 4, 2, 3)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
res = self.head(sequence_output)
|
| 258 |
+
|
| 259 |
+
return res
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
if __name__ == '__main__':
|
| 263 |
+
|
| 264 |
+
model = ViT_B_MINTIME()
|
| 265 |
+
model = model.cuda()
|
| 266 |
+
dummy_input = torch.randn(4,8,3,224,224)
|
| 267 |
+
dummy_input = dummy_input.cuda()
|
| 268 |
+
print(model(dummy_input))
|
| 269 |
+
|
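For readers unfamiliar with the einops patterns that drive the divided space-time attention in SizeInvariantTimeSformer above, here is a minimal stand-alone sketch (dummy shapes, heads and cls-token handling omitted; not part of the original file) of what the two rearrange patterns passed to Attention.forward actually do:

import torch
from einops import rearrange

b, f, n, d = 2, 8, 196, 512          # batch, frames, patches per frame, embed dim
tokens = torch.randn(b, f * n, d)    # patch tokens with the cls token already spliced off

# time attention: tokens at the same spatial location attend across frames
time_view = rearrange(tokens, 'b (f n) d -> (b n) f d', n=n)
assert time_view.shape == (b * n, f, d)

# space attention: tokens of the same frame attend across spatial locations
space_view = rearrange(tokens, 'b (f n) d -> (b f) n d', f=f)
assert space_view.shape == (b * f, n, d)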
models/NPR.py
ADDED
|
@@ -0,0 +1,284 @@
|
| 1 |
+
"""
|
| 2 |
+
NPR: Rethinking the Up-Sampling Operations in CNN-based Generative Network for Generalizable Deepfake Detection @ CVPR'2024
|
| 3 |
+
Copyright (c) Beijing Jiaotong University and its affiliates.
|
| 4 |
+
Modified by Chuangchuang Tan from https://github.com/chuangchuangtan/NPR-DeepfakeDetection
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.model_zoo as model_zoo
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
from typing import Any, cast, Dict, List, Optional, Union
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
__all__ = ['ResNet', 'resnet18_npr', 'resnet34_npr', 'resnet50_npr', 'resnet101_npr',
|
| 15 |
+
'resnet152_npr']
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
model_urls = {
|
| 19 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
| 20 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
| 21 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
| 22 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
| 23 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
| 28 |
+
"""3x3 convolution with padding"""
|
| 29 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 30 |
+
padding=1, bias=False)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 34 |
+
"""1x1 convolution"""
|
| 35 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BasicBlock(nn.Module):
|
| 39 |
+
expansion = 1
|
| 40 |
+
|
| 41 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 42 |
+
super(BasicBlock, self).__init__()
|
| 43 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 44 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 45 |
+
self.relu = nn.ReLU(inplace=True)
|
| 46 |
+
self.conv2 = conv3x3(planes, planes)
|
| 47 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 48 |
+
self.downsample = downsample
|
| 49 |
+
self.stride = stride
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
identity = x
|
| 53 |
+
|
| 54 |
+
out = self.conv1(x)
|
| 55 |
+
out = self.bn1(out)
|
| 56 |
+
out = self.relu(out)
|
| 57 |
+
|
| 58 |
+
out = self.conv2(out)
|
| 59 |
+
out = self.bn2(out)
|
| 60 |
+
|
| 61 |
+
if self.downsample is not None:
|
| 62 |
+
identity = self.downsample(x)
|
| 63 |
+
|
| 64 |
+
out += identity
|
| 65 |
+
out = self.relu(out)
|
| 66 |
+
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Bottleneck(nn.Module):
|
| 71 |
+
expansion = 4
|
| 72 |
+
|
| 73 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 74 |
+
super(Bottleneck, self).__init__()
|
| 75 |
+
self.conv1 = conv1x1(inplanes, planes)
|
| 76 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 77 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
| 78 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 79 |
+
self.conv3 = conv1x1(planes, planes * self.expansion)
|
| 80 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 81 |
+
self.relu = nn.ReLU(inplace=True)
|
| 82 |
+
self.downsample = downsample
|
| 83 |
+
self.stride = stride
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
identity = x
|
| 87 |
+
|
| 88 |
+
out = self.conv1(x)
|
| 89 |
+
out = self.bn1(out)
|
| 90 |
+
out = self.relu(out)
|
| 91 |
+
|
| 92 |
+
out = self.conv2(out)
|
| 93 |
+
out = self.bn2(out)
|
| 94 |
+
out = self.relu(out)
|
| 95 |
+
|
| 96 |
+
out = self.conv3(out)
|
| 97 |
+
out = self.bn3(out)
|
| 98 |
+
|
| 99 |
+
if self.downsample is not None:
|
| 100 |
+
identity = self.downsample(x)
|
| 101 |
+
|
| 102 |
+
out += identity
|
| 103 |
+
out = self.relu(out)
|
| 104 |
+
|
| 105 |
+
return out
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class ResNet(nn.Module):
|
| 109 |
+
|
| 110 |
+
def __init__(self, block, layers, num_classes=1, zero_init_residual=False):
|
| 111 |
+
super(ResNet, self).__init__()
|
| 112 |
+
|
| 113 |
+
self.unfoldSize = 2
|
| 114 |
+
self.unfoldIndex = 0
|
| 115 |
+
assert self.unfoldSize > 1
|
| 116 |
+
assert -1 < self.unfoldIndex and self.unfoldIndex < self.unfoldSize*self.unfoldSize
|
| 117 |
+
self.inplanes = 64
|
| 118 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
| 119 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 120 |
+
self.relu = nn.ReLU(inplace=True)
|
| 121 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 122 |
+
self.layer1 = self._make_layer(block, 64 , layers[0])
|
| 123 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 124 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 125 |
+
# self.fc1 = nn.Linear(512 * block.expansion, 1)
|
| 126 |
+
self.fc1 = nn.Linear(512, num_classes)
|
| 127 |
+
|
| 128 |
+
for m in self.modules():
|
| 129 |
+
if isinstance(m, nn.Conv2d):
|
| 130 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 131 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 132 |
+
nn.init.constant_(m.weight, 1)
|
| 133 |
+
nn.init.constant_(m.bias, 0)
|
| 134 |
+
|
| 135 |
+
# Zero-initialize the last BN in each residual branch,
|
| 136 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 137 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 138 |
+
if zero_init_residual:
|
| 139 |
+
for m in self.modules():
|
| 140 |
+
if isinstance(m, Bottleneck):
|
| 141 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 142 |
+
elif isinstance(m, BasicBlock):
|
| 143 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 144 |
+
|
| 145 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 146 |
+
downsample = None
|
| 147 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 148 |
+
downsample = nn.Sequential(
|
| 149 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 150 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
layers = []
|
| 154 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 155 |
+
self.inplanes = planes * block.expansion
|
| 156 |
+
for _ in range(1, blocks):
|
| 157 |
+
layers.append(block(self.inplanes, planes))
|
| 158 |
+
|
| 159 |
+
return nn.Sequential(*layers)
|
| 160 |
+
def interpolate(self, img, factor):
|
| 161 |
+
return F.interpolate(F.interpolate(img, scale_factor=factor, mode='nearest', recompute_scale_factor=True), scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
|
| 162 |
+
def forward(self, x):
|
| 163 |
+
# n,c,w,h = x.shape
|
| 164 |
+
# if -1*w%2 != 0: x = x[:,:,:w%2*-1,: ]
|
| 165 |
+
# if -1*h%2 != 0: x = x[:,:,: ,:h%2*-1]
|
| 166 |
+
# factor = 0.5
|
| 167 |
+
# x_half = F.interpolate(x, scale_factor=factor, mode='nearest', recompute_scale_factor=True)
|
| 168 |
+
# x_re = F.interpolate(x_half, scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
|
| 169 |
+
# NPR = x - x_re
|
| 170 |
+
# n,c,w,h = x.shape
|
| 171 |
+
# if w%2 == 1 : x = x[:,:,:-1,:]
|
| 172 |
+
# if h%2 == 1 : x = x[:,:,:,:-1]
|
| 173 |
+
b, t, _, h, w = x.shape
|
| 174 |
+
x = x.view(b * t, 3, h, w)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
NPR = x - self.interpolate(x, 0.5)
|
| 178 |
+
|
| 179 |
+
x = self.conv1(NPR*2.0/3.0)
|
| 180 |
+
x = self.bn1(x)
|
| 181 |
+
x = self.relu(x)
|
| 182 |
+
x = self.maxpool(x)
|
| 183 |
+
|
| 184 |
+
x = self.layer1(x)
|
| 185 |
+
x = self.layer2(x)
|
| 186 |
+
|
| 187 |
+
x = self.avgpool(x)
|
| 188 |
+
x = x.view(x.size(0), -1)
|
| 189 |
+
x = self.fc1(x)
|
| 190 |
+
x = x.view(b, t, -1)
|
| 191 |
+
x = x.mean(1)
|
| 192 |
+
|
| 193 |
+
return x
|
| 194 |
+
|
| 195 |
+
def infer(self, x):
|
| 196 |
+
# n,c,w,h = x.shape
|
| 197 |
+
# if -1*w%2 != 0: x = x[:,:,:w%2*-1,: ]
|
| 198 |
+
# if -1*h%2 != 0: x = x[:,:,: ,:h%2*-1]
|
| 199 |
+
# factor = 0.5
|
| 200 |
+
# x_half = F.interpolate(x, scale_factor=factor, mode='nearest', recompute_scale_factor=True)
|
| 201 |
+
# x_re = F.interpolate(x_half, scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
|
| 202 |
+
# NPR = x - x_re
|
| 203 |
+
# n,c,w,h = x.shape
|
| 204 |
+
# if w%2 == 1 : x = x[:,:,:-1,:]
|
| 205 |
+
# if h%2 == 1 : x = x[:,:,:,:-1]
|
| 206 |
+
b, t, _, h, w = x.shape
|
| 207 |
+
x = x.view(b * t, 3, h, w)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
NPR = x - self.interpolate(x, 0.5)
|
| 211 |
+
|
| 212 |
+
x = self.conv1(NPR*2.0/3.0)
|
| 213 |
+
x = self.bn1(x)
|
| 214 |
+
x = self.relu(x)
|
| 215 |
+
x = self.maxpool(x)
|
| 216 |
+
|
| 217 |
+
x = self.layer1(x)
|
| 218 |
+
x = self.layer2(x)
|
| 219 |
+
|
| 220 |
+
x = self.avgpool(x)
|
| 221 |
+
x = x.view(x.size(0), -1)
|
| 222 |
+
x = self.fc1(x)
|
| 223 |
+
x = x.view(b, t, -1)
|
| 224 |
+
x = x.mean(1)
|
| 225 |
+
return x
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def resnet18_npr(pretrained=False, **kwargs):
|
| 229 |
+
"""Constructs a ResNet-18 model.
|
| 230 |
+
Args:
|
| 231 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 232 |
+
"""
|
| 233 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
| 234 |
+
if pretrained:
|
| 235 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
|
| 236 |
+
return model
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def resnet34_npr(pretrained=False, **kwargs):
|
| 240 |
+
"""Constructs a ResNet-34 model.
|
| 241 |
+
Args:
|
| 242 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 243 |
+
"""
|
| 244 |
+
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
| 245 |
+
if pretrained:
|
| 246 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
|
| 247 |
+
return model
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def resnet50_npr(pretrained=False, **kwargs):
|
| 251 |
+
"""Constructs a ResNet-50 model.
|
| 252 |
+
Args:
|
| 253 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 254 |
+
"""
|
| 255 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
| 256 |
+
if pretrained:
|
| 257 |
+
model_state = torch.load('/ossfs/workspace/aigc_video/weights/resnet50-19c8e357.pth')
|
| 258 |
+
model.load_state_dict(model_state, strict=False)
|
| 259 |
+
return model
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def resnet101_npr(pretrained=False, **kwargs):
|
| 263 |
+
"""Constructs a ResNet-101 model.
|
| 264 |
+
Args:
|
| 265 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 266 |
+
"""
|
| 267 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
| 268 |
+
if pretrained:
|
| 269 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
|
| 270 |
+
return model
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def resnet152_npr(pretrained=False, **kwargs):
|
| 274 |
+
"""Constructs a ResNet-152 model.
|
| 275 |
+
Args:
|
| 276 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 277 |
+
"""
|
| 278 |
+
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
| 279 |
+
if pretrained:
|
| 280 |
+
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
|
| 281 |
+
return model
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
|
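As a stand-alone reference, the NPR cue computed in ResNet.forward above is simply the residual between each frame and its nearest-neighbour down/up-sampled copy. A minimal sketch (illustrative shapes; not part of the original file):

import torch
import torch.nn.functional as F

def npr_residual(img, factor=0.5):
    # img: (N, 3, H, W); returns the up-sampling artifact map that conv1 consumes
    down = F.interpolate(img, scale_factor=factor, mode='nearest',
                         recompute_scale_factor=True)
    up = F.interpolate(down, scale_factor=1 / factor, mode='nearest',
                       recompute_scale_factor=True)
    return img - up

frames = torch.randn(4, 3, 224, 224)
print(npr_residual(frames).shape)    # torch.Size([4, 3, 224, 224])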
models/STIL.py
ADDED
|
@@ -0,0 +1,641 @@
|
| 1 |
+
'''
|
| 2 |
+
STIL: Spatiotemporal inconsistency learning for deepfake video detection @ ACM MM'2021
|
| 3 |
+
Copyright (c) Tencent Youtu Lab and its affiliates.
|
| 4 |
+
Modified by Zhiyuan Yan from https://github.com/Tencent/TFace?tab=readme-ov-file
|
| 5 |
+
'''
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import datetime
|
| 9 |
+
import logging
|
| 10 |
+
import numpy as np
|
| 11 |
+
from sklearn import metrics
|
| 12 |
+
from typing import Union
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torch.optim as optim
|
| 19 |
+
import torch.utils.model_zoo as model_zoo
|
| 20 |
+
from torch.nn import DataParallel
|
| 21 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 22 |
+
|
| 23 |
+
class ISM_Module(nn.Module):
|
| 24 |
+
def __init__(self, k_size=3):
|
| 25 |
+
"""The Information Supplement Module (ISM).
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
k_size (int, optional): Conv1d kernel_size. Defaults to 3.
|
| 29 |
+
"""
|
| 30 |
+
super(ISM_Module, self).__init__()
|
| 31 |
+
|
| 32 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 33 |
+
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size-1)//2, bias=False)
|
| 34 |
+
self.sigmoid = nn.Sigmoid()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
"""
|
| 39 |
+
Args:
|
| 40 |
+
x (torch.tensor): Input tensor of shape (nt, c, h, w)
|
| 41 |
+
"""
|
| 42 |
+
y = self.avg_pool(x)
|
| 43 |
+
y = self.conv(y.squeeze(-1).transpose(-1,-2)).transpose(-1,-2).unsqueeze(-1)
|
| 44 |
+
y = self.sigmoid(y)
|
| 45 |
+
return x * y.expand_as(x)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class TIM_Module(nn.Module):
|
| 49 |
+
def __init__(self, in_channels, reduction=16, n_segment=8, return_attn=False):
|
| 50 |
+
"""The Temporal Inconsistency Module (TIM).
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
in_channels (int): Input channel number.
|
| 54 |
+
reduction (int, optional): Channel compression ratio r in the split operation. Defaults to 16.
|
| 55 |
+
n_segment (int, optional): Number of input frames. Defaults to 8.
|
| 56 |
+
return_attn (bool, optional): Whether to return the attention part. Defaults to False.
|
| 57 |
+
|
| 58 |
+
"""
|
| 59 |
+
super(TIM_Module, self).__init__()
|
| 60 |
+
self.in_channels = in_channels
|
| 61 |
+
self.reduction = reduction
|
| 62 |
+
self.n_segment = n_segment
|
| 63 |
+
self.return_attn = return_attn
|
| 64 |
+
|
| 65 |
+
self.reduced_channels = self.in_channels // self.reduction
|
| 66 |
+
|
| 67 |
+
# first conv to shrink input channels
|
| 68 |
+
self.conv1 = nn.Conv2d(self.in_channels, self.reduced_channels, kernel_size=1, padding=0, bias=False)
|
| 69 |
+
self.bn1 = nn.BatchNorm2d(self.reduced_channels)
|
| 70 |
+
|
| 71 |
+
self.conv_ht = nn.Conv2d(self.reduced_channels, self.reduced_channels,
|
| 72 |
+
kernel_size=(3, 1), padding=(1, 0), groups=self.reduced_channels, bias=False)
|
| 73 |
+
self.conv_tw = nn.Conv2d(self.reduced_channels, self.reduced_channels,
|
| 74 |
+
kernel_size=(1, 3), padding=(0, 1), groups=self.reduced_channels, bias=False)
|
| 75 |
+
|
| 76 |
+
self.avg_pool_ht = nn.AvgPool2d((2, 1), (2, 1))
|
| 77 |
+
self.avg_pool_tw = nn.AvgPool2d((1, 2), (1, 2))
|
| 78 |
+
|
| 79 |
+
# HTIE in two directions
|
| 80 |
+
self.htie_conv1 = nn.Sequential(
|
| 81 |
+
nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False),
|
| 82 |
+
nn.BatchNorm2d(self.reduced_channels),
|
| 83 |
+
)
|
| 84 |
+
self.vtie_conv1 = nn.Sequential(
|
| 85 |
+
nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(1, 3), padding=(0, 1), bias=False),
|
| 86 |
+
nn.BatchNorm2d(self.reduced_channels),
|
| 87 |
+
)
|
| 88 |
+
self.htie_conv2 = nn.Sequential(
|
| 89 |
+
nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False),
|
| 90 |
+
nn.BatchNorm2d(self.reduced_channels),
|
| 91 |
+
)
|
| 92 |
+
self.vtie_conv2 = nn.Sequential(
|
| 93 |
+
nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(1, 3), padding=(0, 1), bias=False),
|
| 94 |
+
nn.BatchNorm2d(self.reduced_channels),
|
| 95 |
+
)
|
| 96 |
+
self.ht_up_conv = nn.Sequential(
|
| 97 |
+
nn.Conv2d(self.reduced_channels, self.in_channels, kernel_size=1, bias=False),
|
| 98 |
+
nn.BatchNorm2d(self.in_channels)
|
| 99 |
+
)
|
| 100 |
+
self.tw_up_conv = nn.Sequential(
|
| 101 |
+
nn.Conv2d(self.reduced_channels, self.in_channels, kernel_size=1, bias=False),
|
| 102 |
+
nn.BatchNorm2d(self.in_channels)
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
self.sigmoid = nn.Sigmoid()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def feat_ht(self, feat):
|
| 109 |
+
"""The H-T branch in the TIM module.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
feat (torch.tensor): Input feature with shape [n, t, c, h, w] (c is in_channels // reduction)
|
| 113 |
+
|
| 114 |
+
"""
|
| 115 |
+
n, t, c, h, w = feat.size()
|
| 116 |
+
# [n, t, c, h, w] -> [n, w, c, h, t] -> [nw, c, h, t]
|
| 117 |
+
feat_h = feat.permute(0, 4, 2, 3, 1).contiguous().view(-1, c, h, t)
|
| 118 |
+
|
| 119 |
+
# [nw, c, h, t-1]
|
| 120 |
+
feat_h_fwd, _ = feat_h.split([self.n_segment-1, 1], dim=3)
|
| 121 |
+
feat_h_conv = self.conv_ht(feat_h)
|
| 122 |
+
_, feat_h_conv_fwd = feat_h_conv.split([1, self.n_segment-1], dim=3)
|
| 123 |
+
|
| 124 |
+
diff_feat_fwd = feat_h_conv_fwd - feat_h_fwd
|
| 125 |
+
diff_feat_fwd = F.pad(diff_feat_fwd, [0, 1], value=0) # [nw, c, h, t]
|
| 126 |
+
|
| 127 |
+
# HTIE, down_up branch
|
| 128 |
+
diff_feat_fwd1 = self.avg_pool_ht(diff_feat_fwd) # [nw, c, h//2, t]
|
| 129 |
+
diff_feat_fwd1 = self.htie_conv1(diff_feat_fwd1) # [nw, c, h//2, t]
|
| 130 |
+
diff_feat_fwd1 = F.interpolate(diff_feat_fwd1, diff_feat_fwd.size()[2:]) # [nw, c, h, t]
|
| 131 |
+
# HTIE, direct conv branch
|
| 132 |
+
diff_feat_fwd2 = self.htie_conv2(diff_feat_fwd) # [nw, c, h, t]
|
| 133 |
+
|
| 134 |
+
# [nw, C, h, t]
|
| 135 |
+
feat_ht_out = self.ht_up_conv(1/3. * diff_feat_fwd + 1/3. * diff_feat_fwd1 + 1/3. * diff_feat_fwd2)
|
| 136 |
+
feat_ht_out = self.sigmoid(feat_ht_out) - 0.5
|
| 137 |
+
# [nw, C, h, t] -> [n, w, C, h, t] -> [n, t, C, h, w]
|
| 138 |
+
feat_ht_out = feat_ht_out.view(n, w, self.in_channels, h, t).permute(0, 4, 2, 3, 1).contiguous()
|
| 139 |
+
# [n, t, C, h, w] -> [nt, C, h, w]
|
| 140 |
+
feat_ht_out = feat_ht_out.view(-1, self.in_channels, h, w)
|
| 141 |
+
|
| 142 |
+
return feat_ht_out
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def feat_tw(self, feat):
|
| 146 |
+
"""The T-W branch in the TIM module.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
feat (torch.tensor): Input feature with shape [n, t, c, h, w] (c is in_channels // reduction)
|
| 150 |
+
"""
|
| 151 |
+
n, t, c, h, w = feat.size()
|
| 152 |
+
# [n, t, c, h, w] -> [n, h, c, t, w] -> [nh, c, t, w]
|
| 153 |
+
feat_w = feat.permute(0, 3, 2, 1, 4).contiguous().view(-1, c, t, w)
|
| 154 |
+
|
| 155 |
+
# [nh, c, t-1, w]
|
| 156 |
+
feat_w_fwd, _ = feat_w.split([self.n_segment-1, 1], dim=2)
|
| 157 |
+
feat_w_conv = self.conv_tw(feat_w)
|
| 158 |
+
_, feat_w_conv_fwd = feat_w_conv.split([1, self.n_segment-1], dim=2)
|
| 159 |
+
|
| 160 |
+
diff_feat_fwd = feat_w_conv_fwd - feat_w_fwd
|
| 161 |
+
diff_feat_fwd = F.pad(diff_feat_fwd, [0, 0, 0, 1], value=0) # [nh, c, t, w]
|
| 162 |
+
|
| 163 |
+
# VTIE, down_up branch
|
| 164 |
+
diff_feat_fwd1 = self.avg_pool_tw(diff_feat_fwd) # [nh, c, t, w//2]
|
| 165 |
+
diff_feat_fwd1 = self.vtie_conv1(diff_feat_fwd1) # [nh, c, t, w//2]
|
| 166 |
+
diff_feat_fwd1 = F.interpolate(diff_feat_fwd1, diff_feat_fwd.size()[2:]) # [nh, c, t, w]
|
| 167 |
+
# VTIE, direct conv branch
|
| 168 |
+
diff_feat_fwd2 = self.vtie_conv2(diff_feat_fwd) # [nh, c, t, w]
|
| 169 |
+
|
| 170 |
+
# [nh, C, t, w]
|
| 171 |
+
feat_tw_out = self.tw_up_conv(1/3. * diff_feat_fwd + 1/3. * diff_feat_fwd1 + 1/3. * diff_feat_fwd2)
|
| 172 |
+
feat_tw_out = self.sigmoid(feat_tw_out) - 0.5
|
| 173 |
+
# [nh, C, t, w] -> [n, h, C, t, w] -> [n, t, C, h, W]
|
| 174 |
+
feat_tw_out = feat_tw_out.view(n, h, self.in_channels, t, w).permute(0, 3, 2, 1, 4).contiguous()
|
| 175 |
+
# [n, t, C, h, w] -> [nt, C, h, w]
|
| 176 |
+
feat_tw_out = feat_tw_out.view(-1, self.in_channels, h, w)
|
| 177 |
+
|
| 178 |
+
return feat_tw_out
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def forward(self, x):
|
| 182 |
+
"""
|
| 183 |
+
Args:
|
| 184 |
+
x (torch.tensor): Input with shape [nt, c, h, w]
|
| 185 |
+
"""
|
| 186 |
+
# [nt, c, h, w] -> [nt, c//r, h, w]
|
| 187 |
+
bottleneck = self.conv1(x)
|
| 188 |
+
bottleneck = self.bn1(bottleneck)
|
| 189 |
+
# [nt, c//r, h, w] -> [n, t, c//r, h, w]
|
| 190 |
+
bottleneck = bottleneck.view((-1, self.n_segment) + bottleneck.size()[1:])
|
| 191 |
+
|
| 192 |
+
F_h = self.feat_ht(bottleneck) # [nt, c, h, w]
|
| 193 |
+
F_w = self.feat_tw(bottleneck) # [nt, c, h, w]
|
| 194 |
+
|
| 195 |
+
att = 0.5 * (F_h + F_w)
|
| 196 |
+
|
| 197 |
+
if self.return_attn:
|
| 198 |
+
return att
|
| 199 |
+
|
| 200 |
+
y2 = x + x * att
|
| 201 |
+
|
| 202 |
+
return y2
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class ShiftModule(nn.Module):
|
| 206 |
+
def __init__(self, input_channels, n_segment=8, n_div=8, mode='shift'):
|
| 207 |
+
"""A depth-wise conv on the segment level.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
input_channels (int): Input channel number.
|
| 211 |
+
n_segment (int, optional): Number of input frames. Defaults to 8.
|
| 212 |
+
n_div (int, optional): How many channels to group as a fold. Defaults to 8.
|
| 213 |
+
mode (str, optional): One of "shift", "fixed", "norm". Defaults to 'shift'.
|
| 214 |
+
"""
|
| 215 |
+
super(ShiftModule, self).__init__()
|
| 216 |
+
self.input_channels = input_channels
|
| 217 |
+
self.n_segment = n_segment
|
| 218 |
+
self.fold_div = n_div
|
| 219 |
+
self.fold = self.input_channels // self.fold_div
|
| 220 |
+
self.conv = nn.Conv1d(self.fold_div*self.fold, self.fold_div*self.fold,
|
| 221 |
+
kernel_size=3, padding=1, groups=self.fold_div*self.fold,
|
| 222 |
+
bias=False)
|
| 223 |
+
|
| 224 |
+
if mode == 'shift':
|
| 225 |
+
self.conv.weight.requires_grad = True
|
| 226 |
+
self.conv.weight.data.zero_()
|
| 227 |
+
# shift left
|
| 228 |
+
self.conv.weight.data[:self.fold, 0, 2] = 1
|
| 229 |
+
# shift right
|
| 230 |
+
self.conv.weight.data[self.fold: 2 * self.fold, 0, 0] = 1
|
| 231 |
+
if 2*self.fold < self.input_channels:
|
| 232 |
+
self.conv.weight.data[2 * self.fold:, 0, 1] = 1 # fixed
|
| 233 |
+
elif mode == 'fixed':
|
| 234 |
+
self.conv.weight.requires_grad = True
|
| 235 |
+
self.conv.weight.data.zero_()
|
| 236 |
+
self.conv.weight.data[:, 0, 1] = 1 # fixed
|
| 237 |
+
elif mode == 'norm':
|
| 238 |
+
self.conv.weight.requires_grad = True
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def forward(self, x):
|
| 242 |
+
"""
|
| 243 |
+
Args:
|
| 244 |
+
x (torch.tensor): Input with shape [nt, c, h, w]
|
| 245 |
+
"""
|
| 246 |
+
nt, c, h, w = x.size()
|
| 247 |
+
n_batch = nt // self.n_segment
|
| 248 |
+
x = x.view(n_batch, self.n_segment, c, h, w)
|
| 249 |
+
# (n, h, w, c, t)
|
| 250 |
+
x = x.permute(0, 3, 4, 2, 1)
|
| 251 |
+
x = x.contiguous().view(n_batch*h*w, c, self.n_segment)
|
| 252 |
+
# (n*h*w, c, t)
|
| 253 |
+
x = self.conv(x)
|
| 254 |
+
x = x.view(n_batch, h, w, c, self.n_segment)
|
| 255 |
+
# (n, t, c, h, w)
|
| 256 |
+
x = x.permute(0, 4, 3, 1, 2)
|
| 257 |
+
x = x.contiguous().view(nt, c, h, w)
|
| 258 |
+
return x
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class SCConv(nn.Module):
|
| 262 |
+
"""
|
| 263 |
+
The spatial conv in SIM. Used in SCBottleneck
|
| 264 |
+
"""
|
| 265 |
+
def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer):
|
| 266 |
+
super(SCConv, self).__init__()
|
| 267 |
+
self.f_w = nn.Sequential(
|
| 268 |
+
nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
|
| 269 |
+
nn.Conv2d(inplanes, planes, kernel_size=(1,3), stride=1,
|
| 270 |
+
padding=(0,padding), dilation=(1,dilation),
|
| 271 |
+
groups=groups, bias=False),
|
| 272 |
+
norm_layer(planes), nn.ReLU(inplace=True))
|
| 273 |
+
self.f_h = nn.Sequential(
|
| 274 |
+
# nn.AvgPool2d(kernel_size=(pooling_r,1), stride=(pooling_r,1)),
|
| 275 |
+
nn.Conv2d(inplanes, planes, kernel_size=(3,1), stride=1,
|
| 276 |
+
padding=(padding,0), dilation=(dilation,1),
|
| 277 |
+
groups=groups, bias=False),
|
| 278 |
+
norm_layer(planes),
|
| 279 |
+
)
|
| 280 |
+
self.k3 = nn.Sequential(
|
| 281 |
+
nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
|
| 282 |
+
padding=padding, dilation=dilation,
|
| 283 |
+
groups=groups, bias=False),
|
| 284 |
+
norm_layer(planes),
|
| 285 |
+
)
|
| 286 |
+
self.k4 = nn.Sequential(
|
| 287 |
+
nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
|
| 288 |
+
padding=padding, dilation=dilation,
|
| 289 |
+
groups=groups, bias=False),
|
| 290 |
+
norm_layer(planes),
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def forward(self, x):
|
| 295 |
+
identity = x
|
| 296 |
+
|
| 297 |
+
# sigmoid(identity + k2)
|
| 298 |
+
out = torch.sigmoid(
|
| 299 |
+
torch.add(
|
| 300 |
+
identity,
|
| 301 |
+
F.interpolate(self.f_h(self.f_w(x)), identity.size()[2:])
|
| 302 |
+
)
|
| 303 |
+
)
|
| 304 |
+
out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2)
|
| 305 |
+
s2t_info = out
|
| 306 |
+
out = self.k4(out) # k4
|
| 307 |
+
|
| 308 |
+
return out, s2t_info
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class SCBottleneck(nn.Module):
|
| 312 |
+
"""
|
| 313 |
+
SCNet SCBottleneck. Variant of the ResNet Bottleneck.
|
| 314 |
+
"""
|
| 315 |
+
expansion = 4
|
| 316 |
+
pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv.
|
| 317 |
+
|
| 318 |
+
def __init__(self, num_segments, inplanes, planes, stride=1, downsample=None,
|
| 319 |
+
cardinality=1, bottleneck_width=32,
|
| 320 |
+
avd=False, dilation=1, is_first=False,
|
| 321 |
+
norm_layer=None):
|
| 322 |
+
super(SCBottleneck, self).__init__()
|
| 323 |
+
group_width = int(planes * (bottleneck_width / 64.)) * cardinality
|
| 324 |
+
self.conv1_a = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
|
| 325 |
+
self.bn1_a = norm_layer(group_width)
|
| 326 |
+
self.conv1_b = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
|
| 327 |
+
self.bn1_b = norm_layer(group_width)
|
| 328 |
+
self.avd = avd and (stride > 1 or is_first)
|
| 329 |
+
self.tim = TIM_Module(group_width, n_segment=num_segments)
|
| 330 |
+
self.shift = ShiftModule(group_width, n_segment=num_segments, n_div=8, mode='shift')
|
| 331 |
+
self.inplanes = inplanes
|
| 332 |
+
self.planes = planes
|
| 333 |
+
self.ism = ISM_Module()
|
| 334 |
+
self.shift = ShiftModule(group_width, n_segment=num_segments, n_div=8, mode='shift')
|
| 335 |
+
|
| 336 |
+
if self.avd:
|
| 337 |
+
self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
|
| 338 |
+
stride = 1
|
| 339 |
+
|
| 340 |
+
self.k1 = nn.Sequential(
|
| 341 |
+
nn.Conv2d(
|
| 342 |
+
group_width, group_width, kernel_size=3, stride=stride,
|
| 343 |
+
padding=dilation, dilation=dilation,
|
| 344 |
+
groups=cardinality, bias=False),
|
| 345 |
+
norm_layer(group_width),
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
self.scconv = SCConv(
|
| 349 |
+
group_width, group_width, stride=stride,
|
| 350 |
+
padding=dilation, dilation=dilation,
|
| 351 |
+
groups=cardinality, pooling_r=self.pooling_r, norm_layer=norm_layer)
|
| 352 |
+
|
| 353 |
+
self.conv3 = nn.Conv2d(
|
| 354 |
+
group_width * 2, planes * 4, kernel_size=1, bias=False)
|
| 355 |
+
self.bn3 = norm_layer(planes*4)
|
| 356 |
+
|
| 357 |
+
self.relu = nn.ReLU(inplace=True)
|
| 358 |
+
self.downsample = downsample
|
| 359 |
+
self.dilation = dilation
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def forward(self, x):
|
| 363 |
+
"""Forward func which splits the input into two branchs a and b.
|
| 364 |
+
a: trace features
|
| 365 |
+
b: spatial features
|
| 366 |
+
"""
|
| 367 |
+
residual = x
|
| 368 |
+
|
| 369 |
+
out_a = self.relu(self.bn1_a(self.conv1_a(x)))
|
| 370 |
+
out_b = self.relu(self.bn1_b(self.conv1_b(x)))
|
| 371 |
+
|
| 372 |
+
# spatial representations
|
| 373 |
+
out_b, s2t_info = self.scconv(out_b)
|
| 374 |
+
out_b = self.relu(out_b)
|
| 375 |
+
|
| 376 |
+
# trace features
|
| 377 |
+
out_a = self.tim(out_a)
|
| 378 |
+
out_a = self.shift(out_a + self.ism(s2t_info))
|
| 379 |
+
out_a = self.relu(self.k1(out_a))
|
| 380 |
+
|
| 381 |
+
if self.avd:
|
| 382 |
+
out_a = self.avd_layer(out_a)
|
| 383 |
+
out_b = self.avd_layer(out_b)
|
| 384 |
+
|
| 385 |
+
out = self.conv3(torch.cat([out_a, out_b], dim=1))
|
| 386 |
+
out = self.bn3(out)
|
| 387 |
+
|
| 388 |
+
if self.downsample is not None:
|
| 389 |
+
residual = self.downsample(x)
|
| 390 |
+
|
| 391 |
+
out += residual
|
| 392 |
+
out = self.relu(out)
|
| 393 |
+
|
| 394 |
+
return out
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class SCNet(nn.Module):
|
| 398 |
+
def __init__(self, num_segments, block, layers, groups=1, bottleneck_width=32,
|
| 399 |
+
num_classes=1000, dilated=False, dilation=1,
|
| 400 |
+
deep_stem=False, stem_width=64, avg_down=False,
|
| 401 |
+
avd=False, norm_layer=nn.BatchNorm2d):
|
| 402 |
+
"""SCNet, a variant based on ResNet.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
num_segments (int):
|
| 406 |
+
Number of input frames.
|
| 407 |
+
block (class):
|
| 408 |
+
Class for the residual block.
|
| 409 |
+
layers (list):
|
| 410 |
+
Number of layers in each block.
|
| 411 |
+
num_classes (int, optional):
|
| 412 |
+
Number of classification classes. Defaults to 1000.
|
| 413 |
+
dilated (bool, optional):
|
| 414 |
+
Whether to apply dilation conv. Defaults to False.
|
| 415 |
+
dilation (int, optional):
|
| 416 |
+
The dilation parameter in dilation conv. Defaults to 1.
|
| 417 |
+
deep_stem (bool, optional):
|
| 418 |
+
Whether to replace the 7x7 conv in the input stem with three 3x3 convs. Defaults to False.
|
| 419 |
+
stem_width (int, optional):
|
| 420 |
+
Stem width in conv1 stem. Defaults to 64.
|
| 421 |
+
avg_down (bool, optional):
|
| 422 |
+
Whether to use AvgPool instead of stride conv when downsampling in the bottleneck. Defaults to False.
|
| 423 |
+
avd (bool, optional):
|
| 424 |
+
The avd parameter for the block. Defaults to False.
|
| 425 |
+
norm_layer (class, optional):
|
| 426 |
+
Normalization layer. Defaults to nn.BatchNorm2d.
|
| 427 |
+
"""
|
| 428 |
+
self.cardinality = groups
|
| 429 |
+
self.bottleneck_width = bottleneck_width
|
| 430 |
+
# ResNet-D params
|
| 431 |
+
self.inplanes = stem_width*2 if deep_stem else 64
|
| 432 |
+
self.avg_down = avg_down
|
| 433 |
+
self.avd = avd
|
| 434 |
+
self.num_segments = num_segments
|
| 435 |
+
|
| 436 |
+
super(SCNet, self).__init__()
|
| 437 |
+
conv_layer = nn.Conv2d
|
| 438 |
+
if deep_stem:
|
| 439 |
+
self.conv1 = nn.Sequential(
|
| 440 |
+
conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
|
| 441 |
+
norm_layer(stem_width),
|
| 442 |
+
nn.ReLU(inplace=True),
|
| 443 |
+
conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
|
| 444 |
+
norm_layer(stem_width),
|
| 445 |
+
nn.ReLU(inplace=True),
|
| 446 |
+
conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False),
|
| 447 |
+
)
|
| 448 |
+
else:
|
| 449 |
+
self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
|
| 450 |
+
bias=False)
|
| 451 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 452 |
+
self.relu = nn.ReLU(inplace=True)
|
| 453 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 454 |
+
self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
|
| 455 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
|
| 456 |
+
if dilated or dilation == 4:
|
| 457 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
|
| 458 |
+
dilation=2, norm_layer=norm_layer)
|
| 459 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
|
| 460 |
+
dilation=4, norm_layer=norm_layer)
|
| 461 |
+
elif dilation==2:
|
| 462 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
| 463 |
+
dilation=1, norm_layer=norm_layer)
|
| 464 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
|
| 465 |
+
dilation=2, norm_layer=norm_layer)
|
| 466 |
+
else:
|
| 467 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
| 468 |
+
norm_layer=norm_layer)
|
| 469 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
| 470 |
+
norm_layer=norm_layer)
|
| 471 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 472 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
| 473 |
+
|
| 474 |
+
for m in self.modules():
|
| 475 |
+
if isinstance(m, nn.Conv2d):
|
| 476 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 477 |
+
elif isinstance(m, norm_layer):
|
| 478 |
+
nn.init.constant_(m.weight, 1)
|
| 479 |
+
nn.init.constant_(m.bias, 0)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
|
| 483 |
+
is_first=True):
|
| 484 |
+
"""
|
| 485 |
+
Core function to build layers.
|
| 486 |
+
"""
|
| 487 |
+
downsample = None
|
| 488 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 489 |
+
down_layers = []
|
| 490 |
+
if self.avg_down:
|
| 491 |
+
if dilation == 1:
|
| 492 |
+
down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
|
| 493 |
+
ceil_mode=True, count_include_pad=False))
|
| 494 |
+
else:
|
| 495 |
+
down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
|
| 496 |
+
ceil_mode=True, count_include_pad=False))
|
| 497 |
+
down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 498 |
+
kernel_size=1, stride=1, bias=False))
|
| 499 |
+
else:
|
| 500 |
+
down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 501 |
+
kernel_size=1, stride=stride, bias=False))
|
| 502 |
+
down_layers.append(norm_layer(planes * block.expansion))
|
| 503 |
+
downsample = nn.Sequential(*down_layers)
|
| 504 |
+
|
| 505 |
+
layers = []
|
| 506 |
+
if dilation == 1 or dilation == 2:
|
| 507 |
+
layers.append(block(self.num_segments, self.inplanes, planes, stride, downsample=downsample,
|
| 508 |
+
cardinality=self.cardinality,
|
| 509 |
+
bottleneck_width=self.bottleneck_width,
|
| 510 |
+
avd=self.avd, dilation=1, is_first=is_first,
|
| 511 |
+
norm_layer=norm_layer))
|
| 512 |
+
elif dilation == 4:
|
| 513 |
+
layers.append(block(self.num_segments, self.inplanes, planes, stride, downsample=downsample,
|
| 514 |
+
cardinality=self.cardinality,
|
| 515 |
+
bottleneck_width=self.bottleneck_width,
|
| 516 |
+
avd=self.avd, dilation=2, is_first=is_first,
|
| 517 |
+
norm_layer=norm_layer))
|
| 518 |
+
else:
|
| 519 |
+
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
|
| 520 |
+
|
| 521 |
+
self.inplanes = planes * block.expansion
|
| 522 |
+
for i in range(1, blocks):
|
| 523 |
+
layers.append(block(self.num_segments, self.inplanes, planes,
|
| 524 |
+
cardinality=self.cardinality,
|
| 525 |
+
bottleneck_width=self.bottleneck_width,
|
| 526 |
+
avd=self.avd, dilation=dilation,
|
| 527 |
+
norm_layer=norm_layer))
|
| 528 |
+
|
| 529 |
+
return nn.Sequential(*layers)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def features(self, input):
|
| 533 |
+
x = self.conv1(input)
|
| 534 |
+
x = self.bn1(x)
|
| 535 |
+
x = self.relu(x)
|
| 536 |
+
x = self.maxpool(x)
|
| 537 |
+
|
| 538 |
+
x = self.layer1(x)
|
| 539 |
+
x = self.layer2(x)
|
| 540 |
+
x = self.layer3(x)
|
| 541 |
+
x = self.layer4(x)
|
| 542 |
+
return x
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def logits(self, features):
|
| 546 |
+
x = self.avgpool(features)
|
| 547 |
+
x = x.view(x.size(0), -1)
|
| 548 |
+
x = self.fc(x)
|
| 549 |
+
return x
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def forward(self, input):
|
| 553 |
+
x = self.features(input)
|
| 554 |
+
x = self.logits(x)
|
| 555 |
+
return x
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def scnet50_v1d(num_segments, pretrained=False, **kwargs):
|
| 559 |
+
"""
|
| 560 |
+
SCNet backbone, which is based on ResNet-50
|
| 561 |
+
Args:
|
| 562 |
+
num_segments (int):
|
| 563 |
+
Number of input frames.
|
| 564 |
+
pretrained (bool, optional):
|
| 565 |
+
Whether to load pretrained weights.
|
| 566 |
+
"""
|
| 567 |
+
model = SCNet(num_segments, SCBottleneck, [3, 4, 6, 3],
|
| 568 |
+
deep_stem=True, stem_width=32, avg_down=True,
|
| 569 |
+
avd=True, **kwargs)
|
| 570 |
+
if pretrained:
|
| 571 |
+
model_state = torch.load('/ossfs/workspace/GenVideo/pretrained_weights/scnet50_v1d-4109d1e1.pth')
|
| 572 |
+
model.load_state_dict(model_state, strict=False)
|
| 573 |
+
|
| 574 |
+
return model
|
| 575 |
+
|
| 576 |
+
class STIL_Model(nn.Module):
|
| 577 |
+
def __init__(self,
|
| 578 |
+
num_class=1,
|
| 579 |
+
num_segment=8,
|
| 580 |
+
add_softmax=False,
|
| 581 |
+
**kwargs):
|
| 582 |
+
""" Model Builder for STIL model.
|
| 583 |
+
STIL: Spatiotemporal Inconsistency Learning for DeepFake Video Detection (https://arxiv.org/abs/2109.01860)
|
| 584 |
+
|
| 585 |
+
Args:
|
| 586 |
+
num_class (int, optional): Number of classes. Defaults to 1.
|
| 587 |
+
num_segment (int, optional): Number of segments (frames) fed to the model. Defaults to 8.
|
| 588 |
+
add_softmax (bool, optional): Whether to add softmax layer at the end. Defaults to False.
|
| 589 |
+
"""
|
| 590 |
+
super().__init__()
|
| 591 |
+
|
| 592 |
+
self.num_class = num_class
|
| 593 |
+
self.num_segment = num_segment
|
| 594 |
+
self.build_model()
|
| 595 |
+
|
| 596 |
+
def build_model(self):
|
| 597 |
+
self.base_model = scnet50_v1d(self.num_segment, pretrained=True)
|
| 598 |
+
|
| 599 |
+
fc_feature_dim = self.base_model.fc.in_features
|
| 600 |
+
self.base_model.fc = nn.Linear(fc_feature_dim, self.num_class)
|
| 601 |
+
|
| 602 |
+
def forward(self, x):
|
| 603 |
+
"""Forward pass of the model.
|
| 604 |
+
|
| 605 |
+
Args:
|
| 606 |
+
x (torch.tensor): input tensor of shape (n, t*c, h, w). n is the batch_size, t is num_segment
|
| 607 |
+
"""
|
| 608 |
+
# img channel default to 3
|
| 609 |
+
img_channel = 3
|
| 610 |
+
|
| 611 |
+
# x: [n, tc, h, w] -> [nt, c, h, w]
|
| 612 |
+
# out: [nt, num_class]
|
| 613 |
+
out = self.base_model(
|
| 614 |
+
x.view((-1, img_channel) + x.size()[2:])
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
out = out.view(-1, self.num_segment, self.num_class) # [n, t, num_class]
|
| 618 |
+
out = out.mean(1, keepdim=False) # [n, num_class]
|
| 619 |
+
return out
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def set_segment(self, num_segment):
|
| 623 |
+
"""Change num_segment of the model.
|
| 624 |
+
Useful when training and testing feed different numbers of frames.
|
| 625 |
+
|
| 626 |
+
Args:
|
| 627 |
+
num_segment (int): New number of segments.
|
| 628 |
+
"""
|
| 629 |
+
self.num_segment = num_segment
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
class Det_STIL(nn.Module):
|
| 633 |
+
def __init__(self):
|
| 634 |
+
super(Det_STIL, self).__init__()
|
| 635 |
+
self.model = STIL_Model()
|
| 636 |
+
|
| 637 |
+
def forward(self, x):
|
| 638 |
+
b, t, _, h, w = x.shape
|
| 639 |
+
images = x.view(b, t*3, h, w)
|
| 640 |
+
x = self.model(images)
|
| 641 |
+
return x
|
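A minimal sketch of the tensor bookkeeping that Det_STIL and STIL_Model perform above (the frame logits below are a random stand-in for the SCNet backbone output, since the real model loads local pretrained weights; not part of the original file):

import torch

n, t, h, w = 2, 8, 224, 224
clip = torch.randn(n, t, 3, h, w)

flat = clip.view(n, t * 3, h, w)        # Det_STIL: fold the frames into the channel axis
frames = flat.view(-1, 3, h, w)         # STIL_Model: unfold back to (n*t, 3, h, w)
assert frames.shape == (n * t, 3, h, w)

frame_logits = torch.randn(n * t, 1)    # stand-in for base_model(frames)
clip_logits = frame_logits.view(n, t, 1).mean(1)
print(clip_logits.shape)                # torch.Size([2, 1])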
models/TALL.py
ADDED
|
@@ -0,0 +1,935 @@
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Swin Transformer
|
| 3 |
+
# Copyright (c) 2021 Microsoft
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Written by Ze Liu
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.utils.checkpoint as checkpoint
|
| 11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 12 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 13 |
+
from timm.models.registry import register_model
|
| 14 |
+
from torch.hub import load_state_dict_from_url
|
| 15 |
+
import logging
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
import math
|
| 18 |
+
|
| 19 |
+
_logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _cfg(url='', **kwargs):
|
| 23 |
+
return {
|
| 24 |
+
'url': url,
|
| 25 |
+
'num_classes': 1, 'input_size': (3, 224, 224), 'pool_size': None,
|
| 26 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
| 27 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 28 |
+
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
| 29 |
+
**kwargs
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
default_cfgs = {
|
| 34 |
+
# patch models (my experiments)
|
| 35 |
+
'swin_base_in1k_patch4_224': _cfg(
|
| 36 |
+
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',
|
| 37 |
+
),
|
| 38 |
+
'swin_base_patch4_window7_224_22k': _cfg(
|
| 39 |
+
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
|
| 40 |
+
)
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
class Mlp(nn.Module):
|
| 44 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 45 |
+
super().__init__()
|
| 46 |
+
out_features = out_features or in_features
|
| 47 |
+
hidden_features = hidden_features or in_features
|
| 48 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 49 |
+
self.act = act_layer()
|
| 50 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 51 |
+
self.drop = nn.Dropout(drop)
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
x = self.fc1(x)
|
| 55 |
+
x = self.act(x)
|
| 56 |
+
x = self.drop(x)
|
| 57 |
+
x = self.fc2(x)
|
| 58 |
+
x = self.drop(x)
|
| 59 |
+
return x
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def window_partition(x, window_size):
|
| 63 |
+
"""
|
| 64 |
+
Args:
|
| 65 |
+
x: (B, H, W, C)
|
| 66 |
+
window_size (int): window size
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 70 |
+
"""
|
| 71 |
+
B, H, W, C = x.shape
|
| 72 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
| 73 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 74 |
+
return windows
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def window_reverse(windows, window_size, H, W):
|
| 78 |
+
"""
|
| 79 |
+
Args:
|
| 80 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 81 |
+
window_size (int): Window size
|
| 82 |
+
H (int): Height of image
|
| 83 |
+
W (int): Width of image
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
x: (B, H, W, C)
|
| 87 |
+
"""
|
| 88 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 89 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
| 90 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 91 |
+
return x
|
| 92 |
+
|
| 93 |
+
# adjust image size for the pyramid structure (i.e. each side must be a multiple of 32 * window_size)
|
| 94 |
+
def create_new_image_size(img_size, thumbnail_dim, window_size):
|
| 95 |
+
h, w = img_size * thumbnail_dim[0], img_size * thumbnail_dim[1]
|
| 96 |
+
new_h, new_w = h, w
|
| 97 |
+
dim = 32 * window_size
|
| 98 |
+
if h % (32 * window_size) != 0:
|
| 99 |
+
new_h = (h // dim + 1) * dim
|
| 100 |
+
if w % (32 * window_size) != 0:
|
| 101 |
+
new_w = (w // dim + 1) * dim
|
| 102 |
+
|
| 103 |
+
return (new_h//thumbnail_dim[0], new_w//thumbnail_dim[1])
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class WindowAttention(nn.Module):
|
| 107 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 108 |
+
It supports both shifted and non-shifted windows.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
dim (int): Number of input channels.
|
| 112 |
+
window_size (tuple[int]): The height and width of the window.
|
| 113 |
+
num_heads (int): Number of attention heads.
|
| 114 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 115 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 116 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 117 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 121 |
+
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.dim = dim
|
| 124 |
+
self.window_size = window_size # Wh, Ww
|
| 125 |
+
self.num_heads = num_heads
|
| 126 |
+
head_dim = dim // num_heads
|
| 127 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 128 |
+
|
| 129 |
+
# define a parameter table of relative position bias
|
| 130 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 131 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 132 |
+
# get pair-wise relative position index for each token inside the window
|
| 133 |
+
coords_h = torch.arange(self.window_size[0])
|
| 134 |
+
coords_w = torch.arange(self.window_size[1])
|
| 135 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 136 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 137 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 138 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 139 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
| 140 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 141 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 142 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 143 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 144 |
+
|
| 145 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 146 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 147 |
+
self.proj = nn.Linear(dim, dim)
|
| 148 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 149 |
+
|
| 150 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 151 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 152 |
+
|
| 153 |
+
def forward(self, x, mask=None):
|
| 154 |
+
"""
|
| 155 |
+
Args:
|
| 156 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 157 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 158 |
+
"""
|
| 159 |
+
B_, N, C = x.shape
|
| 160 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 161 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 162 |
+
|
| 163 |
+
q = q * self.scale
|
| 164 |
+
attn = (q @ k.transpose(-2, -1))
|
| 165 |
+
|
| 166 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 167 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
| 168 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 169 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 170 |
+
|
| 171 |
+
if mask is not None:
|
| 172 |
+
nW = mask.shape[0]
|
| 173 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
| 174 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 175 |
+
attn = self.softmax(attn)
|
| 176 |
+
else:
|
| 177 |
+
attn = self.softmax(attn)
|
| 178 |
+
|
| 179 |
+
attn = self.attn_drop(attn)
|
| 180 |
+
|
| 181 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 182 |
+
x = self.proj(x)
|
| 183 |
+
x = self.proj_drop(x)
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
def extra_repr(self) -> str:
|
| 187 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
| 188 |
+
|
| 189 |
+
def flops(self, N):
|
| 190 |
+
# calculate flops for 1 window with token length of N
|
| 191 |
+
flops = 0
|
| 192 |
+
# qkv = self.qkv(x)
|
| 193 |
+
flops += N * self.dim * 3 * self.dim
|
| 194 |
+
# attn = (q @ k.transpose(-2, -1))
|
| 195 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
| 196 |
+
# x = (attn @ v)
|
| 197 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
| 198 |
+
# x = self.proj(x)
|
| 199 |
+
flops += N * self.dim * self.dim
|
| 200 |
+
return flops
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class SwinTransformerBlock(nn.Module):
|
| 204 |
+
r""" Swin Transformer Block.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
dim (int): Number of input channels.
|
| 208 |
+
input_resolution (tuple[int]): Input resolution.
|
| 209 |
+
num_heads (int): Number of attention heads.
|
| 210 |
+
window_size (int): Window size.
|
| 211 |
+
shift_size (int): Shift size for SW-MSA.
|
| 212 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 213 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 214 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 215 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 216 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 217 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 218 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 219 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
| 223 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
| 224 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, bottleneck=False, use_checkpoint=False):
|
| 225 |
+
super().__init__()
|
| 226 |
+
self.dim = dim
|
| 227 |
+
self.input_resolution = input_resolution
|
| 228 |
+
self.num_heads = num_heads
|
| 229 |
+
self.window_size = window_size
|
| 230 |
+
self.shift_size = shift_size
|
| 231 |
+
self.mlp_ratio = mlp_ratio
|
| 232 |
+
self.use_checkpoint = use_checkpoint
|
| 233 |
+
|
| 234 |
+
if min(self.input_resolution) <= self.window_size:
|
| 235 |
+
# if window size is larger than input resolution, we don't partition windows
|
| 236 |
+
self.shift_size = 0
|
| 237 |
+
self.window_size = min(self.input_resolution)
|
| 238 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
|
| 239 |
+
|
| 240 |
+
self.norm1 = norm_layer(dim)
|
| 241 |
+
self.attn = WindowAttention(
|
| 242 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
| 243 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 244 |
+
|
| 245 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 246 |
+
self.norm2 = norm_layer(dim)
|
| 247 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 248 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 249 |
+
|
| 250 |
+
if self.shift_size > 0:
|
| 251 |
+
# calculate attention mask for SW-MSA
|
| 252 |
+
H, W = self.input_resolution
|
| 253 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
| 254 |
+
h_slices = (slice(0, -self.window_size),
|
| 255 |
+
slice(-self.window_size, -self.shift_size),
|
| 256 |
+
slice(-self.shift_size, None))
|
| 257 |
+
w_slices = (slice(0, -self.window_size),
|
| 258 |
+
slice(-self.window_size, -self.shift_size),
|
| 259 |
+
slice(-self.shift_size, None))
|
| 260 |
+
cnt = 0
|
| 261 |
+
for h in h_slices:
|
| 262 |
+
for w in w_slices:
|
| 263 |
+
img_mask[:, h, w, :] = cnt
|
| 264 |
+
cnt += 1
|
| 265 |
+
|
| 266 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 267 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
| 268 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 269 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 270 |
+
else:
|
| 271 |
+
attn_mask = None
|
| 272 |
+
|
| 273 |
+
self.register_buffer("attn_mask", attn_mask)
|
| 274 |
+
|
| 275 |
+
def forward_attn(self, x):
|
| 276 |
+
H, W = self.input_resolution
|
| 277 |
+
B, L, C = x.shape
|
| 278 |
+
assert L == H * W, "input feature has wrong size"
|
| 279 |
+
|
| 280 |
+
x = self.norm1(x)
|
| 281 |
+
x = x.view(B, H, W, C)
|
| 282 |
+
|
| 283 |
+
# cyclic shift
|
| 284 |
+
if self.shift_size > 0:
|
| 285 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
| 286 |
+
else:
|
| 287 |
+
shifted_x = x
|
| 288 |
+
|
| 289 |
+
# partition windows
|
| 290 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 291 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
| 292 |
+
|
| 293 |
+
# W-MSA/SW-MSA
|
| 294 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
| 295 |
+
|
| 296 |
+
# merge windows
|
| 297 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
| 298 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
| 299 |
+
|
| 300 |
+
# reverse cyclic shift
|
| 301 |
+
if self.shift_size > 0:
|
| 302 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
| 303 |
+
else:
|
| 304 |
+
x = shifted_x
|
| 305 |
+
x = x.view(B, H * W, C)
|
| 306 |
+
|
| 307 |
+
return x
|
| 308 |
+
|
| 309 |
+
def forward_mlp(self, x):
|
| 310 |
+
return self.drop_path(self.mlp(self.norm2(x)))
|
| 311 |
+
|
| 312 |
+
def forward(self, x):
|
| 313 |
+
shortcut = x
|
| 314 |
+
if self.use_checkpoint:
|
| 315 |
+
x = checkpoint.checkpoint(self.forward_attn, x)
|
| 316 |
+
else:
|
| 317 |
+
x = self.forward_attn(x)
|
| 318 |
+
x = shortcut + self.drop_path(x)
|
| 319 |
+
|
| 320 |
+
if self.use_checkpoint:
|
| 321 |
+
x = x + checkpoint.checkpoint(self.forward_mlp, x)
|
| 322 |
+
else:
|
| 323 |
+
x = x + self.forward_mlp(x)
|
| 324 |
+
|
| 325 |
+
return x
|
| 326 |
+
|
| 327 |
+
def extra_repr(self) -> str:
|
| 328 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
| 329 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
| 330 |
+
|
| 331 |
+
def flops(self):
|
| 332 |
+
flops = 0
|
| 333 |
+
H, W = self.input_resolution
|
| 334 |
+
# norm1
|
| 335 |
+
flops += self.dim * H * W
|
| 336 |
+
# W-MSA/SW-MSA
|
| 337 |
+
nW = H * W / self.window_size / self.window_size
|
| 338 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
| 339 |
+
# mlp
|
| 340 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
| 341 |
+
# norm2
|
| 342 |
+
flops += self.dim * H * W
|
| 343 |
+
return flops
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class PatchMerging(nn.Module):
|
| 347 |
+
r""" Patch Merging Layer.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
| 351 |
+
dim (int): Number of input channels.
|
| 352 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 353 |
+
"""
|
| 354 |
+
|
| 355 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
| 356 |
+
super().__init__()
|
| 357 |
+
self.input_resolution = input_resolution
|
| 358 |
+
self.dim = dim
|
| 359 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 360 |
+
self.norm = norm_layer(4 * dim)
|
| 361 |
+
|
| 362 |
+
def forward(self, x):
|
| 363 |
+
"""
|
| 364 |
+
x: B, H*W, C
|
| 365 |
+
"""
|
| 366 |
+
H, W = self.input_resolution
|
| 367 |
+
B, L, C = x.shape
|
| 368 |
+
assert L == H * W, "input feature has wrong size"
|
| 369 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
| 370 |
+
|
| 371 |
+
x = x.view(B, H, W, C)
|
| 372 |
+
|
| 373 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 374 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 375 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 376 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 377 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 378 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 379 |
+
|
| 380 |
+
x = self.norm(x)
|
| 381 |
+
x = self.reduction(x)
|
| 382 |
+
|
| 383 |
+
return x
|
| 384 |
+
|
| 385 |
+
def extra_repr(self) -> str:
|
| 386 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
| 387 |
+
|
| 388 |
+
def flops(self):
|
| 389 |
+
H, W = self.input_resolution
|
| 390 |
+
flops = H * W * self.dim
|
| 391 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
| 392 |
+
return flops
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class BasicLayer(nn.Module):
|
| 396 |
+
""" A basic Swin Transformer layer for one stage.
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
dim (int): Number of input channels.
|
| 400 |
+
input_resolution (tuple[int]): Input resolution.
|
| 401 |
+
depth (int): Number of blocks.
|
| 402 |
+
num_heads (int): Number of attention heads.
|
| 403 |
+
window_size (int): Local window size.
|
| 404 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 405 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 406 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 407 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 408 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 409 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 410 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 411 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 412 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 413 |
+
"""
|
| 414 |
+
|
| 415 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
| 416 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
| 417 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
| 418 |
+
bottleneck=False):
|
| 419 |
+
|
| 420 |
+
super().__init__()
|
| 421 |
+
self.dim = dim
|
| 422 |
+
self.input_resolution = input_resolution
|
| 423 |
+
self.depth = depth
|
| 424 |
+
self.use_checkpoint = use_checkpoint
|
| 425 |
+
|
| 426 |
+
# build blocks
|
| 427 |
+
self.blocks = nn.ModuleList([
|
| 428 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
| 429 |
+
num_heads=num_heads, window_size=window_size,
|
| 430 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
| 431 |
+
mlp_ratio=mlp_ratio,
|
| 432 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 433 |
+
drop=drop, attn_drop=attn_drop,
|
| 434 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 435 |
+
norm_layer=norm_layer,
|
| 436 |
+
bottleneck=bottleneck if i == depth-1 else False,
|
| 437 |
+
use_checkpoint=use_checkpoint)
|
| 438 |
+
for i in range(depth)])
|
| 439 |
+
|
| 440 |
+
# patch merging layer
|
| 441 |
+
if downsample is not None:
|
| 442 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
| 443 |
+
else:
|
| 444 |
+
self.downsample = None
|
| 445 |
+
|
| 446 |
+
def forward(self, x):
|
| 447 |
+
for blk in self.blocks:
|
| 448 |
+
if self.use_checkpoint:
|
| 449 |
+
x = checkpoint.checkpoint(blk, x)
|
| 450 |
+
else:
|
| 451 |
+
x = blk(x)
|
| 452 |
+
if self.downsample is not None:
|
| 453 |
+
x = self.downsample(x)
|
| 454 |
+
return x
|
| 455 |
+
|
| 456 |
+
def extra_repr(self) -> str:
|
| 457 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
| 458 |
+
|
| 459 |
+
def flops(self):
|
| 460 |
+
flops = 0
|
| 461 |
+
for blk in self.blocks:
|
| 462 |
+
flops += blk.flops()
|
| 463 |
+
if self.downsample is not None:
|
| 464 |
+
flops += self.downsample.flops()
|
| 465 |
+
return flops
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
class PatchEmbed(nn.Module):
|
| 469 |
+
r""" Image to Patch Embedding
|
| 470 |
+
|
| 471 |
+
Args:
|
| 472 |
+
img_size (int): Image size. Default: 224.
|
| 473 |
+
patch_size (int): Patch token size. Default: 4.
|
| 474 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 475 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 476 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 477 |
+
"""
|
| 478 |
+
|
| 479 |
+
def __init__(self, img_size=(224,224), patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
| 480 |
+
super().__init__()
|
| 481 |
+
#img_size = to_2tuple(img_size)
|
| 482 |
+
patch_size = to_2tuple(patch_size)
|
| 483 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
| 484 |
+
self.img_size = img_size
|
| 485 |
+
self.patch_size = patch_size
|
| 486 |
+
self.patches_resolution = patches_resolution
|
| 487 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
| 488 |
+
|
| 489 |
+
self.in_chans = in_chans
|
| 490 |
+
self.embed_dim = embed_dim
|
| 491 |
+
|
| 492 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 493 |
+
if norm_layer is not None:
|
| 494 |
+
self.norm = norm_layer(embed_dim)
|
| 495 |
+
else:
|
| 496 |
+
self.norm = None
|
| 497 |
+
|
| 498 |
+
def forward(self, x):
|
| 499 |
+
B, C, H, W = x.shape
|
| 500 |
+
# FIXME look at relaxing size constraints
|
| 501 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
| 502 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 503 |
+
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
|
| 504 |
+
if self.norm is not None:
|
| 505 |
+
x = self.norm(x)
|
| 506 |
+
return x
|
| 507 |
+
|
| 508 |
+
def flops(self):
|
| 509 |
+
Ho, Wo = self.patches_resolution
|
| 510 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
| 511 |
+
if self.norm is not None:
|
| 512 |
+
flops += Ho * Wo * self.embed_dim
|
| 513 |
+
return flops
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
class SwinTransformer(nn.Module):
|
| 517 |
+
r""" Swin Transformer
|
| 518 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 519 |
+
https://arxiv.org/pdf/2103.14030
|
| 520 |
+
Args:
|
| 521 |
+
img_size (int | tuple(int)): Input image size. Default 224
|
| 522 |
+
patch_size (int | tuple(int)): Patch size. Default: 4
|
| 523 |
+
in_chans (int): Number of input image channels. Default: 3
|
| 524 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
| 525 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
| 526 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
| 527 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
| 528 |
+
window_size (int): Window size. Default: 7
|
| 529 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
| 530 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 531 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
| 532 |
+
drop_rate (float): Dropout rate. Default: 0
|
| 533 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
| 534 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
| 535 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 536 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
| 537 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
| 538 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
| 539 |
+
"""
|
| 540 |
+
|
| 541 |
+
def __init__(self, duration=8, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
|
| 542 |
+
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
|
| 543 |
+
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
| 544 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
| 545 |
+
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
| 546 |
+
use_checkpoint=False, thumbnail_rows=1, bottleneck=False, **kwargs):
|
| 547 |
+
super().__init__()
|
| 548 |
+
|
| 549 |
+
self.duration = duration
|
| 550 |
+
self.num_classes = num_classes
|
| 551 |
+
self.num_layers = len(depths)
|
| 552 |
+
self.embed_dim = embed_dim
|
| 553 |
+
self.ape = ape
|
| 554 |
+
self.patch_norm = patch_norm
|
| 555 |
+
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
| 556 |
+
self.mlp_ratio = mlp_ratio
|
| 557 |
+
self.thumbnail_rows = thumbnail_rows
|
| 558 |
+
|
| 559 |
+
self.img_size = img_size
|
| 560 |
+
self.window_size = [window_size for _ in depths] if not isinstance(window_size, list) else window_size
|
| 561 |
+
self.image_mode = True
|
| 562 |
+
|
| 563 |
+
self.frame_padding = self.duration % thumbnail_rows if self.image_mode is True else 0
|
| 564 |
+
if self.frame_padding != 0:
|
| 565 |
+
self.frame_padding = self.thumbnail_rows - self.frame_padding
|
| 566 |
+
self.duration += self.frame_padding
|
| 567 |
+
|
| 568 |
+
# split image into non-overlapping patches
|
| 569 |
+
if self.image_mode:
|
| 570 |
+
thumbnail_dim = (thumbnail_rows, self.duration // thumbnail_rows)
|
| 571 |
+
thumbnail_size = (img_size * thumbnail_dim[0], img_size * thumbnail_dim[1])
|
| 572 |
+
else:
|
| 573 |
+
thumbnail_size = (img_size, img_size)
|
| 574 |
+
|
| 575 |
+
print ('---------------------------------------')
|
| 576 |
+
print ('duration:', self.duration, 'frame padding:', self.frame_padding, 'image_size:', self.img_size, 'patch_size:', patch_size, 'thumbnail_size:', (thumbnail_rows, self.duration//thumbnail_rows), 'ape:', self.ape)
|
| 577 |
+
|
| 578 |
+
self.patch_embed = PatchEmbed(
|
| 579 |
+
img_size=(img_size, img_size), patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
| 580 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
| 581 |
+
num_patches = self.patch_embed.num_patches
|
| 582 |
+
patches_resolution = self.patch_embed.patches_resolution
|
| 583 |
+
self.patches_resolution = patches_resolution
|
| 584 |
+
|
| 585 |
+
# absolute position embedding
|
| 586 |
+
if self.ape:
|
| 587 |
+
# self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
| 588 |
+
# trunc_normal_(self.absolute_pos_embed, std=.02)
|
| 589 |
+
self.frame_pos_embed = nn.Parameter(torch.zeros(1, self.duration, embed_dim))
|
| 590 |
+
trunc_normal_(self.frame_pos_embed, std=.02)
|
| 591 |
+
|
| 592 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 593 |
+
|
| 594 |
+
# stochastic depth
|
| 595 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
| 596 |
+
|
| 597 |
+
# build layers
|
| 598 |
+
self.layers = nn.ModuleList()
|
| 599 |
+
for i_layer in range(self.num_layers):
|
| 600 |
+
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
|
| 601 |
+
input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
| 602 |
+
patches_resolution[1] // (2 ** i_layer)),
|
| 603 |
+
depth=depths[i_layer],
|
| 604 |
+
num_heads=num_heads[i_layer],
|
| 605 |
+
window_size=self.window_size[i_layer],
|
| 606 |
+
mlp_ratio=self.mlp_ratio,
|
| 607 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 608 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
| 609 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
| 610 |
+
norm_layer=norm_layer,
|
| 611 |
+
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
| 612 |
+
use_checkpoint=use_checkpoint,
|
| 613 |
+
bottleneck=bottleneck)
|
| 614 |
+
self.layers.append(layer)
|
| 615 |
+
|
| 616 |
+
self.norm = norm_layer(self.num_features)
|
| 617 |
+
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
| 618 |
+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
| 619 |
+
|
| 620 |
+
self.apply(self._init_weights)
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
def _init_weights(self, m):
|
| 624 |
+
if isinstance(m, nn.Linear):
|
| 625 |
+
trunc_normal_(m.weight, std=.02)
|
| 626 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 627 |
+
nn.init.constant_(m.bias, 0)
|
| 628 |
+
elif isinstance(m, nn.LayerNorm):
|
| 629 |
+
nn.init.constant_(m.bias, 0)
|
| 630 |
+
nn.init.constant_(m.weight, 1.0)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
@torch.jit.ignore
|
| 635 |
+
def no_weight_decay(self):
|
| 636 |
+
return {'absolute_pos_embed', 'frame_pos_embed'}
|
| 637 |
+
|
| 638 |
+
@torch.jit.ignore
|
| 639 |
+
def no_weight_decay_keywords(self):
|
| 640 |
+
return {'relative_position_bias_table'}
|
| 641 |
+
|
| 642 |
+
def create_thumbnail(self, x):
|
| 643 |
+
# import pdb;pdb.set_trace()
|
| 644 |
+
input_size = x.shape[-2:]
|
| 645 |
+
if input_size != to_2tuple(self.img_size):
|
| 646 |
+
x = nn.functional.interpolate(x, size=self.img_size,mode='bilinear')
|
| 647 |
+
x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)', th=self.thumbnail_rows, c=3)
|
| 648 |
+
return x
|
| 649 |
+
|
| 650 |
+
def pad_frames(self, x):
|
| 651 |
+
frame_num = self.duration - self.frame_padding
|
| 652 |
+
x = x.view((-1,3*frame_num)+x.size()[2:])
|
| 653 |
+
x_padding = torch.zeros((x.shape[0], 3*self.frame_padding) + x.size()[2:]).cuda()
|
| 654 |
+
x = torch.cat((x, x_padding), dim=1)
|
| 655 |
+
assert x.shape[1] == 3 * self.duration, 'frame number %d not the same as adjusted input size %d' % (x.shape[1], 3 * self.duration)
|
| 656 |
+
|
| 657 |
+
return x
|
| 658 |
+
|
| 659 |
+
# need to find a better way to do this, maybe torch.fold?
|
| 660 |
+
def create_image_pos_embed(self):
|
| 661 |
+
img_rows, img_cols = self.patches_resolution
|
| 662 |
+
_, _, T = self.frame_pos_embed.shape
|
| 663 |
+
rows = img_rows // self.thumbnail_rows
|
| 664 |
+
cols = img_cols // (self.duration // self.thumbnail_rows)
|
| 665 |
+
img_pos_embed = torch.zeros(img_rows, img_cols, T).cuda()
|
| 666 |
+
#print (self.duration, T, img_rows, img_cols, rows, cols)
|
| 667 |
+
for i in range(self.duration):
|
| 668 |
+
r_indx = (i // self.thumbnail_rows) * rows
|
| 669 |
+
c_indx = (i % self.thumbnail_rows) * cols
|
| 670 |
+
img_pos_embed[r_indx:r_indx+rows,c_indx:c_indx+cols] = self.frame_pos_embed[0, i]
|
| 671 |
+
#print (r_indx, r_indx+rows, c_indx, c_indx+cols)
|
| 672 |
+
return img_pos_embed.reshape(-1, T)
|
| 673 |
+
|
| 674 |
+
def forward_features(self, x):
|
| 675 |
+
# x = rearrange(x, 'b (t c) h w -> b c h (t w)', t=self.duration)
|
| 676 |
+
# in evaluation, it's Bx(num_crops*num_clips*num_frames*3)xHxW
|
| 677 |
+
# import pdb;pdb.set_trace()
|
| 678 |
+
b, t, _, h, w = x.shape
|
| 679 |
+
x = x.view(b, t*3, h, w)
|
| 680 |
+
if self.frame_padding > 0:
|
| 681 |
+
x = self.pad_frames(x)
|
| 682 |
+
else:
|
| 683 |
+
x = x.view((-1,3*self.duration)+x.size()[2:])
|
| 684 |
+
|
| 685 |
+
if self.image_mode:
|
| 686 |
+
x = self.create_thumbnail(x)
|
| 687 |
+
x = nn.functional.interpolate(x, size=self.img_size,mode='bilinear')
|
| 688 |
+
else:
|
| 689 |
+
x = rearrange(x, 'b (n t c) h w -> (b n t) c h w', t=self.duration, c=3)
|
| 690 |
+
|
| 691 |
+
x = self.patch_embed(x)
|
| 692 |
+
if self.ape:
|
| 693 |
+
# x = x + self.absolute_pos_embed
|
| 694 |
+
img_pos_embed = self.create_image_pos_embed()
|
| 695 |
+
x = x + img_pos_embed
|
| 696 |
+
x = self.pos_drop(x)
|
| 697 |
+
|
| 698 |
+
for layer in self.layers:
|
| 699 |
+
x = layer(x)
|
| 700 |
+
|
| 701 |
+
x = self.norm(x) # B L C
|
| 702 |
+
x = self.avgpool(x.transpose(1, 2)) # B C 1
|
| 703 |
+
x = torch.flatten(x, 1)
|
| 704 |
+
return x
|
| 705 |
+
|
| 706 |
+
def forward(self, x):
|
| 707 |
+
x = self.forward_features(x)
|
| 708 |
+
x = self.head(x)
|
| 709 |
+
if not self.image_mode:
|
| 710 |
+
x = x.view(-1, self.duration, self.num_classes)
|
| 711 |
+
x = torch.mean(x, dim=1)
|
| 712 |
+
return x
|
| 713 |
+
|
| 714 |
+
def flops(self):
|
| 715 |
+
flops = 0
|
| 716 |
+
flops += self.patch_embed.flops()
|
| 717 |
+
for i, layer in enumerate(self.layers):
|
| 718 |
+
flops += layer.flops()
|
| 719 |
+
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
|
| 720 |
+
flops += self.num_features * self.num_classes
|
| 721 |
+
return flops
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_patches=196,
|
| 725 |
+
pretrained_window_size=7, pretrained_model="", strict=True):
|
| 726 |
+
if cfg is None:
|
| 727 |
+
cfg = getattr(model, 'default_cfg')
|
| 728 |
+
if cfg is None or 'url' not in cfg or not cfg['url']:
|
| 729 |
+
_logger.warning("Pretrained model URL is invalid, using random initialization.")
|
| 730 |
+
return
|
| 731 |
+
|
| 732 |
+
if len(pretrained_model) == 0:
|
| 733 |
+
# state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
|
| 734 |
+
# state_dict = load_state_dict_from_url(cfg['url'], progress=False, map_location='cpu')
|
| 735 |
+
state_dict = torch.load('/mnt/new_nas/yansan/models/swin_base_patch4_window7_224_22k.pth', map_location='cpu')
|
| 736 |
+
else:
|
| 737 |
+
try:
|
| 738 |
+
state_dict = load_state_dict(pretrained_model)['model']
|
| 739 |
+
except:
|
| 740 |
+
state_dict = load_state_dict(pretrained_model)
|
| 741 |
+
|
| 742 |
+
if filter_fn is not None:
|
| 743 |
+
state_dict = filter_fn(state_dict)
|
| 744 |
+
|
| 745 |
+
if in_chans == 1:
|
| 746 |
+
conv1_name = cfg['first_conv']
|
| 747 |
+
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
|
| 748 |
+
conv1_weight = state_dict[conv1_name + '.weight']
|
| 749 |
+
conv1_type = conv1_weight.dtype
|
| 750 |
+
conv1_weight = conv1_weight.float()
|
| 751 |
+
O, I, J, K = conv1_weight.shape
|
| 752 |
+
if I > 3:
|
| 753 |
+
assert conv1_weight.shape[1] % 3 == 0
|
| 754 |
+
# For models with space2depth stems
|
| 755 |
+
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
|
| 756 |
+
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
|
| 757 |
+
else:
|
| 758 |
+
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
|
| 759 |
+
conv1_weight = conv1_weight.to(conv1_type)
|
| 760 |
+
state_dict[conv1_name + '.weight'] = conv1_weight
|
| 761 |
+
elif in_chans != 3:
|
| 762 |
+
conv1_name = cfg['first_conv']
|
| 763 |
+
conv1_weight = state_dict[conv1_name + '.weight']
|
| 764 |
+
conv1_type = conv1_weight.dtype
|
| 765 |
+
conv1_weight = conv1_weight.float()
|
| 766 |
+
O, I, J, K = conv1_weight.shape
|
| 767 |
+
if I != 3:
|
| 768 |
+
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
|
| 769 |
+
del state_dict[conv1_name + '.weight']
|
| 770 |
+
strict = False
|
| 771 |
+
else:
|
| 772 |
+
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
|
| 773 |
+
repeat = int(math.ceil(in_chans / 3))
|
| 774 |
+
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
| 775 |
+
conv1_weight *= (3 / float(in_chans))
|
| 776 |
+
conv1_weight = conv1_weight.to(conv1_type)
|
| 777 |
+
state_dict[conv1_name + '.weight'] = conv1_weight
|
| 778 |
+
|
| 779 |
+
#for key, value in state_dict['model'].items():
|
| 780 |
+
# print (key)
|
| 781 |
+
|
| 782 |
+
classifier_name = cfg['classifier']
|
| 783 |
+
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
| 784 |
+
# special case for imagenet trained models with extra background class in pretrained weights
|
| 785 |
+
classifier_weight = state_dict[classifier_name + '.weight']
|
| 786 |
+
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
| 787 |
+
classifier_bias = state_dict[classifier_name + '.bias']
|
| 788 |
+
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
|
| 789 |
+
elif num_classes != cfg['num_classes']: # and len(pretrained_model) == 0:
|
| 790 |
+
# completely discard fully connected for all other differences between pretrained and created model
|
| 791 |
+
del state_dict['model'][classifier_name + '.weight']
|
| 792 |
+
del state_dict['model'][classifier_name + '.bias']
|
| 793 |
+
strict = False
|
| 794 |
+
'''
|
| 795 |
+
## Resizing the positional embeddings in case they don't match
|
| 796 |
+
if img_size != cfg['input_size'][1]:
|
| 797 |
+
pos_embed = state_dict['pos_embed']
|
| 798 |
+
cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
|
| 799 |
+
other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
|
| 800 |
+
new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
|
| 801 |
+
new_pos_embed = new_pos_embed.transpose(1, 2)
|
| 802 |
+
new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
|
| 803 |
+
state_dict['pos_embed'] = new_pos_embed
|
| 804 |
+
'''
|
| 805 |
+
|
| 806 |
+
# remove window_size related parameters
|
| 807 |
+
window_size = (model.window_size)[0]
|
| 808 |
+
print (pretrained_window_size, window_size)
|
| 809 |
+
|
| 810 |
+
new_state_dict = state_dict['model'].copy()
|
| 811 |
+
for key in state_dict['model']:
|
| 812 |
+
if 'attn_mask' in key:
|
| 813 |
+
del new_state_dict[key]
|
| 814 |
+
|
| 815 |
+
#if window_size != pretrained_window_size:
|
| 816 |
+
if 1:
|
| 817 |
+
if 'relative_position_index' in key:
|
| 818 |
+
del new_state_dict[key]
|
| 819 |
+
|
| 820 |
+
# resize it
|
| 821 |
+
if 'relative_position_bias_table' in key:
|
| 822 |
+
#print ('resizing relative_position_bias_table')
|
| 823 |
+
pretrained_table = state_dict['model'][key]
|
| 824 |
+
pretrained_table_size = int(math.sqrt(pretrained_table.shape[0]))
|
| 825 |
+
table_size = int(math.sqrt(model.state_dict()[key].shape[0]))
|
| 826 |
+
#print (pretrained_table_size, table_size)
|
| 827 |
+
if pretrained_table_size != table_size:
|
| 828 |
+
table = pretrained_table.permute(1, 0).view(1, -1, pretrained_table_size, pretrained_table_size)
|
| 829 |
+
table = nn.functional.interpolate(table, size=table_size, mode='bilinear')
|
| 830 |
+
table = table.view(-1, table_size*table_size).permute(1, 0)
|
| 831 |
+
new_state_dict[key] = table
|
| 832 |
+
|
| 833 |
+
for key in model.state_dict():
|
| 834 |
+
if 'bottleneck_norm' in key:
|
| 835 |
+
attn_key = key.replace('bottleneck_norm','norm1')
|
| 836 |
+
#print (key, attn_key)
|
| 837 |
+
new_state_dict[key] = new_state_dict[attn_key]
|
| 838 |
+
|
| 839 |
+
'''
|
| 840 |
+
for key in new_state_dict:
|
| 841 |
+
if key not in model.state_dict():
|
| 842 |
+
print ('----', key)
|
| 843 |
+
else:
|
| 844 |
+
print ('++++', key)
|
| 845 |
+
print ('====================')
|
| 846 |
+
for key in model.state_dict():
|
| 847 |
+
if key not in new_state_dict:
|
| 848 |
+
print ('----', key)
|
| 849 |
+
else:
|
| 850 |
+
print ('++++', key)
|
| 851 |
+
'''
|
| 852 |
+
|
| 853 |
+
print ('loading weights....')
|
| 854 |
+
## Loading the weights
|
| 855 |
+
model_dict = model.state_dict()
|
| 856 |
+
pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
|
| 857 |
+
model_dict.update(pretrained_dict)
|
| 858 |
+
model.load_state_dict(model_dict, strict=False)
|
| 859 |
+
|
| 860 |
+
def _conv_filter(state_dict, patch_size=4):
|
| 861 |
+
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
| 862 |
+
out_dict = {}
|
| 863 |
+
for k, v in state_dict.items():
|
| 864 |
+
if 'patch_embed.proj.weight' in k:
|
| 865 |
+
if v.shape[-1] != patch_size:
|
| 866 |
+
patch_size = v.shape[-1]
|
| 867 |
+
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
| 868 |
+
out_dict[k] = v
|
| 869 |
+
return out_dict
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
def _create_vision_transformer(variant, pretrained=False, pretrained_window_size=7, **kwargs):
|
| 873 |
+
default_cfg = default_cfgs[variant]
|
| 874 |
+
default_num_classes = default_cfg['num_classes']
|
| 875 |
+
default_img_size = default_cfg['input_size'][-1]
|
| 876 |
+
|
| 877 |
+
num_classes = kwargs.pop('num_classes', default_num_classes)
|
| 878 |
+
img_size = kwargs.pop('img_size', default_img_size)
|
| 879 |
+
repr_size = kwargs.pop('representation_size', None)
|
| 880 |
+
|
| 881 |
+
model_cls = SwinTransformer
|
| 882 |
+
model = model_cls(img_size=img_size, num_classes=num_classes, **kwargs)
|
| 883 |
+
model.default_cfg = default_cfg
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
if pretrained:
|
| 887 |
+
load_pretrained(
|
| 888 |
+
model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
|
| 889 |
+
filter_fn=_conv_filter,
|
| 890 |
+
img_size=img_size,
|
| 891 |
+
pretrained_window_size=pretrained_window_size,
|
| 892 |
+
pretrained_model=''
|
| 893 |
+
)
|
| 894 |
+
return model
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
@register_model
|
| 899 |
+
def TALL_SWIN(pretrained=False, **kwargs):
|
| 900 |
+
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
| 901 |
+
With pretrained=True, weights are initialized from swin_base_patch4_window7_224_22k (ImageNet-22k); relative position bias tables are resized to match the window sizes used here.
|
| 902 |
+
"""
|
| 903 |
+
temporal_module_name = kwargs.pop('temporal_module_name', None)
|
| 904 |
+
temporal_attention_only = kwargs.pop('temporal_attention_only', None)
|
| 905 |
+
temporal_heads_scale = kwargs.pop('temporal_heads_scale', 1.0)
|
| 906 |
+
temporal_mlp_scale = kwargs.pop('temporal_mlp_scale', 1.0)
|
| 907 |
+
rel_pos = kwargs.pop('rel_pos', False)
|
| 908 |
+
token_maks = kwargs.pop('token_mask', False)
|
| 909 |
+
frame_cls_tokens = kwargs.pop('frame_cls_tokens', 1)
|
| 910 |
+
kwargs.pop('hub_attention', '')
|
| 911 |
+
kwargs.pop('hub_aggregation', '')
|
| 912 |
+
kwargs.pop('spatial_hub_size', (-1, -1))
|
| 913 |
+
kwargs.pop('temporal_pooling', None)
|
| 914 |
+
kwargs.pop('window_size', -1)
|
| 915 |
+
|
| 916 |
+
embed_dim = 128
|
| 917 |
+
mlp_ratio = 4.
|
| 918 |
+
#drop_path_rate=0.5
|
| 919 |
+
patch_size=4
|
| 920 |
+
window_size=[14,14,14,7]
|
| 921 |
+
depths = [2, 2, 18, 2]
|
| 922 |
+
num_heads = [4, 8, 16, 32]
|
| 923 |
+
use_checkpoint=kwargs.pop('use_checkpoint', False)
|
| 924 |
+
ape = kwargs.pop('hpe_to_token', False)
|
| 925 |
+
bottleneck = True if kwargs.pop('bottleneck', None) is not None else False
|
| 926 |
+
model_kwargs = dict(patch_size=patch_size, window_size=window_size, embed_dim=embed_dim, depths=depths, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
| 927 |
+
use_checkpoint=use_checkpoint, ape=ape, bottleneck=bottleneck, **kwargs)
|
| 928 |
+
print(model_kwargs)
|
| 929 |
+
model = _create_vision_transformer('swin_base_patch4_window7_224_22k', pretrained=pretrained, pretrained_window_size=7, **model_kwargs)
|
| 930 |
+
return model
|
| 931 |
+
|
| 932 |
+
if __name__ == '__main__':
|
| 933 |
+
dummy_input = torch.randn(4,8,3,224,224)
|
| 934 |
+
model = TALL_SWIN(pretrained=True)
|
| 935 |
+
print(model(dummy_input))
|
models/VideoMAE.py
ADDED
|
@@ -0,0 +1,67 @@
| 1 |
+
import torch
|
| 2 |
+
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchvision
|
| 5 |
+
import time
|
| 6 |
+
from .mamba_base import MambaConfig, ResidualBlock
|
| 7 |
+
|
| 8 |
+
def create_reorder_index(N, device):
|
| 9 |
+
new_order = []
|
| 10 |
+
for col in range(N):
|
| 11 |
+
if col % 2 == 0:
|
| 12 |
+
new_order.extend(range(col, N*N, N))
|
| 13 |
+
else:
|
| 14 |
+
new_order.extend(range(col + N*(N-1), col-1, -N))
|
| 15 |
+
return torch.tensor(new_order, device=device)
|
| 16 |
+
|
| 17 |
+
def reorder_data(data, N):
|
| 18 |
+
assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
|
| 19 |
+
device = data.device
|
| 20 |
+
new_order = create_reorder_index(N, device)
|
| 21 |
+
B, t, _, _ = data.shape
|
| 22 |
+
index = new_order.repeat(B, t, 1).unsqueeze(-1)
|
| 23 |
+
reordered_data = torch.gather(data, 2, index.expand_as(data))
|
| 24 |
+
return reordered_data
|
| 25 |
+
|
| 26 |
+
class Videomae_Net(nn.Module):
|
| 27 |
+
def __init__(
|
| 28 |
+
self, channel_size=512, dropout=0.2, class_num=1
|
| 29 |
+
):
|
| 30 |
+
super(Videomae_Net, self).__init__()
|
| 31 |
+
self.model = VideoMAEForVideoClassification.from_pretrained("/ossfs/workspace/GenVideo/pretrained_weights/videomae")
|
| 32 |
+
self.fc1 = nn.Linear(768, class_num)
|
| 33 |
+
self.bn1 = nn.BatchNorm1d(768)
|
| 34 |
+
|
| 35 |
+
self._init_params()
|
| 36 |
+
|
| 37 |
+
def _init_params(self):
|
| 38 |
+
nn.init.xavier_normal_(self.fc1.weight)
|
| 39 |
+
nn.init.constant_(self.fc1.bias, 0)
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
x = self.model.videomae(x)
|
| 43 |
+
sequence_output = x[0]
|
| 44 |
+
print(sequence_output.shape)
|
| 45 |
+
if self.model.fc_norm is not None:
|
| 46 |
+
sequence_output = self.model.fc_norm(sequence_output.mean(1))
|
| 47 |
+
else:
|
| 48 |
+
sequence_output = sequence_output[:, 0]
|
| 49 |
+
x = self.bn1(sequence_output)
|
| 50 |
+
x = self.fc1(x)
|
| 51 |
+
return x
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == '__main__':
|
| 56 |
+
|
| 57 |
+
model = Videomae_Net()
|
| 58 |
+
|
| 59 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 60 |
+
model = model.to(device)
|
| 61 |
+
model.eval()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
input_data = torch.randn(1, 16, 3, 224, 224).to(device)
|
| 65 |
+
|
| 66 |
+
model(input_data)
|
| 67 |
+
|
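
Note (illustrative sketch, not part of the committed file): create_reorder_index above builds a column-wise zig-zag (snake) ordering over an N x N token grid, walking down the even columns and back up the odd ones; reorder_data then gathers tokens in that order. A standalone check of the ordering for a 3 x 3 grid, with the helper copied so the snippet runs on its own:

# Standalone check of the snake ordering (helper copied from above; N=3 is just an example).
import torch

def create_reorder_index(N, device):
    new_order = []
    for col in range(N):
        if col % 2 == 0:
            new_order.extend(range(col, N * N, N))                    # walk down column `col`
        else:
            new_order.extend(range(col + N * (N - 1), col - 1, -N))   # walk back up column `col`
    return torch.tensor(new_order, device=device)

print(create_reorder_index(3, 'cpu'))
# tensor([0, 3, 6, 7, 4, 1, 2, 5, 8]) -> row-major indices visited column by column, alternating direction
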
models/XCLIP.py
ADDED
|
@@ -0,0 +1,33 @@
| 1 |
+
from transformers import XCLIPVisionModel
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.nn.init as init
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
from transformers import XCLIPVisionModel
|
| 13 |
+
class XCLIP(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self, channel_size=512, dropout=0.2, class_num=1
|
| 16 |
+
):
|
| 17 |
+
super(XCLIP, self).__init__()
|
| 18 |
+
|
| 19 |
+
self.backbone = XCLIPVisionModel.from_pretrained("GenVideo/pretrained_weights/xclip")
|
| 20 |
+
self.fc_norm = nn.LayerNorm(768)
|
| 21 |
+
self.head = nn.Linear(768, 1)
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
b, t, _, h, w = x.shape
|
| 25 |
+
images = x.view(b * t, 3, h, w)
|
| 26 |
+
outputs = self.backbone(images, output_hidden_states=True)
|
| 27 |
+
sequence_output = outputs['pooler_output'].reshape(b, t, -1)
|
| 28 |
+
video_level_features = self.fc_norm(sequence_output.mean(1))
|
| 29 |
+
pred = self.head(video_level_features)
|
| 30 |
+
|
| 31 |
+
return pred
|
| 32 |
+
|
| 33 |
+
|
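
Note (illustrative sketch, not part of the committed file): the XCLIP wrapper flattens a clip into B*T frames, encodes each frame with XCLIPVisionModel, averages the per-frame pooler_output over time, and applies LayerNorm plus a 1-unit linear head, giving one logit per video. A shape walk-through with a stand-in 768-dimensional pooled feature (the real backbone is loaded from pretrained weights):

# Shape flow of XCLIP.forward with a dummy 768-d pooled feature per frame.
import torch
import torch.nn as nn

b, t, h, w = 2, 8, 224, 224
x = torch.randn(b, t, 3, h, w)

frames = x.view(b * t, 3, h, w)                    # (16, 3, 224, 224) fed to the vision backbone
pooled = torch.randn(b * t, 768)                   # stand-in for outputs['pooler_output']
sequence_output = pooled.reshape(b, t, -1)         # (2, 8, 768)
video_feat = nn.LayerNorm(768)(sequence_output.mean(1))   # temporal mean pooling -> (2, 768)
logit = nn.Linear(768, 1)(video_feat)              # (2, 1): one real/fake score per clip
print(frames.shape, video_feat.shape, logit.shape)
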
models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
| 1 |
+
from .F3Net import Det_F3_Net
|
| 2 |
+
from .NPR import resnet50_npr
|
| 3 |
+
from .STIL import Det_STIL
|
| 4 |
+
from .DeMamba import XCLIP_DeMamba, CLIP_DeMamba
|
| 5 |
+
from .XCLIP import XCLIP
|
models/__pycache__/DeMamba.cpython-39.pyc
ADDED
Binary file (5.25 kB)

models/__pycache__/F3Net.cpython-39.pyc
ADDED
Binary file (14.1 kB)

models/__pycache__/NPR.cpython-39.pyc
ADDED
Binary file (7.58 kB)

models/__pycache__/STIL.cpython-39.pyc
ADDED
Binary file (17.7 kB)

models/__pycache__/XCLIP.cpython-39.pyc
ADDED
Binary file (1.44 kB)

models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (396 Bytes)

models/__pycache__/mamba_base.cpython-39.pyc
ADDED
Binary file (7.53 kB)

models/__pycache__/pscan.cpython-39.pyc
ADDED
Binary file (5.62 kB)

models/clip/__init__.py
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
from .clip import *
|
models/clip/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (209 Bytes)

models/clip/__pycache__/clip.cpython-39.pyc
ADDED
Binary file (8.21 kB)

models/clip/__pycache__/model.cpython-39.pyc
ADDED
Binary file (14.9 kB)

models/clip/__pycache__/simple_tokenizer.cpython-39.pyc
ADDED
Binary file (5.8 kB)

models/clip/clip.py
ADDED
|
@@ -0,0 +1,233 @@
| 1 |
+
import hashlib
|
| 2 |
+
import os
|
| 3 |
+
import urllib
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Any, Union, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from .model import build_model
|
| 13 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from torchvision.transforms import InterpolationMode
|
| 17 |
+
BICUBIC = InterpolationMode.BICUBIC
|
| 18 |
+
except ImportError:
|
| 19 |
+
BICUBIC = Image.BICUBIC
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if torch.__version__.split(".") < ["1", "7", "1"]:
|
| 23 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
__all__ = ["available_models", "load", "tokenize"]
|
| 27 |
+
_tokenizer = _Tokenizer()
|
| 28 |
+
|
| 29 |
+
_MODELS = {
|
| 30 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
| 31 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
| 32 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
| 33 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
| 34 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
| 35 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
| 36 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
| 37 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _download(url: str, root: str):
|
| 43 |
+
os.makedirs(root, exist_ok=True)
|
| 44 |
+
filename = os.path.basename(url)
|
| 45 |
+
|
| 46 |
+
expected_sha256 = url.split("/")[-2]
|
| 47 |
+
download_target = os.path.join(root, filename)
|
| 48 |
+
|
| 49 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
| 50 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
| 51 |
+
|
| 52 |
+
if os.path.isfile(download_target):
|
| 53 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
| 54 |
+
return download_target
|
| 55 |
+
else:
|
| 56 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
| 57 |
+
|
| 58 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
| 59 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
| 60 |
+
while True:
|
| 61 |
+
buffer = source.read(8192)
|
| 62 |
+
if not buffer:
|
| 63 |
+
break
|
| 64 |
+
|
| 65 |
+
output.write(buffer)
|
| 66 |
+
loop.update(len(buffer))
|
| 67 |
+
|
| 68 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
| 69 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
| 70 |
+
|
| 71 |
+
return download_target
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _convert_image_to_rgb(image):
|
| 75 |
+
return image.convert("RGB")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _transform(n_px):
|
| 79 |
+
return Compose([
|
| 80 |
+
Resize(n_px, interpolation=BICUBIC),
|
| 81 |
+
CenterCrop(n_px),
|
| 82 |
+
_convert_image_to_rgb,
|
| 83 |
+
ToTensor(),
|
| 84 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
| 85 |
+
])
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def available_models() -> List[str]:
|
| 89 |
+
"""Returns the names of available CLIP models"""
|
| 90 |
+
return list(_MODELS.keys())
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
| 94 |
+
"""Load a CLIP model
|
| 95 |
+
|
| 96 |
+
Parameters
|
| 97 |
+
----------
|
| 98 |
+
name : str
|
| 99 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
| 100 |
+
|
| 101 |
+
device : Union[str, torch.device]
|
| 102 |
+
The device to put the loaded model
|
| 103 |
+
|
| 104 |
+
jit : bool
|
| 105 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
| 106 |
+
|
| 107 |
+
download_root: str
|
| 108 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
| 109 |
+
|
| 110 |
+
Returns
|
| 111 |
+
-------
|
| 112 |
+
model : torch.nn.Module
|
| 113 |
+
The CLIP model
|
| 114 |
+
|
| 115 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
| 116 |
+
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    # if name in _MODELS:
    #     model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    # elif os.path.isfile(name):
    #     model_path = name
    # else:
    #     raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    model_path = 'weights/openclip/' + name + '.pt'
    print(model_path)
    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
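For orientation, the loader above resolves checkpoints from weights/openclip/<name>.pt, and tokenize pads or truncates every string to the 77-token context. A minimal usage sketch, assuming load keeps the upstream CLIP-style load(name, device, jit) signature and that a matching checkpoint exists locally (the model name below is illustrative, not part of this commit):

# Illustrative only: "ViT-B-16" must match an existing file weights/openclip/ViT-B-16.pt.
import torch
from models.clip import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B-16", device=device, jit=False)

tokens = clip.tokenize(["a real video frame", "an AI-generated video frame"]).to(device)  # (2, 77)
with torch.no_grad():
    text_features = model.encode_text(tokens)  # (2, embed_dim)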
models/clip/model.py
ADDED
|
@@ -0,0 +1,432 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 1:, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
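build_model infers the full architecture from tensor shapes in the state dict, so it can be exercised without any config file. A short sketch, assuming a plain (non-JIT) state dict saved at a hypothetical path:

# Sketch only; the checkpoint path is hypothetical, build_model is defined above.
import torch
from models.clip.model import build_model

state_dict = torch.load("weights/openclip/ViT-B-16.pt", map_location="cpu")  # plain state dict assumed
model = build_model(state_dict)  # weights converted to fp16, model returned in eval mode
print(model.visual.input_resolution, model.context_length)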
models/clip/simple_tokenizer.py
ADDED
|
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
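A quick round trip through SimpleTokenizer shows the byte-level BPE in action; this sketch assumes bpe_simple_vocab_16e6.txt.gz is available next to this module:

# Round-trip sketch; requires the BPE vocabulary gzip alongside simple_tokenizer.py.
from models.clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("a diffusion-generated video")  # BPE ids only; tokenize() in clip.py adds <|startoftext|>/<|endoftext|>
print(ids)
print(tok.decode(ids))                           # roughly the input text, lowercased and re-spaced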
models/mamba_base.py
ADDED
|
@@ -0,0 +1,352 @@
| 1 |
+
"""
|
| 2 |
+
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
|
| 3 |
+
Copyright (c) Carnegie Mellon University.
|
| 4 |
+
Implemented by alxndrTL from https://github.com/alxndrTL/mamba.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from timm.models.layers import DropPath
|
| 16 |
+
from .pscan import pscan
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class MambaConfig:
|
| 20 |
+
d_model: int = 768 # D
|
| 21 |
+
dt_rank: Union[int, str] = 'auto'
|
| 22 |
+
d_state: int = 16 # N in paper/comments
|
| 23 |
+
expand_factor: int = 2 # E in paper/comments
|
| 24 |
+
d_conv: int = 4
|
| 25 |
+
|
| 26 |
+
dt_min: float = 0.001
|
| 27 |
+
dt_max: float = 0.1
|
| 28 |
+
dt_init: str = "random" # "random" or "constant"
|
| 29 |
+
dt_scale: float = 1.0
|
| 30 |
+
dt_init_floor = 1e-4
|
| 31 |
+
|
| 32 |
+
drop_prob: float = 0.1
|
| 33 |
+
|
| 34 |
+
bias: bool = False
|
| 35 |
+
conv_bias: bool = True
|
| 36 |
+
bimamba: bool = True
|
| 37 |
+
|
| 38 |
+
pscan: bool = True # use parallel scan mode or sequential mode when training
|
| 39 |
+
|
| 40 |
+
def __post_init__(self):
|
| 41 |
+
self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
|
| 42 |
+
|
| 43 |
+
if self.dt_rank == 'auto':
|
| 44 |
+
self.dt_rank = math.ceil(self.d_model / 16)
|
| 45 |
+
|
| 46 |
+
class ResidualBlock(nn.Module):
|
| 47 |
+
def __init__(self, config: MambaConfig):
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.mixer = MambaBlock(config)
|
| 51 |
+
self.norm = RMSNorm(config.d_model)
|
| 52 |
+
self.drop_path = DropPath(drop_prob=config.drop_prob)
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
# x : (B, L, D)
|
| 56 |
+
|
| 57 |
+
# output : (B, L, D)
|
| 58 |
+
|
| 59 |
+
output = self.drop_path(self.mixer(self.norm(x))) + x
|
| 60 |
+
return output
|
| 61 |
+
|
| 62 |
+
def step(self, x, cache):
|
| 63 |
+
# x : (B, D)
|
| 64 |
+
# cache : (h, inputs)
|
| 65 |
+
# h : (B, ED, N)
|
| 66 |
+
# inputs: (B, ED, d_conv-1)
|
| 67 |
+
|
| 68 |
+
# output : (B, D)
|
| 69 |
+
# cache : (h, inputs)
|
| 70 |
+
|
| 71 |
+
output, cache = self.mixer.step(self.norm(x), cache)
|
| 72 |
+
output = output + x
|
| 73 |
+
return output, cache
|
| 74 |
+
|
| 75 |
+
class MambaBlock(nn.Module):
|
| 76 |
+
def __init__(self, config: MambaConfig):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
self.config = config
|
| 80 |
+
|
| 81 |
+
# projects block input from D to 2*ED (two branches)
|
| 82 |
+
self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
|
| 83 |
+
|
| 84 |
+
self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
|
| 85 |
+
kernel_size=config.d_conv, bias=config.conv_bias,
|
| 86 |
+
groups=config.d_inner,
|
| 87 |
+
padding=config.d_conv - 1)
|
| 88 |
+
|
| 89 |
+
# projects x to input-dependent Δ, B, C
|
| 90 |
+
self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
|
| 91 |
+
|
| 92 |
+
# projects Δ from dt_rank to d_inner
|
| 93 |
+
self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
|
| 94 |
+
|
| 95 |
+
# dt initialization
|
| 96 |
+
# dt weights
|
| 97 |
+
dt_init_std = config.dt_rank**-0.5 * config.dt_scale
|
| 98 |
+
if config.dt_init == "constant":
|
| 99 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
| 100 |
+
elif config.dt_init == "random":
|
| 101 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
| 102 |
+
else:
|
| 103 |
+
raise NotImplementedError
|
| 104 |
+
|
| 105 |
+
# dt bias
|
| 106 |
+
dt = torch.exp(
|
| 107 |
+
torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
|
| 108 |
+
).clamp(min=config.dt_init_floor)
|
| 109 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
self.dt_proj.bias.copy_(inv_dt)
|
| 112 |
+
#self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
| 113 |
+
# todo : explain why removed
|
| 114 |
+
|
| 115 |
+
# S4D real initialization
|
| 116 |
+
A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
|
| 117 |
+
self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
|
| 118 |
+
self.D = nn.Parameter(torch.ones(config.d_inner))
|
| 119 |
+
|
| 120 |
+
# projects block output from ED back to D
|
| 121 |
+
self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
|
| 122 |
+
|
| 123 |
+
self.bimamba = config.bimamba
|
| 124 |
+
|
| 125 |
+
if self.bimamba:
|
| 126 |
+
A_b = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
|
| 127 |
+
self.A_b_log = nn.Parameter(torch.log(A_b))
|
| 128 |
+
|
| 129 |
+
self.conv1d_b = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
|
| 130 |
+
kernel_size=config.d_conv, bias=config.conv_bias,
|
| 131 |
+
groups=config.d_inner,
|
| 132 |
+
padding=config.d_conv - 1)
|
| 133 |
+
|
| 134 |
+
self.x_proj_b = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
|
| 135 |
+
self.dt_proj_b = nn.Linear(config.dt_rank, config.d_inner, bias=True)
|
| 136 |
+
self.D_b = nn.Parameter(torch.ones(config.d_inner))
|
| 137 |
+
|
| 138 |
+
def forward(self, x):
|
| 139 |
+
# x : (B, L, D)
|
| 140 |
+
|
| 141 |
+
# y : (B, L, D)
|
| 142 |
+
|
| 143 |
+
_, L, _ = x.shape
|
| 144 |
+
|
| 145 |
+
xz = self.in_proj(x) # (B, L, 2*ED)
|
| 146 |
+
x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)
|
| 147 |
+
|
| 148 |
+
# x branch
|
| 149 |
+
x = x.transpose(1, 2) # (B, ED, L)
|
| 150 |
+
x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
|
| 151 |
+
x = x.transpose(1, 2) # (B, L, ED)
|
| 152 |
+
|
| 153 |
+
x = F.silu(x)
|
| 154 |
+
y = self.ssm(x)
|
| 155 |
+
|
| 156 |
+
# z branch
|
| 157 |
+
z = F.silu(z)
|
| 158 |
+
|
| 159 |
+
output = y * z
|
| 160 |
+
output = self.out_proj(output) # (B, L, D)
|
| 161 |
+
|
| 162 |
+
return output
|
| 163 |
+
|
| 164 |
+
def ssm(self, x):
|
| 165 |
+
# x : (B, L, ED)
|
| 166 |
+
|
| 167 |
+
# y : (B, L, ED)
|
| 168 |
+
|
| 169 |
+
A = -torch.exp(self.A_log.float()) # (ED, N)
|
| 170 |
+
D = self.D.float()
|
| 171 |
+
# TODO remove .float()
|
| 172 |
+
|
| 173 |
+
deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
|
| 174 |
+
|
| 175 |
+
delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
|
| 176 |
+
delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
|
| 177 |
+
|
| 178 |
+
if self.config.pscan:
|
| 179 |
+
y = self.selective_scan(x, delta, A, B, C, D)
|
| 180 |
+
else:
|
| 181 |
+
y = self.selective_scan_seq(x, delta, A, B, C, D)
|
| 182 |
+
|
| 183 |
+
if self.bimamba:
|
| 184 |
+
x_b = x.flip([-1])
|
| 185 |
+
A_b = -torch.exp(self.A_b_log.float()) # (ED, N)
|
| 186 |
+
D_b = self.D_b.float()
|
| 187 |
+
deltaBC_b = self.x_proj_b(x_b)
|
| 188 |
+
delta_b, B_b, C_b = torch.split(deltaBC_b, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
|
| 189 |
+
delta_b = F.softplus(self.dt_proj_b(delta_b)) # (B, L, ED)
|
| 190 |
+
if self.config.pscan:
|
| 191 |
+
y_b = self.selective_scan(x_b, delta_b, A_b, B_b, C_b, D_b)
|
| 192 |
+
else:
|
| 193 |
+
y_b = self.selective_scan_seq(x_b, delta_b, A_b, B_b, C_b, D_b)
|
| 194 |
+
y_b = y_b.flip([-1])
|
| 195 |
+
y = y + y_b
|
| 196 |
+
return y
|
| 197 |
+
|
| 198 |
+
def selective_scan(self, x, delta, A, B, C, D):
|
| 199 |
+
# x : (B, L, ED)
|
| 200 |
+
# Δ : (B, L, ED)
|
| 201 |
+
# A : (ED, N)
|
| 202 |
+
# B : (B, L, N)
|
| 203 |
+
# C : (B, L, N)
|
| 204 |
+
# D : (ED)
|
| 205 |
+
|
| 206 |
+
# y : (B, L, ED)
|
| 207 |
+
|
| 208 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
|
| 209 |
+
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
|
| 210 |
+
|
| 211 |
+
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
|
| 212 |
+
|
| 213 |
+
hs = pscan(deltaA, BX)
|
| 214 |
+
|
| 215 |
+
y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
|
| 216 |
+
|
| 217 |
+
y = y + D * x
|
| 218 |
+
|
| 219 |
+
return y
|
| 220 |
+
|
| 221 |
+
def selective_scan_seq(self, x, delta, A, B, C, D):
|
| 222 |
+
# x : (B, L, ED)
|
| 223 |
+
# Δ : (B, L, ED)
|
| 224 |
+
# A : (ED, N)
|
| 225 |
+
# B : (B, L, N)
|
| 226 |
+
# C : (B, L, N)
|
| 227 |
+
# D : (ED)
|
| 228 |
+
|
| 229 |
+
# y : (B, L, ED)
|
| 230 |
+
|
| 231 |
+
_, L, _ = x.shape
|
| 232 |
+
|
| 233 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
|
| 234 |
+
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
|
| 235 |
+
|
| 236 |
+
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
|
| 237 |
+
|
| 238 |
+
h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
|
| 239 |
+
hs = []
|
| 240 |
+
|
| 241 |
+
for t in range(0, L):
|
| 242 |
+
h = deltaA[:, t] * h + BX[:, t]
|
| 243 |
+
hs.append(h)
|
| 244 |
+
|
| 245 |
+
hs = torch.stack(hs, dim=1) # (B, L, ED, N)
|
| 246 |
+
|
| 247 |
+
y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
|
| 248 |
+
|
| 249 |
+
y = y + D * x
|
| 250 |
+
|
| 251 |
+
return y
|
| 252 |
+
|
| 253 |
+
# -------------------------- inference -------------------------- #
|
| 254 |
+
"""
|
| 255 |
+
Concerning auto-regressive inference
|
| 256 |
+
|
| 257 |
+
The cool part of using Mamba: inference cost is constant with respect to sequence length
|
| 258 |
+
We just have to keep in cache, for each layer, two things :
|
| 259 |
+
- the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
|
| 260 |
+
- the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
|
| 261 |
+
(d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
|
| 262 |
+
(and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)
|
| 263 |
+
|
| 264 |
+
Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
|
| 265 |
+
h is (B, ED, N), and inputs is (B, ED, d_conv-1)
|
| 266 |
+
The MambaBlock.step() receives this cache and, along with the output, also returns the updated cache for the next call.
|
| 267 |
+
|
| 268 |
+
The cache object is initialized as follows : (None, torch.zeros()).
|
| 269 |
+
When h is None, the selective scan function detects it and starts with h=0.
|
| 270 |
+
The torch.zeros() isn't a problem (it's same as just feeding the input, because the conv1d is padded)
|
| 271 |
+
|
| 272 |
+
As we need one such cache variable per layer, we store a caches object, which is simply a list of cache object. (See mamba_lm.py)
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
def step(self, x, cache):
|
| 276 |
+
# x : (B, D)
|
| 277 |
+
# cache : (h, inputs)
|
| 278 |
+
# h : (B, ED, N)
|
| 279 |
+
# inputs : (B, ED, d_conv-1)
|
| 280 |
+
|
| 281 |
+
# y : (B, D)
|
| 282 |
+
# cache : (h, inputs)
|
| 283 |
+
|
| 284 |
+
h, inputs = cache
|
| 285 |
+
|
| 286 |
+
xz = self.in_proj(x) # (B, 2*ED)
|
| 287 |
+
x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)
|
| 288 |
+
|
| 289 |
+
# x branch
|
| 290 |
+
x_cache = x.unsqueeze(2)
|
| 291 |
+
x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)
|
| 292 |
+
|
| 293 |
+
x = F.silu(x)
|
| 294 |
+
y, h = self.ssm_step(x, h)
|
| 295 |
+
|
| 296 |
+
# z branch
|
| 297 |
+
z = F.silu(z)
|
| 298 |
+
|
| 299 |
+
output = y * z
|
| 300 |
+
output = self.out_proj(output) # (B, D)
|
| 301 |
+
|
| 302 |
+
# prepare cache for next call
|
| 303 |
+
inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
|
| 304 |
+
cache = (h, inputs)
|
| 305 |
+
|
| 306 |
+
return output, cache
|
| 307 |
+
|
| 308 |
+
def ssm_step(self, x, h):
|
| 309 |
+
# x : (B, ED)
|
| 310 |
+
# h : (B, ED, N)
|
| 311 |
+
|
| 312 |
+
# y : (B, ED)
|
| 313 |
+
# h : (B, ED, N)
|
| 314 |
+
|
| 315 |
+
A = -torch.exp(self.A_log.float()) # (ED, N) # todo: avoid recomputing this at every step, since it is independent of the timestep
|
| 316 |
+
D = self.D.float()
|
| 317 |
+
# TODO remove .float()
|
| 318 |
+
|
| 319 |
+
deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
|
| 320 |
+
|
| 321 |
+
delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
|
| 322 |
+
delta = F.softplus(self.dt_proj(delta)) # (B, ED)
|
| 323 |
+
|
| 324 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
|
| 325 |
+
deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)
|
| 326 |
+
|
| 327 |
+
BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)
|
| 328 |
+
|
| 329 |
+
if h is None:
|
| 330 |
+
h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
|
| 331 |
+
|
| 332 |
+
h = deltaA * h + BX # (B, ED, N)
|
| 333 |
+
|
| 334 |
+
y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
|
| 335 |
+
|
| 336 |
+
y = y + D * x
|
| 337 |
+
|
| 338 |
+
# todo: why h.squeeze(1)?
|
| 339 |
+
return y, h.squeeze(1)
|
| 340 |
+
|
| 341 |
+
# taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
|
| 342 |
+
class RMSNorm(nn.Module):
|
| 343 |
+
def __init__(self, d_model: int, eps: float = 1e-5):
|
| 344 |
+
super().__init__()
|
| 345 |
+
|
| 346 |
+
self.eps = eps
|
| 347 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 348 |
+
|
| 349 |
+
def forward(self, x):
|
| 350 |
+
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
|
| 351 |
+
|
| 352 |
+
return output
|
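The cache protocol described in the inference notes above (a constant-size (h, inputs) pair per layer) can be exercised on a single ResidualBlock. A minimal sketch with arbitrary sizes, assuming timm is installed for DropPath; it is not part of the committed code:

# Sketch of the recurrent step() interface; batch size and step count are illustrative.
import torch
from models.mamba_base import MambaConfig, ResidualBlock

config = MambaConfig(d_model=768)
block = ResidualBlock(config).eval()

B = 2
cache = (None, torch.zeros(B, config.d_inner, config.d_conv - 1))  # (h, last d_conv-1 inputs)

for t in range(16):                      # feed tokens one at a time; the cache never grows
    x_t = torch.randn(B, config.d_model)
    y_t, cache = block.step(x_t, cache)  # y_t : (B, d_model)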
models/pscan.py
ADDED
|
@@ -0,0 +1,232 @@
| 1 |
+
"""
|
| 2 |
+
Mamba: Linear-Time Sequence Modeling with Selective State Spaces
|
| 3 |
+
Copyright (c) Carnegie Mellon University.
|
| 4 |
+
Implemented by alxndrTL from https://github.com/alxndrTL/mamba.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
An implementation of the parallel scan operation in PyTorch (Blelloch version).
|
| 15 |
+
Please see docs/pscan.ipynb for a detailed explanation of what happens here.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def npo2(len):
|
| 20 |
+
"""
|
| 21 |
+
Returns the next power of 2 above len
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
return 2 ** math.ceil(math.log2(len))
|
| 25 |
+
|
| 26 |
+
def pad_npo2(X):
|
| 27 |
+
"""
|
| 28 |
+
Pads input length dim to the next power of 2
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
X : (B, L, D, N)
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Y : (B, npo2(L), D, N)
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
len_npo2 = npo2(X.size(1))
|
| 38 |
+
pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
|
| 39 |
+
return F.pad(X, pad_tuple, "constant", 0)
|
| 40 |
+
|
| 41 |
+
class PScan(torch.autograd.Function):
|
| 42 |
+
@staticmethod
|
| 43 |
+
def pscan(A, X):
|
| 44 |
+
# A : (B, D, L, N)
|
| 45 |
+
# X : (B, D, L, N)
|
| 46 |
+
|
| 47 |
+
# modifies X in place by doing a parallel scan.
|
| 48 |
+
# more formally, X will be populated by these values :
|
| 49 |
+
# H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
|
| 50 |
+
# which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
|
| 51 |
+
|
| 52 |
+
# only supports L that is a power of two (mainly for a clearer code)
|
| 53 |
+
|
| 54 |
+
B, D, L, _ = A.size()
|
| 55 |
+
num_steps = int(math.log2(L))
|
| 56 |
+
|
| 57 |
+
# up sweep (last 2 steps unfolded)
|
| 58 |
+
Aa = A
|
| 59 |
+
Xa = X
|
| 60 |
+
for _ in range(num_steps-2):
|
| 61 |
+
T = Xa.size(2)
|
| 62 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
| 63 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
| 64 |
+
|
| 65 |
+
Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
|
| 66 |
+
Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
|
| 67 |
+
|
| 68 |
+
Aa = Aa[:, :, :, 1]
|
| 69 |
+
Xa = Xa[:, :, :, 1]
|
| 70 |
+
|
| 71 |
+
# we have only 4, 2 or 1 nodes left
|
| 72 |
+
if Xa.size(2) == 4:
|
| 73 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
|
| 74 |
+
Aa[:, :, 1].mul_(Aa[:, :, 0])
|
| 75 |
+
|
| 76 |
+
Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
|
| 77 |
+
elif Xa.size(2) == 2:
|
| 78 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
|
| 79 |
+
return
|
| 80 |
+
else:
|
| 81 |
+
return
|
| 82 |
+
|
| 83 |
+
# down sweep (first 2 steps unfolded)
|
| 84 |
+
Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
|
| 85 |
+
Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
|
| 86 |
+
Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
|
| 87 |
+
Aa[:, :, 2].mul_(Aa[:, :, 1])
|
| 88 |
+
|
| 89 |
+
for k in range(num_steps-3, -1, -1):
|
| 90 |
+
Aa = A[:, :, 2**k-1:L:2**k]
|
| 91 |
+
Xa = X[:, :, 2**k-1:L:2**k]
|
| 92 |
+
|
| 93 |
+
T = Xa.size(2)
|
| 94 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
| 95 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
| 96 |
+
|
| 97 |
+
Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
|
| 98 |
+
Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def pscan_rev(A, X):
|
| 102 |
+
# A : (B, D, L, N)
|
| 103 |
+
# X : (B, D, L, N)
|
| 104 |
+
|
| 105 |
+
# the same function as above, but in reverse
|
| 106 |
+
# (if you flip the input, call pscan, then flip the output, you get what this function outputs)
|
| 107 |
+
# it is used in the backward pass
|
| 108 |
+
|
| 109 |
+
# only supports L that is a power of two (mainly for a clearer code)
|
| 110 |
+
|
| 111 |
+
B, D, L, _ = A.size()
|
| 112 |
+
num_steps = int(math.log2(L))
|
| 113 |
+
|
| 114 |
+
# up sweep (last 2 steps unfolded)
|
| 115 |
+
Aa = A
|
| 116 |
+
Xa = X
|
| 117 |
+
for _ in range(num_steps-2):
|
| 118 |
+
T = Xa.size(2)
|
| 119 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
| 120 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
| 121 |
+
|
| 122 |
+
Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
|
| 123 |
+
Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
|
| 124 |
+
|
| 125 |
+
Aa = Aa[:, :, :, 0]
|
| 126 |
+
Xa = Xa[:, :, :, 0]
|
| 127 |
+
|
| 128 |
+
# we have only 4, 2 or 1 nodes left
|
| 129 |
+
if Xa.size(2) == 4:
|
| 130 |
+
Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
|
| 131 |
+
Aa[:, :, 2].mul_(Aa[:, :, 3])
|
| 132 |
+
|
| 133 |
+
Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
|
| 134 |
+
elif Xa.size(2) == 2:
|
| 135 |
+
Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
|
| 136 |
+
return
|
| 137 |
+
else:
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
# down sweep (first 2 steps unfolded)
|
| 141 |
+
Aa = A[:, :, 0:L:2**(num_steps-2)]
|
| 142 |
+
Xa = X[:, :, 0:L:2**(num_steps-2)]
|
| 143 |
+
Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
|
| 144 |
+
Aa[:, :, 1].mul_(Aa[:, :, 2])
|
| 145 |
+
|
| 146 |
+
for k in range(num_steps-3, -1, -1):
|
| 147 |
+
Aa = A[:, :, 0:L:2**k]
|
| 148 |
+
Xa = X[:, :, 0:L:2**k]
|
| 149 |
+
|
| 150 |
+
T = Xa.size(2)
|
| 151 |
+
Aa = Aa.view(B, D, T//2, 2, -1)
|
| 152 |
+
Xa = Xa.view(B, D, T//2, 2, -1)
|
| 153 |
+
|
| 154 |
+
Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
|
| 155 |
+
Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
|
| 156 |
+
|
| 157 |
+
@staticmethod
|
| 158 |
+
def forward(ctx, A_in, X_in):
|
| 159 |
+
"""
|
| 160 |
+
Applies the parallel scan operation, as defined above. Returns a new tensor.
|
| 161 |
+
If you can, prefer sequence lengths that are powers of two.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
A_in : (B, L, D, N)
|
| 165 |
+
X_in : (B, L, D, N)
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
H : (B, L, D, N)
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
L = X_in.size(1)
|
| 172 |
+
|
| 173 |
+
# cloning is required because of the in-place ops
|
| 174 |
+
if L == npo2(L):
|
| 175 |
+
A = A_in.clone()
|
| 176 |
+
X = X_in.clone()
|
| 177 |
+
else:
|
| 178 |
+
# pad tensors (and clone btw)
|
| 179 |
+
A = pad_npo2(A_in) # (B, npo2(L), D, N)
|
| 180 |
+
X = pad_npo2(X_in) # (B, npo2(L), D, N)
|
| 181 |
+
|
| 182 |
+
# prepare tensors
|
| 183 |
+
A = A.transpose(2, 1) # (B, D, npo2(L), N)
|
| 184 |
+
X = X.transpose(2, 1) # (B, D, npo2(L), N)
|
| 185 |
+
|
| 186 |
+
# parallel scan (modifies X in-place)
|
| 187 |
+
PScan.pscan(A, X)
|
| 188 |
+
|
| 189 |
+
ctx.save_for_backward(A_in, X)
|
| 190 |
+
|
| 191 |
+
# slice [:, :L] (cut if there was padding)
|
| 192 |
+
return X.transpose(2, 1)[:, :L]
|
| 193 |
+
|
| 194 |
+
@staticmethod
|
| 195 |
+
def backward(ctx, grad_output_in):
|
| 196 |
+
"""
|
| 197 |
+
Flows the gradient from the output to the input. Returns two new tensors.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
ctx : A_in : (B, L, D, N), X : (B, D, L, N)
|
| 201 |
+
grad_output_in : (B, L, D, N)
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
gradA : (B, L, D, N), gradX : (B, L, D, N)
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
A_in, X = ctx.saved_tensors
|
| 208 |
+
|
| 209 |
+
L = grad_output_in.size(1)
|
| 210 |
+
|
| 211 |
+
# cloning is required because of the in-place ops
|
| 212 |
+
if L == npo2(L):
|
| 213 |
+
grad_output = grad_output_in.clone()
|
| 214 |
+
# the next padding will clone A_in
|
| 215 |
+
else:
|
| 216 |
+
grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
|
| 217 |
+
A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
|
| 218 |
+
|
| 219 |
+
# prepare tensors
|
| 220 |
+
grad_output = grad_output.transpose(2, 1)
|
| 221 |
+
A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
|
| 222 |
+
A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
|
| 223 |
+
|
| 224 |
+
# reverse parallel scan (modifies grad_output in-place)
|
| 225 |
+
PScan.pscan_rev(A, grad_output)
|
| 226 |
+
|
| 227 |
+
Q = torch.zeros_like(X)
|
| 228 |
+
Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
|
| 229 |
+
|
| 230 |
+
return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
|
| 231 |
+
|
| 232 |
+
pscan = PScan.apply
|
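pscan computes the first-order recurrence H[t] = A[t] * H[t-1] + X[t] over the length dimension, as documented in PScan.pscan. A small numerical check against the naive sequential loop (the non-power-of-two length exercises the padding path; shapes follow the (B, L, D, N) convention of forward()):

# Sanity-check sketch for the parallel scan; purely illustrative.
import torch
from models.pscan import pscan

B, L, D, N = 2, 12, 4, 3
A = torch.rand(B, L, D, N, dtype=torch.float64)
X = torch.randn(B, L, D, N, dtype=torch.float64)

H_ref = torch.zeros_like(X)
h = torch.zeros(B, D, N, dtype=torch.float64)
for t in range(L):                         # H[t] = A[t] * H[t-1] + X[t], with H[-1] = 0
    h = A[:, t] * h + X[:, t]
    H_ref[:, t] = h

print(torch.allclose(pscan(A, X), H_ref))  # expected: True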
models/time_transformer
ADDED
|
@@ -0,0 +1,256 @@
| 1 |
+
'''
|
| 2 |
+
Exploring Temporal Coherence for More General Video Face Forgery Detection @ ICCV'2021
|
| 3 |
+
Copyright (c) Xiamen University and its affiliates.
|
| 4 |
+
Modified by Yinglin Zheng from https://github.com/yinglinzheng/FTCN
|
| 5 |
+
'''
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn, einsum
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from einops import rearrange, repeat
|
| 12 |
+
from einops.layers.torch import Rearrange
|
| 13 |
+
|
| 14 |
+
class Residual(nn.Module):
|
| 15 |
+
def __init__(self, fn):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.fn = fn
|
| 18 |
+
def forward(self, x, **kwargs):
|
| 19 |
+
return self.fn(x, **kwargs) + x
|
| 20 |
+
|
| 21 |
+
class PreNorm(nn.Module):
|
| 22 |
+
def __init__(self, dim, fn):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.norm = nn.LayerNorm(dim)
|
| 25 |
+
self.fn = fn
|
| 26 |
+
def forward(self, x, **kwargs):
|
| 27 |
+
return self.fn(self.norm(x), **kwargs)
|
| 28 |
+
|
| 29 |
+
class FeedForward(nn.Module):
|
| 30 |
+
def __init__(self, dim, hidden_dim, dropout = 0.):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.net = nn.Sequential(
|
| 33 |
+
nn.Linear(dim, hidden_dim),
|
| 34 |
+
nn.GELU(),
|
| 35 |
+
nn.Dropout(dropout),
|
| 36 |
+
nn.Linear(hidden_dim, dim),
|
| 37 |
+
nn.Dropout(dropout)
|
| 38 |
+
)
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
return self.net(x)
|
| 41 |
+
|
| 42 |
+
class Attention(nn.Module):
|
| 43 |
+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
| 44 |
+
super().__init__()
|
| 45 |
+
inner_dim = dim_head * heads
|
| 46 |
+
project_out = not (heads == 1 and dim_head == dim)
|
| 47 |
+
|
| 48 |
+
self.heads = heads
|
| 49 |
+
self.scale = dim_head ** -0.5
|
| 50 |
+
|
| 51 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
| 52 |
+
|
| 53 |
+
self.to_out = nn.Sequential(
|
| 54 |
+
nn.Linear(inner_dim, dim),
|
| 55 |
+
nn.Dropout(dropout)
|
| 56 |
+
) if project_out else nn.Identity()
|
| 57 |
+
|
| 58 |
+
def forward(self, x, mask = None):
|
| 59 |
+
b, n, _, h = *x.shape, self.heads
|
| 60 |
+
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
| 61 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
| 62 |
+
|
| 63 |
+
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
| 64 |
+
mask_value = -torch.finfo(dots.dtype).max
|
| 65 |
+
|
| 66 |
+
if mask is not None:
|
| 67 |
+
mask = F.pad(mask.flatten(1), (1, 0), value = True)
|
| 68 |
+
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
| 69 |
+
mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
|
| 70 |
+
dots.masked_fill_(~mask, mask_value)
|
| 71 |
+
del mask
|
| 72 |
+
|
| 73 |
+
attn = dots.softmax(dim=-1)
|
| 74 |
+
|
| 75 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
| 76 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
| 77 |
+
out = self.to_out(out)
|
| 78 |
+
return out
|
| 79 |
+
|
| 80 |
+
class Transformer(nn.Module):
|
| 81 |
+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.layers = nn.ModuleList([])
|
| 84 |
+
for _ in range(depth):
|
| 85 |
+
self.layers.append(nn.ModuleList([
|
| 86 |
+
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
|
| 87 |
+
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
| 88 |
+
]))
|
| 89 |
+
def forward(self, x, mask = None):
|
| 90 |
+
for attn, ff in self.layers:
|
| 91 |
+
x = attn(x, mask = mask)
|
| 92 |
+
x = ff(x)
|
| 93 |
+
return x
|
| 94 |
+
|
| 95 |
+
class ViT(nn.Module):
|
| 96 |
+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
| 97 |
+
super().__init__()
|
| 98 |
+
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
| 99 |
+
num_patches = (image_size // patch_size) ** 2
|
| 100 |
+
patch_dim = channels * patch_size ** 2
|
| 101 |
+
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
| 102 |
+
|
| 103 |
+
self.to_patch_embedding = nn.Sequential(
|
| 104 |
+
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
| 105 |
+
nn.Linear(patch_dim, dim),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
| 109 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
| 110 |
+
self.dropout = nn.Dropout(emb_dropout)
|
| 111 |
+
|
| 112 |
+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
| 113 |
+
|
| 114 |
+
self.pool = pool
|
| 115 |
+
self.to_latent = nn.Identity()
|
| 116 |
+
|
| 117 |
+
self.mlp_head = nn.Sequential(
|
| 118 |
+
nn.LayerNorm(dim),
|
| 119 |
+
nn.Linear(dim, num_classes)
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
def forward(self, img, mask = None):
|
| 123 |
+
x = self.to_patch_embedding(img)
|
| 124 |
+
b, n, _ = x.shape #batch,num_patches,channels #
|
| 125 |
+
|
| 126 |
+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
| 127 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 128 |
+
x += self.pos_embedding[:, :(n + 1)]
|
| 129 |
+
x = self.dropout(x)
|
| 130 |
+
|
| 131 |
+
x = self.transformer(x, mask)
|
| 132 |
+
|
| 133 |
+
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
| 134 |
+
|
| 135 |
+
x = self.to_latent(x)
|
| 136 |
+
return self.mlp_head(x)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def valid_idx(idx, h):
|
| 141 |
+
i = idx // h
|
| 142 |
+
j = idx % h
|
| 143 |
+
pad = h // 7
|
| 144 |
+
if j < pad or i >= h - pad or j >= h - pad:
|
| 145 |
+
return False
|
| 146 |
+
else:
|
| 147 |
+
return True
|
| 148 |
+
|
| 149 |
+
import random
|
| 150 |
+
from math import sqrt
|
| 151 |
+
class RandomSelect(nn.Module):
|
| 152 |
+
def __init__(self):
|
| 153 |
+
super().__init__()
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
# batch,7x7
|
| 157 |
+
size=x.shape[1]
|
| 158 |
+
h=int(sqrt(size))
|
| 159 |
+
candidates = list(range(size))
|
| 160 |
+
candidates = [idx for idx in candidates if valid_idx(idx, h)]
|
| 161 |
+
max_k = len(candidates)
|
| 162 |
+
if self.training:
|
| 163 |
+
k = 8
|
| 164 |
+
if k==-1:
|
| 165 |
+
k=max_k
|
| 166 |
+
else:
|
| 167 |
+
k = max_k
|
| 168 |
+
candidates = random.sample(candidates, k)
|
| 169 |
+
x = x[:,candidates]
|
| 170 |
+
return x
|
| 171 |
+
|
| 172 |
+
class VideoiT(nn.Module):
|
| 173 |
+
def __init__(self, *, image_size, patch_size, num_patches, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
| 174 |
+
super().__init__()
|
| 175 |
+
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
| 176 |
+
patch_dim = channels * patch_size ** 2
|
| 177 |
+
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
| 178 |
+
|
| 179 |
+
self.to_patch = Rearrange('b c t (h p1) (w p2) -> b (h w) t (p1 p2 c)', p1 = patch_size, p2 = patch_size)
|
| 180 |
+
self.patch_to_embedding=nn.Linear(patch_dim, dim)
|
| 181 |
+
self.num_patches=num_patches
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
| 185 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
| 186 |
+
self.dropout = nn.Dropout(emb_dropout)
|
| 187 |
+
|
| 188 |
+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
| 189 |
+
|
| 190 |
+
self.pool = pool
|
| 191 |
+
self.random_select=RandomSelect()
|
| 192 |
+
self.to_latent = nn.Identity()
|
| 193 |
+
|
| 194 |
+
self.mlp_head = nn.Sequential(
|
| 195 |
+
nn.LayerNorm(dim),
|
| 196 |
+
nn.Linear(dim, num_classes)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def forward(self, img, mask = None):
|
| 200 |
+
real_b=img.shape[0]
|
| 201 |
+
x = self.to_patch(img)
|
| 202 |
+
x = self.random_select(x)
|
| 203 |
+
n=x.shape[1]
|
| 204 |
+
x=x.reshape(real_b*n,self.num_patches,-1)
|
| 205 |
+
x = self.patch_to_embedding(x)
|
| 206 |
+
b, n, _ = x.shape #batch,num_patches,channels #
|
| 207 |
+
|
| 208 |
+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
| 209 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 210 |
+
x += self.pos_embedding[:, :(n + 1)]
|
| 211 |
+
x = self.dropout(x)
|
| 212 |
+
|
| 213 |
+
x = self.transformer(x, mask)
|
| 214 |
+
|
| 215 |
+
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
| 216 |
+
|
| 217 |
+
x = self.to_latent(x)
|
| 218 |
+
x = self.mlp_head(x)
|
| 219 |
+
x = x.reshape(real_b,-1)
|
| 220 |
+
return x
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class TimeTransformer(nn.Module):
|
| 224 |
+
def __init__(self,num_patches, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', dim_head = 64, dropout = 0., emb_dropout = 0.):
|
| 225 |
+
super().__init__()
|
| 226 |
+
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
| 227 |
+
|
| 228 |
+
self.num_patches=num_patches
|
| 229 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
| 230 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
| 231 |
+
self.dropout = nn.Dropout(emb_dropout)
|
| 232 |
+
|
| 233 |
+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
| 234 |
+
|
| 235 |
+
self.pool = pool
|
| 236 |
+
self.to_latent = nn.Identity()
|
| 237 |
+
|
| 238 |
+
self.mlp_head = nn.Sequential(
|
| 239 |
+
nn.LayerNorm(dim),
|
| 240 |
+
nn.Linear(dim, num_classes)
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
def forward(self, x):
|
| 244 |
+
b, n, _ = x.shape #batch,num_patches,channels #
|
| 245 |
+
|
| 246 |
+
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
| 247 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 248 |
+
x += self.pos_embedding[:, :(n + 1)]
|
| 249 |
+
x = self.dropout(x)
|
| 250 |
+
|
| 251 |
+
x = self.transformer(x, mask=None)
|
| 252 |
+
|
| 253 |
+
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
| 254 |
+
|
| 255 |
+
x = self.to_latent(x)
|
| 256 |
+
return self.mlp_head(x)
|
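Note (not part of the commit): a minimal sketch of driving the TimeTransformer head defined above on a batch of per-frame feature tokens. All dimensions below are illustrative assumptions, not repo defaults; the classes are assumed to be in scope from the listing above (the file is committed without a .py extension, so the import path is not shown).

# Hypothetical usage sketch of TimeTransformer; dimensions are assumptions.
import torch
# assuming TimeTransformer from the listing above is importable or defined in scope

temporal_head = TimeTransformer(
    num_patches=8,        # number of temporal tokens (e.g. sampled frames)
    num_classes=2,        # real vs. generated
    dim=512, depth=6, heads=8, mlp_dim=1024,
)

frame_features = torch.randn(4, 8, 512)   # (batch, num_patches, dim)
logits = temporal_head(frame_features)    # (batch, num_classes)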
requirements.txt
ADDED
@@ -0,0 +1,59 @@
absl-py==2.3.1
albucore==0.0.24
albumentations==2.0.8
annotated-types==0.7.0
certifi==2025.10.5
charset-normalizer==3.4.3
eval-type-backport==0.2.2
filelock==3.19.1
fsspec==2025.9.0
ftfy==6.3.1
grpcio==1.75.1
hf-xet==1.1.10
huggingface-hub==0.35.3
idna==3.10
importlib-metadata==8.7.0
jinja2==3.1.6
joblib==1.5.2
markdown==3.9
markupsafe==3.0.3
mpmath==1.3.0
networkx==3.2.1
numpy==2.0.2
opencv-python==4.12.0.88
opencv-python-headless==4.12.0.88
packaging==25.0
pandas==2.3.3
pillow==11.3.0
protobuf==6.32.1
pydantic==2.11.10
pydantic-core==2.33.2
python-dateutil==2.9.0.post0
pytz==2025.2
pyyaml==6.0.3
regex==2025.9.18
requests==2.32.5
safetensors==0.6.2
scikit-learn==1.6.1
scipy==1.13.1
simsimd==6.5.3
six==1.17.0
stringzilla==4.1.0
sympy==1.14.0
tensorboard==2.20.0
tensorboard-data-server==0.7.2
threadpoolctl==3.6.0
timm==1.0.20
tokenizers==0.22.1
torch==2.8.0
torchvision==0.23.0
tqdm==4.67.1
transformers==4.57.0
triton==3.4.0
typing-extensions==4.15.0
typing-inspection==0.4.2
tzdata==2025.2
urllib3==2.5.0
wcwidth==0.2.14
werkzeug==3.1.3
zipp==3.23.0
results.csv
ADDED
@@ -0,0 +1,6 @@
file_name,predicted_prob,predicted_label
kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f000.jpg,0.8246555924415588,1
kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f001.jpg,0.8141521215438843,1
kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f002.jpg,0.8013387322425842,1
kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f003.jpg,0.7972325086593628,1
kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f004.jpg,0.8086235523223877,1
script.py
ADDED
@@ -0,0 +1,22 @@
import subprocess

scripts = [
    "extract_frames.py",
    "eval2.py",
    "create_submission.py"
]

def run_script(script):
    print(f"\n🚀 Running {script}...")
    result = subprocess.run(["python", script], capture_output=True, text=True)
    if result.returncode == 0:
        print(f"✅ {script} completed successfully.\n")
        print(result.stdout)
    else:
        print(f"❌ {script} failed with error:\n{result.stderr}")
        exit(1)

if __name__ == "__main__":
    for script in scripts:
        run_script(script)
    print("🎉 All scripts completed successfully!")
submission.csv
ADDED
@@ -0,0 +1,2 @@
id,pred,score
kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6,generated,0.8092005014419555
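Note (not part of the commit): the video-level score in submission.csv equals the mean of the five per-frame probabilities in results.csv (0.8092005 = mean of 0.8247, 0.8142, 0.8013, 0.7972, 0.8086), which suggests a frame-averaging aggregation. A minimal sketch of that step follows, under the assumption that create_submission.py works this way; the 0.5 threshold and the negative label name are guesses.

# Hypothetical sketch of the frame-to-video aggregation implied by the two CSVs above;
# create_submission.py may differ in its details.
import pandas as pd

frames = pd.read_csv("results.csv")
# strip the _fNNN.jpg suffix to recover the video id
frames["id"] = frames["file_name"].str.replace(r"_f\d+\.jpg$", "", regex=True)

videos = frames.groupby("id", as_index=False)["predicted_prob"].mean()
videos["pred"] = videos["predicted_prob"].apply(lambda p: "generated" if p >= 0.5 else "real")  # labels assumed
videos = videos.rename(columns={"predicted_prob": "score"})[["id", "pred", "score"]]
videos.to_csv("submission.csv", index=False)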