Spaces:
Running
Running
wangmengchao
commited on
Commit
·
282b272
1
Parent(s):
a893799
init
Browse files- app.py +312 -0
- diffsynth/__init__.py +5 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +650 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/video.py +173 -0
- diffsynth/pipelines/__init__.py +1 -0
- diffsynth/pipelines/base.py +127 -0
- diffsynth/pipelines/wan_video.py +290 -0
- diffsynth/prompters/__init__.py +1 -0
- diffsynth/prompters/base_prompter.py +70 -0
- diffsynth/prompters/wan_prompter.py +108 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +105 -0
- diffsynth/schedulers/flow_match.py +79 -0
- diffsynth/vram_management/__init__.py +1 -0
- diffsynth/vram_management/layers.py +95 -0
- infer.py +214 -0
- model.py +229 -0
- requirements.txt +14 -0
- utils.py +49 -0
app.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import argparse
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import librosa
|
| 6 |
+
from infer import load_models,main
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
pipe,fantasytalking,wav2vec_processor,wav2vec = None,None,None,None
|
| 10 |
+
|
| 11 |
+
def generate_video(
|
| 12 |
+
image_path,
|
| 13 |
+
audio_path,
|
| 14 |
+
prompt,
|
| 15 |
+
prompt_cfg_scale,
|
| 16 |
+
audio_cfg_scale,
|
| 17 |
+
audio_weight,
|
| 18 |
+
image_size,
|
| 19 |
+
max_num_frames,
|
| 20 |
+
inference_steps,
|
| 21 |
+
seed,
|
| 22 |
+
):
|
| 23 |
+
# Create the temp directory if it doesn't exist
|
| 24 |
+
output_dir = Path("./output")
|
| 25 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 26 |
+
|
| 27 |
+
# Convert paths to absolute Path objects and normalize them
|
| 28 |
+
print(image_path)
|
| 29 |
+
image_path = Path(image_path).absolute().as_posix()
|
| 30 |
+
audio_path = Path(audio_path).absolute().as_posix()
|
| 31 |
+
|
| 32 |
+
# Parse the arguments
|
| 33 |
+
|
| 34 |
+
args = create_args(
|
| 35 |
+
image_path=image_path,
|
| 36 |
+
audio_path=audio_path,
|
| 37 |
+
prompt=prompt,
|
| 38 |
+
output_dir=str(output_dir),
|
| 39 |
+
audio_weight=audio_weight,
|
| 40 |
+
prompt_cfg_scale=prompt_cfg_scale,
|
| 41 |
+
audio_cfg_scale=audio_cfg_scale,
|
| 42 |
+
image_size=image_size,
|
| 43 |
+
max_num_frames=max_num_frames,
|
| 44 |
+
inference_steps=inference_steps,
|
| 45 |
+
seed=seed,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
global pipe, fantasytalking, wav2vec_processor, wav2vec
|
| 50 |
+
if pipe is None:
|
| 51 |
+
pipe,fantasytalking,wav2vec_processor,wav2vec = load_models(args)
|
| 52 |
+
output_path=main(
|
| 53 |
+
args,pipe,fantasytalking,wav2vec_processor,wav2vec
|
| 54 |
+
)
|
| 55 |
+
return output_path # Ensure the output path is returned
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"Error during processing: {str(e)}")
|
| 58 |
+
raise gr.Error(f"Error during processing: {str(e)}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def create_args(
|
| 62 |
+
image_path: str,
|
| 63 |
+
audio_path: str,
|
| 64 |
+
prompt: str,
|
| 65 |
+
output_dir: str,
|
| 66 |
+
audio_weight: float,
|
| 67 |
+
prompt_cfg_scale: float,
|
| 68 |
+
audio_cfg_scale: float,
|
| 69 |
+
image_size: int,
|
| 70 |
+
max_num_frames: int,
|
| 71 |
+
inference_steps: int,
|
| 72 |
+
seed: int,
|
| 73 |
+
) -> argparse.Namespace:
|
| 74 |
+
parser = argparse.ArgumentParser()
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--wan_model_dir",
|
| 77 |
+
type=str,
|
| 78 |
+
default="./models/Wan2.1-I2V-14B-720P",
|
| 79 |
+
required=False,
|
| 80 |
+
help="The dir of the Wan I2V 14B model.",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--fantasytalking_model_path",
|
| 84 |
+
type=str,
|
| 85 |
+
default="./models/fantasytalking_model.ckpt",
|
| 86 |
+
required=False,
|
| 87 |
+
help="The .ckpt path of fantasytalking model.",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--wav2vec_model_dir",
|
| 91 |
+
type=str,
|
| 92 |
+
default="./models/wav2vec2-base-960h",
|
| 93 |
+
required=False,
|
| 94 |
+
help="The dir of wav2vec model.",
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--image_path",
|
| 98 |
+
type=str,
|
| 99 |
+
default="./assets/images/woman.png",
|
| 100 |
+
required=False,
|
| 101 |
+
help="The path of the image.",
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--audio_path",
|
| 105 |
+
type=str,
|
| 106 |
+
default="./assets/audios/woman.wav",
|
| 107 |
+
required=False,
|
| 108 |
+
help="The path of the audio.",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--prompt",
|
| 112 |
+
type=str,
|
| 113 |
+
default="A woman is talking.",
|
| 114 |
+
required=False,
|
| 115 |
+
help="prompt.",
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--output_dir",
|
| 119 |
+
type=str,
|
| 120 |
+
default="./output",
|
| 121 |
+
help="Dir to save the video.",
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--image_size",
|
| 125 |
+
type=int,
|
| 126 |
+
default=512,
|
| 127 |
+
help="The image will be resized proportionally to this size.",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--audio_scale",
|
| 131 |
+
type=float,
|
| 132 |
+
default=1.0,
|
| 133 |
+
help="Image width.",
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument(
|
| 136 |
+
"--prompt_cfg_scale",
|
| 137 |
+
type=float,
|
| 138 |
+
default=5.0,
|
| 139 |
+
required=False,
|
| 140 |
+
help="prompt cfg scale",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--audio_cfg_scale",
|
| 144 |
+
type=float,
|
| 145 |
+
default=5.0,
|
| 146 |
+
required=False,
|
| 147 |
+
help="audio cfg scale",
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--max_num_frames",
|
| 151 |
+
type=int,
|
| 152 |
+
default=81,
|
| 153 |
+
required=False,
|
| 154 |
+
help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
|
| 155 |
+
)
|
| 156 |
+
parser.add_argument(
|
| 157 |
+
"--inference_steps",
|
| 158 |
+
type=int,
|
| 159 |
+
default=20,
|
| 160 |
+
required=False,
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--fps",
|
| 164 |
+
type=int,
|
| 165 |
+
default=23,
|
| 166 |
+
required=False,
|
| 167 |
+
)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--num_persistent_param_in_dit",
|
| 170 |
+
type=int,
|
| 171 |
+
default=None,
|
| 172 |
+
required=False,
|
| 173 |
+
help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required"
|
| 174 |
+
)
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--seed",
|
| 177 |
+
type=int,
|
| 178 |
+
default=1111,
|
| 179 |
+
required=False,
|
| 180 |
+
)
|
| 181 |
+
args = parser.parse_args(
|
| 182 |
+
[
|
| 183 |
+
"--image_path",
|
| 184 |
+
image_path,
|
| 185 |
+
"--audio_path",
|
| 186 |
+
audio_path,
|
| 187 |
+
"--prompt",
|
| 188 |
+
prompt,
|
| 189 |
+
"--output_dir",
|
| 190 |
+
output_dir,
|
| 191 |
+
"--image_size",
|
| 192 |
+
str(image_size),
|
| 193 |
+
"--audio_scale",
|
| 194 |
+
str(audio_weight),
|
| 195 |
+
"--prompt_cfg_scale",
|
| 196 |
+
str(prompt_cfg_scale),
|
| 197 |
+
"--audio_cfg_scale",
|
| 198 |
+
str(audio_cfg_scale),
|
| 199 |
+
"--max_num_frames",
|
| 200 |
+
str(max_num_frames),
|
| 201 |
+
"--inference_steps",
|
| 202 |
+
str(inference_steps),
|
| 203 |
+
"--seed",
|
| 204 |
+
str(seed),
|
| 205 |
+
]
|
| 206 |
+
)
|
| 207 |
+
print(args)
|
| 208 |
+
return args
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# Create Gradio interface
|
| 212 |
+
with gr.Blocks(title="FantasyTalking Video Generation") as demo:
|
| 213 |
+
gr.Markdown(
|
| 214 |
+
"""
|
| 215 |
+
# FantasyTalking: Realistic Talking Portrait Generation via Coherent Motion Synthesis
|
| 216 |
+
|
| 217 |
+
<div align="center">
|
| 218 |
+
<strong> Mengchao Wang1* Qiang Wang1* Fan Jiang1†
|
| 219 |
+
Yaqi Fan2 Yunpeng Zhang1,2 YongGang Qi2‡
|
| 220 |
+
Kun Zhao1. Mu Xu1 </strong>
|
| 221 |
+
</div>
|
| 222 |
+
|
| 223 |
+
<div align="center">
|
| 224 |
+
<strong>1AMAP,Alibaba Group 2Beijing University of Posts and Telecommunications</strong>
|
| 225 |
+
</div>
|
| 226 |
+
|
| 227 |
+
<div style="display:flex;justify-content:center;column-gap:4px;">
|
| 228 |
+
<a href="https://github.com/Fantasy-AMAP/fantasy-talking">
|
| 229 |
+
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
|
| 230 |
+
</a>
|
| 231 |
+
<a href="https://arxiv.org/abs/2504.04842">
|
| 232 |
+
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
|
| 233 |
+
</a>
|
| 234 |
+
</div>
|
| 235 |
+
"""
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
with gr.Row():
|
| 239 |
+
with gr.Column():
|
| 240 |
+
image_input = gr.Image(label="Input Image", type="filepath")
|
| 241 |
+
audio_input = gr.Audio(label="Input Audio", type="filepath")
|
| 242 |
+
prompt_input = gr.Text(label="Input Prompt")
|
| 243 |
+
with gr.Row():
|
| 244 |
+
prompt_cfg_scale = gr.Slider(
|
| 245 |
+
minimum=1.0,
|
| 246 |
+
maximum=9.0,
|
| 247 |
+
value=5.0,
|
| 248 |
+
step=0.5,
|
| 249 |
+
label="Prompt CFG Scale",
|
| 250 |
+
)
|
| 251 |
+
audio_cfg_scale = gr.Slider(
|
| 252 |
+
minimum=1.0,
|
| 253 |
+
maximum=9.0,
|
| 254 |
+
value=5.0,
|
| 255 |
+
step=0.5,
|
| 256 |
+
label="Audio CFG Scale",
|
| 257 |
+
)
|
| 258 |
+
audio_weight = gr.Slider(
|
| 259 |
+
minimum=0.1,
|
| 260 |
+
maximum=3.0,
|
| 261 |
+
value=1.0,
|
| 262 |
+
step=0.1,
|
| 263 |
+
label="Audio Weight",
|
| 264 |
+
)
|
| 265 |
+
with gr.Row():
|
| 266 |
+
image_size = gr.Number(
|
| 267 |
+
value=512, label="Width/Height Maxsize", precision=0
|
| 268 |
+
)
|
| 269 |
+
max_num_frames = gr.Number(
|
| 270 |
+
value=81, label="The Maximum Frames", precision=0
|
| 271 |
+
)
|
| 272 |
+
inference_steps = gr.Slider(
|
| 273 |
+
minimum=1, maximum=50, value=20, step=1, label="Inference Steps"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
with gr.Row():
|
| 277 |
+
seed = gr.Number(value=1247, label="Random Seed", precision=0)
|
| 278 |
+
|
| 279 |
+
process_btn = gr.Button("Generate Video")
|
| 280 |
+
|
| 281 |
+
with gr.Column():
|
| 282 |
+
video_output = gr.Video(label="Output Video")
|
| 283 |
+
|
| 284 |
+
gr.Examples(
|
| 285 |
+
examples=[
|
| 286 |
+
[
|
| 287 |
+
"/home/wangmengchao.wmc/code/fantasytalking/assets/images/woman.png",
|
| 288 |
+
"/home/wangmengchao.wmc/code/fantasytalking/assets/audios/woman.wav",
|
| 289 |
+
],
|
| 290 |
+
],
|
| 291 |
+
inputs=[image_input, audio_input],
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
process_btn.click(
|
| 295 |
+
fn=generate_video,
|
| 296 |
+
inputs=[
|
| 297 |
+
image_input,
|
| 298 |
+
audio_input,
|
| 299 |
+
prompt_input,
|
| 300 |
+
prompt_cfg_scale,
|
| 301 |
+
audio_cfg_scale,
|
| 302 |
+
audio_weight,
|
| 303 |
+
image_size,
|
| 304 |
+
max_num_frames,
|
| 305 |
+
inference_steps,
|
| 306 |
+
seed,
|
| 307 |
+
],
|
| 308 |
+
outputs=video_output,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if __name__ == "__main__":
|
| 312 |
+
demo.launch(inbrowser=True, share=True)
|
diffsynth/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .data import *
|
| 2 |
+
from .models import *
|
| 3 |
+
from .prompters import *
|
| 4 |
+
from .schedulers import *
|
| 5 |
+
from .pipelines import *
|
diffsynth/configs/__init__.py
ADDED
|
File without changes
|
diffsynth/configs/model_config.py
ADDED
|
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import Literal, TypeAlias
|
| 2 |
+
|
| 3 |
+
from ..models.wan_video_dit import WanModel
|
| 4 |
+
from ..models.wan_video_text_encoder import WanTextEncoder
|
| 5 |
+
from ..models.wan_video_image_encoder import WanImageEncoder
|
| 6 |
+
from ..models.wan_video_vae import WanVideoVAE
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
model_loader_configs = [
|
| 10 |
+
# These configs are provided for detecting model type automatically.
|
| 11 |
+
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
| 12 |
+
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
| 13 |
+
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
| 14 |
+
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
| 15 |
+
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
| 16 |
+
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
| 17 |
+
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
| 18 |
+
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
| 19 |
+
]
|
| 20 |
+
huggingface_model_loader_configs = [
|
| 21 |
+
# These configs are provided for detecting model type automatically.
|
| 22 |
+
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
| 23 |
+
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
| 24 |
+
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
| 25 |
+
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
| 26 |
+
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
| 27 |
+
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
| 28 |
+
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
| 29 |
+
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
| 30 |
+
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
| 31 |
+
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
| 32 |
+
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
| 33 |
+
]
|
| 34 |
+
patch_model_loader_configs = [
|
| 35 |
+
# These configs are provided for detecting model type automatically.
|
| 36 |
+
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
| 37 |
+
# ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
preset_models_on_huggingface = {
|
| 41 |
+
"HunyuanDiT": [
|
| 42 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
| 43 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
| 44 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
| 45 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
| 46 |
+
],
|
| 47 |
+
"stable-video-diffusion-img2vid-xt": [
|
| 48 |
+
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
| 49 |
+
],
|
| 50 |
+
"ExVideo-SVD-128f-v1": [
|
| 51 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
| 52 |
+
],
|
| 53 |
+
# Stable Diffusion
|
| 54 |
+
"StableDiffusion_v15": [
|
| 55 |
+
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
| 56 |
+
],
|
| 57 |
+
"DreamShaper_8": [
|
| 58 |
+
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
| 59 |
+
],
|
| 60 |
+
# Textual Inversion
|
| 61 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
| 62 |
+
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
| 63 |
+
],
|
| 64 |
+
# Stable Diffusion XL
|
| 65 |
+
"StableDiffusionXL_v1": [
|
| 66 |
+
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
| 67 |
+
],
|
| 68 |
+
"BluePencilXL_v200": [
|
| 69 |
+
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
| 70 |
+
],
|
| 71 |
+
"StableDiffusionXL_Turbo": [
|
| 72 |
+
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
| 73 |
+
],
|
| 74 |
+
# Stable Diffusion 3
|
| 75 |
+
"StableDiffusion3": [
|
| 76 |
+
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
| 77 |
+
],
|
| 78 |
+
"StableDiffusion3_without_T5": [
|
| 79 |
+
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
| 80 |
+
],
|
| 81 |
+
# ControlNet
|
| 82 |
+
"ControlNet_v11f1p_sd15_depth": [
|
| 83 |
+
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
| 84 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 85 |
+
],
|
| 86 |
+
"ControlNet_v11p_sd15_softedge": [
|
| 87 |
+
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
| 88 |
+
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
| 89 |
+
],
|
| 90 |
+
"ControlNet_v11f1e_sd15_tile": [
|
| 91 |
+
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
| 92 |
+
],
|
| 93 |
+
"ControlNet_v11p_sd15_lineart": [
|
| 94 |
+
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
| 95 |
+
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
| 96 |
+
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
| 97 |
+
],
|
| 98 |
+
"ControlNet_union_sdxl_promax": [
|
| 99 |
+
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
| 100 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 101 |
+
],
|
| 102 |
+
# AnimateDiff
|
| 103 |
+
"AnimateDiff_v2": [
|
| 104 |
+
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
| 105 |
+
],
|
| 106 |
+
"AnimateDiff_xl_beta": [
|
| 107 |
+
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
| 108 |
+
],
|
| 109 |
+
|
| 110 |
+
# Qwen Prompt
|
| 111 |
+
"QwenPrompt": [
|
| 112 |
+
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 113 |
+
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 114 |
+
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 115 |
+
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 116 |
+
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 117 |
+
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 118 |
+
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 119 |
+
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 120 |
+
],
|
| 121 |
+
# Beautiful Prompt
|
| 122 |
+
"BeautifulPrompt": [
|
| 123 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 124 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 125 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 126 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 127 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 128 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 129 |
+
],
|
| 130 |
+
# Omost prompt
|
| 131 |
+
"OmostPrompt":[
|
| 132 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 133 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 134 |
+
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 135 |
+
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 136 |
+
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 137 |
+
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 138 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 139 |
+
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 140 |
+
],
|
| 141 |
+
# Translator
|
| 142 |
+
"opus-mt-zh-en": [
|
| 143 |
+
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
| 144 |
+
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
| 145 |
+
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
| 146 |
+
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
| 147 |
+
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
| 148 |
+
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
| 149 |
+
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
| 150 |
+
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
| 151 |
+
],
|
| 152 |
+
# IP-Adapter
|
| 153 |
+
"IP-Adapter-SD": [
|
| 154 |
+
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
| 155 |
+
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
| 156 |
+
],
|
| 157 |
+
"IP-Adapter-SDXL": [
|
| 158 |
+
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
| 159 |
+
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
| 160 |
+
],
|
| 161 |
+
"SDXL-vae-fp16-fix": [
|
| 162 |
+
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
| 163 |
+
],
|
| 164 |
+
# Kolors
|
| 165 |
+
"Kolors": [
|
| 166 |
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
| 167 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
| 168 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 169 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 170 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 171 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 172 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 173 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 174 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 175 |
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
| 176 |
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
| 177 |
+
],
|
| 178 |
+
# FLUX
|
| 179 |
+
"FLUX.1-dev": [
|
| 180 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
| 181 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 182 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 183 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 184 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 185 |
+
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 186 |
+
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 187 |
+
],
|
| 188 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
| 189 |
+
"file_list": [
|
| 190 |
+
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
| 191 |
+
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 192 |
+
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 193 |
+
],
|
| 194 |
+
"load_path": [
|
| 195 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
| 196 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 197 |
+
],
|
| 198 |
+
},
|
| 199 |
+
# RIFE
|
| 200 |
+
"RIFE": [
|
| 201 |
+
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
| 202 |
+
],
|
| 203 |
+
# CogVideo
|
| 204 |
+
"CogVideoX-5B": [
|
| 205 |
+
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 206 |
+
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 207 |
+
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 208 |
+
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 209 |
+
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 210 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 211 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 212 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 213 |
+
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
| 214 |
+
],
|
| 215 |
+
# Stable Diffusion 3.5
|
| 216 |
+
"StableDiffusion3.5-large": [
|
| 217 |
+
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
| 218 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 219 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 220 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 221 |
+
],
|
| 222 |
+
}
|
| 223 |
+
preset_models_on_modelscope = {
|
| 224 |
+
# Hunyuan DiT
|
| 225 |
+
"HunyuanDiT": [
|
| 226 |
+
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
| 227 |
+
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
| 228 |
+
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
| 229 |
+
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
| 230 |
+
],
|
| 231 |
+
# Stable Video Diffusion
|
| 232 |
+
"stable-video-diffusion-img2vid-xt": [
|
| 233 |
+
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
| 234 |
+
],
|
| 235 |
+
# ExVideo
|
| 236 |
+
"ExVideo-SVD-128f-v1": [
|
| 237 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
| 238 |
+
],
|
| 239 |
+
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
| 240 |
+
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
| 241 |
+
],
|
| 242 |
+
# Stable Diffusion
|
| 243 |
+
"StableDiffusion_v15": [
|
| 244 |
+
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
| 245 |
+
],
|
| 246 |
+
"DreamShaper_8": [
|
| 247 |
+
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
| 248 |
+
],
|
| 249 |
+
"AingDiffusion_v12": [
|
| 250 |
+
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
| 251 |
+
],
|
| 252 |
+
"Flat2DAnimerge_v45Sharp": [
|
| 253 |
+
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
| 254 |
+
],
|
| 255 |
+
# Textual Inversion
|
| 256 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
| 257 |
+
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
| 258 |
+
],
|
| 259 |
+
# Stable Diffusion XL
|
| 260 |
+
"StableDiffusionXL_v1": [
|
| 261 |
+
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
| 262 |
+
],
|
| 263 |
+
"BluePencilXL_v200": [
|
| 264 |
+
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
| 265 |
+
],
|
| 266 |
+
"StableDiffusionXL_Turbo": [
|
| 267 |
+
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
| 268 |
+
],
|
| 269 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
| 270 |
+
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
| 271 |
+
],
|
| 272 |
+
# Stable Diffusion 3
|
| 273 |
+
"StableDiffusion3": [
|
| 274 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
| 275 |
+
],
|
| 276 |
+
"StableDiffusion3_without_T5": [
|
| 277 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
| 278 |
+
],
|
| 279 |
+
# ControlNet
|
| 280 |
+
"ControlNet_v11f1p_sd15_depth": [
|
| 281 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
| 282 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 283 |
+
],
|
| 284 |
+
"ControlNet_v11p_sd15_softedge": [
|
| 285 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
| 286 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
| 287 |
+
],
|
| 288 |
+
"ControlNet_v11f1e_sd15_tile": [
|
| 289 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
| 290 |
+
],
|
| 291 |
+
"ControlNet_v11p_sd15_lineart": [
|
| 292 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
| 293 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
| 294 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
| 295 |
+
],
|
| 296 |
+
"ControlNet_union_sdxl_promax": [
|
| 297 |
+
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
| 298 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 299 |
+
],
|
| 300 |
+
"Annotators:Depth": [
|
| 301 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
| 302 |
+
],
|
| 303 |
+
"Annotators:Softedge": [
|
| 304 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
| 305 |
+
],
|
| 306 |
+
"Annotators:Lineart": [
|
| 307 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
| 308 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
| 309 |
+
],
|
| 310 |
+
"Annotators:Normal": [
|
| 311 |
+
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
| 312 |
+
],
|
| 313 |
+
"Annotators:Openpose": [
|
| 314 |
+
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
| 315 |
+
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
| 316 |
+
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
| 317 |
+
],
|
| 318 |
+
# AnimateDiff
|
| 319 |
+
"AnimateDiff_v2": [
|
| 320 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
| 321 |
+
],
|
| 322 |
+
"AnimateDiff_xl_beta": [
|
| 323 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
| 324 |
+
],
|
| 325 |
+
# RIFE
|
| 326 |
+
"RIFE": [
|
| 327 |
+
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
| 328 |
+
],
|
| 329 |
+
# Qwen Prompt
|
| 330 |
+
"QwenPrompt": {
|
| 331 |
+
"file_list": [
|
| 332 |
+
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 333 |
+
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 334 |
+
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 335 |
+
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 336 |
+
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 337 |
+
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 338 |
+
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 339 |
+
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 340 |
+
],
|
| 341 |
+
"load_path": [
|
| 342 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 343 |
+
],
|
| 344 |
+
},
|
| 345 |
+
# Beautiful Prompt
|
| 346 |
+
"BeautifulPrompt": {
|
| 347 |
+
"file_list": [
|
| 348 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 349 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 350 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 351 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 352 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 353 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 354 |
+
],
|
| 355 |
+
"load_path": [
|
| 356 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 357 |
+
],
|
| 358 |
+
},
|
| 359 |
+
# Omost prompt
|
| 360 |
+
"OmostPrompt": {
|
| 361 |
+
"file_list": [
|
| 362 |
+
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 363 |
+
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 364 |
+
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 365 |
+
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 366 |
+
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 367 |
+
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 368 |
+
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 369 |
+
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 370 |
+
],
|
| 371 |
+
"load_path": [
|
| 372 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 373 |
+
],
|
| 374 |
+
},
|
| 375 |
+
# Translator
|
| 376 |
+
"opus-mt-zh-en": {
|
| 377 |
+
"file_list": [
|
| 378 |
+
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
| 379 |
+
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
| 380 |
+
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
| 381 |
+
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
| 382 |
+
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
| 383 |
+
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
| 384 |
+
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
| 385 |
+
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
| 386 |
+
],
|
| 387 |
+
"load_path": [
|
| 388 |
+
"models/translator/opus-mt-zh-en",
|
| 389 |
+
],
|
| 390 |
+
},
|
| 391 |
+
# IP-Adapter
|
| 392 |
+
"IP-Adapter-SD": [
|
| 393 |
+
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
| 394 |
+
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
| 395 |
+
],
|
| 396 |
+
"IP-Adapter-SDXL": [
|
| 397 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
| 398 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
| 399 |
+
],
|
| 400 |
+
# Kolors
|
| 401 |
+
"Kolors": {
|
| 402 |
+
"file_list": [
|
| 403 |
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
| 404 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
| 405 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 406 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 407 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 408 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 409 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 410 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 411 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 412 |
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
| 413 |
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
| 414 |
+
],
|
| 415 |
+
"load_path": [
|
| 416 |
+
"models/kolors/Kolors/text_encoder",
|
| 417 |
+
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
| 418 |
+
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
| 419 |
+
],
|
| 420 |
+
},
|
| 421 |
+
"SDXL-vae-fp16-fix": [
|
| 422 |
+
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
| 423 |
+
],
|
| 424 |
+
# FLUX
|
| 425 |
+
"FLUX.1-dev": {
|
| 426 |
+
"file_list": [
|
| 427 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
| 428 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 429 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 430 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 431 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 432 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 433 |
+
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 434 |
+
],
|
| 435 |
+
"load_path": [
|
| 436 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
| 437 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 438 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
| 439 |
+
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
| 440 |
+
],
|
| 441 |
+
},
|
| 442 |
+
"FLUX.1-schnell": {
|
| 443 |
+
"file_list": [
|
| 444 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
| 445 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 446 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 447 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 448 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 449 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 450 |
+
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
| 451 |
+
],
|
| 452 |
+
"load_path": [
|
| 453 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
| 454 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 455 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
| 456 |
+
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
| 457 |
+
],
|
| 458 |
+
},
|
| 459 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
| 460 |
+
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
|
| 461 |
+
],
|
| 462 |
+
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
| 463 |
+
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
|
| 464 |
+
],
|
| 465 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
| 466 |
+
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
|
| 467 |
+
],
|
| 468 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
| 469 |
+
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
|
| 470 |
+
],
|
| 471 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
| 472 |
+
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
|
| 473 |
+
],
|
| 474 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
| 475 |
+
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
|
| 476 |
+
],
|
| 477 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
| 478 |
+
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
|
| 479 |
+
],
|
| 480 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
| 481 |
+
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
| 482 |
+
],
|
| 483 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
| 484 |
+
"file_list": [
|
| 485 |
+
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
| 486 |
+
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 487 |
+
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 488 |
+
],
|
| 489 |
+
"load_path": [
|
| 490 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
| 491 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 492 |
+
],
|
| 493 |
+
},
|
| 494 |
+
# ESRGAN
|
| 495 |
+
"ESRGAN_x4": [
|
| 496 |
+
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
| 497 |
+
],
|
| 498 |
+
# RIFE
|
| 499 |
+
"RIFE": [
|
| 500 |
+
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
| 501 |
+
],
|
| 502 |
+
# Omnigen
|
| 503 |
+
"OmniGen-v1": {
|
| 504 |
+
"file_list": [
|
| 505 |
+
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
|
| 506 |
+
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
| 507 |
+
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
| 508 |
+
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
| 509 |
+
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
| 510 |
+
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
| 511 |
+
],
|
| 512 |
+
"load_path": [
|
| 513 |
+
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
| 514 |
+
"models/OmniGen/OmniGen-v1/model.safetensors",
|
| 515 |
+
]
|
| 516 |
+
},
|
| 517 |
+
# CogVideo
|
| 518 |
+
"CogVideoX-5B": {
|
| 519 |
+
"file_list": [
|
| 520 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 521 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 522 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 523 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 524 |
+
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 525 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 526 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 527 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 528 |
+
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
| 529 |
+
],
|
| 530 |
+
"load_path": [
|
| 531 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 532 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 533 |
+
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
| 534 |
+
],
|
| 535 |
+
},
|
| 536 |
+
# Stable Diffusion 3.5
|
| 537 |
+
"StableDiffusion3.5-large": [
|
| 538 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
| 539 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 540 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 541 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 542 |
+
],
|
| 543 |
+
"StableDiffusion3.5-medium": [
|
| 544 |
+
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
|
| 545 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 546 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 547 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 548 |
+
],
|
| 549 |
+
"StableDiffusion3.5-large-turbo": [
|
| 550 |
+
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
| 551 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 552 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 553 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 554 |
+
],
|
| 555 |
+
"HunyuanVideo":{
|
| 556 |
+
"file_list": [
|
| 557 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
| 558 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 559 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 560 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 561 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 562 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
| 563 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
| 564 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
| 565 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
| 566 |
+
],
|
| 567 |
+
"load_path": [
|
| 568 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
| 569 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 570 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
| 571 |
+
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
| 572 |
+
],
|
| 573 |
+
},
|
| 574 |
+
"HunyuanVideo-fp8":{
|
| 575 |
+
"file_list": [
|
| 576 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
| 577 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 578 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 579 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 580 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 581 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
| 582 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
| 583 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
| 584 |
+
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
| 585 |
+
],
|
| 586 |
+
"load_path": [
|
| 587 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
| 588 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 589 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
| 590 |
+
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
| 591 |
+
],
|
| 592 |
+
},
|
| 593 |
+
}
|
| 594 |
+
Preset_model_id: TypeAlias = Literal[
|
| 595 |
+
"HunyuanDiT",
|
| 596 |
+
"stable-video-diffusion-img2vid-xt",
|
| 597 |
+
"ExVideo-SVD-128f-v1",
|
| 598 |
+
"ExVideo-CogVideoX-LoRA-129f-v1",
|
| 599 |
+
"StableDiffusion_v15",
|
| 600 |
+
"DreamShaper_8",
|
| 601 |
+
"AingDiffusion_v12",
|
| 602 |
+
"Flat2DAnimerge_v45Sharp",
|
| 603 |
+
"TextualInversion_VeryBadImageNegative_v1.3",
|
| 604 |
+
"StableDiffusionXL_v1",
|
| 605 |
+
"BluePencilXL_v200",
|
| 606 |
+
"StableDiffusionXL_Turbo",
|
| 607 |
+
"ControlNet_v11f1p_sd15_depth",
|
| 608 |
+
"ControlNet_v11p_sd15_softedge",
|
| 609 |
+
"ControlNet_v11f1e_sd15_tile",
|
| 610 |
+
"ControlNet_v11p_sd15_lineart",
|
| 611 |
+
"AnimateDiff_v2",
|
| 612 |
+
"AnimateDiff_xl_beta",
|
| 613 |
+
"RIFE",
|
| 614 |
+
"BeautifulPrompt",
|
| 615 |
+
"opus-mt-zh-en",
|
| 616 |
+
"IP-Adapter-SD",
|
| 617 |
+
"IP-Adapter-SDXL",
|
| 618 |
+
"StableDiffusion3",
|
| 619 |
+
"StableDiffusion3_without_T5",
|
| 620 |
+
"Kolors",
|
| 621 |
+
"SDXL-vae-fp16-fix",
|
| 622 |
+
"ControlNet_union_sdxl_promax",
|
| 623 |
+
"FLUX.1-dev",
|
| 624 |
+
"FLUX.1-schnell",
|
| 625 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
| 626 |
+
"jasperai/Flux.1-dev-Controlnet-Depth",
|
| 627 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
| 628 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
| 629 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
| 630 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
| 631 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
| 632 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
| 633 |
+
"InstantX/FLUX.1-dev-IP-Adapter",
|
| 634 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
| 635 |
+
"QwenPrompt",
|
| 636 |
+
"OmostPrompt",
|
| 637 |
+
"ESRGAN_x4",
|
| 638 |
+
"RIFE",
|
| 639 |
+
"OmniGen-v1",
|
| 640 |
+
"CogVideoX-5B",
|
| 641 |
+
"Annotators:Depth",
|
| 642 |
+
"Annotators:Softedge",
|
| 643 |
+
"Annotators:Lineart",
|
| 644 |
+
"Annotators:Normal",
|
| 645 |
+
"Annotators:Openpose",
|
| 646 |
+
"StableDiffusion3.5-large",
|
| 647 |
+
"StableDiffusion3.5-medium",
|
| 648 |
+
"HunyuanVideo",
|
| 649 |
+
"HunyuanVideo-fp8",
|
| 650 |
+
]
|
diffsynth/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames
diffsynth/data/video.py
ADDED
@@ -0,0 +1,173 @@
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    # Split a file name into alternating text and integer chunks so that "frame10.png" sorts after "frame2.png".
    result = []
    number = -1
    for i in file_name:
        if ord(i) >= ord("0") and ord(i) <= ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return Image.open(self.file_list[item]).convert("RGB")

    def __del__(self):
        pass


def crop_and_resize(image, height, width):
    # Center-crop the image to the target aspect ratio, then resize to (width, height).
    image = np.array(image)
    image_height, image_width, _ = image.shape
    if image_height / image_width < height / width:
        croped_width = int(image_height / height * width)
        left = (image_width - croped_width) // 2
        image = image[:, left: left+croped_width]
        image = Image.fromarray(image).resize((width, height))
    else:
        croped_height = int(image_width / width * height)
        left = (image_height - croped_height) // 2
        image = image[left: left+croped_height, :]
        image = Image.fromarray(image).resize((width, height))
    return image


class VideoData:
    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))


def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
    writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()

# def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
#     writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=["-crf", "0", "-preset", "veryslow"])
#     for frame in tqdm(frames, desc="Saving video"):
#         frame = np.array(frame)
#         writer.append_data(frame)
#     writer.close()

# def save_video_h264(frames, save_path, fps, ffmpeg_params=None):
#     import imageio.v3 as iio
#     from tqdm import tqdm
#     import numpy as np
#
#     if ffmpeg_params is None:
#         ffmpeg_params = ["-crf", "0", "-preset", "ultrafast"]  # lossless H.264
#
#     writer = iio.get_writer(save_path, fps=fps, codec="libx264", ffmpeg_params=ffmpeg_params)
#     for frame in tqdm(frames, desc="Saving video"):
#         writer.append_data(np.array(frame))
#     writer.close()


def save_frames(frames, save_path):
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))


if __name__ == '__main__':
    frames = [Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)) for i in range(81)]
    save_video(frames, "haha.mp4", 23, 5)
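Note: as a quick orientation, here is a minimal usage sketch of the helpers defined in diffsynth/data/video.py above; the clip paths are placeholders.

# Minimal sketch using the helpers above; "input.mp4" / "output.mp4" are placeholder paths.
from diffsynth.data.video import VideoData, save_video

video = VideoData(video_file="input.mp4", height=480, width=832)  # frames come back as PIL images, center-cropped and resized
frames = video.raw_data()                                         # decodes every frame through LowMemoryVideo
save_video(frames, "output.mp4", fps=24, quality=9)               # writes the frames back out with imageio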
diffsynth/pipelines/__init__.py
ADDED
@@ -0,0 +1 @@
from .wan_video import WanVideoPipeline
diffsynth/pipelines/base.py
ADDED
@@ -0,0 +1,127 @@
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import GaussianBlur


class BasePipeline(torch.nn.Module):

    def __init__(self, device="cuda", torch_dtype=torch.float16, height_division_factor=64, width_division_factor=64):
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.cpu_offload = False
        self.model_names = []

    def check_resize_height_width(self, height, width):
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"The height cannot be evenly divided by {self.height_division_factor}. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"The width cannot be evenly divided by {self.width_division_factor}. We round it up to {width}.")
        return height, width

    def preprocess_image(self, image):
        # Map a PIL image from [0, 255] to [-1, 1] and add a batch dimension: (1, C, H, W).
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image

    def preprocess_images(self, images):
        return [self.preprocess_image(image) for image in images]

    def vae_output_to_image(self, vae_output):
        image = vae_output[0].cpu().float().permute(1, 2, 0).numpy()
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image

    def vae_output_to_video(self, vae_output):
        video = vae_output.cpu().permute(1, 2, 0).numpy()
        video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
        return video

    def merge_latents(self, value, latents, masks, scales, blur_kernel_size=33, blur_sigma=10.0):
        if len(latents) > 0:
            blur = GaussianBlur(kernel_size=blur_kernel_size, sigma=blur_sigma)
            height, width = value.shape[-2:]
            weight = torch.ones_like(value)
            for latent, mask, scale in zip(latents, masks, scales):
                mask = self.preprocess_image(mask.resize((width, height))).mean(dim=1, keepdim=True) > 0
                mask = mask.repeat(1, latent.shape[1], 1, 1).to(dtype=latent.dtype, device=latent.device)
                mask = blur(mask)
                value += latent * mask * scale
                weight += mask * scale
            value /= weight
        return value

    def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None):
        if special_kwargs is None:
            noise_pred_global = inference_callback(prompt_emb_global)
        else:
            noise_pred_global = inference_callback(prompt_emb_global, special_kwargs)
        if special_local_kwargs_list is None:
            noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
        else:
            noise_pred_locals = [inference_callback(prompt_emb_local, special_kwargs) for prompt_emb_local, special_kwargs in zip(prompt_emb_locals, special_local_kwargs_list)]
        noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
        return noise_pred

    def extend_prompt(self, prompt, local_prompts, masks, mask_scales):
        local_prompts = local_prompts or []
        masks = masks or []
        mask_scales = mask_scales or []
        extended_prompt_dict = self.prompter.extend_prompt(prompt)
        prompt = extended_prompt_dict.get("prompt", prompt)
        local_prompts += extended_prompt_dict.get("prompts", [])
        masks += extended_prompt_dict.get("masks", [])
        mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
        return prompt, local_prompts, masks, mask_scales

    def enable_cpu_offload(self):
        self.cpu_offload = True

    def load_models_to_device(self, loadmodel_names=[]):
        # Only move models when CPU offload is enabled.
        if not self.cpu_offload:
            return
        # Offload the models that are not needed to CPU.
        for model_name in self.model_names:
            if model_name not in loadmodel_names:
                model = getattr(self, model_name)
                if model is not None:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
        # Load the needed models to the target device.
        for model_name in loadmodel_names:
            model = getattr(self, model_name)
            if model is not None:
                if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                    for module in model.modules():
                        if hasattr(module, "onload"):
                            module.onload()
                else:
                    model.to(self.device)
        # Free the CUDA cache.
        torch.cuda.empty_cache()

    def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
        generator = None if seed is None else torch.Generator(device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return noise
diffsynth/pipelines/wan_video.py
ADDED
@@ -0,0 +1,290 @@
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import WanPrompter
import torch, os
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm

from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample


class WanVideoPipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['text_encoder', 'dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        dtype = next(iter(self.text_encoder.parameters())).dtype
        enable_vram_management(
            self.text_encoder,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Embedding: AutoWrappedModule,
                T5RelativeEmbedding: AutoWrappedModule,
                T5LayerNorm: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.dit.parameters())).dtype
        enable_vram_management(
            self.dit,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                WanLayerNorm: AutoWrappedModule,
                WanRMSNorm: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.vae.parameters())).dtype
        enable_vram_management(
            self.vae,
            module_map = {
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv2d: AutoWrappedModule,
                RMS_norm: AutoWrappedModule,
                CausalConv3d: AutoWrappedModule,
                Upsample: AutoWrappedModule,
                torch.nn.SiLU: AutoWrappedModule,
                torch.nn.Dropout: AutoWrappedModule,
            },
            module_config = dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
            )
        self.enable_cpu_offload()

    def fetch_models(self, model_manager: ModelManager):
        text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
        if text_encoder_model_and_path is not None:
            self.text_encoder, tokenizer_path = text_encoder_model_and_path
            self.prompter.fetch_models(self.text_encoder)
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")
        self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
        if device is None: device = model_manager.device
        if torch_dtype is None: torch_dtype = model_manager.torch_dtype
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        return pipe

    def denoising_model(self):
        return self.dit

    def encode_prompt(self, prompt, positive=True):
        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
        return {"context": prompt_emb}

    def encode_image(self, image, num_frames, height, width):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            image = self.preprocess_image(image.resize((width, height))).to(self.device)
            clip_context = self.image_encoder.encode_image([image])
            msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
            msk[:, 1:] = 0
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
            msk = msk.transpose(1, 2)[0]
            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
            y = torch.concat([msk, y])
        return {"clip_fea": clip_context, "y": [y]}

    def tensor2video(self, frames):
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [Image.fromarray(frame) for frame in frames]
        return frames

    def prepare_extra_input(self, latents=None):
        return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def set_ip(self, local_path):
        pass

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_image=None,
        input_video=None,
        denoising_strength=1.0,
        seed=None,
        rand_device="cpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        audio_cfg_scale=None,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(30, 52),
        tile_stride=(15, 26),
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
        **kwargs,
    ):
        # Parameter check
        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

        # Tiler parameters
        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)

        # Initialize noise
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
        if input_video is not None:
            self.load_models_to_device(['vae'])
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2)
            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Encode image
        if input_image is not None and self.image_encoder is not None:
            self.load_models_to_device(["image_encoder", "vae"])
            image_emb = self.encode_image(input_image, num_frames, height, width)
        else:
            image_emb = {}

        # Extra input
        extra_input = self.prepare_extra_input(latents)

        # Denoise
        self.load_models_to_device(["dit"])
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
                timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)

                # Inference
                noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **kwargs)  # (zt, audio, prompt)
                if audio_cfg_scale is not None:
                    audio_scale = kwargs['audio_scale']
                    kwargs['audio_scale'] = 0.0
                    noise_pred_noaudio = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **kwargs)  # (zt, 0, prompt)
                    if cfg_scale != 1.0:  # prompt cfg
                        noise_pred_no_cond = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **kwargs)  # (zt, 0, 0)
                        noise_pred = noise_pred_no_cond + cfg_scale * (noise_pred_noaudio - noise_pred_no_cond) + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio)
                    else:
                        noise_pred = noise_pred_noaudio + audio_cfg_scale * (noise_pred_posi - noise_pred_noaudio)
                    kwargs['audio_scale'] = audio_scale
                else:
                    if cfg_scale != 1.0:
                        noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **kwargs)  # (zt, audio, 0)
                        noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                    else:
                        noise_pred = noise_pred_posi

                # Scheduler
                latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # Decode
        self.load_models_to_device(['vae'])
        frames = self.decode_video(latents, **tiler_kwargs)
        self.load_models_to_device([])
        frames = self.tensor2video(frames[0])

        return frames
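Note: restating the guidance arithmetic used in the denoising loop of __call__ above, with the same variable names as the code.

# When both prompt CFG and audio CFG are active, three DiT passes are combined as:
#   noise_pred = noise_pred_no_cond
#              + cfg_scale       * (noise_pred_noaudio - noise_pred_no_cond)   # prompt guidance, audio muted
#              + audio_cfg_scale * (noise_pred_posi    - noise_pred_noaudio)   # audio guidance on top of the prompt
# With cfg_scale == 1.0 only the audio term is applied; with audio_cfg_scale=None it falls back to standard CFG.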
diffsynth/prompters/__init__.py
ADDED
@@ -0,0 +1 @@
from .wan_prompter import WanPrompter
diffsynth/prompters/base_prompter.py
ADDED
@@ -0,0 +1,70 @@
from ..models.model_manager import ModelManager
import torch


def tokenize_long_prompt(tokenizer, prompt, max_length=None):
    # Get model_max_length from the tokenizer.
    length = tokenizer.model_max_length if max_length is None else max_length

    # To avoid the truncation warning, temporarily set tokenizer.model_max_length to a huge value.
    tokenizer.model_max_length = 99999999

    # Tokenize it!
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Determine the real length, rounded up to a multiple of `length`.
    max_length = (input_ids.shape[1] + length - 1) // length * length

    # Restore tokenizer.model_max_length.
    tokenizer.model_max_length = length

    # Tokenize it again with the fixed length.
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True
    ).input_ids

    # Reshape input_ids to fit the text encoder.
    num_sentence = input_ids.shape[1] // length
    input_ids = input_ids.reshape((num_sentence, length))

    return input_ids


class BasePrompter:
    def __init__(self):
        self.refiners = []
        self.extenders = []

    def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]):
        for refiner_class in refiner_classes:
            refiner = refiner_class.from_model_manager(model_manager)
            self.refiners.append(refiner)

    def load_prompt_extenders(self, model_manager: ModelManager, extender_classes=[]):
        for extender_class in extender_classes:
            extender = extender_class.from_model_manager(model_manager)
            self.extenders.append(extender)

    @torch.no_grad()
    def process_prompt(self, prompt, positive=True):
        if isinstance(prompt, list):
            prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
        else:
            for refiner in self.refiners:
                prompt = refiner(prompt, positive=positive)
        return prompt

    @torch.no_grad()
    def extend_prompt(self, prompt: str, positive=True):
        extended_prompt = dict(prompt=prompt)
        for extender in self.extenders:
            extended_prompt = extender(extended_prompt)
        return extended_prompt
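Note: a worked example of the reshaping done by tokenize_long_prompt above (the token counts are illustrative only).

# Illustrative numbers: with tokenizer.model_max_length == 77 and a prompt that tokenizes to 150 ids,
# max_length is rounded up to (150 + 77 - 1) // 77 * 77 == 154, the prompt is re-tokenized with padding
# to 154 tokens, and the ids are reshaped to (2, 77) so the text encoder sees two fixed-length chunks.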
diffsynth/prompters/wan_prompter.py
ADDED
@@ -0,0 +1,108 @@
from .base_prompter import BasePrompter
from ..models.wan_video_text_encoder import WanTextEncoder
from transformers import AutoTokenizer
import os, torch
import ftfy
import html
import string
import regex as re


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text


class WanPrompter(BasePrompter):

    def __init__(self, tokenizer_path=None, text_len=512):
        super().__init__()
        self.text_len = text_len
        self.text_encoder = None
        self.fetch_tokenizer(tokenizer_path)

    def fetch_tokenizer(self, tokenizer_path=None):
        if tokenizer_path is not None:
            self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.text_len, clean='whitespace')

    def fetch_models(self, text_encoder: WanTextEncoder = None):
        self.text_encoder = text_encoder

    def encode_prompt(self, prompt, positive=True, device="cuda"):
        prompt = self.process_prompt(prompt, positive=positive)

        ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_emb = self.text_encoder(ids, mask)
        prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
        return prompt_emb
diffsynth/schedulers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .ddim import EnhancedDDIMScheduler
from .continuous_ode import ContinuousODEScheduler
from .flow_match import FlowMatchScheduler
diffsynth/schedulers/continuous_ode.py
ADDED
@@ -0,0 +1,59 @@
import torch


class ContinuousODEScheduler():

    def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
        ramp = torch.linspace(1-denoising_strength, 1, num_inference_steps)
        min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
        max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
        self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
        self.timesteps = torch.log(self.sigmas) * 0.25

    def step(self, model_output, timestep, sample, to_final=False):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample *= (sigma*sigma + 1).sqrt()
        estimated_sample = -sigma / (sigma*sigma + 1).sqrt() * model_output + 1 / (sigma*sigma + 1) * sample
        if to_final or timestep_id + 1 >= len(self.timesteps):
            prev_sample = estimated_sample
        else:
            sigma_ = self.sigmas[timestep_id + 1]
            derivative = 1 / sigma * (sample - estimated_sample)
            prev_sample = sample + derivative * (sigma_ - sigma)
            prev_sample /= (sigma_*sigma_ + 1).sqrt()
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        # This scheduler doesn't support this function.
        pass

    def add_noise(self, original_samples, noise, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (original_samples + noise * sigma) / (sigma*sigma + 1).sqrt()
        return sample

    def training_target(self, sample, noise, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        target = (-(sigma*sigma + 1).sqrt() / sigma + 1 / (sigma*sigma + 1).sqrt() / sigma) * sample + 1 / (sigma*sigma + 1).sqrt() * noise
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        weight = (1 + sigma*sigma).sqrt() / sigma
        return weight
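Note: the noise schedule built by set_timesteps above is a rho-spaced interpolation in sigma^(1/rho) space (the same form popularized by Karras et al.); written out with the names used in the code:

# sigma_i = (sigma_max^(1/rho) + ramp_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho,
# where ramp_i runs linearly over [1 - denoising_strength, 1]; the associated "timesteps"
# are scaled log-sigmas: timestep_i = 0.25 * log(sigma_i).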
diffsynth/schedulers/ddim.py
ADDED
@@ -0,0 +1,105 @@
import torch, math


class EnhancedDDIMScheduler():

    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
        self.num_train_timesteps = num_train_timesteps
        if beta_schedule == "scaled_linear":
            betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
        elif beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        if rescale_zero_terminal_snr:
            self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
        self.alphas_cumprod = self.alphas_cumprod.tolist()
        self.set_timesteps(10)
        self.prediction_type = prediction_type

    def rescale_zero_terminal_snr(self, alphas_cumprod):
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T

        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Square to recover alphas_bar.
        alphas_bar = alphas_bar_sqrt.square()

        return alphas_bar

    def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
        # The timesteps are aligned to 999...0, which is different from other implementations,
        # but I think this implementation is more reasonable in theory.
        max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = torch.Tensor([max_timestep])
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])

    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        if self.prediction_type == "epsilon":
            weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
            weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
            prev_sample = sample * weight_x + model_output * weight_e
        elif self.prediction_type == "v_prediction":
            weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
            weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
            prev_sample = sample * weight_x + model_output * weight_e
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
        return prev_sample

    def step(self, model_output, timestep, sample, to_final=False):
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if to_final or timestep_id + 1 >= len(self.timesteps):
            alpha_prod_t_prev = 1.0
        else:
            timestep_prev = int(self.timesteps[timestep_id + 1])
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred

    def add_noise(self, original_samples, noise, timestep):
        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def training_target(self, sample, noise, timestep):
        if self.prediction_type == "epsilon":
            return noise
        else:
            sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
            return target

    def training_weight(self, timestep):
        return 1.0
diffsynth/schedulers/flow_match.py
ADDED
@@ -0,0 +1,79 @@
import torch


class FlowMatchScheduler():

    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
        if shift is not None:
            self.shift = shift
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
        if self.extra_one_step:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
        else:
            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
        if self.reverse_sigmas:
            self.sigmas = 1 - self.sigmas
        self.timesteps = self.sigmas * self.num_train_timesteps
        if training:
            x = self.timesteps
            y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
            y_shifted = y - y.min()
            bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
            self.linear_timesteps_weights = bsmntw_weighing

    def step(self, model_output, timestep, sample, to_final=False):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
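Note: a minimal sketch of how the flow-matching scheduler above is driven, using dummy tensors; the (1, 16, 21, 60, 104) latent shape is only an example.

# Minimal sketch with dummy tensors; the latent shape and step count are illustrative.
import torch
from diffsynth.schedulers.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
scheduler.set_timesteps(num_inference_steps=30, shift=5.0)       # sigmas in (0, 1], timesteps = sigmas * 1000

latents = torch.randn(1, 16, 21, 60, 104)                        # stands in for the noisy video latent
for t in scheduler.timesteps:
    model_output = torch.randn_like(latents)                     # stands in for the DiT's velocity prediction
    latents = scheduler.step(model_output, t, latents)           # Euler update: x += v * (sigma_next - sigma)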
diffsynth/vram_management/__init__.py
ADDED
@@ -0,0 +1 @@
from .layers import *
diffsynth/vram_management/layers.py
ADDED
@@ -0,0 +1,95 @@
+import torch, copy
+from ..models.utils import init_weights_on_device
+
+
+def cast_to(weight, dtype, device):
+    r = torch.empty_like(weight, dtype=dtype, device=device)
+    r.copy_(weight)
+    return r
+
+
+class AutoWrappedModule(torch.nn.Module):
+    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
+        super().__init__()
+        self.module = module.to(dtype=offload_dtype, device=offload_device)
+        self.offload_dtype = offload_dtype
+        self.offload_device = offload_device
+        self.onload_dtype = onload_dtype
+        self.onload_device = onload_device
+        self.computation_dtype = computation_dtype
+        self.computation_device = computation_device
+        self.state = 0
+
+    def offload(self):
+        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.module.to(dtype=self.offload_dtype, device=self.offload_device)
+            self.state = 0
+
+    def onload(self):
+        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.module.to(dtype=self.onload_dtype, device=self.onload_device)
+            self.state = 1
+
+    def forward(self, *args, **kwargs):
+        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
+            module = self.module
+        else:
+            module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
+        return module(*args, **kwargs)
+
+
+class AutoWrappedLinear(torch.nn.Linear):
+    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
+        with init_weights_on_device(device=torch.device("meta")):
+            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
+        self.weight = module.weight
+        self.bias = module.bias
+        self.offload_dtype = offload_dtype
+        self.offload_device = offload_device
+        self.onload_dtype = onload_dtype
+        self.onload_device = onload_device
+        self.computation_dtype = computation_dtype
+        self.computation_device = computation_device
+        self.state = 0
+
+    def offload(self):
+        if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.to(dtype=self.offload_dtype, device=self.offload_device)
+            self.state = 0
+
+    def onload(self):
+        if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
+            self.to(dtype=self.onload_dtype, device=self.onload_device)
+            self.state = 1
+
+    def forward(self, x, *args, **kwargs):
+        if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
+            weight, bias = self.weight, self.bias
+        else:
+            weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
+            bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
+        return torch.nn.functional.linear(x, weight, bias)
+
+
+def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
+    for name, module in model.named_children():
+        for source_module, target_module in module_map.items():
+            if isinstance(module, source_module):
+                num_param = sum(p.numel() for p in module.parameters())
+                if max_num_param is not None and total_num_param + num_param > max_num_param:
+                    module_config_ = overflow_module_config
+                else:
+                    module_config_ = module_config
+                module_ = target_module(module, **module_config_)
+                setattr(model, name, module_)
+                total_num_param += num_param
+                break
+        else:
+            total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
+    return total_num_param
+
+
+def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
+    enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
+    model.vram_management_enabled = True
+
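
Note: enable_vram_management() walks the module tree and replaces every layer listed in module_map with its auto-wrapped counterpart, so weights can rest in an offload dtype/device and only be cast to the computation dtype/device inside forward(). A minimal sketch under assumed settings (the dtype/device choices below are illustrative, not taken from this repo, and a CUDA device is assumed):

import torch
# from diffsynth.vram_management import enable_vram_management, AutoWrappedLinear

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096), torch.nn.GELU())
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=dict(
        offload_dtype=torch.bfloat16, offload_device="cpu",            # where weights rest
        onload_dtype=torch.bfloat16, onload_device="cpu",
        computation_dtype=torch.bfloat16, computation_device="cuda",   # where matmuls run
    ),
)
# The wrapped Linear casts its weight/bias to CUDA bf16 on the fly during forward().
out = model(torch.randn(2, 4096, device="cuda", dtype=torch.bfloat16))
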
infer.py
ADDED
@@ -0,0 +1,214 @@
+import torch
+from diffsynth import ModelManager, WanVideoPipeline
+from PIL import Image
+import argparse
+from transformers import Wav2Vec2Processor, Wav2Vec2Model
+import librosa
+import os
+import subprocess
+import cv2
+from model import FantasyTalkingAudioConditionModel
+from utils import save_video, get_audio_features, resize_image_by_longest_edge
+from pathlib import Path
+from datetime import datetime
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+
+    parser.add_argument(
+        "--wan_model_dir",
+        type=str,
+        default="./models/Wan2.1-I2V-14B-720P",
+        required=False,
+        help="The dir of the Wan I2V 14B model.",
+    )
+    parser.add_argument(
+        "--fantasytalking_model_path",
+        type=str,
+        default="./models/fantasytalking_model.ckpt",
+        required=False,
+        help="The .ckpt path of fantasytalking model.",
+    )
+    parser.add_argument(
+        "--wav2vec_model_dir",
+        type=str,
+        default="./models/wav2vec2-base-960h",
+        required=False,
+        help="The dir of wav2vec model.",
+    )
+
+    parser.add_argument(
+        "--image_path",
+        type=str,
+        default="./assets/images/woman.png",
+        required=False,
+        help="The path of the image.",
+    )
+
+    parser.add_argument(
+        "--audio_path",
+        type=str,
+        default="./assets/audios/woman.wav",
+        required=False,
+        help="The path of the audio.",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="A woman is talking.",
+        required=False,
+        help="prompt.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="./output",
+        help="Dir to save the model.",
+    )
+    parser.add_argument(
+        "--image_size",
+        type=int,
+        default=512,
+        help="The image will be resized proportionally to this size.",
+    )
+    parser.add_argument(
+        "--audio_scale",
+        type=float,
+        default=1.0,
+        help="Audio condition injection weight",
+    )
+    parser.add_argument(
+        "--prompt_cfg_scale",
+        type=float,
+        default=5.0,
+        required=False,
+        help="Prompt cfg scale",
+    )
+    parser.add_argument(
+        "--audio_cfg_scale",
+        type=float,
+        default=5.0,
+        required=False,
+        help="Audio cfg scale",
+    )
+    parser.add_argument(
+        "--max_num_frames",
+        type=int,
+        default=81,
+        required=False,
+        help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated."
+    )
+    parser.add_argument(
+        "--fps",
+        type=int,
+        default=23,
+        required=False,
+    )
+    parser.add_argument(
+        "--num_persistent_param_in_dit",
+        type=int,
+        default=None,
+        required=False,
+        help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required"
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=1111,
+        required=False,
+    )
+    args = parser.parse_args()
+    return args
+
+def load_models(args):
+    # Load Wan I2V models
+    model_manager = ModelManager(device="cpu")
+    model_manager.load_models(
+        [
+            [
+                f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors",
+                f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors",
+                f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors",
+                f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors",
+                f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors",
+                f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors",
+                f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors",
+            ],
+            f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
+            f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
+            f"{args.wan_model_dir}/Wan2.1_VAE.pth",
+        ],
+        # torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
+        torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
+    )
+    pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+
+    # Load FantasyTalking weights
+    fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
+    fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
+
+    # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
+    pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
+
+    # Load wav2vec models
+    wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
+    wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
+
+    return pipe, fantasytalking, wav2vec_processor, wav2vec
+
+
+
+def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    duration = librosa.get_duration(filename=args.audio_path)
+    num_frames = min(int(args.fps * duration // 4) * 4 + 5, args.max_num_frames)
+
+    audio_wav2vec_fea = get_audio_features(wav2vec, wav2vec_processor, args.audio_path, args.fps, num_frames)
+    image = resize_image_by_longest_edge(args.image_path, args.image_size)
+    width, height = image.size
+
+    audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea)
+    pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
+    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(audio_proj_fea, pos_idx_ranges, expand_length=4)  # [b, 21, 9+8, 768]
+
+    # Image-to-video
+    video_audio = pipe(
+        prompt=args.prompt,
+        negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+        input_image=image,
+        width=width,
+        height=height,
+        num_frames=num_frames,
+        num_inference_steps=30,
+        seed=args.seed, tiled=True,
+        audio_scale=args.audio_scale,
+        cfg_scale=args.prompt_cfg_scale,
+        audio_cfg_scale=args.audio_cfg_scale,
+        audio_proj=audio_proj_split,
+        audio_context_lens=audio_context_lens,
+        latents_num_frames=(num_frames - 1) // 4 + 1
+    )
+    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+    save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
+    save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)
+
+    save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
+    final_command = [
+        "ffmpeg", "-y",
+        "-i", save_path_tmp,
+        "-i", args.audio_path,
+        "-c:v", "libx264",
+        "-c:a", "aac",
+        "-shortest",
+        save_path
+    ]
+    subprocess.run(final_command, check=True)
+    os.remove(save_path_tmp)
+    return save_path
+
+if __name__ == "__main__":
+    args = parse_args()
+    pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
+
+    main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
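
Note: in main() the frame count is derived from the audio length, snapped to the 4n+1 pattern that the Wan latent VAE expects, and capped at --max_num_frames. A small worked example, using the script's default fps and an illustrative 3-second clip:

fps, max_num_frames = 23, 81
duration = 3.0  # seconds of input audio (example value)
num_frames = min(int(fps * duration // 4) * 4 + 5, max_num_frames)
# int(69.0 // 4) * 4 + 5 = 17 * 4 + 5 = 73 frames (under the 81-frame cap)
latents_num_frames = (num_frames - 1) // 4 + 1  # 19 latent frames passed to the pipeline
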
model.py
ADDED
@@ -0,0 +1,229 @@
+from diffsynth.models.wan_video_dit import flash_attention, WanModel
+import torch.nn.functional as F
+import torch.nn as nn
+import torch
+import os
+from safetensors import safe_open
+
+
+class AudioProjModel(nn.Module):
+    def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
+        super().__init__()
+        self.cross_attention_dim = cross_attention_dim
+        self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, audio_embeds):
+        context_tokens = self.proj(audio_embeds)
+        context_tokens = self.norm(context_tokens)
+        return context_tokens  # [B, L, C]
+
+
+class WanCrossAttentionProcessor(nn.Module):
+    def __init__(self, context_dim, hidden_dim):
+        super().__init__()
+
+        self.context_dim = context_dim
+        self.hidden_dim = hidden_dim
+
+        self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
+        self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
+
+        nn.init.zeros_(self.k_proj.weight)
+        nn.init.zeros_(self.v_proj.weight)
+
+    def __call__(
+        self,
+        attn: nn.Module,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        context_lens: torch.Tensor,
+        audio_proj: torch.Tensor,
+        audio_context_lens: torch.Tensor,
+        latents_num_frames: int = 21,
+        audio_scale: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        x: [B, L1, C].
+        context: [B, L2, C].
+        context_lens: [B].
+        audio_proj: [B, 21, L3, C]
+        audio_context_lens: [B*21].
+        """
+        context_img = context[:, :257]
+        context = context[:, 257:]
+        b, n, d = x.size(0), attn.num_heads, attn.head_dim
+
+        # compute query, key, value
+        q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
+        k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
+        v = attn.v(context).view(b, -1, n, d)
+        k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
+        v_img = attn.v_img(context_img).view(b, -1, n, d)
+        img_x = flash_attention(q, k_img, v_img, k_lens=None)
+        # compute attention
+        x = flash_attention(q, k, v, k_lens=context_lens)
+        x = x.flatten(2)
+        img_x = img_x.flatten(2)
+
+        if len(audio_proj.shape) == 4:
+            audio_q = q.view(b * latents_num_frames, -1, n, d)  # [b, 21, l1, n, d]
+            ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
+            ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
+            audio_x = flash_attention(
+                audio_q, ip_key, ip_value, k_lens=audio_context_lens
+            )
+            audio_x = audio_x.view(b, q.size(1), n, d)
+            audio_x = audio_x.flatten(2)
+        elif len(audio_proj.shape) == 3:
+            ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
+            ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
+            audio_x = flash_attention(q, ip_key, ip_value, k_lens=audio_context_lens)
+            audio_x = audio_x.flatten(2)
+        # output
+        x = x + img_x + audio_x * audio_scale
+        x = attn.o(x)
+        return x
+
+
+class FantasyTalkingAudioConditionModel(nn.Module):
+    def __init__(self, wan_dit: WanModel, audio_in_dim: int, audio_proj_dim: int):
+        super().__init__()
+
+        self.audio_in_dim = audio_in_dim
+        self.audio_proj_dim = audio_proj_dim
+
+        # audio proj model
+        self.proj_model = self.init_proj(self.audio_proj_dim)
+        self.set_audio_processor(wan_dit)
+
+    def init_proj(self, cross_attention_dim=5120):
+        proj_model = AudioProjModel(
+            audio_in_dim=self.audio_in_dim, cross_attention_dim=cross_attention_dim
+        )
+        return proj_model
+
+    def set_audio_processor(self, wan_dit):
+        attn_procs = {}
+        for name in wan_dit.attn_processors.keys():
+            attn_procs[name] = WanCrossAttentionProcessor(
+                context_dim=self.audio_proj_dim, hidden_dim=wan_dit.dim
+            )
+        wan_dit.set_attn_processor(attn_procs)
+
+    def load_audio_processor(self, ip_ckpt: str, wan_dit):
+        if os.path.splitext(ip_ckpt)[-1] == ".safetensors":
+            state_dict = {"proj_model": {}, "audio_processor": {}}
+            with safe_open(ip_ckpt, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    if key.startswith("proj_model."):
+                        state_dict["proj_model"][key.replace("proj_model.", "")] = (
+                            f.get_tensor(key)
+                        )
+                    elif key.startswith("audio_processor."):
+                        state_dict["audio_processor"][
+                            key.replace("audio_processor.", "")
+                        ] = f.get_tensor(key)
+        else:
+            state_dict = torch.load(ip_ckpt, map_location="cpu")
+        self.proj_model.load_state_dict(state_dict["proj_model"])
+        wan_dit.load_state_dict(state_dict["audio_processor"], strict=False)
+
+    def get_proj_fea(self, audio_fea=None):
+
+        return self.proj_model(audio_fea) if audio_fea is not None else None
+
+    def split_audio_sequence(self, audio_proj_length, num_frames=81):
+        """
+        Map the audio feature sequence to corresponding latent frame slices.
+
+        Args:
+            audio_proj_length (int): The total length of the audio feature sequence
+                (e.g., 173 in audio_proj[1, 173, 768]).
+            num_frames (int): The number of video frames in the training data (default: 81).
+
+        Returns:
+            list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
+                (within the audio feature sequence) corresponding to a latent frame.
+        """
+        # Average number of tokens per original video frame
+        tokens_per_frame = audio_proj_length / num_frames
+
+        # Each latent frame covers 4 video frames, and we want the center
+        tokens_per_latent_frame = tokens_per_frame * 4
+        half_tokens = int(tokens_per_latent_frame / 2)
+
+        pos_indices = []
+        for i in range(int((num_frames - 1) / 4) + 1):
+            if i == 0:
+                pos_indices.append(0)
+            else:
+                start_token = tokens_per_frame * ((i - 1) * 4 + 1)
+                end_token = tokens_per_frame * (i * 4 + 1)
+                center_token = int((start_token + end_token) / 2) - 1
+                pos_indices.append(center_token)
+
+        # Build index ranges centered around each position
+        pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
+
+        # Adjust the first range to avoid negative start index
+        pos_idx_ranges[0] = [
+            -(half_tokens * 2 - pos_idx_ranges[1][0]),
+            pos_idx_ranges[1][0],
+        ]
+
+        return pos_idx_ranges
+
+    def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
+        """
+        Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
+        if the range exceeds the input boundaries.
+
+        Args:
+            input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
+            pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
+            expand_length (int): Number of tokens to expand on both sides of each subsequence.
+
+        Returns:
+            sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
+                Each element is a padded subsequence.
+            k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
+                Useful for ignoring padding tokens in attention masks.
+        """
+        pos_idx_ranges = [
+            [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
+        ]
+        sub_sequences = []
+        seq_len = input_tensor.size(1)  # 173
+        max_valid_idx = seq_len - 1  # 172
+        k_lens_list = []
+        for start, end in pos_idx_ranges:
+            # Calculate the fill amount
+            pad_front = max(-start, 0)
+            pad_back = max(end - max_valid_idx, 0)
+
+            # Calculate the start and end indices of the valid part
+            valid_start = max(start, 0)
+            valid_end = min(end, max_valid_idx)
+
+            # Extract the valid part
+            if valid_start <= valid_end:
+                valid_part = input_tensor[:, valid_start : valid_end + 1, :]
+            else:
+                valid_part = input_tensor.new_zeros(
+                    (1, 0, input_tensor.size(2))
+                )
+
+            # In the sequence dimension (the 1st dimension) perform padding
+            padded_subseq = F.pad(
+                valid_part,
+                (0, 0, 0, pad_back + pad_front, 0, 0),
+                mode="constant",
+                value=0,
+            )
+            k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
+
+            sub_sequences.append(padded_subseq)
+        return torch.stack(sub_sequences, dim=1), torch.tensor(
+            k_lens_list, dtype=torch.long
+        )
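
Note: to make the splitting logic concrete, here is a sketch of what split_audio_sequence() and split_tensor_with_padding() produce for the example numbers used in the docstrings (173 audio tokens, 81 frames); fantasytalking is assumed to be an already-constructed FantasyTalkingAudioConditionModel. The channel dimension is irrelevant to the splitting itself.

import torch

audio_proj_fea = torch.randn(1, 173, 768)  # projected audio features [B, L, C]

ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1), num_frames=81)
print(len(ranges))  # 21 latent frames -> 21 [start, end] index ranges

splits, k_lens = fantasytalking.split_tensor_with_padding(
    audio_proj_fea, ranges, expand_length=4
)
print(splits.shape, k_lens.shape)
# torch.Size([1, 21, 17, 768]) and torch.Size([21]) for these numbers:
# each range spans 9 tokens, expanded by 4 on each side (the "9+8" noted in infer.py),
# and k_lens records the unpadded length of each slice for attention masking.
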
requirements.txt
ADDED
@@ -0,0 +1,14 @@
+torch>=2.0.0
+torchvision
+cupy-cuda12x
+transformers==4.46.2
+controlnet-aux==0.0.7
+imageio
+imageio[ffmpeg]
+safetensors
+einops
+sentencepiece
+protobuf
+modelscope
+ftfy
+librosa

utils.py
ADDED
@@ -0,0 +1,49 @@
+import imageio, librosa
+import torch
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+
+
+def resize_image_by_longest_edge(image_path, target_size):
+    image = Image.open(image_path).convert("RGB")
+    width, height = image.size
+    scale = target_size / max(width, height)
+    new_size = (int(width * scale), int(height * scale))
+    return image.resize(new_size, Image.LANCZOS)
+
+
+def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
+    writer = imageio.get_writer(
+        save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
+    )
+    for frame in tqdm(frames, desc="Saving video"):
+        frame = np.array(frame)
+        writer.append_data(frame)
+    writer.close()
+
+
+def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
+    sr = 16000
+    audio_input, sample_rate = librosa.load(audio_path, sr=sr)  # resample to 16 kHz
+
+    start_time = 0
+    # end_time = (0 + (num_frames - 1) * 1) / fps
+    end_time = num_frames / fps
+
+    start_sample = int(start_time * sr)
+    end_sample = int(end_time * sr)
+
+    try:
+        audio_segment = audio_input[start_sample:end_sample]
+    except Exception:
+        audio_segment = audio_input
+
+    input_values = audio_processor(
+        audio_segment, sampling_rate=sample_rate, return_tensors="pt"
+    ).input_values.to("cuda")
+
+    with torch.no_grad():
+        fea = wav2vec(input_values).last_hidden_state
+
+    return fea
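
Note: get_audio_features() trims the waveform to num_frames / fps seconds and runs it through wav2vec2; the base model emits roughly 50 feature vectors per second of 16 kHz audio with a hidden size of 768, which is where the ~173-token example in model.py comes from. A hedged usage sketch, reusing the default local paths from infer.py:

from transformers import Wav2Vec2Model, Wav2Vec2Processor

wav2vec_processor = Wav2Vec2Processor.from_pretrained("./models/wav2vec2-base-960h")
wav2vec = Wav2Vec2Model.from_pretrained("./models/wav2vec2-base-960h").to("cuda")

fea = get_audio_features(
    wav2vec, wav2vec_processor, "./assets/audios/woman.wav", fps=23, num_frames=81
)
print(fea.shape)  # [1, ~175, 768]: 81 frames at 23 fps is ~3.5 s of audio
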