dangthr committed on
Commit 2f849ec · verified · 1 Parent(s): d167b3e

Update app.py

Files changed (1):
  1. app.py +251 -179

app.py CHANGED
@@ -20,13 +20,17 @@ import datetime
  import random
  import math
  import subprocess
  from huggingface_hub import snapshot_download
- import requests
  import shutil

- # --- Global settings ---
  if torch.cuda.is_available():
- device = "cuda"
  if torch.cuda.get_device_capability()[0] >= 8:
  dtype = torch.bfloat16
  else:
@@ -36,236 +40,304 @@ else:
  dtype = torch.float32

  def filter_kwargs(cls, kwargs):
- """Filter out keyword arguments that do not belong to the class constructor."""
  import inspect
  sig = inspect.signature(cls.__init__)
  valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
  filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
  return filtered_kwargs

- # <<< Core fix: detect and handle Git LFS pointer files >>>
- def is_lfs_pointer(file_path):
- """Check whether a file is a Git LFS pointer file."""
- try:
- # Pointer files are usually small (< 2 KB)
- if os.path.getsize(file_path) > 2048:
- return False
- with open(file_path, 'r') as f:
- first_line = f.readline().strip()
- # The first line of a pointer file is usually 'version https://git-lfs.github.com/spec/v1'
- if 'git-lfs' in first_line:
- return True
- except (OSError, UnicodeDecodeError):
- # If the file cannot be read or is not a text file, it is not a pointer file
- return False
- return False
-
- def resolve_path(user_path, repo_root):
  """
- Resolve a file path with the correct priority and handle Git LFS pointer files.
- """
- # Check whether the local path exists
- if os.path.exists(user_path):
- # Check whether it is an invalid LFS pointer file
- if is_lfs_pointer(user_path):
- print(f"Warning: local file '{user_path}' is a Git LFS pointer file. Will look for the full file in the Hugging Face cache.")
- # If it is a pointer file, ignore it and look in the HF cache in the next step
- else:
- # If it is a normal file, use it directly
- print(f"Found local file: {os.path.abspath(user_path)}")
- return os.path.abspath(user_path)

- # If the local file is missing or is an LFS pointer file, look in the HF cache directory
- potential_repo_path = os.path.join(repo_root, user_path)
- if os.path.exists(potential_repo_path):
- print(f"Found file in the Hugging Face cache directory: {potential_repo_path}")
- return potential_repo_path
-
- return None
- # <<< End of fix >>>
-
- def setup_models(repo_root, model_version):
- """Load all required models and configuration."""
- pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
- pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
-
- config_relative_path = "deepspeed_config/wan2.1/wan_civitai.yaml"
- config_path = resolve_path(config_relative_path, repo_root)
- if not config_path:
- raise FileNotFoundError(f"Config file '{config_relative_path}' was not found in the current directory or the HF cache.")
-
- print(f"Loading config from {config_path}...")
- config = OmegaConf.load(config_path)
- sampler_name = "Flow"
-
- print("Loading tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')))
-
- print("Loading text encoder...")
- text_encoder = WanT5EncoderModel.from_pretrained(
- os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
- additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
- low_cpu_mem_usage=True,
- torch_dtype=dtype,
- ).eval()
-
- print("Loading VAE...")
- vae = AutoencoderKLWan.from_pretrained(
- os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
- additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
- )
-
- print("Loading Wav2Vec...")
- wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
- wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")

- print("Loading CLIP image encoder...")
- clip_image_encoder = CLIPModel.from_pretrained(os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))).eval()

- print("Loading base Transformer 3D model...")
- transformer3d = WanTransformer3DFantasyModel.from_pretrained(
- os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
- transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
- low_cpu_mem_usage=False,
- torch_dtype=dtype,
- )
-
  if model_version == "square":
  transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
- else: # rec_vec
  transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
-
  if os.path.exists(transformer_path):
- print(f"Loading StableAvatar weights from {transformer_path}...")
  state_dict = torch.load(transformer_path, map_location="cpu")
  state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
  m, u = transformer3d.load_state_dict(state_dict, strict=False)
- print(f"StableAvatar weights loaded successfully. Missing keys: {len(m)}; Unexpected keys: {len(u)}")
  else:
- raise FileNotFoundError(f"StableAvatar weight file not found: {transformer_path}. Please make sure the model has been fully downloaded.")
-
- scheduler_class = { "Flow": FlowMatchEulerDiscreteScheduler }[sampler_name]
- scheduler = scheduler_class(**filter_kwargs(scheduler_class, OmegaConf.to_container(config['scheduler_kwargs'])))
-
- print("Building pipeline...")
- pipeline = WanI2VTalkingInferenceLongPipeline(
- tokenizer=tokenizer, text_encoder=text_encoder, vae=vae,
- transformer=transformer3d, clip_image_encoder=clip_image_encoder,
- scheduler=scheduler, wav2vec_processor=wav2vec_processor, wav2vec=wav2vec,
- )
-
- return pipeline, transformer3d, vae

- def run_inference(
- pipeline, transformer3d, vae, image_path, audio_path, prompt,
- negative_prompt, seed, output_filename, gpu_memory_mode="model_cpu_offload",
- width=512, height=512, num_inference_steps=50, fps=25, **kwargs
  ):
- """Run inference to generate the video."""
- if seed < 0:
  seed = random.randint(0, np.iinfo(np.int32).max)
- print(f"Seed used: {seed}")

- if gpu_memory_mode == "sequential_cpu_offload":
  pipeline.enable_sequential_cpu_offload(device=device)
- elif gpu_memory_mode == "model_cpu_offload":
  pipeline.enable_model_cpu_offload(device=device)
  else:
  pipeline.to(device=device)

  with torch.no_grad():
- print("Preparing input data...")
- video_length = 81
  input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
-
  sr = 16000
- vocal_input, _ = librosa.load(audio_path, sr=sr)

- print("Pipeline running... this may take a while.")
  sample = pipeline(
- prompt, num_frames=video_length, negative_prompt=negative_prompt,
- width=width, height=height, guidance_scale=6.0,
- generator=torch.Generator().manual_seed(seed), num_inference_steps=num_inference_steps,
- video=input_video, mask_video=input_video_mask, clip_image=clip_image,
- text_guide_scale=3.0, audio_guide_scale=5.0, vocal_input_values=vocal_input,
- motion_frame=25, fps=fps, sr=sr, cond_file_path=image_path,
- overlap_window_length=10, seed=seed, overlapping_weight_scheme="uniform",
  ).videos

- print("Saving video...")
  os.makedirs("outputs", exist_ok=True)
- video_path = os.path.join("outputs", f"{output_filename}.mp4")
  save_videos_grid(sample, video_path, fps=fps)

- output_video_with_audio = os.path.join("outputs", f"{output_filename}_audio.mp4")
-
- print("Merging audio into the video...")
  subprocess.run([
- "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
- "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
  output_video_with_audio
  ], check=True)

- os.remove(video_path)
-
- print(f"✅ Generation complete! Video saved to: {output_video_with_audio}")
  return output_video_with_audio, seed

  def main():
- parser = argparse.ArgumentParser(description="StableAvatar command-line inference tool")
- parser.add_argument('--prompt', type=str, default="a beautiful woman is talking, masterpiece, best quality", help='Positive prompt')
- parser.add_argument('--input_image', type=str, default="example_case/case-6/reference.png", help='Path to the input image')
- parser.add_argument('--input_audio', type=str, default="example_case/case-6/audio.wav", help='Path to the input audio')
- parser.add_argument('--seed', type=int, default=42, help='Random seed; -1 means random')
- parser.add_argument('--negative_prompt', type=str, default="vivid color, static, blur details, text, style, painting, picture, still, gray, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, malformed, deformed, bad anatomy, fused fingers, still image, messy background, many people in the background, walking backwards", help='Negative prompt')
- parser.add_argument('--width', type=int, default=512, help='Video width')
- parser.add_argument('--height', type=int, default=512, help='Video height')
- parser.add_argument('--num_inference_steps', type=int, default=50, help='Number of inference steps')
- parser.add_argument('--fps', type=int, default=25, help='Video frame rate')
- parser.add_argument('--gpu_memory_mode', type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload"], help='GPU memory optimization mode')
- parser.add_argument('--model_version', type=str, default="square", choices=["square", "rec_vec"], help='StableAvatar model version')
- args = parser.parse_args()

- print("--- Step 1: Checking and downloading models and config files ---")
  repo_root = snapshot_download(
- repo_id="FrancisRing/StableAvatar",
  allow_patterns=[
  "StableAvatar-1.3B/*",
  "Wan2.1-Fun-V1.1-1.3B-InP/*",
  "wav2vec2-base-960h/*",
- "deepspeed_config/**",
- "example_case/**"
  ],
  )
- print("Model files are ready.")
-
- print("\n--- Step 2: Resolving input file paths ---")
- final_image_path = resolve_path(args.input_image, repo_root)
- if not final_image_path:
- print(f"Error: could not find image file {args.input_image}")
- return

- final_audio_path = resolve_path(args.input_audio, repo_root)
- if not final_audio_path:
- print(f"Error: could not find audio file {args.input_audio}")
- return

- print("\n--- Step 3: Loading models ---")
- pipeline, transformer3d, vae = setup_models(repo_root, args.model_version)
- print("Models loaded.")

- print("\n--- Step 4: Running inference ---")
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- run_inference(
- pipeline=pipeline, transformer3d=transformer3d, vae=vae,
- image_path=final_image_path, audio_path=final_audio_path,
- prompt=args.prompt, negative_prompt=args.negative_prompt,
- seed=args.seed, output_filename=f"output_{timestamp}",
- gpu_memory_mode=args.gpu_memory_mode, width=args.width,
- height=args.height, num_inference_steps=args.num_inference_steps,
- fps=args.fps
  )

- if __name__ == "__main__":
- main()
-

  import random
  import math
  import subprocess
+ from moviepy.editor import VideoFileClip
  from huggingface_hub import snapshot_download
  import shutil

+ try:
+ from audio_separator.separator import Separator
+ except:
+ print("Unable to use vocal separation feature. Please install audio-separator[gpu].")
+
  if torch.cuda.is_available():
+ device = "cuda"
  if torch.cuda.get_device_capability()[0] >= 8:
  dtype = torch.bfloat16
  else:
  dtype = torch.float32

  def filter_kwargs(cls, kwargs):
  import inspect
  sig = inspect.signature(cls.__init__)
  valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
  filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
  return filtered_kwargs

+ def load_transformer_model(model_version):
  """
+ Load the transformer model corresponding to the selected model version.

+ Args:
+ model_version (str): model version, "square" or "rec_vec"

+ Returns:
+ WanTransformer3DFantasyModel: the loaded transformer model
+ """
+ global transformer3d

  if model_version == "square":
  transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
+ elif model_version == "rec_vec":
  transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
+ else:
+ # Default to the square version
+ transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
+
+ print(f"Loading model: {transformer_path}")
+
  if os.path.exists(transformer_path):
  state_dict = torch.load(transformer_path, map_location="cpu")
  state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
  m, u = transformer3d.load_state_dict(state_dict, strict=False)
+ print(f"Model loaded successfully: {transformer_path}")
+ print(f"Missing keys: {len(m)}; Unexpected keys: {len(u)}")
+ return transformer3d
  else:
+ print(f"Error: model file does not exist: {transformer_path}")
+ return None

+ def generate_video(
+ GPU_memory_mode="model_cpu_offload",
+ teacache_threshold=0.0,
+ num_skip_start_steps=5,
+ image_path=None,
+ audio_path=None,
+ prompt="",
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ width=512,
+ height=512,
+ guidance_scale=6.0,
+ num_inference_steps=50,
+ text_guide_scale=3.0,
+ audio_guide_scale=5.0,
+ motion_frame=25,
+ fps=25,
+ overlap_window_length=10,
+ seed_param=42,
+ overlapping_weight_scheme="uniform",
  ):
+ global pipeline, transformer3d
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+ if seed_param < 0:
  seed = random.randint(0, np.iinfo(np.int32).max)
+ else:
+ seed = seed_param

+ print(f"Using seed: {seed}")
+ print(f"Input image: {image_path}")
+ print(f"Input audio: {audio_path}")
+ print(f"Prompt: {prompt}")
+
+ if GPU_memory_mode == "sequential_cpu_offload":
+ replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
+ transformer3d.freqs = transformer3d.freqs.to(device=device)
  pipeline.enable_sequential_cpu_offload(device=device)
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
+ convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
+ convert_weight_dtype_wrapper(transformer3d, dtype)
+ pipeline.enable_model_cpu_offload(device=device)
+ elif GPU_memory_mode == "model_cpu_offload":
  pipeline.enable_model_cpu_offload(device=device)
  else:
  pipeline.to(device=device)
+
+ if teacache_threshold > 0:
+ coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
+ pipeline.transformer.enable_teacache(
+ coefficients,
+ num_inference_steps,
+ teacache_threshold,
+ num_skip_start_steps=num_skip_start_steps,
+ )

  with torch.no_grad():
+ video_length = int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if clip_sample_n_frames != 1 else 1
  input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
  sr = 16000
+ vocal_input, sample_rate = librosa.load(audio_path, sr=sr)

+ print("Starting video generation...")
  sample = pipeline(
+ prompt,
+ num_frames=video_length,
+ negative_prompt=negative_prompt,
+ width=width,
+ height=height,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator().manual_seed(seed),
+ num_inference_steps=num_inference_steps,
+ video=input_video,
+ mask_video=input_video_mask,
+ clip_image=clip_image,
+ text_guide_scale=text_guide_scale,
+ audio_guide_scale=audio_guide_scale,
+ vocal_input_values=vocal_input,
+ motion_frame=motion_frame,
+ fps=fps,
+ sr=sr,
+ cond_file_path=image_path,
+ overlap_window_length=overlap_window_length,
+ seed=seed,
+ overlapping_weight_scheme=overlapping_weight_scheme,
  ).videos

  os.makedirs("outputs", exist_ok=True)
+ video_path = os.path.join("outputs", f"{timestamp}.mp4")
  save_videos_grid(sample, video_path, fps=fps)
+ output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")

+ print("Merging audio into the video...")
  subprocess.run([
+ "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
+ "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
  output_video_with_audio
  ], check=True)

+ print(f"Generation complete! Output file: {output_video_with_audio}")
  return output_video_with_audio, seed

+ def parse_args():
+ parser = argparse.ArgumentParser(description="StableAvatar Video Generation")
+ parser.add_argument("--prompt", type=str, default="", help="Text prompt")
+ parser.add_argument("--input_image", type=str, required=True, help="Input image path or URL")
+ parser.add_argument("--input_audio", type=str, required=True, help="Input audio path or URL")
+ parser.add_argument("--seed", type=int, default=42, help="Random seed; -1 means random")
+ parser.add_argument("--negative_prompt", type=str,
+ default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+ help="Negative prompt")
+ parser.add_argument("--width", type=int, default=512, help="Video width")
+ parser.add_argument("--height", type=int, default=512, help="Video height")
+ parser.add_argument("--guidance_scale", type=float, default=6.0, help="Guidance scale")
+ parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
+ parser.add_argument("--text_guide_scale", type=float, default=3.0, help="Text guidance scale")
+ parser.add_argument("--audio_guide_scale", type=float, default=5.0, help="Audio guidance scale")
+ parser.add_argument("--motion_frame", type=int, default=25, help="Number of motion frames")
+ parser.add_argument("--fps", type=int, default=25, help="Video frame rate")
+ parser.add_argument("--overlap_window_length", type=int, default=10, help="Overlap window length")
+ parser.add_argument("--overlapping_weight_scheme", type=str, default="uniform",
+ choices=["uniform", "log"], help="Overlapping weight scheme")
+ parser.add_argument("--GPU_memory_mode", type=str, default="model_cpu_offload",
+ choices=["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
+ help="GPU memory mode")
+ parser.add_argument("--teacache_threshold", type=float, default=0.0, help="TeaCache threshold")
+ parser.add_argument("--num_skip_start_steps", type=int, default=5, help="Number of starting steps to skip")
+ parser.add_argument("--model_version", type=str, default="square",
+ choices=["square", "rec_vec"], help="Model version")
+
+ return parser.parse_args()
+
+ def download_file(url, local_path):
+ """Download a remote file to a local path."""
+ import urllib.request
+ print(f"Downloading {url} to {local_path}")
+ urllib.request.urlretrieve(url, local_path)
+ print(f"Download complete: {local_path}")
+ return local_path
+
  def main():
+ args = parse_args()
+
+ # Handle input files (URL or local path)
+ image_path = args.input_image
+ audio_path = args.input_audio
+
+ # If the input is a URL, download it to a temporary file
+ if image_path.startswith('http'):
+ os.makedirs("temp", exist_ok=True)
+ local_image_path = f"temp/temp_image_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
+ image_path = download_file(image_path, local_image_path)
+
+ if audio_path.startswith('http'):
+ os.makedirs("temp", exist_ok=True)
+ audio_ext = os.path.splitext(audio_path)[1] or '.wav'
+ local_audio_path = f"temp/temp_audio_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}{audio_ext}"
+ audio_path = download_file(audio_path, local_audio_path)
+
+ # Check that the input files exist
+ if not os.path.exists(image_path):
+ print(f"Error: image file does not exist: {image_path}")
+ return
+
+ if not os.path.exists(audio_path):
+ print(f"Error: audio file does not exist: {audio_path}")
+ return
+
+ # Load the selected model version
+ load_transformer_model(args.model_version)
+
+ # Generate the video
+ output_path, used_seed = generate_video(
+ GPU_memory_mode=args.GPU_memory_mode,
+ teacache_threshold=args.teacache_threshold,
+ num_skip_start_steps=args.num_skip_start_steps,
+ image_path=image_path,
+ audio_path=audio_path,
+ prompt=args.prompt,
+ negative_prompt=args.negative_prompt,
+ width=args.width,
+ height=args.height,
+ guidance_scale=args.guidance_scale,
+ num_inference_steps=args.num_inference_steps,
+ text_guide_scale=args.text_guide_scale,
+ audio_guide_scale=args.audio_guide_scale,
+ motion_frame=args.motion_frame,
+ fps=args.fps,
+ overlap_window_length=args.overlap_window_length,
+ seed_param=args.seed,
+ overlapping_weight_scheme=args.overlapping_weight_scheme,
+ )
+
+ print(f"\n=== Generation complete ===")
+ print(f"Output file: {output_path}")
+ print(f"Seed used: {used_seed}")

+ if __name__ == "__main__":
+ # Initialize models and configuration
+ REPO_ID = "FrancisRing/StableAvatar"
  repo_root = snapshot_download(
+ repo_id=REPO_ID,
  allow_patterns=[
  "StableAvatar-1.3B/*",
  "Wan2.1-Fun-V1.1-1.3B-InP/*",
  "wav2vec2-base-960h/*",
+ "assets/**",
+ "Kim_Vocal_2.onnx",
  ],
  )
+ pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
+ pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")

+ # Vocal separation ONNX model
+ audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")

+ config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
+ sampler_name = "Flow"
+ clip_sample_n_frames = 81

+ print("Initializing models...")
+ tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), )
+ text_encoder = WanT5EncoderModel.from_pretrained(
+ os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
+ additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
+ low_cpu_mem_usage=True,
+ torch_dtype=dtype,
+ )
+ text_encoder = text_encoder.eval()
+ vae = AutoencoderKLWan.from_pretrained(
+ os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
+ additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
+ )
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
+ wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
+ clip_image_encoder = CLIPModel.from_pretrained(os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), )
+ clip_image_encoder = clip_image_encoder.eval()
+ transformer3d = WanTransformer3DFantasyModel.from_pretrained(
+ os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
+ transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
+ low_cpu_mem_usage=False,
+ torch_dtype=dtype,
  )

+ Choosen_Scheduler = scheduler_dict = {
+ "Flow": FlowMatchEulerDiscreteScheduler,
+ }[sampler_name]
+ scheduler = Choosen_Scheduler(
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
+ )
+ pipeline = WanI2VTalkingInferenceLongPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer3d,
+ clip_image_encoder=clip_image_encoder,
+ scheduler=scheduler,
+ wav2vec_processor=wav2vec_processor,
+ wav2vec=wav2vec,
+ )
+
+ print("Models initialized!")
+ main()
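
For reference, a minimal sketch of how the updated command-line entry point might be invoked after this change. It assumes app.py sits in the working directory and that the image/audio paths and prompt shown are placeholders; only --input_image and --input_audio are required by the new parse_args(), and the remaining flags fall back to the defaults above.

    # Hedged usage sketch: placeholder inputs, flag names taken from parse_args() above.
    import subprocess

    subprocess.run(
        [
            "python", "app.py",
            "--input_image", "reference.png",        # local path or http(s) URL (placeholder)
            "--input_audio", "audio.wav",            # local path or http(s) URL (placeholder)
            "--prompt", "a person is talking",       # placeholder prompt
            "--model_version", "square",             # or "rec_vec"
            "--GPU_memory_mode", "model_cpu_offload",
            "--seed", "42",
        ],
        check=True,
    )

Note that URL inputs are first downloaded into temp/ by download_file() before the existence checks, and the final result is written to outputs/<timestamp>_audio.mp4 after ffmpeg muxes the audio track back into the generated video.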