# Spaces: Runtime error
| # Copyright Alibaba Inc. All Rights Reserved. | |
| import argparse | |
| import os | |
| import subprocess | |
| from datetime import datetime | |
| from pathlib import Path | |
| import cv2 | |
| import librosa | |
| import torch | |
| from PIL import Image | |
| from transformers import Wav2Vec2Model, Wav2Vec2Processor | |
# NOTE: the following import is required in the full version
| # from diffsynth import ModelManager, WanVideoPipeline | |
| from model import FantasyTalkingAudioConditionModel | |
| from utils import get_audio_features, resize_image_by_longest_edge, save_video | |
def parse_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of reading ``sys.argv[1:]``.
            Pass an explicit list (``[]`` for "all defaults") when calling
            from a process whose own command line is unrelated to this
            script, e.g. a web app or a test runner.

    Returns:
        argparse.Namespace: One attribute per option declared below.
    """
    parser = argparse.ArgumentParser()

    # Model locations.
    parser.add_argument(
        "--wan_model_dir",
        type=str,
        default="./models/Wan2.1-I2V-14B-720P",
        help="Wan I2V 14B模型目录",
    )
    parser.add_argument(
        "--fantasytalking_model_path",
        type=str,
        default="./models/fantasytalking_model.ckpt",
        help="FantasyTalking模型路径",
    )
    parser.add_argument(
        "--wav2vec_model_dir",
        type=str,
        default="./models/wav2vec2-base-960h",
        help="Wav2Vec模型目录",
    )

    # Inputs and output location.
    parser.add_argument(
        "--image_path",
        type=str,
        default="./assets/images/woman.png",
        help="输入图像路径",
    )
    parser.add_argument(
        "--audio_path",
        type=str,
        default="./assets/audios/woman.wav",
        help="输入音频路径",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman is talking.",
        help="提示词",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="输出目录",
    )

    # Generation settings.
    parser.add_argument("--image_size", type=int, default=512, help="图像尺寸")
    parser.add_argument(
        "--audio_scale",
        type=float,
        default=1.0,
        help="音频条件注入权重",
    )
    parser.add_argument(
        "--prompt_cfg_scale",
        type=float,
        default=5.0,
        help="提示词CFG比例",
    )
    parser.add_argument(
        "--audio_cfg_scale",
        type=float,
        default=5.0,
        help="音频CFG比例",
    )
    parser.add_argument("--max_num_frames", type=int, default=81, help="最大帧数")
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=30,
        help="推理步数",
    )
    parser.add_argument("--seed", type=int, default=1247, help="随机种子")
    parser.add_argument("--fps", type=int, default=24, help="帧率")
    parser.add_argument(
        "--num_persistent_param_in_dit",
        type=int,
        default=3_000_000_000,  # tuned for 16 GB GPUs
        help="DiT中持久参数数量,用于VRAM管理",
    )

    # argparse itself falls back to sys.argv[1:] when argv is None.
    return parser.parse_args(argv)
def load_models(args):
    """Return placeholder model handles (demo mode).

    Nothing is loaded in this version: every handle is ``None`` and ``args``
    is accepted but not consulted.

    Args:
        args: Parsed command-line options; unused here.

    Returns:
        tuple: ``(pipe, fantasytalking, wav2vec_processor, wav2vec)``, each
        ``None``.
    """
    print("正在加载模型...")

    # The full version loads the real models at this point:
    #   model_manager = ModelManager(device="cpu")
    #   model_manager.load_models([...])
    #   pipe = WanVideoPipeline.from_model_manager(
    #       model_manager, torch_dtype=torch.bfloat16, device="cuda")

    # Simulated loading: one placeholder per handle the caller unpacks.
    handles = (None,) * 4

    print("模型加载完成(演示模式)")
    return handles
def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
    """Run the demo-mode inference pass and return the planned output path.

    Prints the inputs, creates the output directory if it is missing, and
    builds a timestamped ``.mp4`` path. No video file is written in this
    version, and the four model handles are accepted but not used.

    Args:
        args: Options namespace; ``image_path``, ``audio_path``, ``prompt``
            and ``output_dir`` are read.
        pipe: Unused in demo mode.
        fantasytalking: Unused in demo mode.
        wav2vec_processor: Unused in demo mode.
        wav2vec: Unused in demo mode.

    Returns:
        str: ``<output_dir>/output_<YYYYmmdd_HHMMSS>.mp4``.
    """
    # Echo the request.
    print(f"输入图像: {args.image_path}")
    print(f"输入音频: {args.audio_path}")
    print(f"提示词: {args.prompt}")

    # Ensure the destination directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    # The full version would run the actual inference here.
    print("开始推理...")

    # Simulated result: name the file after the current local time.
    now = datetime.now()
    current_time = now.strftime("%Y%m%d_%H%M%S")
    output_path = f"{args.output_dir}/output_{current_time}.mp4"
    print(f"输出将保存到: {output_path}")

    print("推理完成(演示模式)")
    return output_path
if __name__ == "__main__":
    # Script entry point: parse the CLI, obtain the model handles, then run
    # inference with them.
    args = parse_args()
    pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
    main(
        args=args,
        pipe=pipe,
        fantasytalking=fantasytalking,
        wav2vec_processor=wav2vec_processor,
        wav2vec=wav2vec,
    )