Kalpit committed
Commit d39b279 · 0 Parent(s)

feat: Add model files with LFS

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +4 -0
  2. .gitignore +2 -0
  3. LICENSE +201 -0
  4. Preprocess/compress.py +38 -0
  5. Preprocess/folder2csv.py +72 -0
  6. Preprocess/video2frame.py +57 -0
  7. README.md +88 -0
  8. __pycache__/dataloader2.cpython-39.pyc +0 -0
  9. __pycache__/util.cpython-39.pyc +0 -0
  10. commands.md +3 -0
  11. create_csv.py +44 -0
  12. create_submission.py +35 -0
  13. dataloader.py +281 -0
  14. dataloader2.py +246 -0
  15. eval.py +48 -0
  16. eval2.py +110 -0
  17. extract_frames.py +74 -0
  18. models/DeMamba.py +176 -0
  19. models/F3Net.py +434 -0
  20. models/FTCN.py +143 -0
  21. models/MINTIME +269 -0
  22. models/NPR.py +284 -0
  23. models/STIL.py +641 -0
  24. models/TALL.py +935 -0
  25. models/VideoMAE.py +67 -0
  26. models/XCLIP.py +33 -0
  27. models/__init__.py +5 -0
  28. models/__pycache__/DeMamba.cpython-39.pyc +0 -0
  29. models/__pycache__/F3Net.cpython-39.pyc +0 -0
  30. models/__pycache__/NPR.cpython-39.pyc +0 -0
  31. models/__pycache__/STIL.cpython-39.pyc +0 -0
  32. models/__pycache__/XCLIP.cpython-39.pyc +0 -0
  33. models/__pycache__/__init__.cpython-39.pyc +0 -0
  34. models/__pycache__/mamba_base.cpython-39.pyc +0 -0
  35. models/__pycache__/pscan.cpython-39.pyc +0 -0
  36. models/clip/__init__.py +1 -0
  37. models/clip/__pycache__/__init__.cpython-39.pyc +0 -0
  38. models/clip/__pycache__/clip.cpython-39.pyc +0 -0
  39. models/clip/__pycache__/model.cpython-39.pyc +0 -0
  40. models/clip/__pycache__/simple_tokenizer.cpython-39.pyc +0 -0
  41. models/clip/clip.py +233 -0
  42. models/clip/model.py +432 -0
  43. models/clip/simple_tokenizer.py +132 -0
  44. models/mamba_base.py +352 -0
  45. models/pscan.py +232 -0
  46. models/time_transformer +256 -0
  47. requirements.txt +59 -0
  48. results.csv +6 -0
  49. script.py +22 -0
  50. submission.csv +2 -0
.gitattributes ADDED
@@ -0,0 +1,4 @@
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
2
+ *.pt filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
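
These four filter rules route the common checkpoint formats (*.pth, *.pt, *.bin, *.safetensors) through Git LFS instead of regular Git storage. A minimal sketch of how such lines are typically produced, assuming Git LFS is installed locally:

```bash
# Hypothetical setup commands; `git lfs track` appends the filter rules above to .gitattributes.
git lfs install
git lfs track "*.pth" "*.pt" "*.bin" "*.safetensors"
git add .gitattributes
git commit -m "Track model weights with LFS"
```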
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ results/kling_9k_9k/best_acc.pth
2
+ models/clip/bpe_simple_vocab_16e6.txt.gz
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Preprocess/compress.py ADDED
@@ -0,0 +1,38 @@
1
+ import subprocess
2
+ from glob import glob
3
+
4
+ def convert_and_compress_video(input_video_path, output_video_path, crf=23, preset='medium'):
5
+ """
6
+ Convert and compress a video with ffmpeg.
7
+ :param input_video_path: path to the input video file
8
+ :param output_video_path: path to the compressed output video file
9
+ :param crf: Constant Rate Factor; lower values mean better quality but larger files; the default of 23 is a reasonable balance
10
+ :param preset: encoder preset that trades compression speed against file size, e.g. 'ultrafast', 'fast', 'medium', 'slow', 'veryslow'; defaults to 'medium'
11
+ """
12
+ # Check whether the file is a GIF
13
+ if input_video_path.lower().endswith('.gif'):
14
+ # Build the ffmpeg command
15
+ # First convert the GIF to MP4
16
+ convert_cmd = f'ffmpeg -i "{input_video_path}" -vf "palettegen" -y palette.png'
17
+ subprocess.run(convert_cmd, shell=True, check=True)
18
+ convert_cmd = f'ffmpeg -i "{input_video_path}" -i palette.png -lavfi "paletteuse" -c:v libx264 -preset {preset} -crf {crf} -c:a copy "{output_video_path}"'
19
+ else:
20
+ # Compress the video directly
21
+ convert_cmd = f'ffmpeg -i "{input_video_path}" -c:v libx264 -preset {preset} -crf {crf} -c:a copy "{output_video_path}"'
22
+
23
+ # Run the ffmpeg command
24
+ try:
25
+ subprocess.run(convert_cmd, shell=True, check=True)
26
+ print(f"Video compression finished, output file: {output_video_path}")
27
+ except subprocess.CalledProcessError as e:
28
+ print(f"Video compression failed: {e}")
29
+
30
+ video_paths = '/home3/Transformer'
31
+ all_dirs = glob(video_paths+'/*')
32
+ output_dirs = '/home3/robust/compress/Transformer_'
33
+
34
+ for path in all_dirs:
35
+ name = path.split('/')[-1]
36
+ out_path = output_dirs + name.replace('.gif', '.mp4') if path.lower().endswith('.gif') else output_dirs + name
37
+ print(name)
38
+ convert_and_compress_video(path, out_path, crf=28)
Preprocess/folder2csv.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import csv
3
+ import pandas as pd
4
+ from pandas import Series, DataFrame
5
+ from glob import glob
6
+ import os
7
+
8
+ def count_images_in_folder(folder_path):
9
+ image_count = 0
10
+ image_names = []
11
+ for file_name in os.listdir(folder_path):
12
+ if file_name.endswith('.png') or file_name.endswith('.jpg') or file_name.endswith('.jpeg'):
13
+ image_count += 1
14
+ image_names.append(int(file_name.split('.')[0]))
15
+ image_names.sort()
16
+ return image_count, image_names
17
+
18
+
19
+ folder_path = './SD_frames'
20
+ all_dirs = []
21
+
22
+ for root, dirs, files in os.walk(folder_path):
23
+ for dir in dirs:
24
+ all_dirs.append(os.path.join(root, dir))
25
+
26
+ label = list()
27
+ save_path = list()
28
+ frame_counts = list()
29
+ frame_seq_counts = list()
30
+ content_paths = list()
31
+ chinese_labels = list()
32
+
33
+
34
+ for video_path in all_dirs:
35
+ frame_paths = glob(video_path + '/*')
36
+ temp_frame_count, temp_frame_seqs = count_images_in_folder(video_path)
37
+ if temp_frame_count == 0:
38
+ continue
39
+
40
+ for frame in frame_paths:
41
+ content_path = frame.split('/')[1:-1]
42
+ content_path = '/'.join(content_path)
43
+ # input your own path
44
+ content_path = '/home/AIGC_Video_Det/SD/' + content_path
45
+
46
+ frame_path = frame.split('/')[1:]
47
+ frame_path = '/'.join(frame_path)
48
+ frame_path = '/home/AIGC_Video_Det/SD/' + frame_path
49
+
50
+ print(content_path, frame_path)
51
+ label.append(str(1))
52
+ frame_counts.append(int(temp_frame_count))
53
+ frame_seq_counts.append(temp_frame_seqs)
54
+ save_path.append(frame_path)
55
+ content_paths.append(content_path)
56
+ chinese_labels.append('AIGC视频')  # 'AIGC video'
57
+ # chinese_labels.append('真实视频')  # 'real video'
58
+ break
59
+
60
+ dic={
61
+ 'content_path': Series(data=content_paths),
62
+ 'image_path': Series(data=save_path),
63
+ 'type_id': Series(data=chinese_labels),
64
+ 'label': Series(data=label),
65
+ 'frame_len': Series(data=frame_counts),
66
+ 'frame_seq': Series(data=frame_seq_counts)
67
+ }
68
+
69
+ print(dic)
70
+ pd.DataFrame(dic).to_csv('SD.csv', encoding='utf-8', index=False)
71
+
72
+
Preprocess/video2frame.py ADDED
@@ -0,0 +1,57 @@
1
+ import os
2
+ from glob import glob
3
+ from moviepy.editor import VideoFileClip
4
+ import multiprocessing
5
+ import cv2, math
6
+
7
+ def get_video_length(file_path):
8
+ video = VideoFileClip(file_path)
9
+ return video.duration
10
+
11
+ def process_video(video_path):
12
+ video_name = video_path.split('/')[-1]
13
+ video_name = video_name.split('.')[:-1]
14
+ video_name = '.'.join(video_name)
15
+
16
+ path = video_path.split('/')[1:-1]
17
+ path = '/'.join(path)
18
+
19
+ image_path = './SD_frames/'+path+'/'+ video_name+'/'
20
+ print(video_name)
21
+ if os.path.exists(image_path):
22
+ print("Path already exists")
23
+ else:
24
+ print(video_name, "path does not exist")
25
+ try:
26
+ try:
27
+ video_length = get_video_length(video_path)
28
+ print(video_name, f"video length: {video_length} seconds")
29
+ os.makedirs(os.path.dirname(image_path), exist_ok=True)
30
+
31
+ if video_length >= 4 :
32
+ inter_val = 2
33
+ os.system(f"cd {image_path} | ffmpeg -loglevel quiet -i {video_path} -r {inter_val} {image_path}%d.jpg")
34
+ else:
35
+ inter_val = math.ceil(8 / video_length)
36
+ os.system(f"cd {image_path} | ffmpeg -loglevel quiet -i {video_path} -r {inter_val} {image_path}%d.jpg")
37
+
38
+ except Exception as e:
39
+ print("Exception occurred:", str(e))
40
+ except:
41
+ print("Skip")
42
+
43
+ if __name__ == '__main__':
44
+ print("Getting frames!!")
45
+ video_paths = './SD'
46
+ all_dirs = []
47
+ all_dirs = glob(video_paths+'/*')
48
+
49
+ print(all_dirs)
50
+
51
+ pool = multiprocessing.Pool(processes=8)
52
+ pool.map(process_video, all_dirs)
53
+ pool.close()
54
+ pool.join()
55
+
56
+
57
+
README.md ADDED
@@ -0,0 +1,88 @@
1
+ This is the official code of the paper 'DeMamba: AI-Generated Video Detection on Million-Scale GenVideo Benchmark'.
2
+
3
+ ## :boom: News!!!
4
+ - (25.09.24) We have released a lightweight version of GenVideo, named [GenVideo-100K](https://modelscope.cn/datasets/cccnju/GenVideo-100K), with 10,000 samples for each forgery category.
5
+ - (24.12.11) We are pleased to announce that our AI-generated content (AIGC) video detection model, developed on GenVideo, has passed the evaluation by the China Academy of Information and Communications Technology (CAICT) and achieved the highest rating, making us the first organization in China to be officially registered and approved. [[certificate](https://github.com/chenhaoxing/DeMamba/blob/main/figs/xty.jpg)] [[report](https://mp.weixin.qq.com/s/OoW7EI1QoSrQ3FIftfbudg)]
6
+
7
+
8
+ ## :dart: Todo
9
+
10
+ - [x] Release source code.
11
+ - [x] Release dataset.
12
+
13
+
14
+ ## :file_folder: Dataset download
15
+ ![](figs/tab_fig.jpg)
16
+
17
+ ### Data preparation process
18
+ - Download the original videos.
19
+
20
+ - Generated videos: all generated videos can be downloaded from [ModelScope](https://modelscope.cn/collections/Gen-Video-7de46cd6846f4e).
21
+
22
+ - Real videos: the MSR-VTT data is contained in the GenVideo-Val.zip file. The selected Youku videos are also provided at the link above. For Kinetics-400, you will need to download it yourself from [https://github.com/cvdfoundation/kinetics-dataset](https://github.com/cvdfoundation/kinetics-dataset).
23
+
24
+ - Preprocess the videos and generate the data-list CSV file (an illustrative row is shown after this README).
25
+
26
+ Statistics of real and generated videos in the GenVideo dataset:
27
+ | **Video Source** | **Type** | **Task** | **Time** | **Resolution** | **FPS** | **Length** | **Training Set** | **Testing Set** |
28
+ |-----------------------------------------------------|----------|----------|----------|----------------|---------|------------|------------------|----------------|
29
+ | Kinetics-400 | Real | - | 17.05 | 224-340 | - | 5-10s | 260,232 | - |
30
+ | Youku-mPLUG | Real | - | 23.07 | - | - | 10-120s | 953,279 | - |
31
+ | MSR-VTT | Real | - | 16.05 | - | - | 10-30s | - | 10,000 |
32
+ | ZeroScope | Fake | T2V | 23.07 | 1024×576 | 8 | 3s | 133,169 | - |
33
+ | I2VGen-XL | Fake | I2V | 23.12 | 1280×720 | 8 | 2s | 61,975 | - |
34
+ | SVD | Fake | I2V | 23.12 | 1024×576 | 8 | 4s | 149,026 | - |
35
+ | VideoCrafter | Fake | T2V | 24.01 | 1024×576 | 8 | 2s | 39,485 | - |
36
+ | Pika | Fake | T2V&I2V | 24.02 | 1088×640 | 24 | 3s | 98,377 | - |
37
+ | DynamiCrafter | Fake | I2V | 24.03 | 1024×576 | 8 | 3s | 46,205 | - |
38
+ | SD | Fake | T2V&I2V | 23-24 | 512-1024 | 8 | 2-6s | 200,720 | - |
39
+ | SEINE | Fake | I2V | 24.04 | 1024×576 | 8 | 2-4s | 24,737 | - |
40
+ | Latte | Fake | T2V | 24.03 | 512×512 | 8 | 2s | 149,979 | - |
41
+ | OpenSora | Fake | T2V | 24.03 | 512×512 | 8 | 2s | 177,410 | - |
42
+ | ModelScope | Fake | T2V | 23.03 | 256×256 | 8 | 4s | - | 700 |
43
+ | MorphStudio | Fake | T2V | 23.08 | 1280×720 | 8 | 2s | - | 700 |
44
+ | MoonValley | Fake | T2V | 24.01 | 1024×576 | 16 | 3s | - | 626 |
45
+ | HotShot | Fake | T2V | 23.10 | 672×384 | 8 | 1s | - | 700 |
46
+ | Show_1 | Fake | T2V | 23.10 | 576×320 | 8 | 4s | - | 700 |
47
+ | Gen2 | Fake | I2V&T2V | 23.09 | 896×512 | 24 | 4s | - | 1,380 |
48
+ | Crafter | Fake | T2V | 23.04 | 256×256 | 8 | 4s | - | 1,400 |
49
+ | Lavie | Fake | T2V | 23.09 | 1280×2048 | 8 | 2s | - | 1,400 |
50
+ | Sora | Fake | T2V | 24.02 | - | - | -60s | - | 56 |
51
+ | WildScrape | Fake | T2V&I2V | 24 | 512-1024 | 8-16 | 2-6s | - | 926 |
52
+ | **Total Count** | - | - | - | - | - | - | 2,294,594 | 19,588 |
53
+
54
+ ## :snake: Detail Mamba (DeMamba)
55
+
56
+ ![](figs/logo.png)
57
+ <p align="center"><em>In memory of Kobe Bryant (generated by GPT-4o)</em></p>
58
+
59
+ > "Determination wins games, but Detail wins championships." — *Kobe Bryant, in his Show Detail, 2018*
60
+
61
+ ![](figs/VFOD.png)
62
+ <p align="center"><em>The overall framework of our Detail Mamba (DeMamba)</em></p>
63
+
64
+
65
+ ## :space_invader: Citing GenVideo&DeMamba
66
+ If you use GenVideo or DeMamba in your research or use the codebase here, please use the following BibTeX entry.
67
+
68
+ ```BibTeX
69
+ @article{DeMamba,
70
+ title={DeMamba: AI-Generated Video Detection on Million-Scale GenVideo Benchmark},
71
+ author={Haoxing Chen and Yan Hong and Zizheng Huang and Zhuoer Xu and Zhangxuan Gu and Yaohui Li and Jun Lan and Huijia Zhu and Jianfu Zhang and Weiqiang Wang and Huaxiong Li},
72
+ journal={arXiv preprint arXiv:2405.19707},
73
+ year={2024}
74
+ }
75
+ ```
76
+
77
+ ## Star History
78
+
79
+ [![Star History Chart](https://api.star-history.com/svg?repos=chenhaoxing/DeMamba&type=Date)](https://star-history.com/#chenhaoxing/DeMamba&Date)
80
+
81
+ ## Acknowledgement
82
+ Many thanks to the nice work of [STIL](https://github.com/wizyoung/STIL-DeepFake-Video-Detection), [CLIP](https://github.com/openai/CLIP), [XCLIP](https://github.com/microsoft/VideoX/tree/master/X-CLIP), [NPR](https://github.com/chuangchuangtan/NPR-DeepfakeDetection/tree/main) and [VideoMAE](https://github.com/MCG-NJU/VideoMAE-Action-Detection).
83
+
84
+ ## :email: Contact
85
+ If you have any questions, feel free to contact us: hx.chen@hotmail.com.
86
+
87
+
88
+
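
For reference, the data-list CSV written by Preprocess/folder2csv.py (and read by dataloader.py) uses the columns content_path, image_path, type_id, label, frame_len, and frame_seq, where frame_seq is a Python-style list of frame indices that the loader eval()s back and resolves to <content_path>/<index>.jpg. An illustrative row with placeholder paths (the newer create_csv.py/dataloader2.py variant stores full frame paths in frame_seq instead):

```
content_path,image_path,type_id,label,frame_len,frame_seq
/home/AIGC_Video_Det/SD/clip_0001,/home/AIGC_Video_Det/SD/clip_0001/1.jpg,AIGC视频,1,4,"[1, 2, 3, 4]"
```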
__pycache__/dataloader2.cpython-39.pyc ADDED
Binary file (6.62 kB).
 
__pycache__/util.cpython-39.pyc ADDED
Binary file (3.91 kB).
 
commands.md ADDED
@@ -0,0 +1,3 @@
1
+ python train.py --config /home/kalpit/workspace/aigc/repos/DeMamba/configs/my_config.yaml
2
+
3
+ python train.py --config /home/kalpit/workspace/aigc/repos/DeMamba/configs/vid_grand_v1_train.yaml
create_csv.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ import csv
3
+
4
+ # Source directories
5
+ real_dir = "/home/kalpit/workspace/aigc/data/ShareVeo3/test/0_real"
6
+ fake_dir = "/home/kalpit/workspace/aigc/data/ShareVeo3/test/1_fake"
7
+
8
+ # Output CSV path
9
+ output_csv = "/home/kalpit/workspace/aigc/repos/DeMamba/veo_test.csv"
10
+
11
+ # Function to get all image paths from a directory
12
+ def get_image_paths(directory, label):
13
+ image_paths = []
14
+ for root, _, files in os.walk(directory):
15
+ for file in files:
16
+ if file.lower().endswith(('.png', '.jpg', '.jpeg')):
17
+ full_path = os.path.join(root, file)
18
+ image_paths.append({
19
+ "content_path": full_path,
20
+ "frame_seq": [full_path], # list containing the single frame path
21
+ "label": label
22
+ })
23
+ return image_paths
24
+
25
+ # Collect all images
26
+ data = []
27
+ data.extend(get_image_paths(real_dir, 0))
28
+ data.extend(get_image_paths(fake_dir, 1))
29
+
30
+ # Write to CSV
31
+ with open(output_csv, 'w', newline='', encoding='utf-8') as csvfile:
32
+ fieldnames = ["content_path", "frame_seq", "label"]
33
+ writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
34
+
35
+ writer.writeheader()
36
+ for row in data:
37
+ # Store frame_seq as its string representation (the dataloader eval()s it back into a list)
38
+ writer.writerow({
39
+ "content_path": row["content_path"],
40
+ "frame_seq": str(row["frame_seq"]),
41
+ "label": row["label"]
42
+ })
43
+
44
+ print(f"CSV saved at: {output_csv}")
create_submission.py ADDED
@@ -0,0 +1,35 @@
1
+ import pandas as pd
2
+ import re
3
+
4
+ # Input and output files
5
+ input_csv = "results.csv"
6
+ output_csv = "submission.csv"
7
+
8
+ # Load frame-level results
9
+ df = pd.read_csv(input_csv)
10
+
11
+ # Extract base video ID (remove _f###.jpg)
12
+ def extract_video_id(filename):
13
+ return re.sub(r"_f\d+\.\w+$", "", filename)
14
+
15
+ df["video_id"] = df["file_name"].apply(extract_video_id)
16
+
17
+ # Aggregate probabilities per video (mean of frame scores)
18
+ video_scores = (
19
+ df.groupby("video_id")["predicted_prob"]
20
+ .mean()
21
+ .reset_index()
22
+ .rename(columns={"video_id": "id", "predicted_prob": "score"})
23
+ )
24
+
25
+ # Assign label: generated if >0.5 else real
26
+ video_scores["pred"] = video_scores["score"].apply(
27
+ lambda x: "generated" if x > 0.5 else "real"
28
+ )
29
+
30
+ # Reorder columns
31
+ video_scores = video_scores[["id", "pred", "score"]]
32
+
33
+ # Save submission file
34
+ video_scores.to_csv(output_csv, index=False)
35
+ print(f"Saved submission file to {output_csv}")
dataloader.py ADDED
@@ -0,0 +1,281 @@
1
+ import torch.utils.data as data
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import pandas as pd
4
+ import torch
5
+ import albumentations
6
+ import random
7
+ import os
8
+ import numpy as np
9
+ import cv2
10
+ import math
11
+ import warnings
12
+
13
+
14
+ def crop_center_by_percentage(image, percentage):
15
+ height, width = image.shape[:2]
16
+
17
+ if width > height:
18
+ left_pixels = int(width * percentage)
19
+ right_pixels = int(width * percentage)
20
+ start_x = left_pixels
21
+ end_x = width - right_pixels
22
+ cropped_image = image[:, start_x:end_x]
23
+ else:
24
+ up_pixels = int(height * percentage)
25
+ down_pixels = int(height * percentage)
26
+ start_y = up_pixels
27
+ end_y = height - down_pixels
28
+ cropped_image = image[start_y:end_y, :]
29
+
30
+ return cropped_image
31
+
32
+ class Ours_Dataset_train(Dataset):
33
+ def __init__(self, index_list=None, df=None):
34
+ self.index_list = index_list
35
+ self.df = df
36
+ self.positive_indices = df[df['label'] == 1].index.tolist()
37
+ self.negative_indices = df[df['label'] == 0].index.tolist()
38
+ self.balanced_indices = []
39
+ self.resample()
40
+
41
+ def resample(self):
42
+ # Ensure each epoch uses a balanced dataset
43
+ min_samples = min(len(self.positive_indices), len(self.negative_indices))
44
+ self.balanced_indices.clear()
45
+ self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
46
+ self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
47
+ random.shuffle(self.balanced_indices) # Shuffle to mix positive and negative samples
48
+
49
+ def __getitem__(self, idx):
50
+ real_idx = self.balanced_indices[idx]
51
+ row = self.df.iloc[real_idx]
52
+ video_id = row['content_path']
53
+ label = row['label']
54
+ frame_list = eval(row['frame_seq'])
55
+ label_onehot = [0]*2
56
+ select_frame_nums = 8
57
+
58
+ aug_list = [
59
+ albumentations.Resize(224, 224)
60
+ ]
61
+
62
+ if random.random() < 0.5:
63
+ aug_list.append(albumentations.HorizontalFlip(p=1.0))
64
+ if random.random() < 0.5:
65
+ quality_score = random.randint(50, 100)
66
+ aug_list.append(albumentations.ImageCompression(quality_lower=quality_score, quality_upper=quality_score))
67
+ if random.random() < 0.3:
68
+ aug_list.append(albumentations.GaussNoise(p=1.0))
69
+ if random.random() < 0.3:
70
+ aug_list.append(albumentations.GaussianBlur(blur_limit=(3, 5), p=1.0))
71
+ if random.random() < 0.001:
72
+ aug_list.append(albumentations.ToGray(p=1.0))
73
+
74
+ aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
75
+ trans = albumentations.Compose(aug_list)
76
+
77
+ if len(frame_list) >= select_frame_nums:
78
+ start_frame = random.randint(0, len(frame_list)-select_frame_nums)
79
+ select_frames = frame_list[start_frame:start_frame+select_frame_nums]
80
+ frames = []
81
+ for x in frame_list[start_frame:start_frame+select_frame_nums]:
82
+ while True:
83
+ try:
84
+ temp_image_path = video_id+'/'+str(x)+'.jpg'
85
+ image = download_oss_file('GenVideo/'+ temp_image_path)
86
+ if video_id.startswith("real/youku"):
87
+ image = crop_center_by_percentage(image, 0.15)
88
+ break
89
+ except Exception as e:
90
+ if x+1 < len(frame_list):
91
+ x = x + 1
92
+ elif x - 1 >=0 :
93
+ x = x - 1
94
+ augmented = trans(image=image)
95
+ image = augmented["image"]
96
+ frames.append(image.transpose(2,0,1)[np.newaxis,:])
97
+ else:
98
+ pad_num = select_frame_nums-len(frame_list)
99
+ frames = []
100
+ for x in frame_list:
101
+ temp_image_path = video_id+'/'+str(x)+'.jpg'
102
+ image = download_oss_file('GenVideo/'+temp_image_path)
103
+ if video_id.startswith("real/youku"):
104
+ image = crop_center_by_percentage(image, 0.15)
105
+ augmented = trans(image=image)
106
+ image = augmented["image"]
107
+ frames.append(image.transpose(2,0,1)[np.newaxis,:])
108
+ for i in range(pad_num):
109
+ frames.append(np.zeros((224,224,3)).transpose(2,0,1)[np.newaxis,:])
110
+
111
+ label_onehot[int(label)] = 1
112
+ frames = np.concatenate(frames, 0)
113
+ frames = torch.tensor(frames[np.newaxis,:])
114
+ label_onehot = torch.FloatTensor(label_onehot)
115
+ binary_label = torch.FloatTensor([int(label)])
116
+
117
+ return self.index_list[idx], frames, label_onehot, binary_label
118
+
119
+ def __len__(self):
120
+ return len(self.balanced_indices)
121
+
122
+
123
+ class Ours_Dataset_val(data.Dataset):
124
+ def __init__(self, cfg, index_list=None, df=None):
125
+ self.index_list = index_list
126
+ self.cfg = cfg
127
+ self.df = df
128
+ self.frame_dir = df['image_path'].tolist()
129
+
130
+ def __getitem__(self, idx):
131
+ aug_list = [
132
+ albumentations.Resize(224, 224),
133
+ ]
134
+
135
+ if self.cfg['task'] == 'JPEG_Compress_Attack':
136
+ aug_list.append(albumentations.JpegCompression(quality_lower=35, quality_upper=35,p=1.0))
137
+ if self.cfg['task'] == 'FLIP_Attack':
138
+ if random.random() < 0.5:
139
+ aug_list.append(albumentations.HorizontalFlip(p=1.0))
140
+ else:
141
+ aug_list.append(albumentations.VerticalFlip(p=1.0))
142
+ if self.cfg['task'] == 'CROP_Attack':
143
+ random_crop_x = random.randint(0, 16)
144
+ random_crop_y = random.randint(0, 16)
145
+ crop_width = random.randint(160, 208)
146
+ crop_height = random.randint(160, 208)
147
+ aug_list.append(albumentations.Crop(x_min=random_crop_x, y_min=random_crop_y, x_max=random_crop_x+crop_width, y_max=random_crop_y+crop_height))
148
+ aug_list.append(albumentations.Resize(224, 224))
149
+
150
+ if self.cfg['task'] == 'Color_Attack':
151
+ index = random.choice([i for i in range(4)])
152
+ dicts = {0:[0.5,0,0,0],1:[0,0.5,0,0],2:[0,0,0.5,0],3:[0,0,0,0.5]}
153
+ brightness,contrast,saturation,hue = dicts[index]
154
+ aug_list.append(albumentations.ColorJitter(
155
+ brightness=brightness, contrast=contrast, saturation=saturation, hue=hue))
156
+
157
+ if self.cfg['task'] == 'Gaussian_Attack':
158
+ aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))
159
+
160
+ aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
161
+ trans = albumentations.Compose(aug_list)
162
+
163
+
164
+ df_v = self.df.loc[self.index_list[idx]]
165
+ video_id = df_v['content_path']
166
+ activity_id = df_v['activity_id']
167
+ label = df_v['label']
168
+ label_onehot = [0]*2
169
+ frame_list = eval(df_v['frame_seq'])
170
+
171
+ select_frame_nums = 8
172
+
173
+ if len(frame_list) >= select_frame_nums:
174
+ start_frame = random.randint(0, len(frame_list)-select_frame_nums)
175
+ select_frames = frame_list[start_frame:start_frame+select_frame_nums]
176
+ frames = []
177
+ for x in frame_list[start_frame:start_frame+select_frame_nums]:
178
+ while True:
179
+ try:
180
+ temp_image_path = video_id+'/'+str(x)+'.jpg'
181
+ image = download_oss_file('GenVideo/'+ temp_image_path)
182
+ image = crop_center_by_percentage(image, 0.1)
183
+ break
184
+ except Exception as e:
185
+ if x+1 < len(frame_list):
186
+ x = x + 1
187
+ elif x - 1 >=0 :
188
+ x = x - 1
189
+ augmented = trans(image=image)
190
+ image = augmented["image"]
191
+ frames.append(image.transpose(2,0,1)[np.newaxis,:])
192
+ else:
193
+ pad_num = select_frame_nums-len(frame_list)
194
+ frames = []
195
+ for x in frame_list:
196
+ temp_image_path = video_id+'/'+str(x)+'.jpg'
197
+ image = download_oss_file('GenVideo/'+temp_image_path)
198
+ image = crop_center_by_percentage(image, 0.1)
199
+ augmented = trans(image=image)
200
+ image = augmented["image"]
201
+ frames.append(image.transpose(2,0,1)[np.newaxis,:])
202
+ for i in range(pad_num):
203
+ frames.append(np.zeros((224,224,3)).transpose(2,0,1)[np.newaxis,:])
204
+
205
+ label_onehot[int(label)] = 1
206
+ frames = np.concatenate(frames, 0)
207
+ frames = torch.tensor(frames[np.newaxis,:])
208
+ label_onehot = torch.FloatTensor(label_onehot)
209
+ binary_label = torch.FloatTensor([int(label)])
210
+ return self.index_list[idx], frames, label_onehot, binary_label, video_id
211
+
212
+
213
+ def __len__(self):
214
+ return len(self.index_list)
215
+
216
+
217
+
218
+ def generate_dataset_loader(cfg):
219
+ df_train = pd.read_csv('GenVideo/datasets/train.csv')
220
+
221
+ if cfg['task'] == 'normal':
222
+ df_val = pd.read_csv('GenVideo/datasets/val_id.csv')
223
+ elif cfg['task'] == 'robust_compress':
224
+ df_val = pd.read_csv('GenVideo/datasets/com_28.csv')
225
+ elif cfg['task'] == 'Image_Water_Attack':
226
+ df_val = pd.read_csv('GenVideo/datasets/imgwater.csv')
227
+ elif cfg['task'] == 'Text_Water_Attack':
228
+ df_val = pd.read_csv('GenVideo/datasets/textwater.csv')
229
+ elif cfg['task'] == 'one2many':
230
+ df_val = pd.read_csv('GenVideo/datasets/val_ood.csv')
231
+ if cfg['train_sub_set'] == 'pika':
232
+ prefixes = ["fake/pika", "real"]
233
+ video_condition = df_train['content_path'].str.startswith(prefixes[0])
234
+ for prefix in prefixes[1:]:
235
+ video_condition |= df_train['content_path'].str.startswith(prefix)
236
+ df_train = df_train[video_condition]
237
+ elif cfg['train_sub_set'] == 'SEINE':
238
+ prefixes = ["fake/SEINE", "real"]
239
+ video_condition = df_train['content_path'].str.startswith(prefixes[0])
240
+ for prefix in prefixes[1:]:
241
+ video_condition |= df_train['content_path'].str.startswith(prefix)
242
+ df_train = df_train[video_condition]
243
+ elif cfg['train_sub_set'] == 'OpenSora':
244
+ prefixes = ["fake/OpenSora", "real"]
245
+ video_condition = df_train['content_path'].str.startswith(prefixes[0])
246
+ for prefix in prefixes[1:]:
247
+ video_condition |= df_train['content_path'].str.startswith(prefix)
248
+ df_train = df_train[video_condition]
249
+ elif cfg['train_sub_set'] == 'Latte':
250
+ prefixes = ["fake/Latte", "real"]
251
+ video_condition = df_train['content_path'].str.startswith(prefixes[0])
252
+ for prefix in prefixes[1:]:
253
+ video_condition |= df_train['content_path'].str.startswith(prefix)
254
+ df_train = df_train[video_condition]
255
+ else:
256
+ df_val = pd.read_csv('GenVideo/datasets/val_ood.csv')
257
+
258
+ df_train.reset_index(drop=True, inplace=True)
259
+ df_val.reset_index(drop=True, inplace=True)
260
+
261
+ index_val = df_val.index.tolist()
262
+ index_val = index_val[:]
263
+
264
+ val_dataset = Ours_Dataset_val(cfg, index_val, df_val)
265
+ val_loader = torch.utils.data.DataLoader(
266
+ val_dataset, batch_size=cfg['val_batch_size'], shuffle=False, num_workers=cfg['num_workers'], pin_memory=True, drop_last=False
267
+ )
268
+
269
+ index_train = df_train.index.tolist()
270
+ index_train = index_train[:]
271
+ train_dataset = Ours_Dataset_train(index_train, df_train)
272
+ train_loader = torch.utils.data.DataLoader(
273
+ train_dataset, batch_size=cfg['train_batch_size'], shuffle=True, num_workers=cfg['num_workers'], pin_memory=True, drop_last=True
274
+ )
275
+
276
+ print("******* Training Video IDs", str(len(index_train))," Training Batch size ", str(cfg['train_batch_size'])," *******")
277
+ print("******* Testing Video IDs", str(len(index_val)), " Testing Batch size ", str(cfg['val_batch_size'])," *******")
278
+
279
+ return train_loader, val_loader
280
+
281
+
dataloader2.py ADDED
@@ -0,0 +1,246 @@
1
+ # new_dataloader.py
2
+
3
+ import torch
4
+ import cv2
5
+ import os
6
+ import random
7
+ import pandas as pd
8
+ import albumentations
9
+ import numpy as np
10
+ from torch.utils.data import Dataset, DataLoader
11
+
12
+ # =========== TRANSFORMATION HELPERS ===========
13
+
14
+ def get_train_transforms():
15
+ """Defines the probabilistic augmentations for training."""
16
+ return albumentations.Compose([
17
+ albumentations.Resize(224, 224),
18
+ albumentations.HorizontalFlip(p=0.5),
19
+ albumentations.ImageCompression(quality_lower=50, quality_upper=100, p=0.5),
20
+ albumentations.GaussNoise(p=0.3),
21
+ albumentations.GaussianBlur(blur_limit=(3, 5), p=0.3),
22
+ albumentations.ToGray(p=0.01),
23
+ albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)
24
+ ])
25
+
26
+ def get_val_transforms(cfg):
27
+ """Defines augmentations for validation, handling different attack tasks from the config."""
28
+ aug_list = [albumentations.Resize(224, 224)]
29
+
30
+ task = cfg.get('task', 'normal') # Use .get for safety
31
+
32
+ if task == 'JPEG_Compress_Attack':
33
+ aug_list.append(albumentations.JpegCompression(quality_lower=35, quality_upper=35, p=1.0))
34
+ elif task == 'FLIP_Attack':
35
+ aug_list.append(albumentations.HorizontalFlip(p=0.5)) # Original had random choice, 50% HFlip is common
36
+ elif task == 'CROP_Attack':
37
+ aug_list.append(albumentations.RandomCrop(height=192, width=192, p=1.0))
38
+ aug_list.append(albumentations.Resize(224, 224))
39
+ elif task == 'Color_Attack':
40
+ aug_list.append(albumentations.ColorJitter(p=1.0))
41
+ elif task == 'Gaussian_Attack':
42
+ aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))
43
+
44
+ aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
45
+ return albumentations.Compose(aug_list)
46
+
47
+ # =========== TRAINING DATASET ===========
48
+
49
+ class VideoDataset(Dataset):
50
+ """
51
+ A PyTorch Dataset for loading video frame sequences based on a DataFrame.
52
+ Handles class balancing for each epoch.
53
+ """
54
+ def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
55
+ self.df = df
56
+ self.index_list = index_list
57
+ self.base_data_path = base_data_path
58
+ self.transform = transform
59
+ self.select_frame_nums = select_frame_nums
60
+
61
+ self.positive_indices = self.df[self.df['label'] == 1].index.tolist()
62
+ self.negative_indices = self.df[self.df['label'] == 0].index.tolist()
63
+
64
+ self.balanced_indices = []
65
+ self.resample()
66
+
67
+ def resample(self):
68
+ min_samples = min(len(self.positive_indices), len(self.negative_indices))
69
+ self.balanced_indices.clear()
70
+ self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
71
+ self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
72
+ random.shuffle(self.balanced_indices)
73
+
74
+ def __len__(self):
75
+ return len(self.balanced_indices)
76
+
77
+ def __getitem__(self, idx):
78
+ real_idx = self.balanced_indices[idx]
79
+ row = self.df.iloc[real_idx]
80
+
81
+ video_id = row['content_path']
82
+ label = int(row['label'])
83
+ frame_list = eval(row['frame_seq'])
84
+
85
+ frames = []
86
+
87
+ if len(frame_list) >= self.select_frame_nums:
88
+ start_index = random.randint(0, len(frame_list) - self.select_frame_nums)
89
+ selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
90
+ else:
91
+ selected_frames = frame_list
92
+
93
+ for frame_path in selected_frames:
94
+ try:
95
+ image = cv2.imread(frame_path)
96
+ if image is None:
97
+ raise ValueError("Failed to load")
98
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
99
+ except Exception:
100
+ image = np.zeros((224, 224, 3), dtype=np.uint8)
101
+
102
+ if self.transform:
103
+ image = self.transform(image=image)['image']
104
+
105
+ frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
106
+
107
+
108
+ pad_num = self.select_frame_nums - len(frames)
109
+ if pad_num > 0:
110
+ for _ in range(pad_num):
111
+ frames.append(np.zeros((1, 3, 224, 224)))
112
+
113
+ frames_tensor = np.concatenate(frames, axis=0)
114
+ frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)
115
+
116
+ label_onehot = torch.zeros(2)
117
+ label_onehot[label] = 1.0
118
+ binary_label = torch.FloatTensor([label])
119
+
120
+ original_index = self.index_list[idx]
121
+ return original_index, frames_tensor, label_onehot, binary_label
122
+
123
+ # =========== VALIDATION DATASET ===========
124
+
125
+ class VideoDatasetVal(Dataset):
126
+ """A compatible validation dataset loader."""
127
+ def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
128
+ self.df = df
129
+ self.index_list = index_list
130
+ self.base_data_path = base_data_path
131
+ self.transform = transform
132
+ self.select_frame_nums = select_frame_nums
133
+
134
+ def __len__(self):
135
+ return len(self.index_list)
136
+
137
+ def __getitem__(self, idx):
138
+ # Validation does not use balanced sampling, it uses the provided index directly
139
+ real_idx = self.index_list[idx]
140
+ row = self.df.iloc[real_idx]
141
+
142
+ video_id = row['content_path']
143
+ label = int(row['label'])
144
+ frame_list = eval(row['frame_seq'])
145
+
146
+ # This part is identical to the training dataset's __getitem__
147
+ frames = []
148
+ if len(frame_list) >= self.select_frame_nums:
149
+ start_index = random.randint(0, len(frame_list) - self.select_frame_nums)
150
+ selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
151
+ else:
152
+ selected_frames = frame_list
153
+
154
+ for frame_path in selected_frames:
155
+ try:
156
+ image = cv2.imread(frame_path)
157
+ if image is None:
158
+ raise ValueError("Failed to load")
159
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
160
+ except Exception:
161
+ image = np.zeros((224, 224, 3), dtype=np.uint8)
162
+
163
+ if self.transform:
164
+ image = self.transform(image=image)['image']
165
+
166
+ frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
167
+
168
+
169
+ pad_num = self.select_frame_nums - len(frames)
170
+ if pad_num > 0:
171
+ for _ in range(pad_num):
172
+ frames.append(np.zeros((1, 3, 224, 224)))
173
+
174
+ frames_tensor = np.concatenate(frames, axis=0)
175
+ frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)
176
+
177
+ label_onehot = torch.zeros(2)
178
+ label_onehot[label] = 1.0
179
+ binary_label = torch.FloatTensor([label])
180
+
181
+ # The original validation loader returned video_id at the end
182
+ return self.index_list[idx], frames_tensor, label_onehot, binary_label, video_id
183
+
184
+ # =========== DATALOADER GENERATOR FUNCTION ===========
185
+
186
+ def generate_dataset_loader(cfg):
187
+ """
188
+ The main function to create train and validation dataloaders using the new classes.
189
+ """
190
+ df_train = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_train.csv')
191
+
192
+ # This logic for selecting different validation sets is preserved
193
+ task = cfg.get('task', 'normal')
194
+ if task == 'normal':
195
+ df_val = pd.read_csv('GenVideo/datasets/val_id.csv')
196
+ elif task == 'robust_compress':
197
+ df_val = pd.read_csv('GenVideo/datasets/com_28.csv')
198
+ # ... (add the remaining task-specific validation sets here if needed) ...
199
+ else:
200
+ df_val = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_test.csv')
201
+
202
+ # This logic for subsetting the training data is also preserved
203
+ if cfg.get('train_sub_set'):
204
+ prefixes = [f"fake/{cfg['train_sub_set']}", "real"]
205
+ condition = df_train['content_path'].str.startswith(tuple(prefixes))
206
+ df_train = df_train[condition]
207
+
208
+ df_train.reset_index(drop=True, inplace=True)
209
+ df_val.reset_index(drop=True, inplace=True)
210
+
211
+ index_train = df_train.index.tolist()
212
+ index_val = df_val.index.tolist()
213
+
214
+ # --- Use the new VideoDataset classes ---
215
+ base_data_path = 'GenVideo'
216
+
217
+ train_dataset = VideoDataset(
218
+ df=df_train,
219
+ index_list=index_train,
220
+ base_data_path=base_data_path,
221
+ transform=get_train_transforms(),
222
+ select_frame_nums=8
223
+ )
224
+
225
+ val_dataset = VideoDatasetVal(
226
+ df=df_val,
227
+ index_list=index_val,
228
+ base_data_path=base_data_path,
229
+ transform=get_val_transforms(cfg),
230
+ select_frame_nums=8
231
+ )
232
+
233
+ train_loader = DataLoader(
234
+ train_dataset, batch_size=cfg['train_batch_size'], shuffle=True,
235
+ num_workers=cfg['num_workers'], pin_memory=True, drop_last=True
236
+ )
237
+
238
+ val_loader = DataLoader(
239
+ val_dataset, batch_size=cfg['val_batch_size'], shuffle=False,
240
+ num_workers=cfg['num_workers'], pin_memory=True, drop_last=False
241
+ )
242
+
243
+ print(f"******* Training Videos {len(index_train)}, Batch size {cfg['train_batch_size']} *******")
244
+ print(f"******* Testing Videos {len(index_val)}, Batch size {cfg['val_batch_size']} *******")
245
+
246
+ return train_loader, val_loader
eval.py ADDED
@@ -0,0 +1,48 @@
1
+ import models
2
+ import time
3
+ import torch
4
+ import math
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ import pandas as pd
8
+ from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score, average_precision_score, roc_auc_score
9
+
10
+ def eval_model(cfg, model, val_loader, loss_ce, val_batch_size):
11
+ model.eval()
12
+ outpred_list = []
13
+ gt_label_list = []
14
+ video_list = []
15
+ valLoss = 0
16
+ lossTrainNorm = 0
17
+ print("******** Start Testing. ********")
18
+
19
+ with torch.no_grad(): # No need to track gradients during validation
20
+ for i, (_, input, target, binary_label, video_id) in enumerate(tqdm(val_loader, desc="Validation", total=len(val_loader))):
21
+ if i == 0:
22
+ ss_time = time.time()
23
+
24
+ input = input[:,0]
25
+ varInput = torch.autograd.Variable(input.float().cuda())
26
+ varTarget = torch.autograd.Variable(target.contiguous().cuda())
27
+ var_Binary_Target = torch.autograd.Variable(binary_label.contiguous().cuda())
28
+
29
+ logit = model(varInput)
30
+ lossvalue = loss_ce(logit, var_Binary_Target)
31
+
32
+ valLoss += lossvalue.item()
33
+ lossTrainNorm += 1
34
+ outpred_list.append(logit[:,0].sigmoid().cpu().detach().numpy())
35
+ gt_label_list.append(varTarget.cpu().detach().numpy())
36
+ video_list.append(video_id)
37
+
38
+ valLoss = valLoss / lossTrainNorm
39
+
40
+ outpred = np.concatenate(outpred_list, 0)
41
+ gt_label = np.concatenate(gt_label_list, 0)
42
+ video_list = np.concatenate(video_list, 0)
43
+ pred_labels = [1 if item > 0.5 else 0 for item in outpred]
44
+ true_labels = np.argmax(gt_label, axis=1)
45
+
46
+ pred_accuracy = accuracy_score(true_labels, pred_labels)
47
+
48
+ return pred_accuracy, video_list, pred_labels, true_labels, outpred
eval2.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+ import torch.nn.functional as F
9
+ from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score
10
+
11
+ import models  # assumes the models package is importable (same directory or on PYTHONPATH)
12
+
13
+ # -------- Build model as per your code ----------
14
+ def build_model(model_name):
15
+ if model_name == 'F3Net':
16
+ model = models.Det_F3_Net()
17
+ elif model_name == 'NPR':
18
+ model = models.resnet50_npr()
19
+ elif model_name == 'STIL':
20
+ model = models.Det_STIL()
21
+ elif model_name == 'XCLIP_DeMamba':
22
+ model = models.XCLIP_DeMamba()
23
+ elif model_name == 'CLIP_DeMamba':
24
+ model = models.CLIP_DeMamba()
25
+ elif model_name == 'XCLIP':
26
+ model = models.XCLIP()
27
+ elif model_name == 'CLIP':
28
+ model = models.CLIP_Base()
29
+ elif model_name == 'ViT_B_MINTIME':
30
+ model = models.ViT_B_MINTIME()
31
+ else:
32
+ raise ValueError(f"Unknown model: {model_name}")
33
+ return model
34
+
35
+
36
+ # --------- Evaluation loop -------------
37
+ def eval_on_frames(model, frames_dir, device):
38
+ model.eval()
39
+ transform = transforms.Compose([
40
+ transforms.Resize((224, 224)),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
43
+ std=[0.229, 0.224, 0.225])
44
+ ])
45
+
46
+ frame_paths = []
47
+ for root, _, files in os.walk(frames_dir):
48
+ for f in files:
49
+ if f.lower().endswith(('.png', '.jpg', '.jpeg')):
50
+ frame_paths.append(os.path.join(root, f))
51
+ frame_paths.sort()
52
+
53
+ results = []
54
+
55
+ with torch.no_grad():
56
+ for fp in tqdm(frame_paths, desc="Evaluating frames"):
57
+ img = Image.open(fp).convert("RGB")
58
+
59
+ # Transform and add batch dimension
60
+ x = transform(img).unsqueeze(0).to(device) # [1, C, H, W]
61
+
62
+ # Add temporal dimension expected by DeMamba/XCLIP (T=8)
63
+ x = x.unsqueeze(1).repeat(1, 8, 1, 1, 1) # [1, 8, C, H, W]
64
+
65
+ # Forward pass
66
+ logit = model(x)
67
+ prob = torch.sigmoid(logit[:, 0]).item()
68
+ pred_label = int(prob > 0.5)
69
+
70
+ results.append({
71
+ "file_name": os.path.basename(fp),
72
+ "predicted_prob": prob,
73
+ "predicted_label": pred_label
74
+ })
75
+
76
+ return pd.DataFrame(results)
77
+
78
+
79
+ if __name__ == "__main__":
80
+ # ----- config -----
81
+ model_name = "XCLIP_DeMamba" # change if needed
82
+ model_path = "/home/kalpit/workspace/aigc/repos/SAFE_challenge/seetrails_aigvdet_v2.0.0/results/kling_9k_9k/best_acc.pth"
83
+ frames_dir = "./frames"
84
+ output_csv = "results.csv"
85
+
86
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
87
+ print(f"Using device: {device}")
88
+
89
+ # ---- load model ----
90
+ model = build_model(model_name).to(device)
91
+ checkpoint = torch.load(model_path, map_location=device)
92
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
93
+ print(f"Loaded model weights from {model_path}")
94
+
95
+ # ---- evaluate ----
96
+ df_results = eval_on_frames(model, frames_dir, device)
97
+ df_results.to_csv(output_csv, index=False)
98
+ print(f"Saved framewise results to {output_csv}")
99
+
100
+ # ---- optional: basic metrics if you have GT ----
101
+ if "label" in df_results.columns:
102
+ y_true = df_results["label"].values
103
+ y_pred = df_results["predicted_label"].values
104
+ y_prob = df_results["predicted_prob"].values
105
+
106
+ acc = accuracy_score(y_true, y_pred)
107
+ auc = roc_auc_score(y_true, y_prob)
108
+ ap = average_precision_score(y_true, y_prob)
109
+
110
+ print(f"Accuracy: {acc:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}")
extract_frames.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+ import pandas as pd
3
+ from datasets import load_dataset
4
+ import numpy as np
5
+ import tqdm.auto as tqdm
6
+ import io
7
+ import torch
8
+ import av
9
+ from torchvision.utils import save_image
10
+
11
+
12
+ def preprocess_and_save_frames(
13
+ file_like: io.BytesIO,
14
+ video_id: str,
15
+ output_root: str = "./frames",
16
+ crop_size: int = -1,
17
+ every: int = 10,
18
+ max_memory: int = 50 * 1024 * 1024,
19
+ device: str = "cpu"
20
+ ):
21
+ """
22
+ Loads a video and saves frames as images in output_root/<video_id>/.
23
+ """
24
+
25
+ # Ensure the base frames directory exists
26
+ os.makedirs(output_root, exist_ok=True)
27
+
28
+ # Create subfolder for this specific video
29
+ video_dir = os.path.join(output_root, video_id)
30
+ os.makedirs(video_dir, exist_ok=True)
31
+
32
+ center_crop_transform = None
33
+ if crop_size > 0:
34
+ from torchvision import transforms
35
+ center_crop_transform = transforms.CenterCrop(crop_size)
36
+
37
+ file_like.seek(0)
38
+ container = av.open(file_like)
39
+ current_memory = 0
40
+
41
+ for i, frame in enumerate(container.decode(video=0)):
42
+ if i % every == 0:
43
+ frame_array = frame.to_ndarray(format="rgb24")
44
+ frame_tensor = torch.from_numpy(frame_array).permute(2, 0, 1).float() / 255.0
45
+
46
+ if center_crop_transform is not None:
47
+ frame_tensor = center_crop_transform(frame_tensor)
48
+
49
+ frame_path = os.path.join(video_dir, f"frame_{i:05d}.png")
50
+ save_image(frame_tensor, frame_path)
51
+
52
+ frame_bytes = frame_tensor.numel() * 4
53
+ current_memory += frame_bytes
54
+ if current_memory >= max_memory:
55
+ break
56
+
57
+
58
+ # -------- Main section --------
59
+ if __name__ == "__main__":
60
+ DATASET_PATH = "/tmp/data"
61
+ dataset_remote = load_dataset(DATASET_PATH, split="test", streaming=True)
62
+
63
+ # Make sure ./frames exists before processing
64
+ os.makedirs("./frames", exist_ok=True)
65
+
66
+ for el in tqdm.tqdm(dataset_remote, desc="Extracting frames"):
67
+ try:
68
+ video_id = str(el["id"])
69
+ file_like = io.BytesIO(el["video"]["bytes"])
70
+ preprocess_and_save_frames(file_like, video_id)
71
+ except Exception as e:
72
+ print(f"⚠️ Failed {el['id']}: {e}")
73
+
74
+ print("✅ All frames saved under ./frames/<video_id>/")
models/DeMamba.py ADDED
@@ -0,0 +1,176 @@
1
+ from transformers import XCLIPVisionModel
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from .mamba_base import MambaConfig, ResidualBlock
10
+ import torch.nn.init as init
11
+ from .clip import clip
12
+ import math
13
+ from transformers import XCLIPVisionConfig, XCLIPVisionModel
14
+
15
+ def create_reorder_index(N, device):
16
+ new_order = []
17
+ for col in range(N):
18
+ if col % 2 == 0:
19
+ new_order.extend(range(col, N*N, N))
20
+ else:
21
+ new_order.extend(range(col + N*(N-1), col-1, -N))
22
+ return torch.tensor(new_order, device=device)
23
+
24
+ def reorder_data(data, N):
25
+ assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
26
+ device = data.device
27
+ new_order = create_reorder_index(N, device)
28
+ B, t, _, _ = data.shape
29
+ index = new_order.repeat(B, t, 1).unsqueeze(-1)
30
+ reordered_data = torch.gather(data, 2, index.expand_as(data))
31
+ return reordered_data
32
+
33
+ class XCLIP_DeMamba(nn.Module):
34
+ def __init__(
35
+ self, channel_size=768, class_num=1
36
+ ):
37
+ super(XCLIP_DeMamba, self).__init__()
38
+ # self.encoder = XCLIPVisionModel.from_pretrained("GenVideo/pretrained_weights/xclip")
39
+ # my code for training from scratch
40
+ config = XCLIPVisionConfig()
41
+ self.encoder = XCLIPVisionModel(config)
42
+
43
+ blocks = []
44
+ channel = 768
45
+ self.fusing_ratios = 1
46
+ self.patch_nums = (14//self.fusing_ratios)**2
47
+ self.mamba_configs = MambaConfig(d_model=channel)
48
+ self.mamba = ResidualBlock(config = self.mamba_configs)
49
+ # self.fc1 = nn.Linear((self.patch_nums+1)*channel, class_num)
50
+ self.fc1 = nn.Linear(38400, class_num) # my code
51
+ # self.fc_norm = nn.LayerNorm(self.patch_nums*channel)
52
+ self.fc_norm = None # my code
53
+ self.fc_norm2 = nn.LayerNorm(768)
54
+ self.initialize_weights(self.fc1)
55
+ self.dropout = nn.Dropout(p=0.0)
56
+
57
+ def initialize_weights(self, module):
58
+ for m in module.modules():
59
+ if isinstance(m, nn.Linear):
60
+ init.xavier_uniform_(m.weight)
61
+ if m.bias is not None:
62
+ init.constant_(m.bias, 0)
63
+ elif isinstance(m, nn.Conv2d):
64
+ init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
65
+ if m.bias is not None:
66
+ init.constant_(m.bias, 0)
67
+ elif isinstance(m, nn.BatchNorm2d):
68
+ init.constant_(m.weight, 1)
69
+ init.constant_(m.bias, 0)
70
+
71
+ def forward(self, x):
72
+ b, t, _, h, w = x.shape
73
+ images = x.view(b * t, 3, h, w)
74
+ outputs = self.encoder(images, output_hidden_states=True)
75
+ sequence_output = outputs['last_hidden_state'][:,1:,:]
76
+ _, _, c = sequence_output.shape
77
+
78
+ global_feat = outputs['pooler_output'].reshape(b, t, -1)
79
+ global_feat = global_feat.mean(1)
80
+ global_feat = self.fc_norm2(global_feat)
81
+
82
+ sequence_output = sequence_output.view(b, t, -1, c)
83
+ _, _, f_w, _ = sequence_output.shape
84
+ f_h, f_w = int(math.sqrt(f_w)), int(math.sqrt(f_w))
85
+
86
+ s = f_h//self.fusing_ratios
87
+ sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
88
+ x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b*s*s, t, -1, c)
89
+ b_l = b*s*s
90
+
91
+ x = reorder_data(x, self.fusing_ratios)
92
+ x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
93
+ res = self.mamba(x)
94
+
95
+ video_level_features = res.mean(1)
96
+ video_level_features = video_level_features.view(b, -1)
97
+
98
+ # my code: create the LayerNorm lazily on the first forward pass so its width matches the flattened Mamba features (parameters added this way are not seen by an optimizer built earlier)
99
+ if self.fc_norm is None:
100
+ self.fc_norm = nn.LayerNorm(video_level_features.size(-1)).to(video_level_features.device)
101
+
102
+ video_level_features = self.fc_norm(video_level_features)
103
+ video_level_features = torch.cat((global_feat, video_level_features), dim=1)
104
+
105
+ pred = self.fc1(video_level_features)
106
+ pred = self.dropout(pred)
107
+
108
+ return pred
109
+
110
+
111
+
112
+ class CLIP_DeMamba(nn.Module):
113
+ def __init__(
114
+ self, channel_size=512, class_num=1
115
+ ):
116
+ super(CLIP_DeMamba, self).__init__()
117
+ self.clip_model, preprocess = clip.load('ViT-B-14')
118
+ self.clip_model = self.clip_model.float()
119
+ blocks = []
120
+ channel = 512
121
+ self.fusing_ratios = 2
122
+ self.patch_nums = (14//self.fusing_ratios)**2
123
+ self.mamba_configs = MambaConfig(d_model=channel)
124
+ self.mamba = ResidualBlock(config = self.mamba_configs)
125
+
126
+ self.fc1 = nn.Linear(channel*(self.patch_nums+1), class_num)
127
+
128
+ self.bn1 = nn.BatchNorm1d(channel)
129
+ self.initialize_weights(self.fc1)
130
+
131
+ def initialize_weights(self, module):
132
+ for m in module.modules():
133
+ if isinstance(m, nn.Linear):
134
+ init.xavier_uniform_(m.weight)
135
+ if m.bias is not None:
136
+ init.constant_(m.bias, 0)
137
+ elif isinstance(m, nn.Conv2d):
138
+ init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
139
+ if m.bias is not None:
140
+ init.constant_(m.bias, 0)
141
+ elif isinstance(m, nn.BatchNorm2d):
142
+ init.constant_(m.weight, 1)
143
+ init.constant_(m.bias, 0)
144
+
145
+ def forward(self, x):
146
+ b, t, _, h, w = x.shape
147
+ images = x.view(b * t, 3, h, w)
148
+ sequence_output = self.clip_model.encode_image(images)
149
+ _, _, c = sequence_output.shape
150
+ sequence_output = sequence_output.view(b, t, -1, c)
151
+
152
+ global_feat = sequence_output.reshape(b, -1, c)
153
+ global_feat = global_feat.mean(1)
154
+
155
+ _, _, f_w, _ = sequence_output.shape
156
+ f_h, f_w = int(math.sqrt(f_w)), int(math.sqrt(f_w))
157
+
158
+ s = f_h//self.fusing_ratios
159
+ sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
160
+ x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b*s*s, t, -1, c)
161
+ b_l = b*s*s
162
+
163
+ x = reorder_data(x, self.fusing_ratios)
164
+ x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
165
+ res = self.mamba(x)
166
+ video_level_features = res.mean(1)
167
+ video_level_features = video_level_features.view(b, -1)
168
+
169
+ video_level_features = torch.cat((global_feat, video_level_features), dim=1)
170
+ x = self.fc1(video_level_features)
171
+
172
+ return x
173
+
174
+ if __name__ == '__main__':
175
+ model = CLIP_DeMamba()
176
+ print(model)
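reorder_data feeds the patch tokens of each spatial group to the Mamba block in a column-wise serpentine order rather than the usual row-major order. The standalone check below (not repo code) reproduces the create_reorder_index logic on a 3x3 grid so the scan pattern is easy to see.

import torch

def serpentine_index(N):
    # Even columns are read top-to-bottom, odd columns bottom-to-top,
    # mirroring create_reorder_index above.
    order = []
    for col in range(N):
        if col % 2 == 0:
            order.extend(range(col, N * N, N))
        else:
            order.extend(range(col + N * (N - 1), col - 1, -N))
    return torch.tensor(order)

grid = torch.arange(9).view(3, 3)              # patch ids laid out row-major
print(grid.flatten()[serpentine_index(3)])     # tensor([0, 3, 6, 7, 4, 1, 2, 5, 8])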
models/F3Net.py ADDED
@@ -0,0 +1,434 @@
1
+ """
2
+ F3Net: Thinking in Frequency, Face Forgery Detection by Mining Frequency-Aware Clues @ ECCV'2020
3
+ Copyright (c) University of Chinese Academy of Sciences and its affiliates.
4
+ Modified by Jun Wei from https://github.com/weijun88/F3Net
5
+ """
6
+ import os
7
+ import sys
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.model_zoo as model_zoo
14
+
15
+ pretrained_settings = {
16
+ 'xception': {
17
+ 'imagenet': {
18
+ 'dir': '/ossfs/workspace/aigc_video/weights/xception-b5690688.pth',
19
+ 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
20
+ 'input_space': 'RGB',
21
+ 'input_size': [3, 299, 299],
22
+ 'input_range': [0, 1],
23
+ 'mean': [0.5, 0.5, 0.5],
24
+ 'std': [0.5, 0.5, 0.5],
25
+ 'num_classes': 1000,
26
+ 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
27
+ }
28
+ }
29
+ }
30
+
31
+ class F3Net(nn.Module):
32
+ """
33
+ Implementation is mainly referenced from https://github.com/yyk-wew/F3Net
34
+ """
35
+ def __init__(self,
36
+ num_classes: int=2,
37
+ img_width: int=299,
38
+ img_height: int=299,
39
+ LFS_window_size: int=10,
40
+ LFS_M: int=6) -> None:
41
+ super(F3Net, self).__init__()
42
+ assert img_width == img_height
43
+ self.img_size = img_width
44
+ self.num_classes = num_classes
45
+ self._LFS_window_size = LFS_window_size
46
+ self._LFS_M = LFS_M
47
+
48
+
49
+ self.fad_head = FAD_Head(self.img_size)
50
+ self.lfs_head = LFS_Head(self.img_size, self._LFS_window_size, self._LFS_M)
51
+
52
+ self.fad_excep = self._init_xcep_fad()
53
+ self.lfs_excep = self._init_xcep_lfs()
54
+
55
+ self.mix_block7 = MixBlock(c_in=728, width=19, height=19)
56
+ self.mix_block12 = MixBlock(c_in=1024, width=10, height=10)
57
+ self.excep_forwards = ['conv1', 'bn1', 'relu', 'conv2', 'bn2', 'relu',
58
+ 'block1', 'block2', 'block3', 'block4', 'block5', 'block6',
59
+ 'block7', 'block8', 'block9', 'block10' , 'block11', 'block12',
60
+ 'conv3', 'bn3', 'relu', 'conv4', 'bn4']
61
+
62
+ # classifier
63
+ self.relu = nn.ReLU(inplace=True)
64
+ self.fc = nn.Linear(4096, num_classes)
65
+ self.dp = nn.Dropout(p=0.2)
66
+
67
+ def _init_xcep_fad(self):
68
+ fad_excep = return_pytorch04_xception(True)
69
+ conv1_data = fad_excep.conv1.weight.data
70
+ # let new conv1 use old param to balance the network
71
+ fad_excep.conv1 = nn.Conv2d(12, 32, 3, 2, 0, bias=False)
72
+ for i in range(4):
73
+ fad_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / 4.0
74
+ return fad_excep
75
+
76
+ def _init_xcep_lfs(self):
77
+ lfs_excep = return_pytorch04_xception(True)
78
+ conv1_data = lfs_excep.conv1.weight.data
79
+ # let new conv1 use old param to balance the network
80
+ lfs_excep.conv1 = nn.Conv2d(self._LFS_M, 32, 3, 1, 0, bias=False)
81
+ for i in range(int(self._LFS_M / 3)):
82
+ lfs_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / float(self._LFS_M / 3.0)
83
+ return lfs_excep
84
+
85
+ def _features(self, x_fad, x_fls):
86
+ for forward_func in self.excep_forwards:
87
+ x_fad = getattr(self.fad_excep, forward_func)(x_fad)
88
+ x_fls = getattr(self.lfs_excep, forward_func)(x_fls)
89
+ if forward_func == 'block7':
90
+ x_fad, x_fls = self.mix_block7(x_fad, x_fls)
91
+ if forward_func == 'block12':
92
+ x_fad, x_fls = self.mix_block12(x_fad, x_fls)
93
+ return x_fad, x_fls
94
+
95
+ def _norm_feature(self, x):
96
+ x = self.relu(x)
97
+ x = F.adaptive_avg_pool2d(x, (1,1))
98
+ x = x.view(x.size(0), -1)
99
+ return x
100
+
101
+ def forward(self, x):
102
+ fad_input = self.fad_head(x)
103
+ lfs_input = self.lfs_head(x)
104
+ x_fad, x_fls = self._features(fad_input, lfs_input)
105
+ x_fad = self._norm_feature(x_fad)
106
+ x_fls = self._norm_feature(x_fls)
107
+ x_cat = torch.cat((x_fad, x_fls), dim=1)
108
+ x_drop = self.dp(x_cat)
109
+ logit = self.fc(x_drop)
110
+ return logit
111
+
112
+ class SeparableConv2d(nn.Module):
113
+ def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
114
+ super(SeparableConv2d,self).__init__()
115
+
116
+ self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
117
+ self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
118
+
119
+ def forward(self,x):
120
+ x = self.conv1(x)
121
+ x = self.pointwise(x)
122
+ return x
123
+
124
+ class Block(nn.Module):
125
+ def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
126
+ super(Block, self).__init__()
127
+
128
+ if out_filters != in_filters or strides!=1:
129
+ self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
130
+ self.skipbn = nn.BatchNorm2d(out_filters)
131
+ else:
132
+ self.skip=None
133
+
134
+ self.relu = nn.ReLU(inplace=True)
135
+ rep=[]
136
+
137
+ filters=in_filters
138
+ if grow_first:
139
+ rep.append(self.relu)
140
+ rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
141
+ rep.append(nn.BatchNorm2d(out_filters))
142
+ filters = out_filters
143
+
144
+ for i in range(reps-1):
145
+ rep.append(self.relu)
146
+ rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
147
+ rep.append(nn.BatchNorm2d(filters))
148
+
149
+ if not grow_first:
150
+ rep.append(self.relu)
151
+ rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
152
+ rep.append(nn.BatchNorm2d(out_filters))
153
+
154
+ if not start_with_relu:
155
+ rep = rep[1:]
156
+ else:
157
+ rep[0] = nn.ReLU(inplace=False)
158
+
159
+ if strides != 1:
160
+ rep.append(nn.MaxPool2d(3,strides,1))
161
+ self.rep = nn.Sequential(*rep)
162
+
163
+ def forward(self,inp):
164
+ x = self.rep(inp)
165
+
166
+ if self.skip is not None:
167
+ skip = self.skip(inp)
168
+ skip = self.skipbn(skip)
169
+ else:
170
+ skip = inp
171
+
172
+ x+=skip
173
+ return x
174
+
175
+ class Xception(nn.Module):
176
+ """
177
+ Xception optimized for the ImageNet dataset, as specified in
178
+ https://arxiv.org/pdf/1610.02357.pdf
179
+ """
180
+ def __init__(self, num_classes=1000):
181
+ """ Constructor
182
+ Args:
183
+ num_classes: number of classes
184
+ """
185
+ super(Xception, self).__init__()
186
+ self.num_classes = num_classes
187
+
188
+ self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False)
189
+ self.bn1 = nn.BatchNorm2d(32)
190
+ self.relu = nn.ReLU(inplace=True)
191
+
192
+ self.conv2 = nn.Conv2d(32,64,3,bias=False)
193
+ self.bn2 = nn.BatchNorm2d(64)
194
+ #do relu here
195
+
196
+ self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
197
+ self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
198
+ self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)
199
+
200
+ self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
201
+ self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
202
+ self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
203
+ self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)
204
+
205
+ self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
206
+ self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
207
+ self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
208
+ self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)
209
+
210
+ self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
211
+
212
+ self.conv3 = SeparableConv2d(1024,1536,3,1,1)
213
+ self.bn3 = nn.BatchNorm2d(1536)
214
+
215
+ #do relu here
216
+ self.conv4 = SeparableConv2d(1536,2048,3,1,1)
217
+ self.bn4 = nn.BatchNorm2d(2048)
218
+
219
+ def xception(num_classes=1000, pretrained='imagenet'):
220
+ model = Xception(num_classes=num_classes)
221
+ if pretrained:
222
+ settings = pretrained_settings['xception'][pretrained]
223
+ assert num_classes == settings['num_classes'], \
224
+ "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
225
+
226
+ model = Xception(num_classes=num_classes)
227
+ model.load_state_dict(torch.load(settings['dir']))  # settings['dir'] is a checkpoint path, so load it before applying
228
+
229
+ model.input_space = settings['input_space']
230
+ model.input_size = settings['input_size']
231
+ model.input_range = settings['input_range']
232
+ model.mean = settings['mean']
233
+ model.std = settings['std']
234
+
235
+ return model
236
+
237
+ def return_pytorch04_xception(pretrained=True):
238
+ model = xception(pretrained=False)
239
+ if pretrained:
240
+ state_dict = torch.load(
241
+ '/ossfs/workspace/GenVideo/weights/xception-b5690688.pth')
242
+ for name, weights in state_dict.items():
243
+ if 'pointwise' in name:
244
+ state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
245
+ model.load_state_dict(state_dict, strict=False)
246
+
247
+ return model
248
+
249
+ class Filter(nn.Module):
250
+ def __init__(self, size,
251
+ band_start,
252
+ band_end,
253
+ use_learnable=True,
254
+ norm=False):
255
+ super(Filter, self).__init__()
256
+ self.use_learnable = use_learnable
257
+
258
+ self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False)
259
+ if self.use_learnable:
260
+ self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
261
+ self.learnable.data.normal_(0., 0.1)
262
+ # Todo
263
+ # self.learnable = nn.Parameter(torch.rand((size, size)) * 0.2 - 0.1, requires_grad=True)
264
+
265
+ self.norm = norm
266
+ if norm:
267
+ self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False)
268
+
269
+
270
+ def forward(self, x):
271
+ if self.use_learnable:
272
+ filt = self.base + norm_sigma(self.learnable)
273
+ else:
274
+ filt = self.base
275
+
276
+ if self.norm:
277
+ y = x * filt / self.ft_num
278
+ else:
279
+ y = x * filt
280
+ return y
281
+
282
+ class FAD_Head(nn.Module):
283
+ def __init__(self, size):
284
+ super(FAD_Head, self).__init__()
285
+ # init DCT matrix
286
+ self._DCT_all = nn.Parameter(torch.tensor(DCT_mat(size)).float(), requires_grad=False)
287
+ self._DCT_all_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(size)).float(), 0, 1), requires_grad=False)
288
+
289
+ # define base filters and learnable
290
+ # 0 - 1/16 || 1/16 - 1/8 || 1/8 - 1
291
+ low_filter = Filter(size, 0, size // 16)
292
+ middle_filter = Filter(size, size // 16, size // 8)
293
+ high_filter = Filter(size, size // 8, size)
294
+ all_filter = Filter(size, 0, size * 2)
295
+
296
+ self.filters = nn.ModuleList([low_filter, middle_filter, high_filter, all_filter])
297
+
298
+ def forward(self, x):
299
+ # DCT
300
+ x_freq = self._DCT_all @ x @ self._DCT_all_T # [N, 3, 299, 299]
301
+
302
+ # 4 kernel
303
+ y_list = []
304
+ for i in range(4):
305
+ x_pass = self.filters[i](x_freq) # [N, 3, 299, 299]
306
+ y = self._DCT_all_T @ x_pass @ self._DCT_all # [N, 3, 299, 299]
307
+ y_list.append(y)
308
+ out = torch.cat(y_list, dim=1) # [N, 12, 299, 299]
309
+ return out
310
+
311
+
312
+ class LFS_Head(nn.Module):
313
+ def __init__(self, size, window_size, M):
314
+ super(LFS_Head, self).__init__()
315
+
316
+ self.window_size = window_size
317
+ self._M = M
318
+
319
+ # init DCT matrix
320
+ self._DCT_patch = nn.Parameter(torch.tensor(DCT_mat(window_size)).float(), requires_grad=False)
321
+ self._DCT_patch_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(window_size)).float(), 0, 1), requires_grad=False)
322
+
323
+ self.unfold = nn.Unfold(kernel_size=(window_size, window_size), stride=2, padding=4)
324
+
325
+ # init filters
326
+ self.filters = nn.ModuleList([Filter(window_size, window_size * 2. / M * i, window_size * 2. / M * (i+1), norm=True) for i in range(M)])
327
+
328
+ def forward(self, x):
329
+ # turn RGB into Gray
330
+ x_gray = 0.299*x[:,0,:,:] + 0.587*x[:,1,:,:] + 0.114*x[:,2,:,:]
331
+ x = x_gray.unsqueeze(1)
332
+
333
+ # rescale to 0 - 255
334
+ x = (x + 1.) * 122.5
335
+
336
+ # calculate size
337
+ N, C, W, H = x.size()
338
+ S = self.window_size
339
+ size_after = int((W - S + 8)/2) + 1
340
+ assert size_after == 149
341
+
342
+ # sliding window unfold and DCT
343
+ x_unfold = self.unfold(x) # [N, C * S * S, L] L:block num
344
+ L = x_unfold.size()[2]
345
+ x_unfold = x_unfold.transpose(1, 2).reshape(N, L, C, S, S) # [N, L, C, S, S]
346
+ x_dct = self._DCT_patch @ x_unfold @ self._DCT_patch_T
347
+
348
+ # M kernels filtering
349
+ y_list = []
350
+ for i in range(self._M):
351
+ # y = self.filters[i](x_dct) # [N, L, C, S, S]
352
+ # y = torch.abs(y)
353
+ # y = torch.sum(y, dim=[2,3,4]) # [N, L]
354
+ # y = torch.log10(y + 1e-15)
355
+ y = torch.abs(x_dct)
356
+ y = torch.log10(y + 1e-15)
357
+ y = self.filters[i](y)
358
+ y = torch.sum(y, dim=[2,3,4])
359
+ y = y.reshape(N, size_after, size_after).unsqueeze(dim=1) # [N, 1, 149, 149]
360
+ y_list.append(y)
361
+ out = torch.cat(y_list, dim=1) # [N, M, 149, 149]
362
+ return out
363
+
364
+ class MixBlock(nn.Module):
365
+
366
+ def __init__(self, c_in, width, height):
367
+ super(MixBlock, self).__init__()
368
+ self.FAD_query = nn.Conv2d(c_in, c_in, (1,1))
369
+ self.LFS_query = nn.Conv2d(c_in, c_in, (1,1))
370
+
371
+ self.FAD_key = nn.Conv2d(c_in, c_in, (1,1))
372
+ self.LFS_key = nn.Conv2d(c_in, c_in, (1,1))
373
+
374
+ self.softmax = nn.Softmax(dim=-1)
375
+ self.relu = nn.ReLU()
376
+
377
+ self.FAD_gamma = nn.Parameter(torch.zeros(1))
378
+ self.LFS_gamma = nn.Parameter(torch.zeros(1))
379
+
380
+ self.FAD_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in)
381
+ self.FAD_bn = nn.BatchNorm2d(c_in)
382
+ self.LFS_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in)
383
+ self.LFS_bn = nn.BatchNorm2d(c_in)
384
+
385
+ def forward(self, x_FAD, x_LFS):
386
+ B, C, W, H = x_FAD.size()
387
+ assert W == H
388
+
389
+ q_FAD = self.FAD_query(x_FAD).view(-1, W, H) # [BC, W, H]
390
+ q_LFS = self.LFS_query(x_LFS).view(-1, W, H)
391
+ M_query = torch.cat([q_FAD, q_LFS], dim=2) # [BC, W, 2H]
392
+
393
+ k_FAD = self.FAD_key(x_FAD).view(-1, W, H).transpose(1, 2) # [BC, H, W]
394
+ k_LFS = self.LFS_key(x_LFS).view(-1, W, H).transpose(1, 2)
395
+ M_key = torch.cat([k_FAD, k_LFS], dim=1) # [BC, 2H, W]
396
+
397
+ energy = torch.bmm(M_query, M_key) #[BC, W, W]
398
+ attention = self.softmax(energy).view(B, C, W, W)
399
+
400
+ att_LFS = x_LFS * attention * (torch.sigmoid(self.LFS_gamma) * 2.0 - 1.0)
401
+ y_FAD = x_FAD + self.FAD_bn(self.FAD_conv(att_LFS))
402
+
403
+ att_FAD = x_FAD * attention * (torch.sigmoid(self.FAD_gamma) * 2.0 - 1.0)
404
+ y_LFS = x_LFS + self.LFS_bn(self.LFS_conv(att_FAD))
405
+ return y_FAD, y_LFS
406
+
407
+ def DCT_mat(size):
408
+ m = [[ (np.sqrt(1./size) if i == 0 else np.sqrt(2./size)) * np.cos((j + 0.5) * np.pi * i / size) for j in range(size)] for i in range(size)]
409
+ return m
410
+
411
+ def generate_filter(start, end, size):
412
+ return [[0. if i + j > end or i + j <= start else 1. for j in range(size)] for i in range(size)]
413
+
414
+ def norm_sigma(x):
415
+ return 2. * torch.sigmoid(x) - 1.
416
+
417
+
418
+ class Det_F3_Net(nn.Module):
419
+ def __init__(self):
420
+ super(Det_F3_Net, self).__init__()
421
+ self.f3net = F3Net(num_classes=1)
422
+
423
+ def forward(self, x):
424
+ b, t, _, h, w = x.shape
425
+ images = x.view(b * t, 3, h, w)
426
+ sequence_output = self.f3net(images)
427
+ sequence_output = sequence_output.view(b, t, -1)
428
+ sequence_output = sequence_output.mean(1)
429
+
430
+ return sequence_output
431
+
432
+ if __name__ == '__main__':
433
+ model = F3Net()
434
+ print(model)
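FAD_Head assumes that DCT_mat builds an orthonormal DCT-II basis, so a band-filtered spectrum can be mapped back to the pixel domain with the transposed matrix. A quick standalone check of that property follows (DCT_mat is duplicated here only to keep the snippet self-contained).

import numpy as np
import torch

def DCT_mat(size):
    return [[(np.sqrt(1. / size) if i == 0 else np.sqrt(2. / size))
             * np.cos((j + 0.5) * np.pi * i / size) for j in range(size)]
            for i in range(size)]

D = torch.tensor(DCT_mat(8)).float()
print(torch.allclose(D @ D.t(), torch.eye(8), atol=1e-5))   # True: rows are orthonormal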
models/FTCN.py ADDED
@@ -0,0 +1,143 @@
1
+ '''
2
+ Exploring Temporal Coherence for More General Video Face Forgery Detection @ ICCV'2021
3
+ Copyright (c) Xiamen University and its affiliates.
4
+ Modified by Yinglin Zheng from https://github.com/yinglinzheng/FTCN
5
+ '''
6
+
7
+ import torch
8
+ from torch import nn
9
+ from .time_transformer import TimeTransformer
10
+ from .clip import clip
11
+
12
+ class RandomPatchPool(nn.Module):
13
+ def __init__(self):
14
+ super().__init__()
15
+
16
+ def forward(self, x):
17
+ # batch,channel,16,7x7
18
+ b, c, t, h, w = x.shape
19
+ x = x.reshape(b, c, t, h * w)
20
+ if self.training and my_cfg.model.transformer.random_select:  # my_cfg is the global config object of the original FTCN code and is not defined in this port
21
+ while True:
22
+ idx = random.randint(0, h * w - 1)
23
+ i = idx // h
24
+ j = idx % h
25
+ if j == 0 or i == h - 1 or j == h - 1:
26
+ continue
27
+ else:
28
+ break
29
+ else:
30
+ idx = h * w // 2
31
+ x = x[..., idx]
32
+ return x
33
+
34
+
35
+ def valid_idx(idx, h):
36
+ i = idx // h
37
+ j = idx % h
38
+ if j == 0 or i == h - 1 or j == h - 1:
39
+ return False
40
+ else:
41
+ return True
42
+
43
+
44
+ class RandomAvgPool(nn.Module):
45
+ def __init__(self):
46
+ super().__init__()
47
+
48
+ def forward(self, x):
49
+ # batch,channel,16,7x7
50
+ b, c, t, h, w = x.shape
51
+ x = x.reshape(b, c, t, h * w)
52
+ candidates = list(range(h * w))
53
+ candidates = [idx for idx in candidates if valid_idx(idx, h)]
54
+ max_k = len(candidates)
55
+ if self.training and my_cfg.model.transformer.random_select:
56
+ k = my_cfg.model.transformer.k
57
+ else:
58
+ k = max_k
59
+ candidates = random.sample(candidates, k)
60
+ x = x[..., candidates].mean(-1)
61
+ return x
62
+
63
+ class TransformerHead(nn.Module):
64
+ def __init__(self, spatial_size=7, time_size=8, in_channels=2048):
65
+ super().__init__()
66
+ # if my_cfg.model.inco.no_time_pool:
67
+ # time_size = time_size * 2
68
+ patch_type = 'time'
69
+ if patch_type == "time":
70
+ self.pool = nn.AvgPool3d((1, spatial_size, spatial_size))
71
+ self.num_patches = time_size
72
+ elif patch_type == "spatial":
73
+ self.pool = nn.AvgPool3d((time_size, 1, 1))
74
+ self.num_patches = spatial_size ** 2
75
+ elif patch_type == "random":
76
+ self.pool = RandomPatchPool()
77
+ self.num_patches = time_size
78
+ elif patch_type == "random_avg":
79
+ self.pool = RandomAvgPool()
80
+ self.num_patches = time_size
81
+ elif patch_type == "all":
82
+ self.pool = nn.Identity()
83
+ self.num_patches = time_size * spatial_size * spatial_size
84
+ else:
85
+ raise NotImplementedError(patch_type)
86
+
87
+ self.dim = -1
88
+ if self.dim == -1:
89
+ self.dim = in_channels
90
+
91
+ self.in_channels = in_channels
92
+
93
+ if self.dim != self.in_channels:
94
+ self.fc = nn.Linear(self.in_channels, self.dim)
95
+
96
+ default_params = dict(
97
+ dim=self.dim, depth=6, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1,
98
+ )
99
+
100
+ self.time_T = TimeTransformer(
101
+ num_patches=self.num_patches, num_classes=1, **default_params
102
+ )
103
+
104
+
105
+ def forward(self, x):
106
+ x = self.pool(x)
107
+ x = x.reshape(-1, self.in_channels, self.num_patches)
108
+ x = x.permute(0, 2, 1)
109
+ if self.dim != self.in_channels:
110
+ x = self.fc(x.reshape(-1, self.in_channels))
111
+ x = x.reshape(-1, self.num_patches, self.dim)
112
+ x = self.time_T(x)
113
+
114
+ return x
115
+
116
+
117
+ class ViT_B_FTCN(nn.Module):
118
+ def __init__(
119
+ self, channel_size=512, class_num=1
120
+ ):
121
+ super(ViT_B_FTCN, self).__init__()
122
+ self.clip_model, preprocess = clip.load('ViT-B-16')
123
+ self.clip_model = self.clip_model.float()
124
+
125
+ self.head = TransformerHead(spatial_size=14, time_size=8, in_channels=512)
126
+
127
+ def forward(self, x):
128
+ b, t, _, h, w = x.shape
129
+ images = x.view(b * t, 3, h, w)
130
+ sequence_output = self.clip_model.encode_image(images)
131
+ _, _, c = sequence_output.shape
132
+ sequence_output = sequence_output.view(b, t, 14, 14, c)
133
+ sequence_output = sequence_output.permute(0, 4, 1, 2, 3)
134
+
135
+ res = self.head(sequence_output)
136
+
137
+ return res
138
+ if __name__ == '__main__':
139
+ model = ViT_B_FTCN()
140
+ model = model.cuda()
141
+ dummy_input = torch.randn(4,8,3,224,224)
142
+ dummy_input = dummy_input.cuda()
143
+ model(dummy_input)
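With the hard-coded patch_type = 'time', TransformerHead averages the spatial grid away and keeps one token per frame before the TimeTransformer. A standalone illustration (not repo code) of exactly that pooling step and the token shape it produces:

import torch
import torch.nn as nn

pool = nn.AvgPool3d((1, 14, 14))            # keep the 8 frames, average the 14x14 grid
feats = torch.randn(2, 512, 8, 14, 14)      # [batch, channels, frames, H, W]
tokens = pool(feats).reshape(-1, 512, 8).permute(0, 2, 1)
print(tokens.shape)                          # torch.Size([2, 8, 512])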
models/MINTIME ADDED
@@ -0,0 +1,269 @@
1
+ """
2
+ MINTIME: Multi-Identity size-iNvariant TIMEsformer for Video Deepfake Detection@TIFS'2024
3
+ Copyright (c) ISTI-CNR and its affiliates.
4
+ Modified by Davide Alessandro Coccomini from https://github.com/davide-coccomini/MINTIME-Multi-Identity-size-iNvariant-TIMEsformer-for-Video-Deepfake-Detection
5
+ """
6
+
7
+ import torch
8
+ from torch import nn, einsum
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from statistics import mean
12
+ from torch.nn.init import trunc_normal_
13
+ import cv2
14
+ import numpy as np
15
+ from random import random
16
+ from .clip import clip
17
+ from einops.layers.torch import Rearrange
18
+
19
+ # helpers
20
+ def exists(val):
21
+ return val is not None
22
+
23
+ # classes
24
+ class PreNorm(nn.Module):
25
+ def __init__(self, dim, fn):
26
+ super().__init__()
27
+ self.fn = fn
28
+ self.norm = nn.LayerNorm(dim)
29
+
30
+ def forward(self, x, *args, **kwargs):
31
+ x = self.norm(x)
32
+ return self.fn(x, *args, **kwargs)
33
+
34
+ # time token shift
35
+ def shift(t, amt):
36
+ if amt == 0:
37
+ return t
38
+ return F.pad(t, (0, 0, 0, 0, amt, -amt))
39
+
40
+ class PreTokenShift(nn.Module):
41
+ def __init__(self, frames, fn):
42
+ super().__init__()
43
+ self.frames = frames
44
+ self.fn = fn
45
+
46
+ def forward(self, x, *args, **kwargs):
47
+ f, dim = self.frames, x.shape[-1]
48
+ cls_x, x = x[:, :1], x[:, 1:]
49
+ x = rearrange(x, 'b (f n) d -> b f n d', f = f)
50
+
51
+ # shift along time frame before and after
52
+ dim_chunk = (dim // 3)
53
+ chunks = x.split(dim_chunk, dim = -1)
54
+ chunks_to_shift, rest = chunks[:3], chunks[3:]
55
+ shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
56
+ x = torch.cat((*shifted_chunks, *rest), dim = -1)
57
+
58
+ x = rearrange(x, 'b f n d -> b (f n) d')
59
+ x = torch.cat((cls_x, x), dim = 1)
60
+ return self.fn(x, *args, **kwargs)
61
+
62
+ # feedforward
63
+ class GEGLU(nn.Module):
64
+ def forward(self, x):
65
+ x, gates = x.chunk(2, dim = -1)
66
+ return x * F.gelu(gates)
67
+
68
+ class FeedForward(nn.Module):
69
+ def __init__(self, dim, mult = 4, dropout = 0.):
70
+ super().__init__()
71
+ self.net = nn.Sequential(
72
+ nn.Linear(dim, dim * mult * 2),
73
+ GEGLU(),
74
+ nn.Dropout(dropout),
75
+ nn.Linear(dim * mult, dim)
76
+ )
77
+
78
+ def forward(self, x):
79
+ return self.net(x)
80
+
81
+ # attention
82
+ def attn(q, k, v):
83
+ sim = torch.einsum('b i d, b j d -> b i j', q, k)
84
+ attn = sim.softmax(dim = -1)
85
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
86
+ return out, attn
87
+
88
+ class Attention(nn.Module):
89
+ def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.):
90
+ super().__init__()
91
+ self.heads = heads
92
+ self.scale = dim_head ** -0.5
93
+ inner_dim = dim_head * heads
94
+
95
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
96
+ self.to_out = nn.Sequential(
97
+ nn.Linear(inner_dim, dim),
98
+ nn.Dropout(dropout)
99
+ )
100
+
101
+ def forward(self, x, einops_from, einops_to, **einops_dims):
102
+ h = self.heads
103
+ q, k, v = self.to_qkv(x).chunk(3, dim = -1)
104
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
105
+
106
+ q = q * self.scale
107
+
108
+ # splice out classification token at index 1
109
+ (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))
110
+
111
+ # let classification token attend to key / values of all patches across time and space
112
+ cls_out, cls_attentions = attn(cls_q, k, v)
113
+
114
+ # rearrange across time or space
115
+ q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))
116
+
117
+ # expand cls token keys and values across time or space and concat
118
+ r = q_.shape[0] // cls_k.shape[0]
119
+ cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))
120
+
121
+ k_ = torch.cat((cls_k, k_), dim = 1)
122
+ v_ = torch.cat((cls_v, v_), dim = 1)
123
+
124
+ # attention
125
+ out, attentions = attn(q_, k_, v_)
126
+
127
+ # merge back time or space
128
+ out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
129
+
130
+ # concat back the cls token
131
+ out = torch.cat((cls_out, out), dim = 1)
132
+
133
+ # merge back the heads
134
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
135
+
136
+ # combine heads out
137
+ return self.to_out(out), cls_attentions
138
+
139
+ class SizeInvariantTimeSformer(nn.Module):
140
+ def __init__(
141
+ self,
142
+ *,
143
+ require_attention = False
144
+ ):
145
+ super().__init__()
146
+ self.dim = 512
147
+ self.num_frames = 8
148
+ self.max_identities = 1
149
+ self.image_size = 224
150
+ self.num_classes = 1
151
+ self.patch_size = 1
152
+ self.num_patches = 196
153
+ self.channels = 512
154
+ self.depth = 9
155
+ self.heads = 8
156
+ self.dim_head = 64
157
+ self.attn_dropout = 0.
158
+ self.ff_dropout = 0.
159
+ self.shift_tokens = False
160
+ self.enable_size_emb = True
161
+ self.enable_pos_emb = True
162
+ self.require_attention = require_attention
163
+
164
+ num_positions = self.num_frames * self.channels
165
+ self.to_patch_embedding = nn.Linear(self.channels, self.dim)
166
+ self.cls_token = nn.Parameter(torch.randn(1, self.dim))
167
+ self.pos_emb = nn.Embedding(num_positions + 1, self.dim)
168
+
169
+ if self.enable_size_emb:
170
+ self.size_emb = nn.Embedding(num_positions + 1, self.dim)
171
+
172
+ self.layers = nn.ModuleList([])
173
+ for _ in range(self.depth):
174
+ ff = FeedForward(self.dim, dropout = self.ff_dropout)
175
+ time_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
176
+ spatial_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
177
+ if self.shift_tokens:
178
+ time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(self.num_frames, t), (time_attn, spatial_attn, ff))
179
+
180
+ time_attn, spatial_attn, ff = map(lambda t: PreNorm(self.dim, t), (time_attn, spatial_attn, ff))
181
+ self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))
182
+
183
+ self.to_out = nn.Sequential(
184
+ nn.LayerNorm(self.dim),
185
+ nn.Linear(self.dim, self.num_classes)
186
+ )
187
+
188
+ # Initialization
189
+ trunc_normal_(self.pos_emb.weight, std=.02)
190
+ trunc_normal_(self.cls_token, std=.02)
191
+ if self.enable_size_emb:
192
+ trunc_normal_(self.size_emb.weight, std=.02)
193
+ self.apply(self._init_weights)
194
+
195
+ def _init_weights(self, m):
196
+ if isinstance(m, nn.Linear):
197
+ trunc_normal_(m.weight, std=.02)
198
+ if isinstance(m, nn.Linear) and m.bias is not None:
199
+ nn.init.constant_(m.bias, 0)
200
+ elif isinstance(m, nn.LayerNorm):
201
+ nn.init.constant_(m.bias, 0)
202
+ nn.init.constant_(m.weight, 1.0)
203
+
204
+ @torch.jit.ignore
205
+ def no_weight_decay(self):
206
+ return {'pos_emb', 'cls_token'}
207
+
208
+ def forward(self, x):
209
+ b, f, c, h, w = x.shape
210
+ n = h * w
211
+ device = x.device
212
+
213
+ x = rearrange(x, 'b f c h w -> b (f h w) c') # B x F*P*P x C
214
+ tokens = self.to_patch_embedding(x) # B x 8*7*7 x dim
215
+
216
+ # Add cls token
217
+ cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
218
+ x = torch.cat((cls_token, tokens), dim = 1)
219
+
220
+ # Positional embedding
221
+ x += self.pos_emb(torch.arange(x.shape[1], device=device))
222
+
223
+ # Time and space attention
224
+ for (time_attn, spatial_attn, ff) in self.layers:
225
+ y, _ = time_attn(x, 'b (f n) d', '(b n) f d', n = n)
226
+ x = x + y
227
+ y, _ = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f)
228
+ x = x + y
229
+ x = ff(x) + x
230
+
231
+ cls_token = x[:, 0]
232
+
233
+ # both branches were identical in this port: attention maps are not returned either way
+ return self.to_out(cls_token)
237
+
238
+
239
+ class ViT_B_MINTIME(nn.Module):
240
+ def __init__(
241
+ self, channel_size=512, class_num=1
242
+ ):
243
+ super(ViT_B_MINTIME, self).__init__()
244
+ self.clip_model, preprocess = clip.load('ViT-B-16')
245
+ self.clip_model = self.clip_model.float()
246
+ self.head = SizeInvariantTimeSformer()
247
+
248
+ def forward(self, x):
249
+ b, t, _, h, w = x.shape
250
+ images = x.view(b * t, 3, h, w)
251
+ sequence_output = self.clip_model.encode_image(images)
252
+ _, _, c = sequence_output.shape
253
+ sequence_output = sequence_output.view(b, t, 14, 14, c)
254
+ sequence_output = sequence_output.permute(0, 1, 4, 2, 3)
255
+
256
+
257
+ res = self.head(sequence_output)
258
+
259
+ return res
260
+
261
+
262
+ if __name__ == '__main__':
263
+
264
+ model = ViT_B_MINTIME()
265
+ model = model.cuda()
266
+ dummy_input = torch.randn(4,8,3,224,224)
267
+ dummy_input = dummy_input.cuda()
268
+ print(model(dummy_input))
269
+
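The Attention module implements divided space-time attention through einops reshapes: time attention groups tokens as '(b n) f d', spatial attention as '(b f) n d'. A small standalone illustration (not repo code) of those two views on dummy tokens:

import torch
from einops import rearrange

b, f, n, d = 2, 8, 4, 16                     # batch, frames, patches per frame, embed dim
tokens = torch.randn(b, f * n, d)            # patch tokens without the CLS token

time_view = rearrange(tokens, 'b (f n) d -> (b n) f d', f=f, n=n)    # attend over frames
space_view = rearrange(tokens, 'b (f n) d -> (b f) n d', f=f, n=n)   # attend over patches
print(time_view.shape, space_view.shape)     # torch.Size([8, 8, 16]) torch.Size([16, 4, 16])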
models/NPR.py ADDED
@@ -0,0 +1,284 @@
1
+ """
2
+ NPR: Rethinking the Up-Sampling Operations in CNN-based Generative Network for Generalizable Deepfake Detection@CVPR'2024
3
+ Copyright (c) Beijing Jiaotong University and its affiliates.
4
+ Modified by Chuangchuang Tan from https://github.com/chuangchuangtan/NPR-DeepfakeDetection
5
+ """
6
+
7
+ import torch.nn as nn
8
+ import torch
9
+ import torch.utils.model_zoo as model_zoo
10
+ from torch.nn import functional as F
11
+ from typing import Any, cast, Dict, List, Optional, Union
12
+ import numpy as np
13
+
14
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
15
+ 'resnet152']
16
+
17
+
18
+ model_urls = {
19
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
20
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
21
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
22
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
23
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
24
+ }
25
+
26
+
27
+ def conv3x3(in_planes, out_planes, stride=1):
28
+ """3x3 convolution with padding"""
29
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30
+ padding=1, bias=False)
31
+
32
+
33
+ def conv1x1(in_planes, out_planes, stride=1):
34
+ """1x1 convolution"""
35
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
36
+
37
+
38
+ class BasicBlock(nn.Module):
39
+ expansion = 1
40
+
41
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
42
+ super(BasicBlock, self).__init__()
43
+ self.conv1 = conv3x3(inplanes, planes, stride)
44
+ self.bn1 = nn.BatchNorm2d(planes)
45
+ self.relu = nn.ReLU(inplace=True)
46
+ self.conv2 = conv3x3(planes, planes)
47
+ self.bn2 = nn.BatchNorm2d(planes)
48
+ self.downsample = downsample
49
+ self.stride = stride
50
+
51
+ def forward(self, x):
52
+ identity = x
53
+
54
+ out = self.conv1(x)
55
+ out = self.bn1(out)
56
+ out = self.relu(out)
57
+
58
+ out = self.conv2(out)
59
+ out = self.bn2(out)
60
+
61
+ if self.downsample is not None:
62
+ identity = self.downsample(x)
63
+
64
+ out += identity
65
+ out = self.relu(out)
66
+
67
+ return out
68
+
69
+
70
+ class Bottleneck(nn.Module):
71
+ expansion = 4
72
+
73
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
74
+ super(Bottleneck, self).__init__()
75
+ self.conv1 = conv1x1(inplanes, planes)
76
+ self.bn1 = nn.BatchNorm2d(planes)
77
+ self.conv2 = conv3x3(planes, planes, stride)
78
+ self.bn2 = nn.BatchNorm2d(planes)
79
+ self.conv3 = conv1x1(planes, planes * self.expansion)
80
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
81
+ self.relu = nn.ReLU(inplace=True)
82
+ self.downsample = downsample
83
+ self.stride = stride
84
+
85
+ def forward(self, x):
86
+ identity = x
87
+
88
+ out = self.conv1(x)
89
+ out = self.bn1(out)
90
+ out = self.relu(out)
91
+
92
+ out = self.conv2(out)
93
+ out = self.bn2(out)
94
+ out = self.relu(out)
95
+
96
+ out = self.conv3(out)
97
+ out = self.bn3(out)
98
+
99
+ if self.downsample is not None:
100
+ identity = self.downsample(x)
101
+
102
+ out += identity
103
+ out = self.relu(out)
104
+
105
+ return out
106
+
107
+
108
+ class ResNet(nn.Module):
109
+
110
+ def __init__(self, block, layers, num_classes=1, zero_init_residual=False):
111
+ super(ResNet, self).__init__()
112
+
113
+ self.unfoldSize = 2
114
+ self.unfoldIndex = 0
115
+ assert self.unfoldSize > 1
116
+ assert -1 < self.unfoldIndex and self.unfoldIndex < self.unfoldSize*self.unfoldSize
117
+ self.inplanes = 64
118
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
119
+ self.bn1 = nn.BatchNorm2d(64)
120
+ self.relu = nn.ReLU(inplace=True)
121
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
122
+ self.layer1 = self._make_layer(block, 64 , layers[0])
123
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
124
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
125
+ # self.fc1 = nn.Linear(512 * block.expansion, 1)
126
+ self.fc1 = nn.Linear(512, num_classes)
127
+
128
+ for m in self.modules():
129
+ if isinstance(m, nn.Conv2d):
130
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
131
+ elif isinstance(m, nn.BatchNorm2d):
132
+ nn.init.constant_(m.weight, 1)
133
+ nn.init.constant_(m.bias, 0)
134
+
135
+ # Zero-initialize the last BN in each residual branch,
136
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
137
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
138
+ if zero_init_residual:
139
+ for m in self.modules():
140
+ if isinstance(m, Bottleneck):
141
+ nn.init.constant_(m.bn3.weight, 0)
142
+ elif isinstance(m, BasicBlock):
143
+ nn.init.constant_(m.bn2.weight, 0)
144
+
145
+ def _make_layer(self, block, planes, blocks, stride=1):
146
+ downsample = None
147
+ if stride != 1 or self.inplanes != planes * block.expansion:
148
+ downsample = nn.Sequential(
149
+ conv1x1(self.inplanes, planes * block.expansion, stride),
150
+ nn.BatchNorm2d(planes * block.expansion),
151
+ )
152
+
153
+ layers = []
154
+ layers.append(block(self.inplanes, planes, stride, downsample))
155
+ self.inplanes = planes * block.expansion
156
+ for _ in range(1, blocks):
157
+ layers.append(block(self.inplanes, planes))
158
+
159
+ return nn.Sequential(*layers)
160
+ def interpolate(self, img, factor):
161
+ return F.interpolate(F.interpolate(img, scale_factor=factor, mode='nearest', recompute_scale_factor=True), scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
162
+ def forward(self, x):
163
+ # n,c,w,h = x.shape
164
+ # if -1*w%2 != 0: x = x[:,:,:w%2*-1,: ]
165
+ # if -1*h%2 != 0: x = x[:,:,: ,:h%2*-1]
166
+ # factor = 0.5
167
+ # x_half = F.interpolate(x, scale_factor=factor, mode='nearest', recompute_scale_factor=True)
168
+ # x_re = F.interpolate(x_half, scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
169
+ # NPR = x - x_re
170
+ # n,c,w,h = x.shape
171
+ # if w%2 == 1 : x = x[:,:,:-1,:]
172
+ # if h%2 == 1 : x = x[:,:,:,:-1]
173
+ b, t, _, h, w = x.shape
174
+ x = x.view(b * t, 3, h, w)
175
+
176
+
177
+ NPR = x - self.interpolate(x, 0.5)
178
+
179
+ x = self.conv1(NPR*2.0/3.0)
180
+ x = self.bn1(x)
181
+ x = self.relu(x)
182
+ x = self.maxpool(x)
183
+
184
+ x = self.layer1(x)
185
+ x = self.layer2(x)
186
+
187
+ x = self.avgpool(x)
188
+ x = x.view(x.size(0), -1)
189
+ x = self.fc1(x)
190
+ x = x.view(b, t, -1)
191
+ x = x.mean(1)
192
+
193
+ return x
194
+
195
+ def infer(self, x):
196
+ # n,c,w,h = x.shape
197
+ # if -1*w%2 != 0: x = x[:,:,:w%2*-1,: ]
198
+ # if -1*h%2 != 0: x = x[:,:,: ,:h%2*-1]
199
+ # factor = 0.5
200
+ # x_half = F.interpolate(x, scale_factor=factor, mode='nearest', recompute_scale_factor=True)
201
+ # x_re = F.interpolate(x_half, scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
202
+ # NPR = x - x_re
203
+ # n,c,w,h = x.shape
204
+ # if w%2 == 1 : x = x[:,:,:-1,:]
205
+ # if h%2 == 1 : x = x[:,:,:,:-1]
206
+ b, t, _, h, w = x.shape
207
+ x = x.view(b * t, 3, h, w)
208
+
209
+
210
+ NPR = x - self.interpolate(x, 0.5)
211
+
212
+ x = self.conv1(NPR*2.0/3.0)
213
+ x = self.bn1(x)
214
+ x = self.relu(x)
215
+ x = self.maxpool(x)
216
+
217
+ x = self.layer1(x)
218
+ x = self.layer2(x)
219
+
220
+ x = self.avgpool(x)
221
+ x = x.view(x.size(0), -1)
222
+ x = self.fc1(x)
223
+ x = x.view(b, t, -1)
224
+ x = x.mean(1)
225
+ return x
226
+
227
+
228
+ def resnet18_npr(pretrained=False, **kwargs):
229
+ """Constructs a ResNet-18 model.
230
+ Args:
231
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
232
+ """
233
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
234
+ if pretrained:
235
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
236
+ return model
237
+
238
+
239
+ def resnet34_npr(pretrained=False, **kwargs):
240
+ """Constructs a ResNet-34 model.
241
+ Args:
242
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
243
+ """
244
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
245
+ if pretrained:
246
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
247
+ return model
248
+
249
+
250
+ def resnet50_npr(pretrained=False, **kwargs):
251
+ """Constructs a ResNet-50 model.
252
+ Args:
253
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
254
+ """
255
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
256
+ if pretrained:
257
+ model_state = torch.load('/ossfs/workspace/aigc_video/weights/resnet50-19c8e357.pth')
258
+ model.load_state_dict(model_state, strict=False)
259
+ return model
260
+
261
+
262
+ def resnet101_npr(pretrained=False, **kwargs):
263
+ """Constructs a ResNet-101 model.
264
+ Args:
265
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
266
+ """
267
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
268
+ if pretrained:
269
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
270
+ return model
271
+
272
+
273
+ def resnet152_npr(pretrained=False, **kwargs):
274
+ """Constructs a ResNet-152 model.
275
+ Args:
276
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
277
+ """
278
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
279
+ if pretrained:
280
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
281
+ return model
282
+
283
+
284
+
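The core of NPR is the residual between an image and its nearest-neighbour down-then-up resampled copy, which is where the up-sampling traces of generative models concentrate. A standalone sketch of that signal (equivalent to self.interpolate with factor 0.5 for even-sized inputs):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 224, 224)
x_rec = F.interpolate(F.interpolate(x, scale_factor=0.5, mode='nearest'),
                      scale_factor=2.0, mode='nearest')
npr = x - x_rec                              # the NPR map fed into conv1 (scaled by 2/3)
print(npr.shape)                             # torch.Size([1, 3, 224, 224])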
models/STIL.py ADDED
@@ -0,0 +1,641 @@
1
+ '''
2
+ STIL: Spatiotemporal inconsistency learning for deepfake video detection @ ACM MM'2021
3
+ Copyright (c) Tencent Youtu Lab and its affiliates.
4
+ Modified by Zhiyuan Yan from https://github.com/Tencent/TFace?tab=readme-ov-file
5
+ '''
6
+
7
+ import os
8
+ import datetime
9
+ import logging
10
+ import numpy as np
11
+ from sklearn import metrics
12
+ from typing import Union
13
+ from collections import defaultdict
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+ import torch.utils.model_zoo as model_zoo
20
+ from torch.nn import DataParallel
21
+ from torch.utils.tensorboard import SummaryWriter
22
+
23
+ class ISM_Module(nn.Module):
24
+ def __init__(self, k_size=3):
25
+ """The Information Supplement Module (ISM).
26
+
27
+ Args:
28
+ k_size (int, optional): Conv1d kernel_size . Defaults to 3.
29
+ """
30
+ super(ISM_Module, self).__init__()
31
+
32
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
33
+ self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size-1)//2, bias=False)
34
+ self.sigmoid = nn.Sigmoid()
35
+
36
+
37
+ def forward(self, x):
38
+ """
39
+ Args:
40
+ x (torch.tensor): Input tensor of shape (nt, c, h, w)
41
+ """
42
+ y = self.avg_pool(x)
43
+ y = self.conv(y.squeeze(-1).transpose(-1,-2)).transpose(-1,-2).unsqueeze(-1)
44
+ y = self.sigmoid(y)
45
+ return x * y.expand_as(x)
46
+
47
+
48
+ class TIM_Module(nn.Module):
49
+ def __init__(self, in_channels, reduction=16, n_segment=8, return_attn=False):
50
+ """The Temporal Inconsistency Module (TIM).
51
+
52
+ Args:
53
+ in_channels (int): Input channel number.
54
+ reduction (int, optional): Channel compression ratio r in the split operation.. Defaults to 16.
55
+ n_segment (int, optional): Number of input frames.. Defaults to 8.
56
+ return_attn (bool, optional): Whether to return the attention part. Defaults to False.
57
+
58
+ """
59
+ super(TIM_Module, self).__init__()
60
+ self.in_channels = in_channels
61
+ self.reduction = reduction
62
+ self.n_segment = n_segment
63
+ self.return_attn = return_attn
64
+
65
+ self.reduced_channels = self.in_channels // self.reduction
66
+
67
+ # first conv to shrink input channels
68
+ self.conv1 = nn.Conv2d(self.in_channels, self.reduced_channels, kernel_size=1, padding=0, bias=False)
69
+ self.bn1 = nn.BatchNorm2d(self.reduced_channels)
70
+
71
+ self.conv_ht = nn.Conv2d(self.reduced_channels, self.reduced_channels,
72
+ kernel_size=(3, 1), padding=(1, 0), groups=self.reduced_channels, bias=False)
73
+ self.conv_tw = nn.Conv2d(self.reduced_channels, self.reduced_channels,
74
+ kernel_size=(1, 3), padding=(0, 1), groups=self.reduced_channels, bias=False)
75
+
76
+ self.avg_pool_ht = nn.AvgPool2d((2, 1), (2, 1))
77
+ self.avg_pool_tw = nn.AvgPool2d((1, 2), (1, 2))
78
+
79
+ # HTIE in two directions
80
+ self.htie_conv1 = nn.Sequential(
81
+ nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False),
82
+ nn.BatchNorm2d(self.reduced_channels),
83
+ )
84
+ self.vtie_conv1 = nn.Sequential(
85
+ nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(1, 3), padding=(0, 1), bias=False),
86
+ nn.BatchNorm2d(self.reduced_channels),
87
+ )
88
+ self.htie_conv2 = nn.Sequential(
89
+ nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(3, 1), padding=(1, 0), bias=False),
90
+ nn.BatchNorm2d(self.reduced_channels),
91
+ )
92
+ self.vtie_conv2 = nn.Sequential(
93
+ nn.Conv2d(self.reduced_channels, self.reduced_channels, kernel_size=(1, 3), padding=(0, 1), bias=False),
94
+ nn.BatchNorm2d(self.reduced_channels),
95
+ )
96
+ self.ht_up_conv = nn.Sequential(
97
+ nn.Conv2d(self.reduced_channels, self.in_channels, kernel_size=1, bias=False),
98
+ nn.BatchNorm2d(self.in_channels)
99
+ )
100
+ self.tw_up_conv = nn.Sequential(
101
+ nn.Conv2d(self.reduced_channels, self.in_channels, kernel_size=1, bias=False),
102
+ nn.BatchNorm2d(self.in_channels)
103
+ )
104
+
105
+ self.sigmoid = nn.Sigmoid()
106
+
107
+
108
+ def feat_ht(self, feat):
109
+ """The H-T branch in the TIM module.
110
+
111
+ Args:
112
+ feat (torch.tensor): Input feature with shape [n, t, c, h, w] (c is in_channels // reduction)
113
+
114
+ """
115
+ n, t, c, h, w = feat.size()
116
+ # [n, t, c, h, w] -> [n, w, c, h, t] -> [nw, c, h, t]
117
+ feat_h = feat.permute(0, 4, 2, 3, 1).contiguous().view(-1, c, h, t)
118
+
119
+ # [nw, c, h, t-1]
120
+ feat_h_fwd, _ = feat_h.split([self.n_segment-1, 1], dim=3)
121
+ feat_h_conv = self.conv_ht(feat_h)
122
+ _, feat_h_conv_fwd = feat_h_conv.split([1, self.n_segment-1], dim=3)
123
+
124
+ diff_feat_fwd = feat_h_conv_fwd - feat_h_fwd
125
+ diff_feat_fwd = F.pad(diff_feat_fwd, [0, 1], value=0) # [nw, c, h, t]
126
+
127
+ # HTIE, down_up branch
128
+ diff_feat_fwd1 = self.avg_pool_ht(diff_feat_fwd) # [nw, c, h//2, t]
129
+ diff_feat_fwd1 = self.htie_conv1(diff_feat_fwd1) # [nw, c, h//2, t]
130
+ diff_feat_fwd1 = F.interpolate(diff_feat_fwd1, diff_feat_fwd.size()[2:]) # [nw, c, h, t]
131
+ # HTIE, direct conv branch
132
+ diff_feat_fwd2 = self.htie_conv2(diff_feat_fwd) # [nw, c, h, t]
133
+
134
+ # [nw, C, h, t]
135
+ feat_ht_out = self.ht_up_conv(1/3. * diff_feat_fwd + 1/3. * diff_feat_fwd1 + 1/3. * diff_feat_fwd2)
136
+ feat_ht_out = self.sigmoid(feat_ht_out) - 0.5
137
+ # [nw, C, h, t] -> [n, w, C, h, t] -> [n, t, C, h, w]
138
+ feat_ht_out = feat_ht_out.view(n, w, self.in_channels, h, t).permute(0, 4, 2, 3, 1).contiguous()
139
+ # [n, t, C, h, w] -> [nt, C, h, w]
140
+ feat_ht_out = feat_ht_out.view(-1, self.in_channels, h, w)
141
+
142
+ return feat_ht_out
143
+
144
+
145
+ def feat_tw(self, feat):
146
+ """The T-W branch in the TIM module.
147
+
148
+ Args:
149
+ feat (torch.tensor): Input feature with shape [n, t, c, h, w] (c is in_channels // reduction)
150
+ """
151
+ n, t, c, h, w = feat.size()
152
+ # [n, t, c, h, w] -> [n, h, c, t, w] -> [nh, c, t, w]
153
+ feat_w = feat.permute(0, 3, 2, 1, 4).contiguous().view(-1, c, t, w)
154
+
155
+ # [nh, c, t-1, w]
156
+ feat_w_fwd, _ = feat_w.split([self.n_segment-1, 1], dim=2)
157
+ feat_w_conv = self.conv_tw(feat_w)
158
+ _, feat_w_conv_fwd = feat_w_conv.split([1, self.n_segment-1], dim=2)
159
+
160
+ diff_feat_fwd = feat_w_conv_fwd - feat_w_fwd
161
+ diff_feat_fwd = F.pad(diff_feat_fwd, [0, 0, 0, 1], value=0) # [nh, c, t, w]
162
+
163
+ # VTIE, down_up branch
164
+ diff_feat_fwd1 = self.avg_pool_tw(diff_feat_fwd) # [nh, c, t, w//2]
165
+ diff_feat_fwd1 = self.vtie_conv1(diff_feat_fwd1) # [nh, c, t, w//2]
166
+ diff_feat_fwd1 = F.interpolate(diff_feat_fwd1, diff_feat_fwd.size()[2:]) # [nh, c, t, w]
167
+ # VTIE, direct conv branch
168
+ diff_feat_fwd2 = self.vtie_conv2(diff_feat_fwd) # [nh, c, t, w]
169
+
170
+ # [nh, C, t, w]
171
+ feat_tw_out = self.tw_up_conv(1/3. * diff_feat_fwd + 1/3. * diff_feat_fwd1 + 1/3. * diff_feat_fwd2)
172
+ feat_tw_out = self.sigmoid(feat_tw_out) - 0.5
173
+ # [nh, C, t, w] -> [n, h, C, t, w] -> [n, t, C, h, W]
174
+ feat_tw_out = feat_tw_out.view(n, h, self.in_channels, t, w).permute(0, 3, 2, 1, 4).contiguous()
175
+ # [n, t, C, h, w] -> [nt, C, h, w]
176
+ feat_tw_out = feat_tw_out.view(-1, self.in_channels, h, w)
177
+
178
+ return feat_tw_out
179
+
180
+
181
+ def forward(self, x):
182
+ """
183
+ Args:
184
+ x (torch.tensor): Input with shape [nt, c, h, w]
185
+ """
186
+ # [nt, c, h, w] -> [nt, c//r, h, w]
187
+ bottleneck = self.conv1(x)
188
+ bottleneck = self.bn1(bottleneck)
189
+ # [nt, c//r, h, w] -> [n, t, c//r, h, w]
190
+ bottleneck = bottleneck.view((-1, self.n_segment) + bottleneck.size()[1:])
191
+
192
+ F_h = self.feat_ht(bottleneck) # [nt, c, h, w]
193
+ F_w = self.feat_tw(bottleneck) # [nt, c, h, w]
194
+
195
+ att = 0.5 * (F_h + F_w)
196
+
197
+ if self.return_attn:
198
+ return att
199
+
200
+ y2 = x + x * att
201
+
202
+ return y2
203
+
204
+
205
+ class ShiftModule(nn.Module):
206
+ def __init__(self, input_channels, n_segment=8, n_div=8, mode='shift'):
207
+ """A depth-wise conv on the segment level.
208
+
209
+ Args:
210
+ input_channels (int): Input channel number.
211
+ n_segment (int, optional): Number of input frames. Defaults to 8.
212
+ n_div (int, optional): How many channels to group as a fold. Defaults to 8.
213
+ mode (str, optional): One of "shift", "fixed", "norm". Defaults to 'shift'.
214
+ """
215
+ super(ShiftModule, self).__init__()
216
+ self.input_channels = input_channels
217
+ self.n_segment = n_segment
218
+ self.fold_div = n_div
219
+ self.fold = self.input_channels // self.fold_div
220
+ self.conv = nn.Conv1d(self.fold_div*self.fold, self.fold_div*self.fold,
221
+ kernel_size=3, padding=1, groups=self.fold_div*self.fold,
222
+ bias=False)
223
+
224
+ if mode == 'shift':
225
+ self.conv.weight.requires_grad = True
226
+ self.conv.weight.data.zero_()
227
+ # shift left
228
+ self.conv.weight.data[:self.fold, 0, 2] = 1
229
+ # shift right
230
+ self.conv.weight.data[self.fold: 2 * self.fold, 0, 0] = 1
231
+ if 2*self.fold < self.input_channels:
232
+ self.conv.weight.data[2 * self.fold:, 0, 1] = 1 # fixed
233
+ elif mode == 'fixed':
234
+ self.conv.weight.requires_grad = True
235
+ self.conv.weight.data.zero_()
236
+ self.conv.weight.data[:, 0, 1] = 1 # fixed
237
+ elif mode == 'norm':
238
+ self.conv.weight.requires_grad = True
239
+
240
+
241
+ def forward(self, x):
242
+ """
243
+ Args:
244
+ x (torch.tensor): Input with shape [nt, c, h, w]
245
+ """
246
+ nt, c, h, w = x.size()
247
+ n_batch = nt // self.n_segment
248
+ x = x.view(n_batch, self.n_segment, c, h, w)
249
+ # (n, h, w, c, t)
250
+ x = x.permute(0, 3, 4, 2, 1)
251
+ x = x.contiguous().view(n_batch*h*w, c, self.n_segment)
252
+ # (n*h*w, c, t)
253
+ x = self.conv(x)
254
+ x = x.view(n_batch, h, w, c, self.n_segment)
255
+ # (n, t, c, h, w)
256
+ x = x.permute(0, 4, 3, 1, 2)
257
+ x = x.contiguous().view(nt, c, h, w)
258
+ return x
259
+
260
+
261
+ class SCConv(nn.Module):
262
+ """
263
+ The spatial conv in SIM. Used in SCBottleneck
264
+ """
265
+ def __init__(self, inplanes, planes, stride, padding, dilation, groups, pooling_r, norm_layer):
266
+ super(SCConv, self).__init__()
267
+ self.f_w = nn.Sequential(
268
+ nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
269
+ nn.Conv2d(inplanes, planes, kernel_size=(1,3), stride=1,
270
+ padding=(0,padding), dilation=(1,dilation),
271
+ groups=groups, bias=False),
272
+ norm_layer(planes), nn.ReLU(inplace=True))
273
+ self.f_h = nn.Sequential(
274
+ # nn.AvgPool2d(kernel_size=(pooling_r,1), stride=(pooling_r,1)),
275
+ nn.Conv2d(inplanes, planes, kernel_size=(3,1), stride=1,
276
+ padding=(padding,0), dilation=(dilation,1),
277
+ groups=groups, bias=False),
278
+ norm_layer(planes),
279
+ )
280
+ self.k3 = nn.Sequential(
281
+ nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
282
+ padding=padding, dilation=dilation,
283
+ groups=groups, bias=False),
284
+ norm_layer(planes),
285
+ )
286
+ self.k4 = nn.Sequential(
287
+ nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
288
+ padding=padding, dilation=dilation,
289
+ groups=groups, bias=False),
290
+ norm_layer(planes),
291
+ )
292
+
293
+
294
+ def forward(self, x):
295
+ identity = x
296
+
297
+ # sigmoid(identity + k2)
298
+ out = torch.sigmoid(
299
+ torch.add(
300
+ identity,
301
+ F.interpolate(self.f_h(self.f_w(x)), identity.size()[2:])
302
+ )
303
+ )
304
+ out = torch.mul(self.k3(x), out) # k3 * sigmoid(identity + k2)
305
+ s2t_info = out
306
+ out = self.k4(out) # k4
307
+
308
+ return out, s2t_info
309
+
310
+
311
+ class SCBottleneck(nn.Module):
312
+ """
313
+ SCNet SCBottleneck, a variant of the ResNet Bottleneck block.
314
+ """
315
+ expansion = 4
316
+ pooling_r = 4 # down-sampling rate of the avg pooling layer in the K3 path of SC-Conv.
317
+
318
+ def __init__(self, num_segments, inplanes, planes, stride=1, downsample=None,
319
+ cardinality=1, bottleneck_width=32,
320
+ avd=False, dilation=1, is_first=False,
321
+ norm_layer=None):
322
+ super(SCBottleneck, self).__init__()
323
+ group_width = int(planes * (bottleneck_width / 64.)) * cardinality
324
+ self.conv1_a = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
325
+ self.bn1_a = norm_layer(group_width)
326
+ self.conv1_b = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
327
+ self.bn1_b = norm_layer(group_width)
328
+ self.avd = avd and (stride > 1 or is_first)
329
+ self.tim = TIM_Module(group_width, n_segment=num_segments)
330
+ self.shift = ShiftModule(group_width, n_segment=num_segments, n_div=8, mode='shift')
331
+ self.inplanes = inplanes
332
+ self.planes = planes
333
+ self.ism = ISM_Module()
334
+
335
+
336
+ if self.avd:
337
+ self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
338
+ stride = 1
339
+
340
+ self.k1 = nn.Sequential(
341
+ nn.Conv2d(
342
+ group_width, group_width, kernel_size=3, stride=stride,
343
+ padding=dilation, dilation=dilation,
344
+ groups=cardinality, bias=False),
345
+ norm_layer(group_width),
346
+ )
347
+
348
+ self.scconv = SCConv(
349
+ group_width, group_width, stride=stride,
350
+ padding=dilation, dilation=dilation,
351
+ groups=cardinality, pooling_r=self.pooling_r, norm_layer=norm_layer)
352
+
353
+ self.conv3 = nn.Conv2d(
354
+ group_width * 2, planes * 4, kernel_size=1, bias=False)
355
+ self.bn3 = norm_layer(planes*4)
356
+
357
+ self.relu = nn.ReLU(inplace=True)
358
+ self.downsample = downsample
359
+ self.dilation = dilation
360
+
361
+
362
+ def forward(self, x):
363
+ """Forward func which splits the input into two branchs a and b.
364
+ a: trace features
365
+ b: spatial features
366
+ """
367
+ residual = x
368
+
369
+ out_a = self.relu(self.bn1_a(self.conv1_a(x)))
370
+ out_b = self.relu(self.bn1_b(self.conv1_b(x)))
371
+
372
+ # spatial representations
373
+ out_b, s2t_info = self.scconv(out_b)
374
+ out_b = self.relu(out_b)
375
+
376
+ # trace features
377
+ out_a = self.tim(out_a)
378
+ out_a = self.shift(out_a + self.ism(s2t_info))
379
+ out_a = self.relu(self.k1(out_a))
380
+
381
+ if self.avd:
382
+ out_a = self.avd_layer(out_a)
383
+ out_b = self.avd_layer(out_b)
384
+
385
+ out = self.conv3(torch.cat([out_a, out_b], dim=1))
386
+ out = self.bn3(out)
387
+
388
+ if self.downsample is not None:
389
+ residual = self.downsample(x)
390
+
391
+ out += residual
392
+ out = self.relu(out)
393
+
394
+ return out
395
+
396
+
397
+ class SCNet(nn.Module):
398
+ def __init__(self, num_segments, block, layers, groups=1, bottleneck_width=32,
399
+ num_classes=1000, dilated=False, dilation=1,
400
+ deep_stem=False, stem_width=64, avg_down=False,
401
+ avd=False, norm_layer=nn.BatchNorm2d):
402
+ """SCNet, a variant based on ResNet.
403
+
404
+ Args:
405
+ num_segments (int):
406
+ Number of input frames.
407
+ block (class):
408
+ Class for the residual block.
409
+ layers (list):
410
+ Number of layers in each block.
411
+ num_classes (int, optional):
412
+ Number of classification classes. Defaults to 1000.
413
+ dilated (bool, optional):
414
+ Whether to apply dilation conv. Defaults to False.
415
+ dilation (int, optional):
416
+ The dilation parameter in dilation conv. Defaults to 1.
417
+ deep_stem (bool, optional):
418
+ Whether to replace the 7x7 conv in the input stem with three 3x3 convs. Defaults to False.
419
+ stem_width (int, optional):
420
+ Stem width in conv1 stem. Defaults to 64.
421
+ avg_down (bool, optional):
422
+ Whether to use AvgPool instead of stride conv when downsampling in the bottleneck. Defaults to False.
423
+ avd (bool, optional):
424
+ The avd parameter for the block. Defaults to False.
425
+ norm_layer (class, optional):
426
+ Normalization layer. Defaults to nn.BatchNorm2d.
427
+ """
428
+ self.cardinality = groups
429
+ self.bottleneck_width = bottleneck_width
430
+ # ResNet-D params
431
+ self.inplanes = stem_width*2 if deep_stem else 64
432
+ self.avg_down = avg_down
433
+ self.avd = avd
434
+ self.num_segments = num_segments
435
+
436
+ super(SCNet, self).__init__()
437
+ conv_layer = nn.Conv2d
438
+ if deep_stem:
439
+ self.conv1 = nn.Sequential(
440
+ conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
441
+ norm_layer(stem_width),
442
+ nn.ReLU(inplace=True),
443
+ conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
444
+ norm_layer(stem_width),
445
+ nn.ReLU(inplace=True),
446
+ conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False),
447
+ )
448
+ else:
449
+ self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
450
+ bias=False)
451
+ self.bn1 = norm_layer(self.inplanes)
452
+ self.relu = nn.ReLU(inplace=True)
453
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
454
+ self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
455
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
456
+ if dilated or dilation == 4:
457
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
458
+ dilation=2, norm_layer=norm_layer)
459
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
460
+ dilation=4, norm_layer=norm_layer)
461
+ elif dilation==2:
462
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
463
+ dilation=1, norm_layer=norm_layer)
464
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
465
+ dilation=2, norm_layer=norm_layer)
466
+ else:
467
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
468
+ norm_layer=norm_layer)
469
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
470
+ norm_layer=norm_layer)
471
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
472
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
473
+
474
+ for m in self.modules():
475
+ if isinstance(m, nn.Conv2d):
476
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
477
+ elif isinstance(m, norm_layer):
478
+ nn.init.constant_(m.weight, 1)
479
+ nn.init.constant_(m.bias, 0)
480
+
481
+
482
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
483
+ is_first=True):
484
+ """
485
+ Core function to build layers.
486
+ """
487
+ downsample = None
488
+ if stride != 1 or self.inplanes != planes * block.expansion:
489
+ down_layers = []
490
+ if self.avg_down:
491
+ if dilation == 1:
492
+ down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
493
+ ceil_mode=True, count_include_pad=False))
494
+ else:
495
+ down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
496
+ ceil_mode=True, count_include_pad=False))
497
+ down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
498
+ kernel_size=1, stride=1, bias=False))
499
+ else:
500
+ down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
501
+ kernel_size=1, stride=stride, bias=False))
502
+ down_layers.append(norm_layer(planes * block.expansion))
503
+ downsample = nn.Sequential(*down_layers)
504
+
505
+ layers = []
506
+ if dilation == 1 or dilation == 2:
507
+ layers.append(block(self.num_segments, self.inplanes, planes, stride, downsample=downsample,
508
+ cardinality=self.cardinality,
509
+ bottleneck_width=self.bottleneck_width,
510
+ avd=self.avd, dilation=1, is_first=is_first,
511
+ norm_layer=norm_layer))
512
+ elif dilation == 4:
513
+ layers.append(block(self.num_segments, self.inplanes, planes, stride, downsample=downsample,
514
+ cardinality=self.cardinality,
515
+ bottleneck_width=self.bottleneck_width,
516
+ avd=self.avd, dilation=2, is_first=is_first,
517
+ norm_layer=norm_layer))
518
+ else:
519
+ raise RuntimeError("=> unknown dilation size: {}".format(dilation))
520
+
521
+ self.inplanes = planes * block.expansion
522
+ for i in range(1, blocks):
523
+ layers.append(block(self.num_segments, self.inplanes, planes,
524
+ cardinality=self.cardinality,
525
+ bottleneck_width=self.bottleneck_width,
526
+ avd=self.avd, dilation=dilation,
527
+ norm_layer=norm_layer))
528
+
529
+ return nn.Sequential(*layers)
530
+
531
+
532
+ def features(self, input):
533
+ x = self.conv1(input)
534
+ x = self.bn1(x)
535
+ x = self.relu(x)
536
+ x = self.maxpool(x)
537
+
538
+ x = self.layer1(x)
539
+ x = self.layer2(x)
540
+ x = self.layer3(x)
541
+ x = self.layer4(x)
542
+ return x
543
+
544
+
545
+ def logits(self, features):
546
+ x = self.avgpool(features)
547
+ x = x.view(x.size(0), -1)
548
+ x = self.fc(x)
549
+ return x
550
+
551
+
552
+ def forward(self, input):
553
+ x = self.features(input)
554
+ x = self.logits(x)
555
+ return x
556
+
557
+
558
+ def scnet50_v1d(num_segments, pretrained=False, **kwargs):
559
+ """
560
+ SCNet backbone, which is based on ResNet-50
561
+ Args:
562
+ num_segments (int):
563
+ Number of input frames.
564
+ pretrained (bool, optional):
565
+ Whether to load pretrained weights.
566
+ """
567
+ model = SCNet(num_segments, SCBottleneck, [3, 4, 6, 3],
568
+ deep_stem=True, stem_width=32, avg_down=True,
569
+ avd=True, **kwargs)
570
+ if pretrained:
571
+ model_state = torch.load('/ossfs/workspace/GenVideo/pretrained_weights/scnet50_v1d-4109d1e1.pth')
572
+ model.load_state_dict(model_state, strict=False)
573
+
574
+ return model
575
+
576
+ class STIL_Model(nn.Module):
577
+ def __init__(self,
578
+ num_class=1,
579
+ num_segment=8,
580
+ add_softmax=False,
581
+ **kwargs):
582
+ """ Model Builder for STIL model.
583
+ STIL: Spatiotemporal Inconsistency Learning for DeepFake Video Detection (https://arxiv.org/abs/2109.01860)
584
+
585
+ Args:
586
+ num_class (int, optional): Number of classes. Defaults to 1.
587
+ num_segment (int, optional): Number of segments (frames) fed to the model. Defaults to 8.
588
+ add_softmax (bool, optional): Whether to add softmax layer at the end. Defaults to False.
589
+ """
590
+ super().__init__()
591
+
592
+ self.num_class = num_class
593
+ self.num_segment = num_segment
594
+ self.build_model()
595
+
596
+ def build_model(self):
597
+ self.base_model = scnet50_v1d(self.num_segment, pretrained=True)
598
+
599
+ fc_feature_dim = self.base_model.fc.in_features
600
+ self.base_model.fc = nn.Linear(fc_feature_dim, self.num_class)
601
+
602
+ def forward(self, x):
603
+ """Forward pass of the model.
604
+
605
+ Args:
606
+ x (torch.tensor): input tensor of shape (n, t*c, h, w). n is the batch_size, t is num_segment
607
+ """
608
+ # img channel default to 3
609
+ img_channel = 3
610
+
611
+ # x: [n, tc, h, w] -> [nt, c, h, w]
612
+ # out: [nt, num_class]
613
+ out = self.base_model(
614
+ x.view((-1, img_channel) + x.size()[2:])
615
+ )
616
+
617
+ out = out.view(-1, self.num_segment, self.num_class) # [n, t, num_class]
618
+ out = out.mean(1, keepdim=False) # [n, num_class]
619
+ return out
620
+
621
+
622
+ def set_segment(self, num_segment):
623
+ """Change num_segment of the model.
624
+ Useful when the train and test want to feed different number of frames.
625
+
626
+ Args:
627
+ num_segment (int): New number of segments.
628
+ """
629
+ self.num_segment = num_segment
630
+
631
+
632
+ class Det_STIL(nn.Module):
633
+ def __init__(self):
634
+ super(Det_STIL, self).__init__()
635
+ self.model = STIL_Model()
636
+
637
+ def forward(self, x):
638
+ b, t, _, h, w = x.shape
639
+ images = x.view(b, t*3, h, w)
640
+ x = self.model(images)
641
+ return x
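A minimal smoke-test sketch for the Det_STIL wrapper above (illustrative only, not a file in this commit). It assumes 8-frame 224x224 clips in the [batch, frames, channels, height, width] layout expected by forward, and that the SCNet checkpoint path hard-coded in scnet50_v1d exists locally, since STIL_Model builds its backbone with pretrained=True.

import torch
from models.STIL import Det_STIL

model = Det_STIL().eval()
clip = torch.randn(2, 8, 3, 224, 224)        # [b, t, c, h, w]
with torch.no_grad():
    logits = model(clip)                     # [2, 1] raw logits; apply sigmoid for a fake-probability
print(logits.shape)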
models/TALL.py ADDED
@@ -0,0 +1,935 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13
+ from timm.models.registry import register_model
14
+ from torch.hub import load_state_dict_from_url
15
+ import logging
16
+ from einops import rearrange
17
+ import math
18
+
19
+ _logger = logging.getLogger(__name__)
20
+
21
+
22
+ def _cfg(url='', **kwargs):
23
+ return {
24
+ 'url': url,
25
+ 'num_classes': 1, 'input_size': (3, 224, 224), 'pool_size': None,
26
+ 'crop_pct': .9, 'interpolation': 'bicubic',
27
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
28
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
29
+ **kwargs
30
+ }
31
+
32
+
33
+ default_cfgs = {
34
+ # patch models (my experiments)
35
+ 'swin_base_in1k_patch4_224': _cfg(
36
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',
37
+ ),
38
+ 'swin_base_patch4_window7_224_22k': _cfg(
39
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
40
+ )
41
+ }
42
+
43
+ class Mlp(nn.Module):
44
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
45
+ super().__init__()
46
+ out_features = out_features or in_features
47
+ hidden_features = hidden_features or in_features
48
+ self.fc1 = nn.Linear(in_features, hidden_features)
49
+ self.act = act_layer()
50
+ self.fc2 = nn.Linear(hidden_features, out_features)
51
+ self.drop = nn.Dropout(drop)
52
+
53
+ def forward(self, x):
54
+ x = self.fc1(x)
55
+ x = self.act(x)
56
+ x = self.drop(x)
57
+ x = self.fc2(x)
58
+ x = self.drop(x)
59
+ return x
60
+
61
+
62
+ def window_partition(x, window_size):
63
+ """
64
+ Args:
65
+ x: (B, H, W, C)
66
+ window_size (int): window size
67
+
68
+ Returns:
69
+ windows: (num_windows*B, window_size, window_size, C)
70
+ """
71
+ B, H, W, C = x.shape
72
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
73
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
74
+ return windows
75
+
76
+
77
+ def window_reverse(windows, window_size, H, W):
78
+ """
79
+ Args:
80
+ windows: (num_windows*B, window_size, window_size, C)
81
+ window_size (int): Window size
82
+ H (int): Height of image
83
+ W (int): Width of image
84
+
85
+ Returns:
86
+ x: (B, H, W, C)
87
+ """
88
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
89
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
90
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
91
+ return x
92
+
93
+ # adjust image size for the pyramid structure (i.e. each side must be a multiple of 32 * window_size)
94
+ def create_new_image_size(img_size, thumbnail_dim, window_size):
95
+ h, w = img_size * thumbnail_dim[0], img_size * thumbnail_dim[1]
96
+ new_h, new_w = h, w
97
+ dim = 32 * window_size
98
+ if h % (32 * window_size) != 0:
99
+ new_h = (h // dim + 1) * dim
100
+ if w % (32 * window_size) != 0:
101
+ new_w = (w // dim + 1) * dim
102
+
103
+ return (new_h//thumbnail_dim[0], new_w//thumbnail_dim[1])
104
+
105
+
106
+ class WindowAttention(nn.Module):
107
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
108
+ It supports both of shifted and non-shifted window.
109
+
110
+ Args:
111
+ dim (int): Number of input channels.
112
+ window_size (tuple[int]): The height and width of the window.
113
+ num_heads (int): Number of attention heads.
114
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
115
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
116
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
117
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
118
+ """
119
+
120
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
121
+
122
+ super().__init__()
123
+ self.dim = dim
124
+ self.window_size = window_size # Wh, Ww
125
+ self.num_heads = num_heads
126
+ head_dim = dim // num_heads
127
+ self.scale = qk_scale or head_dim ** -0.5
128
+
129
+ # define a parameter table of relative position bias
130
+ self.relative_position_bias_table = nn.Parameter(
131
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
132
+ # get pair-wise relative position index for each token inside the window
133
+ coords_h = torch.arange(self.window_size[0])
134
+ coords_w = torch.arange(self.window_size[1])
135
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
136
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
137
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
138
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
139
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
140
+ relative_coords[:, :, 1] += self.window_size[1] - 1
141
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
142
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
143
+ self.register_buffer("relative_position_index", relative_position_index)
144
+
145
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
146
+ self.attn_drop = nn.Dropout(attn_drop)
147
+ self.proj = nn.Linear(dim, dim)
148
+ self.proj_drop = nn.Dropout(proj_drop)
149
+
150
+ trunc_normal_(self.relative_position_bias_table, std=.02)
151
+ self.softmax = nn.Softmax(dim=-1)
152
+
153
+ def forward(self, x, mask=None):
154
+ """
155
+ Args:
156
+ x: input features with shape of (num_windows*B, N, C)
157
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
158
+ """
159
+ B_, N, C = x.shape
160
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
161
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
162
+
163
+ q = q * self.scale
164
+ attn = (q @ k.transpose(-2, -1))
165
+
166
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
167
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
168
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
169
+ attn = attn + relative_position_bias.unsqueeze(0)
170
+
171
+ if mask is not None:
172
+ nW = mask.shape[0]
173
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
174
+ attn = attn.view(-1, self.num_heads, N, N)
175
+ attn = self.softmax(attn)
176
+ else:
177
+ attn = self.softmax(attn)
178
+
179
+ attn = self.attn_drop(attn)
180
+
181
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
182
+ x = self.proj(x)
183
+ x = self.proj_drop(x)
184
+ return x
185
+
186
+ def extra_repr(self) -> str:
187
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
188
+
189
+ def flops(self, N):
190
+ # calculate flops for 1 window with token length of N
191
+ flops = 0
192
+ # qkv = self.qkv(x)
193
+ flops += N * self.dim * 3 * self.dim
194
+ # attn = (q @ k.transpose(-2, -1))
195
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
196
+ # x = (attn @ v)
197
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
198
+ # x = self.proj(x)
199
+ flops += N * self.dim * self.dim
200
+ return flops
201
+
202
+
203
+ class SwinTransformerBlock(nn.Module):
204
+ r""" Swin Transformer Block.
205
+
206
+ Args:
207
+ dim (int): Number of input channels.
208
+ input_resolution (tuple[int]): Input resolution.
209
+ num_heads (int): Number of attention heads.
210
+ window_size (int): Window size.
211
+ shift_size (int): Shift size for SW-MSA.
212
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
213
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
214
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
215
+ drop (float, optional): Dropout rate. Default: 0.0
216
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
217
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
218
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
219
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
220
+ """
221
+
222
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
223
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
224
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, bottleneck=False, use_checkpoint=False):
225
+ super().__init__()
226
+ self.dim = dim
227
+ self.input_resolution = input_resolution
228
+ self.num_heads = num_heads
229
+ self.window_size = window_size
230
+ self.shift_size = shift_size
231
+ self.mlp_ratio = mlp_ratio
232
+ self.use_checkpoint = use_checkpoint
233
+
234
+ if min(self.input_resolution) <= self.window_size:
235
+ # if window size is larger than input resolution, we don't partition windows
236
+ self.shift_size = 0
237
+ self.window_size = min(self.input_resolution)
238
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
239
+
240
+ self.norm1 = norm_layer(dim)
241
+ self.attn = WindowAttention(
242
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
243
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
244
+
245
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
246
+ self.norm2 = norm_layer(dim)
247
+ mlp_hidden_dim = int(dim * mlp_ratio)
248
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
249
+
250
+ if self.shift_size > 0:
251
+ # calculate attention mask for SW-MSA
252
+ H, W = self.input_resolution
253
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
254
+ h_slices = (slice(0, -self.window_size),
255
+ slice(-self.window_size, -self.shift_size),
256
+ slice(-self.shift_size, None))
257
+ w_slices = (slice(0, -self.window_size),
258
+ slice(-self.window_size, -self.shift_size),
259
+ slice(-self.shift_size, None))
260
+ cnt = 0
261
+ for h in h_slices:
262
+ for w in w_slices:
263
+ img_mask[:, h, w, :] = cnt
264
+ cnt += 1
265
+
266
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
267
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
268
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
269
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
270
+ else:
271
+ attn_mask = None
272
+
273
+ self.register_buffer("attn_mask", attn_mask)
274
+
275
+ def forward_attn(self, x):
276
+ H, W = self.input_resolution
277
+ B, L, C = x.shape
278
+ assert L == H * W, "input feature has wrong size"
279
+
280
+ x = self.norm1(x)
281
+ x = x.view(B, H, W, C)
282
+
283
+ # cyclic shift
284
+ if self.shift_size > 0:
285
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
286
+ else:
287
+ shifted_x = x
288
+
289
+ # partition windows
290
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
291
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
292
+
293
+ # W-MSA/SW-MSA
294
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
295
+
296
+ # merge windows
297
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
298
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
299
+
300
+ # reverse cyclic shift
301
+ if self.shift_size > 0:
302
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
303
+ else:
304
+ x = shifted_x
305
+ x = x.view(B, H * W, C)
306
+
307
+ return x
308
+
309
+ def forward_mlp(self, x):
310
+ return self.drop_path(self.mlp(self.norm2(x)))
311
+
312
+ def forward(self, x):
313
+ shortcut = x
314
+ if self.use_checkpoint:
315
+ x = checkpoint.checkpoint(self.forward_attn, x)
316
+ else:
317
+ x = self.forward_attn(x)
318
+ x = shortcut + self.drop_path(x)
319
+
320
+ if self.use_checkpoint:
321
+ x = x + checkpoint.checkpoint(self.forward_mlp, x)
322
+ else:
323
+ x = x + self.forward_mlp(x)
324
+
325
+ return x
326
+
327
+ def extra_repr(self) -> str:
328
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
329
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
330
+
331
+ def flops(self):
332
+ flops = 0
333
+ H, W = self.input_resolution
334
+ # norm1
335
+ flops += self.dim * H * W
336
+ # W-MSA/SW-MSA
337
+ nW = H * W / self.window_size / self.window_size
338
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
339
+ # mlp
340
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
341
+ # norm2
342
+ flops += self.dim * H * W
343
+ return flops
344
+
345
+
346
+ class PatchMerging(nn.Module):
347
+ r""" Patch Merging Layer.
348
+
349
+ Args:
350
+ input_resolution (tuple[int]): Resolution of input feature.
351
+ dim (int): Number of input channels.
352
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
353
+ """
354
+
355
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
356
+ super().__init__()
357
+ self.input_resolution = input_resolution
358
+ self.dim = dim
359
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
360
+ self.norm = norm_layer(4 * dim)
361
+
362
+ def forward(self, x):
363
+ """
364
+ x: B, H*W, C
365
+ """
366
+ H, W = self.input_resolution
367
+ B, L, C = x.shape
368
+ assert L == H * W, "input feature has wrong size"
369
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
370
+
371
+ x = x.view(B, H, W, C)
372
+
373
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
374
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
375
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
376
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
377
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
378
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
379
+
380
+ x = self.norm(x)
381
+ x = self.reduction(x)
382
+
383
+ return x
384
+
385
+ def extra_repr(self) -> str:
386
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
387
+
388
+ def flops(self):
389
+ H, W = self.input_resolution
390
+ flops = H * W * self.dim
391
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
392
+ return flops
393
+
394
+
395
+ class BasicLayer(nn.Module):
396
+ """ A basic Swin Transformer layer for one stage.
397
+
398
+ Args:
399
+ dim (int): Number of input channels.
400
+ input_resolution (tuple[int]): Input resolution.
401
+ depth (int): Number of blocks.
402
+ num_heads (int): Number of attention heads.
403
+ window_size (int): Local window size.
404
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
405
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
406
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
407
+ drop (float, optional): Dropout rate. Default: 0.0
408
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
409
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
410
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
411
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
412
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
413
+ """
414
+
415
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
416
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
417
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
418
+ bottleneck=False):
419
+
420
+ super().__init__()
421
+ self.dim = dim
422
+ self.input_resolution = input_resolution
423
+ self.depth = depth
424
+ self.use_checkpoint = use_checkpoint
425
+
426
+ # build blocks
427
+ self.blocks = nn.ModuleList([
428
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
429
+ num_heads=num_heads, window_size=window_size,
430
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
431
+ mlp_ratio=mlp_ratio,
432
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
433
+ drop=drop, attn_drop=attn_drop,
434
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
435
+ norm_layer=norm_layer,
436
+ bottleneck=bottleneck if i == depth-1 else False,
437
+ use_checkpoint=use_checkpoint)
438
+ for i in range(depth)])
439
+
440
+ # patch merging layer
441
+ if downsample is not None:
442
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
443
+ else:
444
+ self.downsample = None
445
+
446
+ def forward(self, x):
447
+ for blk in self.blocks:
448
+ if self.use_checkpoint:
449
+ x = checkpoint.checkpoint(blk, x)
450
+ else:
451
+ x = blk(x)
452
+ if self.downsample is not None:
453
+ x = self.downsample(x)
454
+ return x
455
+
456
+ def extra_repr(self) -> str:
457
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
458
+
459
+ def flops(self):
460
+ flops = 0
461
+ for blk in self.blocks:
462
+ flops += blk.flops()
463
+ if self.downsample is not None:
464
+ flops += self.downsample.flops()
465
+ return flops
466
+
467
+
468
+ class PatchEmbed(nn.Module):
469
+ r""" Image to Patch Embedding
470
+
471
+ Args:
472
+ img_size (int): Image size. Default: 224.
473
+ patch_size (int): Patch token size. Default: 4.
474
+ in_chans (int): Number of input image channels. Default: 3.
475
+ embed_dim (int): Number of linear projection output channels. Default: 96.
476
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
477
+ """
478
+
479
+ def __init__(self, img_size=(224,224), patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
480
+ super().__init__()
481
+ #img_size = to_2tuple(img_size)
482
+ patch_size = to_2tuple(patch_size)
483
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
484
+ self.img_size = img_size
485
+ self.patch_size = patch_size
486
+ self.patches_resolution = patches_resolution
487
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
488
+
489
+ self.in_chans = in_chans
490
+ self.embed_dim = embed_dim
491
+
492
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
493
+ if norm_layer is not None:
494
+ self.norm = norm_layer(embed_dim)
495
+ else:
496
+ self.norm = None
497
+
498
+ def forward(self, x):
499
+ B, C, H, W = x.shape
500
+ # FIXME look at relaxing size constraints
501
+ assert H == self.img_size[0] and W == self.img_size[1], \
502
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
503
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
504
+ if self.norm is not None:
505
+ x = self.norm(x)
506
+ return x
507
+
508
+ def flops(self):
509
+ Ho, Wo = self.patches_resolution
510
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
511
+ if self.norm is not None:
512
+ flops += Ho * Wo * self.embed_dim
513
+ return flops
514
+
515
+
516
+ class SwinTransformer(nn.Module):
517
+ r""" Swin Transformer
518
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
519
+ https://arxiv.org/pdf/2103.14030
520
+ Args:
521
+ img_size (int | tuple(int)): Input image size. Default 224
522
+ patch_size (int | tuple(int)): Patch size. Default: 4
523
+ in_chans (int): Number of input image channels. Default: 3
524
+ num_classes (int): Number of classes for classification head. Default: 1000
525
+ embed_dim (int): Patch embedding dimension. Default: 96
526
+ depths (tuple(int)): Depth of each Swin Transformer layer.
527
+ num_heads (tuple(int)): Number of attention heads in different layers.
528
+ window_size (int): Window size. Default: 7
529
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
530
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
531
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
532
+ drop_rate (float): Dropout rate. Default: 0
533
+ attn_drop_rate (float): Attention dropout rate. Default: 0
534
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
535
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
536
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
537
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
538
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
539
+ """
540
+
541
+ def __init__(self, duration=8, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
542
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
543
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
544
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
545
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
546
+ use_checkpoint=False, thumbnail_rows=1, bottleneck=False, **kwargs):
547
+ super().__init__()
548
+
549
+ self.duration = duration
550
+ self.num_classes = num_classes
551
+ self.num_layers = len(depths)
552
+ self.embed_dim = embed_dim
553
+ self.ape = ape
554
+ self.patch_norm = patch_norm
555
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
556
+ self.mlp_ratio = mlp_ratio
557
+ self.thumbnail_rows = thumbnail_rows
558
+
559
+ self.img_size = img_size
560
+ self.window_size = [window_size for _ in depths] if not isinstance(window_size, list) else window_size
561
+ self.image_mode = True
562
+
563
+ self.frame_padding = self.duration % thumbnail_rows if self.image_mode is True else 0
564
+ if self.frame_padding != 0:
565
+ self.frame_padding = self.thumbnail_rows - self.frame_padding
566
+ self.duration += self.frame_padding
567
+
568
+ # split image into non-overlapping patches
569
+ if self.image_mode:
570
+ thumbnail_dim = (thumbnail_rows, self.duration // thumbnail_rows)
571
+ thumbnail_size = (img_size * thumbnail_dim[0], img_size * thumbnail_dim[1])
572
+ else:
573
+ thumbnail_size = (img_size, img_size)
574
+
575
+ print ('---------------------------------------')
576
+ print ('duration:', self.duration, 'frame padding:', self.frame_padding, 'image_size:', self.img_size, 'patch_size:', patch_size, 'thumbnail_size:', (thumbnail_rows, self.duration//thumbnail_rows), 'ape:', self.ape)
577
+
578
+ self.patch_embed = PatchEmbed(
579
+ img_size=(img_size, img_size), patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
580
+ norm_layer=norm_layer if self.patch_norm else None)
581
+ num_patches = self.patch_embed.num_patches
582
+ patches_resolution = self.patch_embed.patches_resolution
583
+ self.patches_resolution = patches_resolution
584
+
585
+ # absolute position embedding
586
+ if self.ape:
587
+ # self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
588
+ # trunc_normal_(self.absolute_pos_embed, std=.02)
589
+ self.frame_pos_embed = nn.Parameter(torch.zeros(1, self.duration, embed_dim))
590
+ trunc_normal_(self.frame_pos_embed, std=.02)
591
+
592
+ self.pos_drop = nn.Dropout(p=drop_rate)
593
+
594
+ # stochastic depth
595
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
596
+
597
+ # build layers
598
+ self.layers = nn.ModuleList()
599
+ for i_layer in range(self.num_layers):
600
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
601
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
602
+ patches_resolution[1] // (2 ** i_layer)),
603
+ depth=depths[i_layer],
604
+ num_heads=num_heads[i_layer],
605
+ window_size=self.window_size[i_layer],
606
+ mlp_ratio=self.mlp_ratio,
607
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
608
+ drop=drop_rate, attn_drop=attn_drop_rate,
609
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
610
+ norm_layer=norm_layer,
611
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
612
+ use_checkpoint=use_checkpoint,
613
+ bottleneck=bottleneck)
614
+ self.layers.append(layer)
615
+
616
+ self.norm = norm_layer(self.num_features)
617
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
618
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
619
+
620
+ self.apply(self._init_weights)
621
+
622
+
623
+ def _init_weights(self, m):
624
+ if isinstance(m, nn.Linear):
625
+ trunc_normal_(m.weight, std=.02)
626
+ if isinstance(m, nn.Linear) and m.bias is not None:
627
+ nn.init.constant_(m.bias, 0)
628
+ elif isinstance(m, nn.LayerNorm):
629
+ nn.init.constant_(m.bias, 0)
630
+ nn.init.constant_(m.weight, 1.0)
631
+
632
+
633
+
634
+ @torch.jit.ignore
635
+ def no_weight_decay(self):
636
+ return {'absolute_pos_embed', 'frame_pos_embed'}
637
+
638
+ @torch.jit.ignore
639
+ def no_weight_decay_keywords(self):
640
+ return {'relative_position_bias_table'}
641
+
642
+ def create_thumbnail(self, x):
643
+ # import pdb;pdb.set_trace()
644
+ input_size = x.shape[-2:]
645
+ if input_size != to_2tuple(self.img_size):
646
+ x = nn.functional.interpolate(x, size=self.img_size,mode='bilinear')
647
+ x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)', th=self.thumbnail_rows, c=3)
648
+ return x
649
+
650
+ def pad_frames(self, x):
651
+ frame_num = self.duration - self.frame_padding
652
+ x = x.view((-1,3*frame_num)+x.size()[2:])
653
+ x_padding = torch.zeros((x.shape[0], 3*self.frame_padding) + x.size()[2:]).cuda()
654
+ x = torch.cat((x, x_padding), dim=1)
655
+ assert x.shape[1] == 3 * self.duration, 'frame number %d not the same as adjusted input size %d' % (x.shape[1], 3 * self.duration)
656
+
657
+ return x
658
+
659
+ # need to find a better way to do this, maybe torch.fold?
660
+ def create_image_pos_embed(self):
661
+ img_rows, img_cols = self.patches_resolution
662
+ _, _, T = self.frame_pos_embed.shape
663
+ rows = img_rows // self.thumbnail_rows
664
+ cols = img_cols // (self.duration // self.thumbnail_rows)
665
+ img_pos_embed = torch.zeros(img_rows, img_cols, T).cuda()
666
+ #print (self.duration, T, img_rows, img_cols, rows, cols)
667
+ for i in range(self.duration):
668
+ r_indx = (i // self.thumbnail_rows) * rows
669
+ c_indx = (i % self.thumbnail_rows) * cols
670
+ img_pos_embed[r_indx:r_indx+rows,c_indx:c_indx+cols] = self.frame_pos_embed[0, i]
671
+ #print (r_indx, r_indx+rows, c_indx, c_indx+cols)
672
+ return img_pos_embed.reshape(-1, T)
673
+
674
+ def forward_features(self, x):
675
+ # x = rearrange(x, 'b (t c) h w -> b c h (t w)', t=self.duration)
676
+ # in evaluation, it's Bx(num_crops*num_cips*num_frames*3)xHxW
677
+ # import pdb;pdb.set_trace()
678
+ b, t, _, h, w = x.shape
679
+ x = x.view(b, t*3, h, w)
680
+ if self.frame_padding > 0:
681
+ x = self.pad_frames(x)
682
+ else:
683
+ x = x.view((-1,3*self.duration)+x.size()[2:])
684
+
685
+ if self.image_mode:
686
+ x = self.create_thumbnail(x)
687
+ x = nn.functional.interpolate(x, size=self.img_size,mode='bilinear')
688
+ else:
689
+ x = rearrange(x, 'b (n t c) h w -> (b n t) c h w', t=self.duration, c=3)
690
+
691
+ x = self.patch_embed(x)
692
+ if self.ape:
693
+ # x = x + self.absolute_pos_embed
694
+ img_pos_embed = self.create_image_pos_embed()
695
+ x = x + img_pos_embed
696
+ x = self.pos_drop(x)
697
+
698
+ for layer in self.layers:
699
+ x = layer(x)
700
+
701
+ x = self.norm(x) # B L C
702
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
703
+ x = torch.flatten(x, 1)
704
+ return x
705
+
706
+ def forward(self, x):
707
+ x = self.forward_features(x)
708
+ x = self.head(x)
709
+ if not self.image_mode:
710
+ x = x.view(-1, self.duration, self.num_classes)
711
+ x = torch.mean(x, dim=1)
712
+ return x
713
+
714
+ def flops(self):
715
+ flops = 0
716
+ flops += self.patch_embed.flops()
717
+ for i, layer in enumerate(self.layers):
718
+ flops += layer.flops()
719
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
720
+ flops += self.num_features * self.num_classes
721
+ return flops
722
+
723
+
724
+ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_patches=196,
725
+ pretrained_window_size=7, pretrained_model="", strict=True):
726
+ if cfg is None:
727
+ cfg = getattr(model, 'default_cfg')
728
+ if cfg is None or 'url' not in cfg or not cfg['url']:
729
+ _logger.warning("Pretrained model URL is invalid, using random initialization.")
730
+ return
731
+
732
+ if len(pretrained_model) == 0:
733
+ # state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
734
+ # state_dict = load_state_dict_from_url(cfg['url'], progress=False, map_location='cpu')
735
+ state_dict = torch.load('/mnt/new_nas/yansan/models/swin_base_patch4_window7_224_22k.pth', map_location='cpu')
736
+ else:
737
+ try:
738
+ state_dict = load_state_dict(pretrained_model)['model']
739
+ except:
740
+ state_dict = load_state_dict(pretrained_model)
741
+
742
+ if filter_fn is not None:
743
+ state_dict = filter_fn(state_dict)
744
+
745
+ if in_chans == 1:
746
+ conv1_name = cfg['first_conv']
747
+ _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
748
+ conv1_weight = state_dict[conv1_name + '.weight']
749
+ conv1_type = conv1_weight.dtype
750
+ conv1_weight = conv1_weight.float()
751
+ O, I, J, K = conv1_weight.shape
752
+ if I > 3:
753
+ assert conv1_weight.shape[1] % 3 == 0
754
+ # For models with space2depth stems
755
+ conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
756
+ conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
757
+ else:
758
+ conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
759
+ conv1_weight = conv1_weight.to(conv1_type)
760
+ state_dict[conv1_name + '.weight'] = conv1_weight
761
+ elif in_chans != 3:
762
+ conv1_name = cfg['first_conv']
763
+ conv1_weight = state_dict[conv1_name + '.weight']
764
+ conv1_type = conv1_weight.dtype
765
+ conv1_weight = conv1_weight.float()
766
+ O, I, J, K = conv1_weight.shape
767
+ if I != 3:
768
+ _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
769
+ del state_dict[conv1_name + '.weight']
770
+ strict = False
771
+ else:
772
+ _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
773
+ repeat = int(math.ceil(in_chans / 3))
774
+ conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
775
+ conv1_weight *= (3 / float(in_chans))
776
+ conv1_weight = conv1_weight.to(conv1_type)
777
+ state_dict[conv1_name + '.weight'] = conv1_weight
778
+
779
+ #for key, value in state_dict['model'].items():
780
+ # print (key)
781
+
782
+ classifier_name = cfg['classifier']
783
+ if num_classes == 1000 and cfg['num_classes'] == 1001:
784
+ # special case for imagenet trained models with extra background class in pretrained weights
785
+ classifier_weight = state_dict[classifier_name + '.weight']
786
+ state_dict[classifier_name + '.weight'] = classifier_weight[1:]
787
+ classifier_bias = state_dict[classifier_name + '.bias']
788
+ state_dict[classifier_name + '.bias'] = classifier_bias[1:]
789
+ elif num_classes != cfg['num_classes']: # and len(pretrained_model) == 0:
790
+ # completely discard fully connected for all other differences between pretrained and created model
791
+ del state_dict['model'][classifier_name + '.weight']
792
+ del state_dict['model'][classifier_name + '.bias']
793
+ strict = False
794
+ '''
795
+ ## Resizing the positional embeddings in case they don't match
796
+ if img_size != cfg['input_size'][1]:
797
+ pos_embed = state_dict['pos_embed']
798
+ cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
799
+ other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
800
+ new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
801
+ new_pos_embed = new_pos_embed.transpose(1, 2)
802
+ new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
803
+ state_dict['pos_embed'] = new_pos_embed
804
+ '''
805
+
806
+ # remove window_size related parameters
807
+ window_size = (model.window_size)[0]
808
+ print (pretrained_window_size, window_size)
809
+
810
+ new_state_dict = state_dict['model'].copy()
811
+ for key in state_dict['model']:
812
+ if 'attn_mask' in key:
813
+ del new_state_dict[key]
814
+
815
+ #if window_size != pretrained_window_size:
816
+ if 1:
817
+ if 'relative_position_index' in key:
818
+ del new_state_dict[key]
819
+
820
+ # resize it
821
+ if 'relative_position_bias_table' in key:
822
+ #print ('resizing relative_position_bias_table')
823
+ pretrained_table = state_dict['model'][key]
824
+ pretrained_table_size = int(math.sqrt(pretrained_table.shape[0]))
825
+ table_size = int(math.sqrt(model.state_dict()[key].shape[0]))
826
+ #print (pretrained_table_size, table_size)
827
+ if pretrained_table_size != table_size:
828
+ table = pretrained_table.permute(1, 0).view(1, -1, pretrained_table_size, pretrained_table_size)
829
+ table = nn.functional.interpolate(table, size=table_size, mode='bilinear')
830
+ table = table.view(-1, table_size*table_size).permute(1, 0)
831
+ new_state_dict[key] = table
832
+
833
+ for key in model.state_dict():
834
+ if 'bottleneck_norm' in key:
835
+ attn_key = key.replace('bottleneck_norm','norm1')
836
+ #print (key, attn_key)
837
+ new_state_dict[key] = new_state_dict[attn_key]
838
+
839
+ '''
840
+ for key in new_state_dict:
841
+ if key not in model.state_dict():
842
+ print ('----', key)
843
+ else:
844
+ print ('++++', key)
845
+ print ('====================')
846
+ for key in model.state_dict():
847
+ if key not in new_state_dict:
848
+ print ('----', key)
849
+ else:
850
+ print ('++++', key)
851
+ '''
852
+
853
+ print ('loading weights....')
854
+ ## Loading the weights
855
+ model_dict = model.state_dict()
856
+ pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
857
+ model_dict.update(pretrained_dict)
858
+ model.load_state_dict(model_dict, strict=False)
859
+
860
+ def _conv_filter(state_dict, patch_size=4):
861
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
862
+ out_dict = {}
863
+ for k, v in state_dict.items():
864
+ if 'patch_embed.proj.weight' in k:
865
+ if v.shape[-1] != patch_size:
866
+ patch_size = v.shape[-1]
867
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
868
+ out_dict[k] = v
869
+ return out_dict
870
+
871
+
872
+ def _create_vision_transformer(variant, pretrained=False, pretrained_window_size=7, **kwargs):
873
+ default_cfg = default_cfgs[variant]
874
+ default_num_classes = default_cfg['num_classes']
875
+ default_img_size = default_cfg['input_size'][-1]
876
+
877
+ num_classes = kwargs.pop('num_classes', default_num_classes)
878
+ img_size = kwargs.pop('img_size', default_img_size)
879
+ repr_size = kwargs.pop('representation_size', None)
880
+
881
+ model_cls = SwinTransformer
882
+ model = model_cls(img_size=img_size, num_classes=num_classes, **kwargs)
883
+ model.default_cfg = default_cfg
884
+
885
+
886
+ if pretrained:
887
+ load_pretrained(
888
+ model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
889
+ filter_fn=_conv_filter,
890
+ img_size=img_size,
891
+ pretrained_window_size=pretrained_window_size,
892
+ pretrained_model=''
893
+ )
894
+ return model
895
+
896
+
897
+
898
+ @register_model
899
+ def TALL_SWIN(pretrained=False, **kwargs):
900
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
901
+ When pretrained=True, the backbone is initialized from swin_base_patch4_window7_224_22k (ImageNet-22k) weights.
902
+ """
903
+ temporal_module_name = kwargs.pop('temporal_module_name', None)
904
+ temporal_attention_only = kwargs.pop('temporal_attention_only', None)
905
+ temporal_heads_scale = kwargs.pop('temporal_heads_scale', 1.0)
906
+ temporal_mlp_scale = kwargs.pop('temporal_mlp_scale', 1.0)
907
+ rel_pos = kwargs.pop('rel_pos', False)
908
+ token_mask = kwargs.pop('token_mask', False)
909
+ frame_cls_tokens = kwargs.pop('frame_cls_tokens', 1)
910
+ kwargs.pop('hub_attention', '')
911
+ kwargs.pop('hub_aggregation', '')
912
+ kwargs.pop('spatial_hub_size', (-1, -1))
913
+ kwargs.pop('temporal_pooling', None)
914
+ kwargs.pop('window_size', -1)
915
+
916
+ embed_dim = 128
917
+ mlp_ratio = 4.
918
+ #drop_path_rate=0.5
919
+ patch_size=4
920
+ window_size=[14,14,14,7]
921
+ depths = [2, 2, 18, 2]
922
+ num_heads = [4, 8, 16, 32]
923
+ use_checkpoint=kwargs.pop('use_checkpoint', False)
924
+ ape = kwargs.pop('hpe_to_token', False)
925
+ bottleneck = True if kwargs.pop('bottleneck', None) is not None else False
926
+ model_kwargs = dict(patch_size=patch_size, window_size=window_size, embed_dim=embed_dim, depths=depths, num_heads=num_heads, mlp_ratio=mlp_ratio,
927
+ use_checkpoint=use_checkpoint, ape=ape, bottleneck=bottleneck, **kwargs)
928
+ print(model_kwargs)
929
+ model = _create_vision_transformer('swin_base_patch4_window7_224_22k', pretrained=pretrained, pretrained_window_size=7, **model_kwargs)
930
+ return model
931
+
932
+ if __name__ == '__main__':
933
+ dummy_input = torch.randn(4,8,3,224,224)
934
+ model = TALL_SWIN(pretrained=True)
935
+ print(model(dummy_input))
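For reference, a standalone sketch of the thumbnail rearrangement that create_thumbnail above relies on: with thumbnail_rows=2, eight 224x224 frames are tiled into a single 448x896 image before patch embedding. The tensors here are random and only illustrate the einops pattern used in the model.

import torch
from einops import rearrange

frames = torch.randn(1, 8 * 3, 224, 224)     # [b, t*c, h, w], as produced inside forward_features
thumbnail = rearrange(frames, 'b (th tw c) h w -> b c (th h) (tw w)', th=2, c=3)
print(thumbnail.shape)                       # torch.Size([1, 3, 448, 896])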
models/VideoMAE.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
3
+ import torch.nn as nn
4
+ import torchvision
5
+ import time
6
+ from .mamba_base import MambaConfig, ResidualBlock
7
+
8
+ def create_reorder_index(N, device):
9
+ new_order = []
10
+ for col in range(N):
11
+ if col % 2 == 0:
12
+ new_order.extend(range(col, N*N, N))
13
+ else:
14
+ new_order.extend(range(col + N*(N-1), col-1, -N))
15
+ return torch.tensor(new_order, device=device)
16
+
17
+ def reorder_data(data, N):
18
+ assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
19
+ device = data.device
20
+ new_order = create_reorder_index(N, device)
21
+ B, t, _, _ = data.shape
22
+ index = new_order.repeat(B, t, 1).unsqueeze(-1)
23
+ reordered_data = torch.gather(data, 2, index.expand_as(data))
24
+ return reordered_data
25
+
26
+ class Videomae_Net(nn.Module):
27
+ def __init__(
28
+ self, channel_size=512, dropout=0.2, class_num=1
29
+ ):
30
+ super(Videomae_Net, self).__init__()
31
+ self.model = VideoMAEForVideoClassification.from_pretrained("/ossfs/workspace/GenVideo/pretrained_weights/videomae")
32
+ self.fc1 = nn.Linear(768, class_num)
33
+ self.bn1 = nn.BatchNorm1d(768)
34
+
35
+ self._init_params()
36
+
37
+ def _init_params(self):
38
+ nn.init.xavier_normal_(self.fc1.weight)
39
+ nn.init.constant_(self.fc1.bias, 0)
40
+
41
+ def forward(self, x):
42
+ x = self.model.videomae(x)
43
+ sequence_output = x[0]
44
+ # sequence_output: (B, num_tokens, 768) last hidden states from the VideoMAE encoder
45
+ if self.model.fc_norm is not None:
46
+ sequence_output = self.model.fc_norm(sequence_output.mean(1))
47
+ else:
48
+ sequence_output = sequence_output[:, 0]
49
+ x = self.bn1(sequence_output)
50
+ x = self.fc1(x)
51
+ return x
52
+
53
+
54
+
55
+ if __name__ == '__main__':
56
+
57
+ model = Videomae_Net()
58
+
59
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
+ model = model.to(device)
61
+ model.eval()
62
+
63
+
64
+ input_data = torch.randn(1, 16, 3, 224, 224).to(device)
65
+
66
+ model(input_data)
67
+
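The create_reorder_index / reorder_data helpers above build a column-wise serpentine (zig-zag) ordering over an N x N patch grid: even columns are read top-down, odd columns bottom-up. A quick sanity check of the index pattern (illustrative only; it assumes the repository root is on PYTHONPATH and its dependencies import cleanly):

import torch
from models.VideoMAE import create_reorder_index

idx = create_reorder_index(3, torch.device('cpu'))
print(idx.tolist())  # [0, 3, 6, 7, 4, 1, 2, 5, 8] -- down column 0, up column 1, down column 2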
models/XCLIP.py ADDED
@@ -0,0 +1,33 @@
1
+ from transformers import XCLIPVisionModel
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.nn.init as init
10
+ import math
11
+
12
+
13
+ class XCLIP(nn.Module):
14
+ def __init__(
15
+ self, channel_size=512, dropout=0.2, class_num=1
16
+ ):
17
+ super(XCLIP, self).__init__()
18
+
19
+ self.backbone = XCLIPVisionModel.from_pretrained("GenVideo/pretrained_weights/xclip")
20
+ self.fc_norm = nn.LayerNorm(768)
21
+ self.head = nn.Linear(768, 1)
22
+
23
+ def forward(self, x):
24
+ b, t, _, h, w = x.shape
25
+ images = x.view(b * t, 3, h, w)
26
+ outputs = self.backbone(images, output_hidden_states=True)
27
+ sequence_output = outputs['pooler_output'].reshape(b, t, -1)
28
+ video_level_features = self.fc_norm(sequence_output.mean(1))
29
+ pred = self.head(video_level_features)
30
+
31
+ return pred
32
+
33
+
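For reference, the XCLIP wrapper above takes a clip of shape (B, T, 3, H, W), averages the per-frame pooler_output over time, and emits one logit per video. A usage sketch (illustrative only; it assumes the X-CLIP vision weights are available at the local path GenVideo/pretrained_weights/xclip hard-coded in __init__):

import torch
from models import XCLIP

model = XCLIP().eval()
clip = torch.randn(2, 8, 3, 224, 224)  # (batch, frames, channels, height, width)
with torch.no_grad():
    pred = model(clip)  # shape (2, 1); apply torch.sigmoid for a fake-video probability
print(pred.shape)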
models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .F3Net import Det_F3_Net
2
+ from .NPR import resnet50_npr
3
+ from .STIL import Det_STIL
4
+ from .DeMamba import XCLIP_DeMamba, CLIP_DeMamba
5
+ from .XCLIP import XCLIP
models/__pycache__/DeMamba.cpython-39.pyc ADDED
Binary file (5.25 kB). View file
 
models/__pycache__/F3Net.cpython-39.pyc ADDED
Binary file (14.1 kB). View file
 
models/__pycache__/NPR.cpython-39.pyc ADDED
Binary file (7.58 kB). View file
 
models/__pycache__/STIL.cpython-39.pyc ADDED
Binary file (17.7 kB). View file
 
models/__pycache__/XCLIP.cpython-39.pyc ADDED
Binary file (1.44 kB). View file
 
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (396 Bytes). View file
 
models/__pycache__/mamba_base.cpython-39.pyc ADDED
Binary file (7.53 kB). View file
 
models/__pycache__/pscan.cpython-39.pyc ADDED
Binary file (5.62 kB). View file
 
models/clip/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .clip import *
models/clip/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (209 Bytes). View file
 
models/clip/__pycache__/clip.cpython-39.pyc ADDED
Binary file (8.21 kB). View file
 
models/clip/__pycache__/model.cpython-39.pyc ADDED
Binary file (14.9 kB). View file
 
models/clip/__pycache__/simple_tokenizer.cpython-39.pyc ADDED
Binary file (5.8 kB). View file
 
models/clip/clip.py ADDED
@@ -0,0 +1,233 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ from tqdm import tqdm
11
+
12
+ from .model import build_model
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
+ try:
16
+ from torchvision.transforms import InterpolationMode
17
+ BICUBIC = InterpolationMode.BICUBIC
18
+ except ImportError:
19
+ BICUBIC = Image.BICUBIC
20
+
21
+
22
+ if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (1, 7):  # compare numerically; plain string comparison mis-orders e.g. "1.13" vs "1.7"
23
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24
+
25
+
26
+ __all__ = ["available_models", "load", "tokenize"]
27
+ _tokenizer = _Tokenizer()
28
+
29
+ _MODELS = {
30
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
37
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
38
+ }
39
+
40
+
41
+
42
+ def _download(url: str, root: str):
43
+ os.makedirs(root, exist_ok=True)
44
+ filename = os.path.basename(url)
45
+
46
+ expected_sha256 = url.split("/")[-2]
47
+ download_target = os.path.join(root, filename)
48
+
49
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
50
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
51
+
52
+ if os.path.isfile(download_target):
53
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
54
+ return download_target
55
+ else:
56
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
57
+
58
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
59
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
60
+ while True:
61
+ buffer = source.read(8192)
62
+ if not buffer:
63
+ break
64
+
65
+ output.write(buffer)
66
+ loop.update(len(buffer))
67
+
68
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
69
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
70
+
71
+ return download_target
72
+
73
+
74
+ def _convert_image_to_rgb(image):
75
+ return image.convert("RGB")
76
+
77
+
78
+ def _transform(n_px):
79
+ return Compose([
80
+ Resize(n_px, interpolation=BICUBIC),
81
+ CenterCrop(n_px),
82
+ _convert_image_to_rgb,
83
+ ToTensor(),
84
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
85
+ ])
86
+
87
+
88
+ def available_models() -> List[str]:
89
+ """Returns the names of available CLIP models"""
90
+ return list(_MODELS.keys())
91
+
92
+
93
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
94
+ """Load a CLIP model
95
+
96
+ Parameters
97
+ ----------
98
+ name : str
99
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
100
+
101
+ device : Union[str, torch.device]
102
+ The device to put the loaded model
103
+
104
+ jit : bool
105
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
106
+
107
+ download_root: str
108
+ path to download the model files; by default, it uses "~/.cache/clip"
109
+
110
+ Returns
111
+ -------
112
+ model : torch.nn.Module
113
+ The CLIP model
114
+
115
+ preprocess : Callable[[PIL.Image], torch.Tensor]
116
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
117
+ """
118
+ # if name in _MODELS:
119
+ # model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
120
+ # elif os.path.isfile(name):
121
+ # model_path = name
122
+ # else:
123
+ # raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
124
+
125
+ model_path = 'weights/openclip/' + name +'.pt'
126
+ print(model_path)
127
+ try:
128
+ # loading JIT archive
129
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
130
+ state_dict = None
131
+ except RuntimeError:
132
+ # loading saved state dict
133
+ if jit:
134
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135
+ jit = False
136
+ state_dict = torch.load(model_path, map_location="cpu")
137
+
138
+ if not jit:
139
+ model = build_model(state_dict or model.state_dict()).to(device)
140
+ if str(device) == "cpu":
141
+ model.float()
142
+ return model, _transform(model.visual.input_resolution)
143
+
144
+ # patch the device names
145
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147
+
148
+ def patch_device(module):
149
+ try:
150
+ graphs = [module.graph] if hasattr(module, "graph") else []
151
+ except RuntimeError:
152
+ graphs = []
153
+
154
+ if hasattr(module, "forward1"):
155
+ graphs.append(module.forward1.graph)
156
+
157
+ for graph in graphs:
158
+ for node in graph.findAllNodes("prim::Constant"):
159
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160
+ node.copyAttributes(device_node)
161
+
162
+ model.apply(patch_device)
163
+ patch_device(model.encode_image)
164
+ patch_device(model.encode_text)
165
+
166
+ # patch dtype to float32 on CPU
167
+ if str(device) == "cpu":
168
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170
+ float_node = float_input.node()
171
+
172
+ def patch_float(module):
173
+ try:
174
+ graphs = [module.graph] if hasattr(module, "graph") else []
175
+ except RuntimeError:
176
+ graphs = []
177
+
178
+ if hasattr(module, "forward1"):
179
+ graphs.append(module.forward1.graph)
180
+
181
+ for graph in graphs:
182
+ for node in graph.findAllNodes("aten::to"):
183
+ inputs = list(node.inputs())
184
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185
+ if inputs[i].node()["value"] == 5:
186
+ inputs[i].node().copyAttributes(float_node)
187
+
188
+ model.apply(patch_float)
189
+ patch_float(model.encode_image)
190
+ patch_float(model.encode_text)
191
+
192
+ model.float()
193
+
194
+ return model, _transform(model.input_resolution.item())
195
+
196
+
197
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
198
+ """
199
+ Returns the tokenized representation of given input string(s)
200
+
201
+ Parameters
202
+ ----------
203
+ texts : Union[str, List[str]]
204
+ An input string or a list of input strings to tokenize
205
+
206
+ context_length : int
207
+ The context length to use; all CLIP models use 77 as the context length
208
+
209
+ truncate: bool
210
+ Whether to truncate the text in case its encoding is longer than the context length
211
+
212
+ Returns
213
+ -------
214
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
215
+ """
216
+ if isinstance(texts, str):
217
+ texts = [texts]
218
+
219
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
220
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
221
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
222
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
223
+
224
+ for i, tokens in enumerate(all_tokens):
225
+ if len(tokens) > context_length:
226
+ if truncate:
227
+ tokens = tokens[:context_length]
228
+ tokens[-1] = eot_token
229
+ else:
230
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
231
+ result[i, :len(tokens)] = torch.tensor(tokens)
232
+
233
+ return result
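A small usage sketch for the tokenizer entry point above (illustrative only; tokenize needs only the bundled BPE vocabulary, whereas load expects a checkpoint under weights/openclip/<name>.pt as hard-coded in this file):

from models.clip import tokenize

tokens = tokenize(["a photo of a cat", "an AI-generated video frame"])
print(tokens.shape)  # torch.Size([2, 77]) -- one row of token ids per input string, zero-padded to context_length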
models/clip/model.py ADDED
@@ -0,0 +1,432 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+
20
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+
23
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24
+
25
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27
+
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.downsample = None
30
+ self.stride = stride
31
+
32
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
33
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34
+ self.downsample = nn.Sequential(OrderedDict([
35
+ ("-1", nn.AvgPool2d(stride)),
36
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37
+ ("1", nn.BatchNorm2d(planes * self.expansion))
38
+ ]))
39
+
40
+ def forward(self, x: torch.Tensor):
41
+ identity = x
42
+
43
+ out = self.relu(self.bn1(self.conv1(x)))
44
+ out = self.relu(self.bn2(self.conv2(out)))
45
+ out = self.avgpool(out)
46
+ out = self.bn3(self.conv3(out))
47
+
48
+ if self.downsample is not None:
49
+ identity = self.downsample(x)
50
+
51
+ out += identity
52
+ out = self.relu(out)
53
+ return out
54
+
55
+
56
+ class AttentionPool2d(nn.Module):
57
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58
+ super().__init__()
59
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
61
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
62
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64
+ self.num_heads = num_heads
65
+
66
+ def forward(self, x):
67
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70
+ x, _ = F.multi_head_attention_forward(
71
+ query=x, key=x, value=x,
72
+ embed_dim_to_check=x.shape[-1],
73
+ num_heads=self.num_heads,
74
+ q_proj_weight=self.q_proj.weight,
75
+ k_proj_weight=self.k_proj.weight,
76
+ v_proj_weight=self.v_proj.weight,
77
+ in_proj_weight=None,
78
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
+ bias_k=None,
80
+ bias_v=None,
81
+ add_zero_attn=False,
82
+ dropout_p=0,
83
+ out_proj_weight=self.c_proj.weight,
84
+ out_proj_bias=self.c_proj.bias,
85
+ use_separate_proj_weight=True,
86
+ training=self.training,
87
+ need_weights=False
88
+ )
89
+
90
+ return x[0]
91
+
92
+
93
+ class ModifiedResNet(nn.Module):
94
+ """
95
+ A ResNet class that is similar to torchvision's but contains the following changes:
96
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98
+ - The final pooling layer is a QKV attention instead of an average pool
99
+ """
100
+
101
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102
+ super().__init__()
103
+ self.output_dim = output_dim
104
+ self.input_resolution = input_resolution
105
+
106
+ # the 3-layer stem
107
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108
+ self.bn1 = nn.BatchNorm2d(width // 2)
109
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110
+ self.bn2 = nn.BatchNorm2d(width // 2)
111
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112
+ self.bn3 = nn.BatchNorm2d(width)
113
+ self.avgpool = nn.AvgPool2d(2)
114
+ self.relu = nn.ReLU(inplace=True)
115
+
116
+ # residual layers
117
+ self._inplanes = width # this is a *mutable* variable used during construction
118
+ self.layer1 = self._make_layer(width, layers[0])
119
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122
+
123
+ embed_dim = width * 32 # the ResNet feature dimension
124
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125
+
126
+ def _make_layer(self, planes, blocks, stride=1):
127
+ layers = [Bottleneck(self._inplanes, planes, stride)]
128
+
129
+ self._inplanes = planes * Bottleneck.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(Bottleneck(self._inplanes, planes))
132
+
133
+ return nn.Sequential(*layers)
134
+
135
+ def forward(self, x):
136
+ def stem(x):
137
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138
+ x = self.relu(bn(conv(x)))
139
+ x = self.avgpool(x)
140
+ return x
141
+
142
+ x = x.type(self.conv1.weight.dtype)
143
+ x = stem(x)
144
+ x = self.layer1(x)
145
+ x = self.layer2(x)
146
+ x = self.layer3(x)
147
+ x = self.layer4(x)
148
+ x = self.attnpool(x)
149
+
150
+ return x
151
+
152
+
153
+ class LayerNorm(nn.LayerNorm):
154
+ """Subclass torch's LayerNorm to handle fp16."""
155
+
156
+ def forward(self, x: torch.Tensor):
157
+ orig_type = x.dtype
158
+ ret = super().forward(x.type(torch.float32))
159
+ return ret.type(orig_type)
160
+
161
+
162
+ class QuickGELU(nn.Module):
163
+ def forward(self, x: torch.Tensor):
164
+ return x * torch.sigmoid(1.702 * x)
165
+
166
+
167
+ class ResidualAttentionBlock(nn.Module):
168
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169
+ super().__init__()
170
+
171
+ self.attn = nn.MultiheadAttention(d_model, n_head)
172
+ self.ln_1 = LayerNorm(d_model)
173
+ self.mlp = nn.Sequential(OrderedDict([
174
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
175
+ ("gelu", QuickGELU()),
176
+ ("c_proj", nn.Linear(d_model * 4, d_model))
177
+ ]))
178
+ self.ln_2 = LayerNorm(d_model)
179
+ self.attn_mask = attn_mask
180
+
181
+ def attention(self, x: torch.Tensor):
182
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184
+
185
+ def forward(self, x: torch.Tensor):
186
+ x = x + self.attention(self.ln_1(x))
187
+ x = x + self.mlp(self.ln_2(x))
188
+ return x
189
+
190
+
191
+ class Transformer(nn.Module):
192
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193
+ super().__init__()
194
+ self.width = width
195
+ self.layers = layers
196
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197
+
198
+ def forward(self, x: torch.Tensor):
199
+ return self.resblocks(x)
200
+
201
+
202
+ class VisionTransformer(nn.Module):
203
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204
+ super().__init__()
205
+ self.input_resolution = input_resolution
206
+ self.output_dim = output_dim
207
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208
+
209
+ scale = width ** -0.5
210
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
211
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212
+ self.ln_pre = LayerNorm(width)
213
+
214
+ self.transformer = Transformer(width, layers, heads)
215
+
216
+ self.ln_post = LayerNorm(width)
217
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218
+
219
+ def forward(self, x: torch.Tensor):
220
+ x = self.conv1(x) # shape = [*, width, grid, grid]
221
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
222
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
223
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
224
+ x = x + self.positional_embedding.to(x.dtype)
225
+ x = self.ln_pre(x)
226
+
227
+ x = x.permute(1, 0, 2) # NLD -> LND
228
+ x = self.transformer(x)
229
+ x = x.permute(1, 0, 2) # LND -> NLD
230
+
231
+ x = self.ln_post(x[:, 1:, :])
232
+
233
+ if self.proj is not None:
234
+ x = x @ self.proj
235
+
236
+ return x
237
+
238
+
239
+ class CLIP(nn.Module):
240
+ def __init__(self,
241
+ embed_dim: int,
242
+ # vision
243
+ image_resolution: int,
244
+ vision_layers: Union[Tuple[int, int, int, int], int],
245
+ vision_width: int,
246
+ vision_patch_size: int,
247
+ # text
248
+ context_length: int,
249
+ vocab_size: int,
250
+ transformer_width: int,
251
+ transformer_heads: int,
252
+ transformer_layers: int
253
+ ):
254
+ super().__init__()
255
+
256
+ self.context_length = context_length
257
+
258
+ if isinstance(vision_layers, (tuple, list)):
259
+ vision_heads = vision_width * 32 // 64
260
+ self.visual = ModifiedResNet(
261
+ layers=vision_layers,
262
+ output_dim=embed_dim,
263
+ heads=vision_heads,
264
+ input_resolution=image_resolution,
265
+ width=vision_width
266
+ )
267
+ else:
268
+ vision_heads = vision_width // 64
269
+ self.visual = VisionTransformer(
270
+ input_resolution=image_resolution,
271
+ patch_size=vision_patch_size,
272
+ width=vision_width,
273
+ layers=vision_layers,
274
+ heads=vision_heads,
275
+ output_dim=embed_dim
276
+ )
277
+
278
+ self.transformer = Transformer(
279
+ width=transformer_width,
280
+ layers=transformer_layers,
281
+ heads=transformer_heads,
282
+ attn_mask=self.build_attention_mask()
283
+ )
284
+
285
+ self.vocab_size = vocab_size
286
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
287
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
288
+ self.ln_final = LayerNorm(transformer_width)
289
+
290
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
291
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
292
+
293
+ self.initialize_parameters()
294
+
295
+ def initialize_parameters(self):
296
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
297
+ nn.init.normal_(self.positional_embedding, std=0.01)
298
+
299
+ if isinstance(self.visual, ModifiedResNet):
300
+ if self.visual.attnpool is not None:
301
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
302
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
303
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
304
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
305
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
306
+
307
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
308
+ for name, param in resnet_block.named_parameters():
309
+ if name.endswith("bn3.weight"):
310
+ nn.init.zeros_(param)
311
+
312
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
313
+ attn_std = self.transformer.width ** -0.5
314
+ fc_std = (2 * self.transformer.width) ** -0.5
315
+ for block in self.transformer.resblocks:
316
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
317
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
318
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
319
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
320
+
321
+ if self.text_projection is not None:
322
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
323
+
324
+ def build_attention_mask(self):
325
+ # lazily create causal attention mask, with full attention between the vision tokens
326
+ # pytorch uses additive attention mask; fill with -inf
327
+ mask = torch.empty(self.context_length, self.context_length)
328
+ mask.fill_(float("-inf"))
329
+ mask.triu_(1) # zero out the lower diagonal
330
+ return mask
331
+
332
+ @property
333
+ def dtype(self):
334
+ return self.visual.conv1.weight.dtype
335
+
336
+ def encode_image(self, image):
337
+ return self.visual(image.type(self.dtype))
338
+
339
+ def encode_text(self, text):
340
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
341
+
342
+ x = x + self.positional_embedding.type(self.dtype)
343
+ x = x.permute(1, 0, 2) # NLD -> LND
344
+ x = self.transformer(x)
345
+ x = x.permute(1, 0, 2) # LND -> NLD
346
+ x = self.ln_final(x).type(self.dtype)
347
+
348
+ # x.shape = [batch_size, n_ctx, transformer.width]
349
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
350
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
351
+
352
+ return x
353
+
354
+ def forward(self, image, text):
355
+ image_features = self.encode_image(image)
356
+ text_features = self.encode_text(text)
357
+
358
+ # normalized features
359
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
360
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
361
+
362
+ # cosine similarity as logits
363
+ logit_scale = self.logit_scale.exp()
364
+ logits_per_image = logit_scale * image_features @ text_features.t()
365
+ logits_per_text = logits_per_image.t()
366
+
367
+ # shape = [global_batch_size, global_batch_size]
368
+ return logits_per_image, logits_per_text
369
+
370
+
371
+ def convert_weights(model: nn.Module):
372
+ """Convert applicable model parameters to fp16"""
373
+
374
+ def _convert_weights_to_fp16(l):
375
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
376
+ l.weight.data = l.weight.data.half()
377
+ if l.bias is not None:
378
+ l.bias.data = l.bias.data.half()
379
+
380
+ if isinstance(l, nn.MultiheadAttention):
381
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
382
+ tensor = getattr(l, attr)
383
+ if tensor is not None:
384
+ tensor.data = tensor.data.half()
385
+
386
+ for name in ["text_projection", "proj"]:
387
+ if hasattr(l, name):
388
+ attr = getattr(l, name)
389
+ if attr is not None:
390
+ attr.data = attr.data.half()
391
+
392
+ model.apply(_convert_weights_to_fp16)
393
+
394
+
395
+ def build_model(state_dict: dict):
396
+ vit = "visual.proj" in state_dict
397
+
398
+ if vit:
399
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
400
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
401
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
402
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
403
+ image_resolution = vision_patch_size * grid_size
404
+ else:
405
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
406
+ vision_layers = tuple(counts)
407
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
408
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
409
+ vision_patch_size = None
410
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
411
+ image_resolution = output_width * 32
412
+
413
+ embed_dim = state_dict["text_projection"].shape[1]
414
+ context_length = state_dict["positional_embedding"].shape[0]
415
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
416
+ transformer_width = state_dict["ln_final.weight"].shape[0]
417
+ transformer_heads = transformer_width // 64
418
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
419
+
420
+ model = CLIP(
421
+ embed_dim,
422
+ image_resolution, vision_layers, vision_width, vision_patch_size,
423
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
424
+ )
425
+
426
+ for key in ["input_resolution", "context_length", "vocab_size"]:
427
+ if key in state_dict:
428
+ del state_dict[key]
429
+
430
+ convert_weights(model)
431
+ model.load_state_dict(state_dict)
432
+ return model.eval()
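Note that this VisionTransformer departs from the upstream OpenAI release: ln_post is applied to x[:, 1:, :], so encode_image returns one projected feature per patch rather than a single pooled embedding. A shape check with a deliberately tiny, randomly initialised CLIP (all sizes are hypothetical and chosen only for the illustration; importing models.clip also loads the tokenizer and its BPE vocabulary):

import torch
from models.clip.model import CLIP

tiny = CLIP(embed_dim=16, image_resolution=32, vision_layers=2, vision_width=64,
            vision_patch_size=16, context_length=77, vocab_size=49408,
            transformer_width=32, transformer_heads=4, transformer_layers=2)
feats = tiny.encode_image(torch.randn(1, 3, 32, 32))
print(feats.shape)  # torch.Size([1, 4, 16]) -- (batch, (32 // 16) ** 2 patches, embed_dim), no CLS token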
models/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
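A round-trip sketch for SimpleTokenizer (illustrative only; it assumes bpe_simple_vocab_16e6.txt.gz sits next to this module and that ftfy and regex are installed):

from models.clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("a diffusion-generated video")
print(ids)             # BPE token ids; tokenize() in clip.py wraps these with <|startoftext|> / <|endoftext|>
print(tok.decode(ids)) # "a diffusion - generated video " -- lower-cased, with BPE word boundaries re-spaced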
models/mamba_base.py ADDED
@@ -0,0 +1,352 @@
1
+ """
2
+ Mamba: Linear-Time Sequence Modeling with Selective State Spaces
3
+ Copyright (c) Carnegie Mellon University.
4
+ Implemented by alxndrTL from https://github.com/alxndrTL/mamba.py
5
+ """
6
+
7
+
8
+ import math
9
+ from dataclasses import dataclass
10
+ from typing import Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from timm.models.layers import DropPath
16
+ from .pscan import pscan
17
+
18
+ @dataclass
19
+ class MambaConfig:
20
+ d_model: int = 768 # D
21
+ dt_rank: Union[int, str] = 'auto'
22
+ d_state: int = 16 # N in paper/comments
23
+ expand_factor: int = 2 # E in paper/comments
24
+ d_conv: int = 4
25
+
26
+ dt_min: float = 0.001
27
+ dt_max: float = 0.1
28
+ dt_init: str = "random" # "random" or "constant"
29
+ dt_scale: float = 1.0
30
+ dt_init_floor: float = 1e-4
31
+
32
+ drop_prob: float = 0.1
33
+
34
+ bias: bool = False
35
+ conv_bias: bool = True
36
+ bimamba: bool = True
37
+
38
+ pscan: bool = True # use parallel scan mode or sequential mode when training
39
+
40
+ def __post_init__(self):
41
+ self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
42
+
43
+ if self.dt_rank == 'auto':
44
+ self.dt_rank = math.ceil(self.d_model / 16)
45
+
46
+ class ResidualBlock(nn.Module):
47
+ def __init__(self, config: MambaConfig):
48
+ super().__init__()
49
+
50
+ self.mixer = MambaBlock(config)
51
+ self.norm = RMSNorm(config.d_model)
52
+ self.drop_path = DropPath(drop_prob=config.drop_prob)
53
+
54
+ def forward(self, x):
55
+ # x : (B, L, D)
56
+
57
+ # output : (B, L, D)
58
+
59
+ output = self.drop_path(self.mixer(self.norm(x))) + x
60
+ return output
61
+
62
+ def step(self, x, cache):
63
+ # x : (B, D)
64
+ # cache : (h, inputs)
65
+ # h : (B, ED, N)
66
+ # inputs: (B, ED, d_conv-1)
67
+
68
+ # output : (B, D)
69
+ # cache : (h, inputs)
70
+
71
+ output, cache = self.mixer.step(self.norm(x), cache)
72
+ output = output + x
73
+ return output, cache
74
+
75
+ class MambaBlock(nn.Module):
76
+ def __init__(self, config: MambaConfig):
77
+ super().__init__()
78
+
79
+ self.config = config
80
+
81
+ # projects block input from D to 2*ED (two branches)
82
+ self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
83
+
84
+ self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
85
+ kernel_size=config.d_conv, bias=config.conv_bias,
86
+ groups=config.d_inner,
87
+ padding=config.d_conv - 1)
88
+
89
+ # projects x to input-dependent Δ, B, C
90
+ self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
91
+
92
+ # projects Δ from dt_rank to d_inner
93
+ self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
94
+
95
+ # dt initialization
96
+ # dt weights
97
+ dt_init_std = config.dt_rank**-0.5 * config.dt_scale
98
+ if config.dt_init == "constant":
99
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
100
+ elif config.dt_init == "random":
101
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
102
+ else:
103
+ raise NotImplementedError
104
+
105
+ # dt bias
106
+ dt = torch.exp(
107
+ torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
108
+ ).clamp(min=config.dt_init_floor)
109
+ inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
110
+ with torch.no_grad():
111
+ self.dt_proj.bias.copy_(inv_dt)
112
+ #self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
113
+ # todo : explain why removed
114
+
115
+ # S4D real initialization
116
+ A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
117
+ self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
118
+ self.D = nn.Parameter(torch.ones(config.d_inner))
119
+
120
+ # projects block output from ED back to D
121
+ self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
122
+
123
+ self.bimamba = config.bimamba
124
+
125
+ if self.bimamba:
126
+ A_b = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
127
+ self.A_b_log = nn.Parameter(torch.log(A_b))
128
+
129
+ self.conv1d_b = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
130
+ kernel_size=config.d_conv, bias=config.conv_bias,
131
+ groups=config.d_inner,
132
+ padding=config.d_conv - 1)
133
+
134
+ self.x_proj_b = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
135
+ self.dt_proj_b = nn.Linear(config.dt_rank, config.d_inner, bias=True)
136
+ self.D_b = nn.Parameter(torch.ones(config.d_inner))
137
+
138
+ def forward(self, x):
139
+ # x : (B, L, D)
140
+
141
+ # y : (B, L, D)
142
+
143
+ _, L, _ = x.shape
144
+
145
+ xz = self.in_proj(x) # (B, L, 2*ED)
146
+ x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)
147
+
148
+ # x branch
149
+ x = x.transpose(1, 2) # (B, ED, L)
150
+ x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
151
+ x = x.transpose(1, 2) # (B, L, ED)
152
+
153
+ x = F.silu(x)
154
+ y = self.ssm(x)
155
+
156
+ # z branch
157
+ z = F.silu(z)
158
+
159
+ output = y * z
160
+ output = self.out_proj(output) # (B, L, D)
161
+
162
+ return output
163
+
164
+ def ssm(self, x):
165
+ # x : (B, L, ED)
166
+
167
+ # y : (B, L, ED)
168
+
169
+ A = -torch.exp(self.A_log.float()) # (ED, N)
170
+ D = self.D.float()
171
+ # TODO remove .float()
172
+
173
+ deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
174
+
175
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
176
+ delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
177
+
178
+ if self.config.pscan:
179
+ y = self.selective_scan(x, delta, A, B, C, D)
180
+ else:
181
+ y = self.selective_scan_seq(x, delta, A, B, C, D)
182
+
183
+ if self.bimamba:
184
+ x_b = x.flip([-1])
185
+ A_b = -torch.exp(self.A_b_log.float()) # (ED, N)
186
+ D_b = self.D_b.float()
187
+ deltaBC_b = self.x_proj_b(x_b)
188
+ delta_b, B_b, C_b = torch.split(deltaBC_b, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
189
+ delta_b = F.softplus(self.dt_proj_b(delta_b)) # (B, L, ED)
190
+ if self.config.pscan:
191
+ y_b = self.selective_scan(x_b, delta_b, A_b, B_b, C_b, D_b)
192
+ else:
193
+ y_b = self.selective_scan_seq(x_b, delta_b, A_b, B_b, C_b, D_b)
194
+ y_b = y_b.flip([-1])
195
+ y = y + y_b
196
+ return y
197
+
198
+ def selective_scan(self, x, delta, A, B, C, D):
199
+ # x : (B, L, ED)
200
+ # Δ : (B, L, ED)
201
+ # A : (ED, N)
202
+ # B : (B, L, N)
203
+ # C : (B, L, N)
204
+ # D : (ED)
205
+
206
+ # y : (B, L, ED)
207
+
208
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
209
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
210
+
211
+ BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
212
+
213
+ hs = pscan(deltaA, BX)
214
+
215
+ y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
216
+
217
+ y = y + D * x
218
+
219
+ return y
220
+
221
+ def selective_scan_seq(self, x, delta, A, B, C, D):
222
+ # x : (B, L, ED)
223
+ # Δ : (B, L, ED)
224
+ # A : (ED, N)
225
+ # B : (B, L, N)
226
+ # C : (B, L, N)
227
+ # D : (ED)
228
+
229
+ # y : (B, L, ED)
230
+
231
+ _, L, _ = x.shape
232
+
233
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
234
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
235
+
236
+ BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
237
+
238
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
239
+ hs = []
240
+
241
+ for t in range(0, L):
242
+ h = deltaA[:, t] * h + BX[:, t]
243
+ hs.append(h)
244
+
245
+ hs = torch.stack(hs, dim=1) # (B, L, ED, N)
246
+
247
+ y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
248
+
249
+ y = y + D * x
250
+
251
+ return y
252
+
253
+ # -------------------------- inference -------------------------- #
254
+ """
255
+ Concerning auto-regressive inference
256
+
257
+ The cool part of using Mamba: inference is constant w.r.t. sequence length
258
+ We just have to keep in cache, for each layer, two things :
259
+ - the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
260
+ - the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
261
+ (d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
262
+ (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)
263
+
264
+ Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
265
+ h is (B, ED, N), and inputs is (B, ED, d_conv-1)
266
+ The MambaBlock.step() receives this cache and, along with the output, also returns the updated cache for the next call.
267
+
268
+ The cache object is initialized as follows : (None, torch.zeros()).
269
+ When h is None, the selective scan function detects it and starts with h=0.
270
+ The torch.zeros() isn't a problem (it's same as just feeding the input, because the conv1d is padded)
271
+
272
+ As we need one such cache variable per layer, we store a caches object, which is simply a list of cache object. (See mamba_lm.py)
273
+ """
274
+
275
+ def step(self, x, cache):
276
+ # x : (B, D)
277
+ # cache : (h, inputs)
278
+ # h : (B, ED, N)
279
+ # inputs : (B, ED, d_conv-1)
280
+
281
+ # y : (B, D)
282
+ # cache : (h, inputs)
283
+
284
+ h, inputs = cache
285
+
286
+ xz = self.in_proj(x) # (B, 2*ED)
287
+ x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)
288
+
289
+ # x branch
290
+ x_cache = x.unsqueeze(2)
291
+ x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)
292
+
293
+ x = F.silu(x)
294
+ y, h = self.ssm_step(x, h)
295
+
296
+ # z branch
297
+ z = F.silu(z)
298
+
299
+ output = y * z
300
+ output = self.out_proj(output) # (B, D)
301
+
302
+ # prepare cache for next call
303
+ inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
304
+ cache = (h, inputs)
305
+
306
+ return output, cache
307
+
308
+ def ssm_step(self, x, h):
309
+ # x : (B, ED)
310
+ # h : (B, ED, N)
311
+
312
+ # y : (B, ED)
313
+ # h : (B, ED, N)
314
+
315
+ A = -torch.exp(self.A_log.float()) # (ED, N) # todo: avoid recomputing this at every step, since it does not depend on the timestep
316
+ D = self.D.float()
317
+ # TODO remove .float()
318
+
319
+ deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
320
+
321
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
322
+ delta = F.softplus(self.dt_proj(delta)) # (B, ED)
323
+
324
+ deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
325
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)
326
+
327
+ BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)
328
+
329
+ if h is None:
330
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
331
+
332
+ h = deltaA * h + BX # (B, ED, N)
333
+
334
+ y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
335
+
336
+ y = y + D * x
337
+
338
+ # todo: why h.squeeze(1)?
339
+ return y, h.squeeze(1)
340
+
341
+ # taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
342
+ class RMSNorm(nn.Module):
343
+ def __init__(self, d_model: int, eps: float = 1e-5):
344
+ super().__init__()
345
+
346
+ self.eps = eps
347
+ self.weight = nn.Parameter(torch.ones(d_model))
348
+
349
+ def forward(self, x):
350
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
351
+
352
+ return output
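A minimal forward-pass sketch for the blocks above (illustrative only; it assumes timm and the repository's other dependencies are installed): a ResidualBlock maps a (B, L, D) token sequence to the same shape, with D = d_model.

import torch
from models.mamba_base import MambaConfig, ResidualBlock

config = MambaConfig(d_model=768)   # d_state=16, expand_factor=2, bimamba=True by default
block = ResidualBlock(config)
tokens = torch.randn(2, 16, 768)    # (batch, sequence length, d_model); 16 is a power of two, which the pscan path prefers
out = block(tokens)
print(out.shape)                    # torch.Size([2, 16, 768])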
models/pscan.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ Mamba: Linear-Time Sequence Modeling with Selective State Spaces
3
+ Copyright (c) Carnegie Mellon University.
4
+ Implemented by alxndrTL from https://github.com/alxndrTL/mamba.py
5
+ """
6
+
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ """
13
+
14
+ An implementation of the parallel scan operation in PyTorch (Blelloch version).
15
+ Please see docs/pscan.ipynb for a detailed explanation of what happens here.
16
+
17
+ """
18
+
19
+ def npo2(len):
20
+ """
21
+ Returns the next power of 2 above len
22
+ """
23
+
24
+ return 2 ** math.ceil(math.log2(len))
25
+
26
+ def pad_npo2(X):
27
+ """
28
+ Pads input length dim to the next power of 2
29
+
30
+ Args:
31
+ X : (B, L, D, N)
32
+
33
+ Returns:
34
+ Y : (B, npo2(L), D, N)
35
+ """
36
+
37
+ len_npo2 = npo2(X.size(1))
38
+ pad_tuple = (0, 0, 0, 0, 0, len_npo2 - X.size(1))
39
+ return F.pad(X, pad_tuple, "constant", 0)
40
+
41
+ class PScan(torch.autograd.Function):
42
+ @staticmethod
43
+ def pscan(A, X):
44
+ # A : (B, D, L, N)
45
+ # X : (B, D, L, N)
46
+
47
+ # modifies X in place by doing a parallel scan.
48
+ # more formally, X will be populated by these values :
49
+ # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
50
+ # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
51
+
52
+ # only supports L that is a power of two (mainly for a clearer code)
53
+
54
+ B, D, L, _ = A.size()
55
+ num_steps = int(math.log2(L))
56
+
57
+ # up sweep (last 2 steps unfolded)
58
+ Aa = A
59
+ Xa = X
60
+ for _ in range(num_steps-2):
61
+ T = Xa.size(2)
62
+ Aa = Aa.view(B, D, T//2, 2, -1)
63
+ Xa = Xa.view(B, D, T//2, 2, -1)
64
+
65
+ Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
66
+ Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])
67
+
68
+ Aa = Aa[:, :, :, 1]
69
+ Xa = Xa[:, :, :, 1]
70
+
71
+ # we have only 4, 2 or 1 nodes left
72
+ if Xa.size(2) == 4:
73
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
74
+ Aa[:, :, 1].mul_(Aa[:, :, 0])
75
+
76
+ Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
77
+ elif Xa.size(2) == 2:
78
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
79
+ return
80
+ else:
81
+ return
82
+
83
+ # down sweep (first 2 steps unfolded)
84
+ Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
85
+ Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
86
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
87
+ Aa[:, :, 2].mul_(Aa[:, :, 1])
88
+
89
+ for k in range(num_steps-3, -1, -1):
90
+ Aa = A[:, :, 2**k-1:L:2**k]
91
+ Xa = X[:, :, 2**k-1:L:2**k]
92
+
93
+ T = Xa.size(2)
94
+ Aa = Aa.view(B, D, T//2, 2, -1)
95
+ Xa = Xa.view(B, D, T//2, 2, -1)
96
+
97
+ Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
98
+ Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])
99
+
100
+ @staticmethod
101
+ def pscan_rev(A, X):
102
+ # A : (B, D, L, N)
103
+ # X : (B, D, L, N)
104
+
105
+ # the same function as above, but in reverse
106
+ # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
107
+ # it is used in the backward pass
108
+
109
+ # only supports L that is a power of two (mainly for a clearer code)
110
+
111
+ B, D, L, _ = A.size()
112
+ num_steps = int(math.log2(L))
113
+
114
+ # up sweep (last 2 steps unfolded)
115
+ Aa = A
116
+ Xa = X
117
+ for _ in range(num_steps-2):
118
+ T = Xa.size(2)
119
+ Aa = Aa.view(B, D, T//2, 2, -1)
120
+ Xa = Xa.view(B, D, T//2, 2, -1)
121
+
122
+ Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
123
+ Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])
124
+
125
+ Aa = Aa[:, :, :, 0]
126
+ Xa = Xa[:, :, :, 0]
127
+
128
+ # we have only 4, 2 or 1 nodes left
129
+ if Xa.size(2) == 4:
130
+ Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
131
+ Aa[:, :, 2].mul_(Aa[:, :, 3])
132
+
133
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
134
+ elif Xa.size(2) == 2:
135
+ Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
136
+ return
137
+ else:
138
+ return
139
+
140
+ # down sweep (first 2 steps unfolded)
141
+ Aa = A[:, :, 0:L:2**(num_steps-2)]
142
+ Xa = X[:, :, 0:L:2**(num_steps-2)]
143
+ Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
144
+ Aa[:, :, 1].mul_(Aa[:, :, 2])
145
+
146
+ for k in range(num_steps-3, -1, -1):
147
+ Aa = A[:, :, 0:L:2**k]
148
+ Xa = X[:, :, 0:L:2**k]
149
+
150
+ T = Xa.size(2)
151
+ Aa = Aa.view(B, D, T//2, 2, -1)
152
+ Xa = Xa.view(B, D, T//2, 2, -1)
153
+
154
+ Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
155
+ Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])
156
+
157
+ @staticmethod
158
+ def forward(ctx, A_in, X_in):
159
+ """
160
+ Applies the parallel scan operation, as defined above. Returns a new tensor.
161
+ If you can, prefer sequence lengths that are powers of two.
162
+
163
+ Args:
164
+ A_in : (B, L, D, N)
165
+ X_in : (B, L, D, N)
166
+
167
+ Returns:
168
+ H : (B, L, D, N)
169
+ """
170
+
171
+ L = X_in.size(1)
172
+
173
+ # cloning is required because of the in-place ops
174
+ if L == npo2(L):
175
+ A = A_in.clone()
176
+ X = X_in.clone()
177
+ else:
178
+ # pad tensors (and clone btw)
179
+ A = pad_npo2(A_in) # (B, npo2(L), D, N)
180
+ X = pad_npo2(X_in) # (B, npo2(L), D, N)
181
+
182
+ # prepare tensors
183
+ A = A.transpose(2, 1) # (B, D, npo2(L), N)
184
+ X = X.transpose(2, 1) # (B, D, npo2(L), N)
185
+
186
+ # parallel scan (modifies X in-place)
187
+ PScan.pscan(A, X)
188
+
189
+ ctx.save_for_backward(A_in, X)
190
+
191
+ # slice [:, :L] (cut if there was padding)
192
+ return X.transpose(2, 1)[:, :L]
193
+
194
+ @staticmethod
195
+ def backward(ctx, grad_output_in):
196
+ """
197
+ Flows the gradient from the output to the input. Returns two new tensors.
198
+
199
+ Args:
200
+ ctx : A_in : (B, L, D, N), X : (B, D, L, N)
201
+ grad_output_in : (B, L, D, N)
202
+
203
+ Returns:
204
+ gradA : (B, L, D, N), gradX : (B, L, D, N)
205
+ """
206
+
207
+ A_in, X = ctx.saved_tensors
208
+
209
+ L = grad_output_in.size(1)
210
+
211
+ # cloning is required because of the in-place ops
212
+ if L == npo2(L):
213
+ grad_output = grad_output_in.clone()
214
+ # the next padding will clone A_in
215
+ else:
216
+ grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
217
+ A_in = pad_npo2(A_in) # (B, npo2(L), D, N)
218
+
219
+ # prepare tensors
220
+ grad_output = grad_output.transpose(2, 1)
221
+ A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
222
+ A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)
223
+
224
+ # reverse parallel scan (modifies grad_output in-place)
225
+ PScan.pscan_rev(A, grad_output)
226
+
227
+ Q = torch.zeros_like(X)
228
+ Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])
229
+
230
+ return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
231
+
232
+ pscan = PScan.apply
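The recurrence that pscan computes, H[t] = A[t] * H[t-1] + X[t] with a zero initial state, can be checked against a plain sequential loop. A small verification sketch (illustrative only; it assumes the repository's dependencies import cleanly):

import torch
from models.pscan import pscan

B, L, D, N = 2, 8, 4, 3                 # L is a power of two, so the padding branch is not exercised
A = torch.rand(B, L, D, N)
X = torch.rand(B, L, D, N)

H = pscan(A, X)                         # parallel Blelloch scan, shape (B, L, D, N)

h = torch.zeros(B, D, N)
ref = []
for t in range(L):                      # the same recurrence, one step at a time
    h = A[:, t] * h + X[:, t]
    ref.append(h)
ref = torch.stack(ref, dim=1)

print(torch.allclose(H, ref, atol=1e-6))  # True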
models/time_transformer ADDED
@@ -0,0 +1,256 @@
1
+ '''
2
+ Exploring Temporal Coherence for More General Video Face Forgery Detection @ ICCV'2021
3
+ Copyright (c) Xiamen University and its affiliates.
4
+ Modified by Yinglin Zheng from https://github.com/yinglinzheng/FTCN
5
+ '''
6
+
7
+ import torch
8
+ from torch import nn, einsum
9
+ import torch.nn.functional as F
10
+
11
+ from einops import rearrange, repeat
12
+ from einops.layers.torch import Rearrange
13
+
14
+ class Residual(nn.Module):
15
+ def __init__(self, fn):
16
+ super().__init__()
17
+ self.fn = fn
18
+ def forward(self, x, **kwargs):
19
+ return self.fn(x, **kwargs) + x
20
+
21
+ class PreNorm(nn.Module):
22
+ def __init__(self, dim, fn):
23
+ super().__init__()
24
+ self.norm = nn.LayerNorm(dim)
25
+ self.fn = fn
26
+ def forward(self, x, **kwargs):
27
+ return self.fn(self.norm(x), **kwargs)
28
+
29
+ class FeedForward(nn.Module):
30
+ def __init__(self, dim, hidden_dim, dropout = 0.):
31
+ super().__init__()
32
+ self.net = nn.Sequential(
33
+ nn.Linear(dim, hidden_dim),
34
+ nn.GELU(),
35
+ nn.Dropout(dropout),
36
+ nn.Linear(hidden_dim, dim),
37
+ nn.Dropout(dropout)
38
+ )
39
+ def forward(self, x):
40
+ return self.net(x)
41
+
42
+ class Attention(nn.Module):
43
+ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
44
+ super().__init__()
45
+ inner_dim = dim_head * heads
46
+ project_out = not (heads == 1 and dim_head == dim)
47
+
48
+ self.heads = heads
49
+ self.scale = dim_head ** -0.5
50
+
51
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
52
+
53
+ self.to_out = nn.Sequential(
54
+ nn.Linear(inner_dim, dim),
55
+ nn.Dropout(dropout)
56
+ ) if project_out else nn.Identity()
57
+
58
+ def forward(self, x, mask = None):
59
+ b, n, _, h = *x.shape, self.heads
60
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
61
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
62
+
63
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
64
+ mask_value = -torch.finfo(dots.dtype).max
65
+
66
+ if mask is not None:
67
+ mask = F.pad(mask.flatten(1), (1, 0), value = True)
68
+ assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
69
+ mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
70
+ dots.masked_fill_(~mask, mask_value)
71
+ del mask
72
+
73
+ attn = dots.softmax(dim=-1)
74
+
75
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
76
+ out = rearrange(out, 'b h n d -> b n (h d)')
77
+ out = self.to_out(out)
78
+ return out
79
+
80
+ class Transformer(nn.Module):
81
+ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
82
+ super().__init__()
83
+ self.layers = nn.ModuleList([])
84
+ for _ in range(depth):
85
+ self.layers.append(nn.ModuleList([
86
+ Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
87
+ Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
88
+ ]))
89
+ def forward(self, x, mask = None):
90
+ for attn, ff in self.layers:
91
+ x = attn(x, mask = mask)
92
+ x = ff(x)
93
+ return x
94
+
95
+ class ViT(nn.Module):
96
+ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
97
+ super().__init__()
98
+ assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
99
+ num_patches = (image_size // patch_size) ** 2
100
+ patch_dim = channels * patch_size ** 2
101
+ assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
102
+
103
+ self.to_patch_embedding = nn.Sequential(
104
+ Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
105
+ nn.Linear(patch_dim, dim),
106
+ )
107
+
108
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
109
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
110
+ self.dropout = nn.Dropout(emb_dropout)
111
+
112
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
113
+
114
+ self.pool = pool
115
+ self.to_latent = nn.Identity()
116
+
117
+ self.mlp_head = nn.Sequential(
118
+ nn.LayerNorm(dim),
119
+ nn.Linear(dim, num_classes)
120
+ )
121
+
122
+ def forward(self, img, mask = None):
123
+ x = self.to_patch_embedding(img)
124
+ b, n, _ = x.shape # (batch, num_patches, embedding dim)
125
+
126
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
127
+ x = torch.cat((cls_tokens, x), dim=1)
128
+ x += self.pos_embedding[:, :(n + 1)]
129
+ x = self.dropout(x)
130
+
131
+ x = self.transformer(x, mask)
132
+
133
+ x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
134
+
135
+ x = self.to_latent(x)
136
+ return self.mlp_head(x)
137
+
138
+
139
+
140
+ def valid_idx(idx, h):
141
+ i = idx // h
142
+ j = idx % h
143
+ pad = h // 7
144
+ if j < pad or i >= h - pad or j >= h - pad:
145
+ return False
146
+ else:
147
+ return True
148
+
149
+ import random
150
+ from math import sqrt
151
+ class RandomSelect(nn.Module):
152
+ def __init__(self):
153
+ super().__init__()
154
+
155
+ def forward(self, x):
156
+ # x: (batch, num_spatial_positions, ...), e.g. a 7x7 grid of 49 positions
157
+ size=x.shape[1]
158
+ h=int(sqrt(size))
159
+ candidates = list(range(size))
160
+ candidates = [idx for idx in candidates if valid_idx(idx, h)]
161
+ max_k = len(candidates)
162
+ if self.training:
163
+ k = 8
164
+ if k==-1:
165
+ k=max_k
166
+ else:
167
+ k = max_k
168
+ candidates = random.sample(candidates, k)
169
+ x = x[:,candidates]
170
+ return x
171
+
172
+ class VideoiT(nn.Module):
173
+ def __init__(self, *, image_size, patch_size, num_patches, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
174
+ super().__init__()
175
+ assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
176
+ patch_dim = channels * patch_size ** 2
177
+ assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
178
+
179
+ self.to_patch = Rearrange('b c t (h p1) (w p2) -> b (h w) t (p1 p2 c)', p1 = patch_size, p2 = patch_size)
180
+ self.patch_to_embedding=nn.Linear(patch_dim, dim)
181
+ self.num_patches=num_patches
182
+
183
+
184
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
185
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
186
+ self.dropout = nn.Dropout(emb_dropout)
187
+
188
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
189
+
190
+ self.pool = pool
191
+ self.random_select=RandomSelect()
192
+ self.to_latent = nn.Identity()
193
+
194
+ self.mlp_head = nn.Sequential(
195
+ nn.LayerNorm(dim),
196
+ nn.Linear(dim, num_classes)
197
+ )
198
+
199
+ def forward(self, img, mask = None):
200
+ real_b=img.shape[0]
201
+ x = self.to_patch(img)
202
+ x = self.random_select(x)
203
+ n=x.shape[1]
204
+ x=x.reshape(real_b*n,self.num_patches,-1)
205
+ x = self.patch_to_embedding(x)
206
+ b, n, _ = x.shape # (b, n, embedding dim), where b = real_b * selected positions
207
+
208
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
209
+ x = torch.cat((cls_tokens, x), dim=1)
210
+ x += self.pos_embedding[:, :(n + 1)]
211
+ x = self.dropout(x)
212
+
213
+ x = self.transformer(x, mask)
214
+
215
+ x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
216
+
217
+ x = self.to_latent(x)
218
+ x = self.mlp_head(x)
219
+ x = x.reshape(real_b,-1)
220
+ return x
221
+
222
+
223
+ class TimeTransformer(nn.Module):
224
+ def __init__(self,num_patches, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', dim_head = 64, dropout = 0., emb_dropout = 0.):
225
+ super().__init__()
226
+ assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
227
+
228
+ self.num_patches=num_patches
229
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
230
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
231
+ self.dropout = nn.Dropout(emb_dropout)
232
+
233
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
234
+
235
+ self.pool = pool
236
+ self.to_latent = nn.Identity()
237
+
238
+ self.mlp_head = nn.Sequential(
239
+ nn.LayerNorm(dim),
240
+ nn.Linear(dim, num_classes)
241
+ )
242
+
243
+ def forward(self, x):
244
+ b, n, _ = x.shape # (batch, num_patches, embedding dim)
245
+
246
+ cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
247
+ x = torch.cat((cls_tokens, x), dim=1)
248
+ x += self.pos_embedding[:, :(n + 1)]
249
+ x = self.dropout(x)
250
+
251
+ x = self.transformer(x, mask=None)
252
+
253
+ x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
254
+
255
+ x = self.to_latent(x)
256
+ return self.mlp_head(x)
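As a quick orientation for the classes above, the following sketch drives TimeTransformer with a batch of per-frame feature tokens. The sizes and the real-vs-generated reading of the output are assumptions for illustration; only the tensor-shape contract comes from the code itself.

import torch
# TimeTransformer is assumed to be in scope; the file above has no .py extension,
# so the exact import path depends on how the repo loads it.

num_frames, feat_dim = 8, 512
model = TimeTransformer(
    num_patches=num_frames,   # one token per frame
    num_classes=1,            # e.g. a single real-vs-generated logit
    dim=feat_dim,
    depth=2,
    heads=8,
    mlp_dim=1024,
)

tokens = torch.randn(4, num_frames, feat_dim)  # (batch, frames, feature dim)
logits = model(tokens)                         # (batch, num_classes), read from the cls token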
requirements.txt ADDED
@@ -0,0 +1,59 @@
1
+ absl-py==2.3.1
2
+ albucore==0.0.24
3
+ albumentations==2.0.8
4
+ annotated-types==0.7.0
5
+ certifi==2025.10.5
6
+ charset-normalizer==3.4.3
7
+ eval-type-backport==0.2.2
8
+ filelock==3.19.1
9
+ fsspec==2025.9.0
10
+ ftfy==6.3.1
11
+ grpcio==1.75.1
12
+ hf-xet==1.1.10
13
+ huggingface-hub==0.35.3
14
+ idna==3.10
15
+ importlib-metadata==8.7.0
16
+ jinja2==3.1.6
17
+ joblib==1.5.2
18
+ markdown==3.9
19
+ markupsafe==3.0.3
20
+ mpmath==1.3.0
21
+ networkx==3.2.1
22
+ numpy==2.0.2
23
+ opencv-python==4.12.0.88
24
+ opencv-python-headless==4.12.0.88
25
+ packaging==25.0
26
+ pandas==2.3.3
27
+ pillow==11.3.0
28
+ protobuf==6.32.1
29
+ pydantic==2.11.10
30
+ pydantic-core==2.33.2
31
+ python-dateutil==2.9.0.post0
32
+ pytz==2025.2
33
+ pyyaml==6.0.3
34
+ regex==2025.9.18
35
+ requests==2.32.5
36
+ safetensors==0.6.2
37
+ scikit-learn==1.6.1
38
+ scipy==1.13.1
39
+ simsimd==6.5.3
40
+ six==1.17.0
41
+ stringzilla==4.1.0
42
+ sympy==1.14.0
43
+ tensorboard==2.20.0
44
+ tensorboard-data-server==0.7.2
45
+ threadpoolctl==3.6.0
46
+ timm==1.0.20
47
+ tokenizers==0.22.1
48
+ torch==2.8.0
49
+ torchvision==0.23.0
50
+ tqdm==4.67.1
51
+ transformers==4.57.0
52
+ triton==3.4.0
53
+ typing-extensions==4.15.0
54
+ typing-inspection==0.4.2
55
+ tzdata==2025.2
56
+ urllib3==2.5.0
57
+ wcwidth==0.2.14
58
+ werkzeug==3.1.3
59
+ zipp==3.23.0
results.csv ADDED
@@ -0,0 +1,6 @@
1
+ file_name,predicted_prob,predicted_label
2
+ kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f000.jpg,0.8246555924415588,1
3
+ kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f001.jpg,0.8141521215438843,1
4
+ kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f002.jpg,0.8013387322425842,1
5
+ kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f003.jpg,0.7972325086593628,1
6
+ kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6_f004.jpg,0.8086235523223877,1
script.py ADDED
@@ -0,0 +1,22 @@
1
+ import subprocess
2
+
3
+ scripts = [
4
+ "extract_frames.py",
5
+ "eval2.py",
6
+ "create_submission.py"
7
+ ]
8
+
9
+ def run_script(script):
10
+ print(f"\n🚀 Running {script}...")
11
+ result = subprocess.run(["python", script], capture_output=True, text=True)
12
+ if result.returncode == 0:
13
+ print(f"✅ {script} completed successfully.\n")
14
+ print(result.stdout)
15
+ else:
16
+ print(f"❌ {script} failed with error:\n{result.stderr}")
17
+ exit(1)
18
+
19
+ if __name__ == "__main__":
20
+ for script in scripts:
21
+ run_script(script)
22
+ print("🎉 All scripts completed successfully!")
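One design note on script.py: capture_output=True buffers each stage's output until the stage finishes. A hypothetical variant that streams progress live and launches children with the current interpreter could look like the sketch below; it is written under those assumptions and is not the pipeline's actual code.

import subprocess
import sys

def run_script_streaming(script: str) -> None:
    # Let the child write directly to this console instead of buffering its output,
    # and reuse the interpreter running this script rather than whatever "python" resolves to.
    result = subprocess.run([sys.executable, script])
    if result.returncode != 0:
        sys.exit(result.returncode)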
submission.csv ADDED
@@ -0,0 +1,2 @@
1
+ id,pred,score
2
+ kling_kling_6ebbdf57a77769ab5a9cfb682249a9cfb9e389666d1b568bedcaf37c45c521f6,generated,0.8092005014419555
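The submission score above equals the mean of the five per-frame probabilities in results.csv (≈0.8092), which suggests create_submission.py aggregates frame-level predictions by averaging. A hypothetical reconstruction of that step is sketched below; the column names come from the two CSVs, while the 0.5 threshold and the "real" label are assumptions.

import pandas as pd

frames = pd.read_csv("results.csv")
# Strip the per-frame suffix (e.g. "_f000.jpg") to recover the video id.
frames["id"] = frames["file_name"].str.replace(r"_f\d+\.jpg$", "", regex=True)

agg = frames.groupby("id")["predicted_prob"].mean().reset_index(name="score")
# Assumed decision rule: predicted_prob is the probability of being generated.
agg["pred"] = agg["score"].apply(lambda s: "generated" if s >= 0.5 else "real")
agg[["id", "pred", "score"]].to_csv("submission.csv", index=False)

# For the single video listed above, the mean of its five frame probabilities is
# ~0.80920, matching the score column in submission.csv.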