dangthr committed on
Commit 160e694 · verified · 1 Parent(s): 7e8f8ba

Update app.py

Files changed (1)
  1. app.py +151 -227
app.py CHANGED
@@ -23,14 +23,12 @@ import subprocess
  from moviepy.editor import VideoFileClip
  from huggingface_hub import snapshot_download
  import shutil
 
- try:
-     from audio_separator.separator import Separator
- except:
-     print("Unable to use vocal separation feature. Please install audio-separator[gpu].")
-
  if torch.cuda.is_available():
-     device = "cuda"
      if torch.cuda.get_device_capability()[0] >= 8:
          dtype = torch.bfloat16
      else:
@@ -47,37 +45,112 @@ def filter_kwargs(cls, kwargs):
      return filtered_kwargs
 
  def load_transformer_model(model_version, repo_root):
-     """
-     Load the transformer model matching the selected model version.
-
-     Args:
-         model_version (str): Model version, "square" or "rec_vec"
-         repo_root (str): Root directory of the downloaded models
-
-     Returns:
-         WanTransformer3DFantasyModel: the loaded transformer model
-     """
-     if model_version == "square":
-         transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
-     elif model_version == "rec_vec":
-         transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
-     else:
-         # Default to the square version
-         transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
-
-     print(f"Loading model: {transformer_path}")
-
      if os.path.exists(transformer_path):
          state_dict = torch.load(transformer_path, map_location="cpu")
          state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
-         return transformer_path, state_dict
      else:
-         raise FileNotFoundError(f"Model file does not exist: {transformer_path}")
 
- def generate_video(
-     transformer3d,
-     pipeline,
-     repo_root,
      GPU_memory_mode="model_cpu_offload",
      teacache_threshold=0,
      num_skip_start_steps=5,
@@ -95,32 +168,37 @@ def generate_video(
      fps=25,
      overlap_window_length=10,
      seed_param=42,
-     overlapping_weight_scheme="uniform",
  ):
      timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
      if seed_param < 0:
          seed = random.randint(0, np.iinfo(np.int32).max)
      else:
          seed = seed_param
-
-     print(f"Using seed: {seed}")
 
      if GPU_memory_mode == "sequential_cpu_offload":
-         replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
          transformer3d.freqs = transformer3d.freqs.to(device=device)
          pipeline.enable_sequential_cpu_offload(device=device)
      elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
-         convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
          convert_weight_dtype_wrapper(transformer3d, dtype)
          pipeline.enable_model_cpu_offload(device=device)
      elif GPU_memory_mode == "model_cpu_offload":
          pipeline.enable_model_cpu_offload(device=device)
      else:
          pipeline.to(device=device)
-
      if teacache_threshold > 0:
-         pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
          coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
          pipeline.transformer.enable_teacache(
              coefficients,
@@ -129,15 +207,12 @@ def generate_video(
              num_skip_start_steps=num_skip_start_steps,
          )
 
      with torch.no_grad():
-         clip_sample_n_frames = 81
-         vae = pipeline.vae
          video_length = int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if clip_sample_n_frames != 1 else 1
          input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
          sr = 16000
          vocal_input, sample_rate = librosa.load(audio_path, sr=sr)
-
-         print("Starting video generation...")
          sample = pipeline(
              prompt,
              num_frames=video_length,
@@ -161,195 +236,43 @@ def generate_video(
              seed=seed,
              overlapping_weight_scheme=overlapping_weight_scheme,
          ).videos
-
      os.makedirs("outputs", exist_ok=True)
      video_path = os.path.join("outputs", f"{timestamp}.mp4")
      save_videos_grid(sample, video_path, fps=fps)
      output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
-
-     print("Merging audio...")
      subprocess.run([
-         "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
-         "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
          output_video_with_audio
      ], check=True)
-
-     return output_video_with_audio, seed
-
- def audio_extractor(video_path, output_dir="outputs"):
-     """Extract the audio track from a video file"""
-     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-     os.makedirs(output_dir, exist_ok=True)
-     out_wav = os.path.abspath(os.path.join(output_dir, f"{timestamp}.wav"))
-     video = VideoFileClip(video_path)
-     audio = video.audio
-     audio.write_audiofile(out_wav, codec="pcm_s16le")
-     return out_wav
 
- def vocal_separation(audio_path, audio_separator_model_file, output_dir="outputs"):
-     """Separate vocals from an audio file"""
-     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-     os.makedirs(output_dir, exist_ok=True)
-
-     audio_separator = Separator(
-         output_dir=os.path.abspath(os.path.join(output_dir, timestamp)),
-         output_single_stem="vocals",
-         model_file_dir=os.path.dirname(audio_separator_model_file),
-     )
-     audio_separator.load_model(os.path.basename(audio_separator_model_file))
-     assert audio_separator.model_instance is not None, "Failed to load the audio separation model."
-     outputs = audio_separator.separate(audio_path)
-     vocal_audio_file = os.path.join(audio_separator.output_dir, outputs[0])
-     destination_file = os.path.abspath(os.path.join(output_dir, f"{timestamp}.wav"))
-     shutil.copy(vocal_audio_file, destination_file)
-     os.remove(vocal_audio_file)
-     return destination_file
 
  def main():
-     parser = argparse.ArgumentParser(description="StableAvatar command-line inference tool")
-
-     # Main arguments
-     parser.add_argument("--prompt", type=str, default="", help="Text prompt")
-     parser.add_argument("--input_image", type=str, required=True, help="Path or URL of the input image")
-     parser.add_argument("--input_audio", type=str, required=True, help="Path or URL of the input audio")
-     parser.add_argument("--seed", type=int, default=42, help="Random seed, -1 for random")
-
-     # Model arguments
-     parser.add_argument("--model_version", type=str, default="square", choices=["square", "rec_vec"], help="Model version")
-     parser.add_argument("--gpu_memory_mode", type=str, default="model_cpu_offload",
-                         choices=["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
-                         help="GPU memory mode")
-     parser.add_argument("--teacache_threshold", type=float, default=0, help="TeaCache threshold")
-     parser.add_argument("--num_skip_start_steps", type=int, default=5, help="Number of start steps to skip")
-
-     # Generation arguments
-     parser.add_argument("--negative_prompt", type=str,
-                         default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-                         help="Negative prompt")
-     parser.add_argument("--width", type=int, default=512, help="Video width")
-     parser.add_argument("--height", type=int, default=512, help="Video height")
-     parser.add_argument("--guidance_scale", type=float, default=6.0, help="Guidance scale")
-     parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
-     parser.add_argument("--text_guide_scale", type=float, default=3.0, help="Text guidance scale")
-     parser.add_argument("--audio_guide_scale", type=float, default=5.0, help="Audio guidance scale")
-     parser.add_argument("--motion_frame", type=int, default=25, help="Number of motion frames")
-     parser.add_argument("--fps", type=int, default=25, help="Frames per second")
-     parser.add_argument("--overlap_window_length", type=int, default=10, help="Overlap window length")
-     parser.add_argument("--overlapping_weight_scheme", type=str, default="uniform", choices=["uniform", "log"], help="Overlapping weight scheme")
-
-     # Utility features
-     parser.add_argument("--extract_audio", type=str, help="Extract audio from a video; pass the video path")
-     parser.add_argument("--separate_vocal", type=str, help="Separate vocals; pass the audio path")
-     parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
-
      args = parser.parse_args()
-
-     # Download models
-     print("Downloading models...")
-     REPO_ID = "FrancisRing/StableAvatar"
-     repo_root = snapshot_download(
-         repo_id=REPO_ID,
-         allow_patterns=[
-             "StableAvatar-1.3B/*",
-             "Wan2.1-Fun-V1.1-1.3B-InP/*",
-             "wav2vec2-base-960h/*",
-             "assets/**",
-             "Kim_Vocal_2.onnx",
-         ],
-     )
-
-     # Utility features
-     if args.extract_audio:
-         print(f"Extracting audio from video: {args.extract_audio}")
-         output_audio = audio_extractor(args.extract_audio, args.output_dir)
-         print(f"Audio saved to: {output_audio}")
-         return
-
-     if args.separate_vocal:
-         print(f"Separating vocals: {args.separate_vocal}")
-         audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")
-         output_audio = vocal_separation(args.separate_vocal, audio_separator_model_file, args.output_dir)
-         print(f"Separated vocals saved to: {output_audio}")
-         return
-
-     # Check required arguments
-     if not args.input_image or not args.input_audio:
-         print("Error: --input_image and --input_audio must be provided")
-         return
-
-     # Initialize models
-     print("Initializing models...")
-     pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
-     pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
-
-     config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
-     sampler_name = "Flow"
-
-     # Load the individual components
-     tokenizer = AutoTokenizer.from_pretrained(
-         os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer'))
-     )
-
-     text_encoder = WanT5EncoderModel.from_pretrained(
-         os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
-         additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
-         low_cpu_mem_usage=True,
-         torch_dtype=dtype,
-     )
-     text_encoder = text_encoder.eval()
-
-     vae = AutoencoderKLWan.from_pretrained(
-         os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
-         additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
-     )
-
-     wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
-     wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
-
-     clip_image_encoder = CLIPModel.from_pretrained(
-         os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))
-     )
-     clip_image_encoder = clip_image_encoder.eval()
-
-     transformer3d = WanTransformer3DFantasyModel.from_pretrained(
-         os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
-         transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
-         low_cpu_mem_usage=False,
-         torch_dtype=dtype,
-     )
-
-     # Load the selected transformer checkpoint
-     transformer_path, state_dict = load_transformer_model(args.model_version, repo_root)
-     m, u = transformer3d.load_state_dict(state_dict, strict=False)
-     print(f"Model loaded successfully: {transformer_path}")
-     print(f"Missing keys: {len(m)}; Unexpected keys: {len(u)}")
-
-     Choosen_Scheduler = {
-         "Flow": FlowMatchEulerDiscreteScheduler,
-     }[sampler_name]
-
-     scheduler = Choosen_Scheduler(
-         **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
-     )
-
-     pipeline = WanI2VTalkingInferenceLongPipeline(
-         tokenizer=tokenizer,
-         text_encoder=text_encoder,
-         vae=vae,
-         transformer=transformer3d,
-         clip_image_encoder=clip_image_encoder,
-         scheduler=scheduler,
-         wav2vec_processor=wav2vec_processor,
-         wav2vec=wav2vec,
-     )
-
-     # Generate the video
-     print("Starting generation...")
-     output_video, used_seed = generate_video(
-         transformer3d=transformer3d,
-         pipeline=pipeline,
-         repo_root=repo_root,
-         GPU_memory_mode=args.gpu_memory_mode,
          teacache_threshold=args.teacache_threshold,
          num_skip_start_steps=args.num_skip_start_steps,
          image_path=args.input_image,
@@ -366,12 +289,13 @@ def main():
          fps=args.fps,
          overlap_window_length=args.overlap_window_length,
          seed_param=args.seed,
-         overlapping_weight_scheme=args.overlapping_weight_scheme,
      )
-
-     print(f"Generation complete!")
-     print(f"Output video: {output_video}")
-     print(f"Seed used: {used_seed}")
 
  if __name__ == "__main__":
      main()
 
  from moviepy.editor import VideoFileClip
  from huggingface_hub import snapshot_download
  import shutil
+ import requests
+ import uuid
 
+ # Device and dtype setup
  if torch.cuda.is_available():
+     device = "cuda"
      if torch.cuda.get_device_capability()[0] >= 8:
          dtype = torch.bfloat16
      else:
 
      return filtered_kwargs
 
  def load_transformer_model(model_version, repo_root):
+     transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", f"transformer3d-{model_version}.pt")
+     print(f"Loading model: {transformer_path}")
      if os.path.exists(transformer_path):
          state_dict = torch.load(transformer_path, map_location="cpu")
          state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+         m, u = transformer3d.load_state_dict(state_dict, strict=False)
+         print(f"Model loaded successfully: {transformer_path}")
+         print(f"Missing keys: {len(m)}; Unexpected keys: {len(u)}")
+         return transformer3d
      else:
+         print(f"Error: Model file does not exist: {transformer_path}")
+         return None
+
+ def download_file(url, local_path):
+     """Download file from URL to local path"""
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()
+         with open(local_path, 'wb') as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
+         return local_path
+     except Exception as e:
+         print(f"Error downloading file from {url}: {e}")
+         return None
+
+ def prepare_input_file(input_path, file_type="image"):
+     """Handle local or remote file inputs"""
+     if input_path.startswith("http://") or input_path.startswith("https://"):
+         ext = ".png" if file_type == "image" else ".wav"
+         local_path = os.path.join("temp", f"{uuid.uuid4()}{ext}")
+         os.makedirs("temp", exist_ok=True)
+         return download_file(input_path, local_path)
+     elif os.path.exists(input_path):
+         return input_path
+     else:
+         print(f"Error: {file_type.capitalize()} file {input_path} does not exist")
+         return None
+
+ # Initialize model paths
+ REPO_ID = "FrancisRing/StableAvatar"
+ repo_root = snapshot_download(
+     repo_id=REPO_ID,
+     allow_patterns=[
+         "StableAvatar-1.3B/*",
+         "Wan2.1-Fun-V1.1-1.3B-InP/*",
+         "wav2vec2-base-960h/*",
+         "assets/**",
+         "Kim_Vocal_2.onnx",
+     ],
+ )
+ pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
+ pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
+ audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")
+
+ # Load configuration and models
+ config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
+ sampler_name = "Flow"
+ clip_sample_n_frames = 81
 
+ tokenizer = AutoTokenizer.from_pretrained(
+     os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer'))
+ )
+ text_encoder = WanT5EncoderModel.from_pretrained(
+     os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
+     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
+     low_cpu_mem_usage=True,
+     torch_dtype=dtype,
+ ).eval()
+ vae = AutoencoderKLWan.from_pretrained(
+     os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
+     additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
+ )
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
+ wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
+ clip_image_encoder = CLIPModel.from_pretrained(
+     os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))
+ ).eval()
+ transformer3d = WanTransformer3DFantasyModel.from_pretrained(
+     os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
+     transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
+     low_cpu_mem_usage=False,
+     torch_dtype=dtype,
+ )
+
+ # Load default transformer model
+ load_transformer_model("square", repo_root)
+
+ # Initialize scheduler and pipeline
+ scheduler_dict = {"Flow": FlowMatchEulerDiscreteScheduler}
+ Choosen_Scheduler = scheduler_dict[sampler_name]
+ scheduler = Choosen_Scheduler(
+     **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
+ )
+ pipeline = WanI2VTalkingInferenceLongPipeline(
+     tokenizer=tokenizer,
+     text_encoder=text_encoder,
+     vae=vae,
+     transformer=transformer3d,
+     clip_image_encoder=clip_image_encoder,
+     scheduler=scheduler,
+     wav2vec_processor=wav2vec_processor,
+     wav2vec=wav2vec,
+ )
+
+ def generate(
      GPU_memory_mode="model_cpu_offload",
      teacache_threshold=0,
      num_skip_start_steps=5,
 
      fps=25,
      overlap_window_length=10,
      seed_param=42,
+     overlapping_weight_scheme="uniform"
  ):
+     global pipeline, transformer3d
      timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
      if seed_param < 0:
          seed = random.randint(0, np.iinfo(np.int32).max)
      else:
          seed = seed_param
 
+     # Handle input files
+     image_path = prepare_input_file(image_path, "image")
+     audio_path = prepare_input_file(audio_path, "audio")
+     if not image_path or not audio_path:
+         return None, None, "Error: Invalid input file paths"
+
+     # Configure pipeline based on GPU memory mode
      if GPU_memory_mode == "sequential_cpu_offload":
+         replace_parameters_by_name(transformer3d, ["modulation"], device=device)
          transformer3d.freqs = transformer3d.freqs.to(device=device)
          pipeline.enable_sequential_cpu_offload(device=device)
      elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
+         convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation"])
          convert_weight_dtype_wrapper(transformer3d, dtype)
          pipeline.enable_model_cpu_offload(device=device)
      elif GPU_memory_mode == "model_cpu_offload":
          pipeline.enable_model_cpu_offload(device=device)
      else:
          pipeline.to(device=device)
+
+     # Enable TeaCache if specified
      if teacache_threshold > 0:
          coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
          pipeline.transformer.enable_teacache(
              coefficients,
 
              num_skip_start_steps=num_skip_start_steps,
          )
 
+     # Perform inference
      with torch.no_grad():
          video_length = int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if clip_sample_n_frames != 1 else 1
          input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
          sr = 16000
          vocal_input, sample_rate = librosa.load(audio_path, sr=sr)
          sample = pipeline(
              prompt,
              num_frames=video_length,
 
              seed=seed,
              overlapping_weight_scheme=overlapping_weight_scheme,
          ).videos
          os.makedirs("outputs", exist_ok=True)
          video_path = os.path.join("outputs", f"{timestamp}.mp4")
          save_videos_grid(sample, video_path, fps=fps)
          output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
          subprocess.run([
+             "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
+             "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
              output_video_with_audio
          ], check=True)
 
+     return output_video_with_audio, seed, f"Generated outputs/{timestamp}.mp4"
 
  def main():
+     parser = argparse.ArgumentParser(description="StableAvatar Inference Script")
+     parser.add_argument("--prompt", type=str, default="", help="Text prompt for generation")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed, -1 for random")
+     parser.add_argument("--input_image", type=str, required=True, help="Path or URL to input image (e.g., ./image.png or https://example.com/image.png)")
+     parser.add_argument("--input_audio", type=str, required=True, help="Path or URL to input audio (e.g., ./audio.wav or https://example.com/audio.wav)")
+     parser.add_argument("--GPU_memory_mode", type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"], help="GPU memory mode")
+     parser.add_argument("--teacache_threshold", type=float, default=0, help="TeaCache threshold, 0 to disable")
+     parser.add_argument("--num_skip_start_steps", type=int, default=5, help="Number of start steps to skip")
+     parser.add_argument("--negative_prompt", type=str, default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", help="Negative prompt")
+     parser.add_argument("--width", type=int, default=512, help="Output video width")
+     parser.add_argument("--height", type=int, default=512, help="Output video height")
+     parser.add_argument("--guidance_scale", type=float, default=6.0, help="Guidance scale")
+     parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
+     parser.add_argument("--text_guide_scale", type=float, default=3.0, help="Text guidance scale")
+     parser.add_argument("--audio_guide_scale", type=float, default=5.0, help="Audio guidance scale")
+     parser.add_argument("--motion_frame", type=int, default=25, help="Motion frame")
+     parser.add_argument("--fps", type=int, default=25, help="Frames per second")
+     parser.add_argument("--overlap_window_length", type=int, default=10, help="Overlap window length")
+     parser.add_argument("--overlapping_weight_scheme", type=str, default="uniform", choices=["uniform", "log"], help="Overlapping weight scheme")
+
      args = parser.parse_args()
+
+     video_path, seed, message = generate(
+         GPU_memory_mode=args.GPU_memory_mode,
          teacache_threshold=args.teacache_threshold,
          num_skip_start_steps=args.num_skip_start_steps,
          image_path=args.input_image,
 
          fps=args.fps,
          overlap_window_length=args.overlap_window_length,
          seed_param=args.seed,
+         overlapping_weight_scheme=args.overlapping_weight_scheme
      )
+
+     if video_path:
+         print(f"{message}\nSeed: {seed}")
+     else:
+         print("Generation failed.")
 
  if __name__ == "__main__":
      main()
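
After this change the script is driven entirely from the command line via the new argparse flags shown above. For illustration only, a representative invocation of the updated app.py might look like the following; the image and audio paths are placeholders, not files shipped with the repository, and --input_image / --input_audio also accept http(s) URLs, which are fetched by prepare_input_file:

    python app.py --input_image ./my_face.png --input_audio ./my_speech.wav --prompt "a person talking" --GPU_memory_mode model_cpu_offload --seed 42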