import torch
import psutil
import argparse
import gradio as gr
import os
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from omegaconf import OmegaConf
from wan.models.cache_utils import get_teacache_coefficients
from wan.models.wan_fantasy_transformer3d_1B import WanTransformer3DFantasyModel
from wan.models.wan_text_encoder import WanT5EncoderModel
from wan.models.wan_vae import AutoencoderKLWan
from wan.models.wan_image_encoder import CLIPModel
from wan.pipeline.wan_inference_long_pipeline import WanI2VTalkingInferenceLongPipeline
from wan.utils.fp8_optimization import replace_parameters_by_name, convert_weight_dtype_wrapper, convert_model_weight_to_float8
from wan.utils.utils import get_image_to_video_latent, save_videos_grid
import numpy as np
import librosa
import datetime
import random
import math
import subprocess
from moviepy.editor import VideoFileClip
from huggingface_hub import snapshot_download
import shutil
import spaces
try:
    from audio_separator.separator import Separator
except Exception:
    print("Unable to use vocal separation feature. Please install audio-separator[gpu].")
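# Pick the compute device and dtype: bfloat16 on GPUs with compute capability >= 8 (Ampere or newer),
# float16 on older GPUs, float32 on CPU.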
if torch.cuda.is_available():
    device = "cuda"
    if torch.cuda.get_device_capability()[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
def filter_kwargs(cls, kwargs):
    import inspect
    sig = inspect.signature(cls.__init__)
    valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
    return filtered_kwargs
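# Illustrative example (hypothetical values): filter_kwargs(FlowMatchEulerDiscreteScheduler,
# {"shift": 3.0, "bogus": 1}) keeps only {"shift": 3.0}, because "bogus" is not an __init__
# parameter of the scheduler.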
def load_transformer_model(model_version):
    """
    Load the transformer checkpoint matching the selected model version.
    Args:
        model_version (str): model version, "square" or "rec_vec"
    Returns:
        WanTransformer3DFantasyModel: the transformer with the checkpoint loaded
    """
    global transformer3d
    if model_version == "square":
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
    elif model_version == "rec_vec":
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
    else:
        # fall back to the square version
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
    print(f"Loading model: {transformer_path}")
    if os.path.exists(transformer_path):
        state_dict = torch.load(transformer_path, map_location="cpu")
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
        m, u = transformer3d.load_state_dict(state_dict, strict=False)
        print(f"Model loaded successfully: {transformer_path}")
        print(f"Missing keys: {len(m)}; Unexpected keys: {len(u)}")
        return transformer3d
    else:
        print(f"Error: model file does not exist: {transformer_path}")
        return None
REPO_ID = "FrancisRing/StableAvatar"
repo_root = snapshot_download(
    repo_id=REPO_ID,
    allow_patterns=[
        "StableAvatar-1.3B/*",
        "Wan2.1-Fun-V1.1-1.3B-InP/*",
        "wav2vec2-base-960h/*",
        "assets/**",
        "Kim_Vocal_2.onnx",
    ],
)
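# snapshot_download returns the local path of the cached snapshot; every checkpoint below is resolved
# relative to this directory.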
pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
# ONNX model for vocal separation
audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")
# model_path = "/datadrive/stableavatar/checkpoints"
# pretrained_model_name_or_path = f"{model_path}/Wan2.1-Fun-V1.1-1.3B-InP"
# pretrained_wav2vec_path = f"{model_path}/wav2vec2-base-960h"
# transformer_path = f"{model_path}/StableAvatar-1.3B/transformer3d-square.pt"
config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
sampler_name = "Flow"
clip_sample_n_frames = 81
tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')))
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=dtype,
)
text_encoder = text_encoder.eval()
vae = AutoencoderKLWan.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
clip_image_encoder = CLIPModel.from_pretrained(os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')))
clip_image_encoder = clip_image_encoder.eval()
transformer3d = WanTransformer3DFantasyModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=False,
    torch_dtype=dtype,
)
# Load the "square" checkpoint by default
load_transformer_model("square")
Chosen_Scheduler = {
    "Flow": FlowMatchEulerDiscreteScheduler,
}[sampler_name]
scheduler = Chosen_Scheduler(
    **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)
pipeline = WanI2VTalkingInferenceLongPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer3d,
    clip_image_encoder=clip_image_encoder,
    scheduler=scheduler,
    wav2vec_processor=wav2vec_processor,
    wav2vec=wav2vec,
)
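# The long-video talking-head pipeline bundles the tokenizer, text encoder, VAE, CLIP image encoder,
# wav2vec audio encoder and the fine-tuned transformer behind a single inference entry point.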
def generate(
    GPU_memory_mode,
    teacache_threshold,
    num_skip_start_steps,
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    text_guide_scale,
    audio_guide_scale,
    motion_frame,
    fps,
    overlap_window_length,
    seed_param,
    overlapping_weight_scheme,
    progress=gr.Progress(track_tqdm=True),
):
    global pipeline, transformer3d
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = seed_param
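    # GPU memory mode trades speed for VRAM: "Normal" keeps the whole pipeline on the GPU (~25G per the
    # UI hint), the cpu_offload modes move weights off the GPU between uses (~13G for model_cpu_offload),
    # and the qfloat8 variant additionally stores transformer weights in float8.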
    if GPU_memory_mode == "sequential_cpu_offload":
        replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
        transformer3d.freqs = transformer3d.freqs.to(device=device)
        pipeline.enable_sequential_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
        convert_weight_dtype_wrapper(transformer3d, dtype)
        pipeline.enable_model_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=device)
    else:
        pipeline.to(device=device)
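    # TeaCache (threshold > 0) reuses cached transformer computations between denoising steps to speed
    # up sampling; the UI recommends 0.1, and 0 leaves it disabled.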
    if teacache_threshold > 0:
        coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
        pipeline.transformer.enable_teacache(
            coefficients,
            num_inference_steps,
            teacache_threshold,
            num_skip_start_steps=num_skip_start_steps,
        )
    with torch.no_grad():
        video_length = int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if clip_sample_n_frames != 1 else 1
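        # The expression above snaps the clip length so that (frames - 1) is a multiple of the VAE
        # temporal compression ratio, i.e. frame counts of the form k * ratio + 1.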
        input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
        sr = 16000
        vocal_input, sample_rate = librosa.load(audio_path, sr=sr)
        sample = pipeline(
            prompt,
            num_frames=video_length,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
            num_inference_steps=num_inference_steps,
            video=input_video,
            mask_video=input_video_mask,
            clip_image=clip_image,
            text_guide_scale=text_guide_scale,
            audio_guide_scale=audio_guide_scale,
            vocal_input_values=vocal_input,
            motion_frame=motion_frame,
            fps=fps,
            sr=sr,
            cond_file_path=image_path,
            overlap_window_length=overlap_window_length,
            seed=seed,
            overlapping_weight_scheme=overlapping_weight_scheme,
        ).videos
    os.makedirs("outputs", exist_ok=True)
    video_path = os.path.join("outputs", f"{timestamp}.mp4")
    save_videos_grid(sample, video_path, fps=fps)
    output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
    subprocess.run([
        "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_path, "-i", audio_path,
        "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
        output_video_with_audio
    ], check=True)
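    # The ffmpeg call above muxes the original audio onto the rendered (silent) video: the video stream
    # is copied as-is and the audio is re-encoded to AAC.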
    return output_video_with_audio, seed, f"Generated outputs/{timestamp}.mp4 / 已生成outputs/{timestamp}.mp4"
def exchange_width_height(width, height):
    return height, width, "✅ Width and Height Swapped / 宽高交换完毕"
def adjust_width_height(image):
    image = load_image(image)
    width, height = image.size
    original_area = width * height
    default_area = 512 * 512
    ratio = math.sqrt(original_area / default_area)
    width = width / ratio // 16 * 16
    height = height / ratio // 16 * 16
    return int(width), int(height), "✅ Adjusted Size Based on Image / 根据图片调整宽高"
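# adjust_width_height rescales the uploaded image so its area is close to 512x512 while keeping the
# aspect ratio, then floors each side to a multiple of 16; e.g. (illustrative) a 1920x1080 upload gives
# ratio 2.8125 and comes out as 672x384.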
def audio_extractor(video_path):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)  # make sure the output directory exists
    out_wav = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))
    video = VideoFileClip(video_path)
    audio = video.audio
    audio.write_audiofile(out_wav, codec="pcm_s16le")
    return out_wav, f"Generated {out_wav} / 已生成 {out_wav}", out_wav  # third value feeds the gr.File download
def vocal_separation(audio_path):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    # audio_separator_model_file = "checkpoints/Kim_Vocal_2.onnx"
    audio_separator = Separator(
        output_dir=os.path.abspath(os.path.join("outputs", timestamp)),
        output_single_stem="vocals",
        model_file_dir=os.path.dirname(audio_separator_model_file),
    )
    audio_separator.load_model(os.path.basename(audio_separator_model_file))
    assert audio_separator.model_instance is not None, "Failed to load the audio separation model."
    outputs = audio_separator.separate(audio_path)
    vocal_audio_file = os.path.join(audio_separator.output_dir, outputs[0])
    destination_file = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))
    shutil.copy(vocal_audio_file, destination_file)
    os.remove(vocal_audio_file)
    return destination_file, f"Generated {destination_file} / 已生成 {destination_file}", destination_file
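# vocal_separation writes the separated vocals stem into outputs/<timestamp>/ via audio-separator, then
# copies it to outputs/<timestamp>.wav and removes the intermediate file.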
def update_language(language):
    if language == "English":
        return {
            GPU_memory_mode: gr.Dropdown(label="GPU Memory Mode", info="Normal uses 25G VRAM, model_cpu_offload uses 13G VRAM"),
            teacache_threshold: gr.Slider(label="TeaCache Threshold", info="Recommended 0.1, 0 disables TeaCache acceleration"),
            num_skip_start_steps: gr.Slider(label="Skip Start Steps", info="Recommended 5"),
            model_version: gr.Dropdown(label="Model Version", choices=["square", "rec_vec"], value="square"),
            image_path: gr.Image(label="Upload Image"),
            audio_path: gr.Audio(label="Upload Audio"),
            prompt: gr.Textbox(label="Prompt"),
            negative_prompt: gr.Textbox(label="Negative Prompt"),
            generate_button: gr.Button("🎬 Start Generation"),
            width: gr.Slider(label="Width"),
            height: gr.Slider(label="Height"),
            exchange_button: gr.Button("🔄 Swap Width/Height"),
            adjust_button: gr.Button("Adjust Size Based on Image"),
            guidance_scale: gr.Slider(label="Guidance Scale"),
            num_inference_steps: gr.Slider(label="Sampling Steps (Recommended 50)"),
            text_guide_scale: gr.Slider(label="Text Guidance Scale"),
            audio_guide_scale: gr.Slider(label="Audio Guidance Scale"),
            motion_frame: gr.Slider(label="Motion Frame"),
            fps: gr.Slider(label="FPS"),
            overlap_window_length: gr.Slider(label="Overlap Window Length"),
            seed_param: gr.Number(label="Seed (positive integer, -1 for random)"),
            overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform"),
            info: gr.Textbox(label="Status"),
            video_output: gr.Video(label="Generated Result"),
            seed_output: gr.Textbox(label="Seed"),
            video_path: gr.Video(label="Upload Video"),
            extractor_button: gr.Button("🎬 Start Extraction"),
            info2: gr.Textbox(label="Status"),
            audio_output: gr.Audio(label="Generated Result"),
            audio_path3: gr.Audio(label="Upload Audio"),
            separation_button: gr.Button("🎬 Start Separation"),
            info3: gr.Textbox(label="Status"),
            audio_output3: gr.Audio(label="Generated Result"),
            example_title: gr.Markdown(value="### Select the following example cases for testing:"),
            example1_label: gr.Markdown(value="**Example 1**"),
            example2_label: gr.Markdown(value="**Example 2**"),
            example3_label: gr.Markdown(value="**Example 3**"),
            example4_label: gr.Markdown(value="**Example 4**"),
            example5_label: gr.Markdown(value="**Example 5**"),
            example1_btn: gr.Button("🚀 Use Example 1", variant="secondary"),
            example2_btn: gr.Button("🚀 Use Example 2", variant="secondary"),
            example3_btn: gr.Button("🚀 Use Example 3", variant="secondary"),
            example4_btn: gr.Button("🚀 Use Example 4", variant="secondary"),
            example5_btn: gr.Button("🚀 Use Example 5", variant="secondary"),
            parameter_settings_title: gr.Accordion(label="Parameter Settings", open=True),
            example_cases_title: gr.Accordion(label="Example Cases", open=True),
            stableavatar_title: gr.TabItem(label="StableAvatar"),
            audio_extraction_title: gr.TabItem(label="Audio Extraction"),
            vocal_separation_title: gr.TabItem(label="Vocal Separation"),
        }
    else:
        return {
            GPU_memory_mode: gr.Dropdown(label="显存模式", info="Normal占用25G显存,model_cpu_offload占用13G显存"),
            teacache_threshold: gr.Slider(label="teacache threshold", info="推荐参数0.1,0为禁用teacache加速"),
            num_skip_start_steps: gr.Slider(label="跳过开始步数", info="推荐参数5"),
            model_version: gr.Dropdown(label="模型版本", choices=["square", "rec_vec"], value="square"),
            image_path: gr.Image(label="上传图片"),
            audio_path: gr.Audio(label="上传音频"),
            prompt: gr.Textbox(label="提示词"),
            negative_prompt: gr.Textbox(label="负面提示词"),
            generate_button: gr.Button("🎬 开始生成"),
            width: gr.Slider(label="宽度"),
            height: gr.Slider(label="高度"),
            exchange_button: gr.Button("🔄 交换宽高"),
            adjust_button: gr.Button("根据图片调整宽高"),
            guidance_scale: gr.Slider(label="guidance scale"),
            num_inference_steps: gr.Slider(label="采样步数(推荐50步)", minimum=1, maximum=100, step=1, value=50),
            text_guide_scale: gr.Slider(label="text guidance scale"),
            audio_guide_scale: gr.Slider(label="audio guidance scale"),
            motion_frame: gr.Slider(label="motion frame"),
            fps: gr.Slider(label="帧率"),
            overlap_window_length: gr.Slider(label="overlap window length"),
            seed_param: gr.Number(label="种子,请输入正整数,-1为随机"),
            overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform"),
            info: gr.Textbox(label="提示信息"),
            video_output: gr.Video(label="生成结果"),
            seed_output: gr.Textbox(label="种子"),
            video_path: gr.Video(label="上传视频"),
            extractor_button: gr.Button("🎬 开始提取"),
            info2: gr.Textbox(label="提示信息"),
            audio_output: gr.Audio(label="生成结果"),
            audio_path3: gr.Audio(label="上传音频"),
            separation_button: gr.Button("🎬 开始分离"),
            info3: gr.Textbox(label="提示信息"),
            audio_output3: gr.Audio(label="生成结果"),
            example_title: gr.Markdown(value="### 选择以下示例案例进行测试:"),
            example1_label: gr.Markdown(value="**示例 1**"),
            example2_label: gr.Markdown(value="**示例 2**"),
            example3_label: gr.Markdown(value="**示例 3**"),
            example4_label: gr.Markdown(value="**示例 4**"),
            example5_label: gr.Markdown(value="**示例 5**"),
            example1_btn: gr.Button("🚀 使用示例 1", variant="secondary"),
            example2_btn: gr.Button("🚀 使用示例 2", variant="secondary"),
            example3_btn: gr.Button("🚀 使用示例 3", variant="secondary"),
            example4_btn: gr.Button("🚀 使用示例 4", variant="secondary"),
            example5_btn: gr.Button("🚀 使用示例 5", variant="secondary"),
            parameter_settings_title: gr.Accordion(label="参数设置", open=True),
            example_cases_title: gr.Accordion(label="示例案例", open=True),
            stableavatar_title: gr.TabItem(label="StableAvatar"),
            audio_extraction_title: gr.TabItem(label="音频提取"),
            vocal_separation_title: gr.TabItem(label="人声分离"),
        }
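# update_language returns a dict mapping each component to a re-labelled copy, so toggling the language
# radio swaps every visible label between English and Chinese without rebuilding the interface.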
BANNER_HTML = """
<div class="hero">
  <div class="brand">
    <!-- If the project has a logo, add it to the repo and point this to it; otherwise delete this line -->
    <!-- <img src="https://raw.githubusercontent.com/Francis-Rings/StableAvatar/main/assets/logo.png" alt="StableAvatar Logo"> -->
    <span class="brand-text">STABLEAVATAR</span>
  </div>
  <div class="titles">
    <h1>StableAvatar</h1>
    <div class="badges">
      <a class="badge" href="https://arxiv.org/abs/2508.08248" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/arXiv-2508.08248-b31b1b">
      </a>
      <a class="badge" href="https://francis-rings.github.io/StableAvatar/" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/Webpage-Visit-2266ee">
      </a>
      <a class="badge" href="https://github.com/Francis-Rings/StableAvatar" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/GitHub-Repo-181717?logo=github&logoColor=white">
      </a>
      <a class="badge" href="https://www.youtube.com/watch?v=6lhvmbzvv3Y" target="_blank" rel="noopener">
        <img src="https://img.shields.io/badge/YouTube-Demo-ff0000?logo=youtube&logoColor=white">
      </a>
    </div>
  </div>
</div>
<hr class="divider">
"""
BANNER_CSS = """
.hero{display:flex;align-items:center;gap:18px;padding:18px;border-radius:14px;
      color:inherit;margin-bottom:12px}
.brand-text{font-weight:800;letter-spacing:2px}
.brand img{height:46px}
.titles h1{font-size:28px;margin:0 0 6px 0}
.badges{display:flex;gap:10px;flex-wrap:wrap}
.badge img{height:22px}
.divider{border:0;border-top:1px solid rgba(0,0,0,0.12);margin:6px 0 18px}
"""
with gr.Blocks(theme=gr.themes.Base(), css=BANNER_CSS) as demo:
    gr.HTML(BANNER_HTML)
    language_radio = gr.Radio(
        choices=["English", "中文"],
        value="English",
        label="Language / 语言"
    )
    with gr.Accordion("Model Settings / 模型设置", open=False):
        with gr.Row():
            GPU_memory_mode = gr.Dropdown(
                label="显存模式",
                info="Normal占用25G显存,model_cpu_offload占用13G显存",
                choices=["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
                value="model_cpu_offload"
            )
            teacache_threshold = gr.Slider(label="teacache threshold", info="推荐参数0.1,0为禁用teacache加速", minimum=0, maximum=1, step=0.01, value=0)
            num_skip_start_steps = gr.Slider(label="跳过开始步数", info="推荐参数5", minimum=0, maximum=100, step=1, value=5)
        with gr.Row():
            model_version = gr.Dropdown(
                label="模型版本",
                choices=["square", "rec_vec"],
                value="square"
            )
    stableavatar_title = gr.TabItem(label="StableAvatar")
    with stableavatar_title:
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_path = gr.Image(label="上传图片", type="filepath", height=280)
                    audio_path = gr.Audio(label="上传音频", type="filepath")
                prompt = gr.Textbox(label="提示词", value="")
                negative_prompt = gr.Textbox(label="负面提示词", value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
                generate_button = gr.Button("🎬 开始生成", variant='primary')
                parameter_settings_title = gr.Accordion(label="参数设置", open=True)
                with parameter_settings_title:
                    with gr.Row():
                        width = gr.Slider(label="宽度", minimum=256, maximum=2048, step=16, value=512)
                        height = gr.Slider(label="高度", minimum=256, maximum=2048, step=16, value=512)
                    with gr.Row():
                        exchange_button = gr.Button("🔄 交换宽高")
                        adjust_button = gr.Button("根据图片调整宽高")
                    with gr.Row():
                        guidance_scale = gr.Slider(label="guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=6.0)
                        num_inference_steps = gr.Slider(label="采样步数(推荐50步)", minimum=1, maximum=100, step=1, value=50)
                    with gr.Row():
                        text_guide_scale = gr.Slider(label="text guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=3.0)
                        audio_guide_scale = gr.Slider(label="audio guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
                    with gr.Row():
                        motion_frame = gr.Slider(label="motion frame", minimum=1, maximum=50, step=1, value=25)
                        fps = gr.Slider(label="帧率", minimum=1, maximum=60, step=1, value=25)
                    with gr.Row():
                        overlap_window_length = gr.Slider(label="overlap window length", minimum=1, maximum=20, step=1, value=10)
                        seed_param = gr.Number(label="种子,请输入正整数,-1为随机", value=42)
                    with gr.Row():
                        overlapping_weight_scheme = gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform")
            with gr.Column():
                info = gr.Textbox(label="提示信息", interactive=False)
                video_output = gr.Video(label="生成结果", interactive=False)
                seed_output = gr.Textbox(label="种子")
        # Example-case section lives inside the StableAvatar tab
        example_cases_title = gr.Accordion(label="示例案例", open=True)
        with example_cases_title:
            example_title = gr.Markdown(value="### 选择以下示例案例进行测试:")
            with gr.Row():
                with gr.Column():
                    example1_label = gr.Markdown(value="**示例 1**")
                    example1_image = gr.Image(value="example_case/case-1/reference.png", label="", interactive=False, height=120, show_label=False)
                    example1_audio = gr.Audio(value="example_case/case-1/audio.wav", label="", interactive=False, show_label=False)
                    example1_btn = gr.Button("🚀 使用示例 1", variant="secondary", size="sm")
                with gr.Column():
                    example2_label = gr.Markdown(value="**示例 2**")
                    example2_image = gr.Image(value="example_case/case-2/reference.png", label="", interactive=False, height=120, show_label=False)
                    example2_audio = gr.Audio(value="example_case/case-2/audio.wav", label="", interactive=False, show_label=False)
                    example2_btn = gr.Button("🚀 使用示例 2", variant="secondary", size="sm")
                with gr.Column():
                    example3_label = gr.Markdown(value="**示例 3**")
                    example3_image = gr.Image(value="example_case/case-6/reference.png", label="", interactive=False, height=120, show_label=False)
                    example3_audio = gr.Audio(value="example_case/case-6/audio.wav", label="", interactive=False, show_label=False)
                    example3_btn = gr.Button("🚀 使用示例 3", variant="secondary", size="sm")
                with gr.Column():
                    example4_label = gr.Markdown(value="**示例 4**")
                    example4_image = gr.Image(value="example_case/case-45/reference.png", label="", interactive=False, height=120, show_label=False)
                    example4_audio = gr.Audio(value="example_case/case-45/audio.wav", label="", interactive=False, show_label=False)
                    example4_btn = gr.Button("🚀 使用示例 4", variant="secondary", size="sm")
                with gr.Column():
                    example5_label = gr.Markdown(value="**示例 5**")
                    example5_image = gr.Image(value="example_case/case-3/reference.jpg", label="", interactive=False, height=120, show_label=False)
                    example5_audio = gr.Audio(value="example_case/case-3/audio.wav", label="", interactive=False, show_label=False)
                    example5_btn = gr.Button("🚀 使用示例 5", variant="secondary", size="sm")
    audio_extraction_title = gr.TabItem(label="音频提取")
    with audio_extraction_title:
        with gr.Row():
            with gr.Column():
                video_path = gr.Video(label="上传视频", height=500)
                extractor_button = gr.Button("🎬 开始提取", variant='primary')
            with gr.Column():
                info2 = gr.Textbox(label="提示信息", interactive=False)
                audio_output = gr.Audio(label="生成结果", interactive=False)
                audio_file = gr.File(label="download audio file")
    vocal_separation_title = gr.TabItem(label="人声分离")
    with vocal_separation_title:
        with gr.Row():
            with gr.Column():
                audio_path3 = gr.Audio(label="上传音频", type="filepath")
                separation_button = gr.Button("🎬 开始分离", variant='primary')
            with gr.Column():
                info3 = gr.Textbox(label="提示信息", interactive=False)
                audio_output3 = gr.Audio(label="生成结果", interactive=False)
                audio_file3 = gr.File(label="download audio file")
    all_components = [GPU_memory_mode, teacache_threshold, num_skip_start_steps, model_version, image_path, audio_path, prompt, negative_prompt, generate_button, width, height, exchange_button, adjust_button, guidance_scale, num_inference_steps, text_guide_scale, audio_guide_scale, motion_frame, fps, overlap_window_length, seed_param, overlapping_weight_scheme, info, video_output, seed_output, video_path, extractor_button, info2, audio_output, audio_path3, separation_button, info3, audio_output3, example_title, example1_label, example2_label, example3_label, example4_label, example1_btn, example2_btn, example3_btn, example4_btn, example5_label, example5_btn, parameter_settings_title, example_cases_title, stableavatar_title, audio_extraction_title, vocal_separation_title]
    language_radio.change(
        fn=update_language,
        inputs=[language_radio],
        outputs=all_components
    )
    # Wire up the model-version selector
    def on_model_version_change(model_version):
        """Reload the matching transformer checkpoint when the model version changes."""
        result = load_transformer_model(model_version)
        if result is not None:
            return f"✅ Switched to the {model_version} model"
        else:
            return "❌ Failed to switch model, please check that the checkpoint file exists"
    model_version.change(
        fn=on_model_version_change,
        inputs=[model_version],
        outputs=[info]
    )
    demo.load(fn=update_language, inputs=[language_radio], outputs=all_components)
    # Wire up the example-case buttons
    def load_example1():
        try:
            with open("example_case/case-1/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except Exception:
            prompt_text = ""
        return "example_case/case-1/reference.png", "example_case/case-1/audio.wav", prompt_text
    def load_example2():
        try:
            with open("example_case/case-2/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except Exception:
            prompt_text = ""
        return "example_case/case-2/reference.png", "example_case/case-2/audio.wav", prompt_text
    def load_example3():
        try:
            with open("example_case/case-6/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except Exception:
            prompt_text = ""
        return "example_case/case-6/reference.png", "example_case/case-6/audio.wav", prompt_text
    def load_example4():
        try:
            with open("example_case/case-45/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except Exception:
            prompt_text = ""
        return "example_case/case-45/reference.png", "example_case/case-45/audio.wav", prompt_text
    def load_example5():
        try:
            with open("example_case/case-3/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except Exception:
            prompt_text = ""
        return "example_case/case-3/reference.jpg", "example_case/case-3/audio.wav", prompt_text
    example1_btn.click(fn=load_example1, outputs=[image_path, audio_path, prompt])
    example2_btn.click(fn=load_example2, outputs=[image_path, audio_path, prompt])
    example3_btn.click(fn=load_example3, outputs=[image_path, audio_path, prompt])
    example4_btn.click(fn=load_example4, outputs=[image_path, audio_path, prompt])
    example5_btn.click(fn=load_example5, outputs=[image_path, audio_path, prompt])
    gr.on(
        triggers=[generate_button.click, prompt.submit, negative_prompt.submit],
        fn=generate,
        inputs=[
            GPU_memory_mode,
            teacache_threshold,
            num_skip_start_steps,
            image_path,
            audio_path,
            prompt,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            text_guide_scale,
            audio_guide_scale,
            motion_frame,
            fps,
            overlap_window_length,
            seed_param,
            overlapping_weight_scheme,
        ],
        outputs=[video_output, seed_output, info]
    )
    exchange_button.click(
        fn=exchange_width_height,
        inputs=[width, height],
        outputs=[width, height, info]
    )
    adjust_button.click(
        fn=adjust_width_height,
        inputs=[image_path],
        outputs=[width, height, info]
    )
    extractor_button.click(
        fn=audio_extractor,
        inputs=[video_path],
        outputs=[audio_output, info2, audio_file]
    )
    separation_button.click(
        fn=vocal_separation,
        inputs=[audio_path3],
        outputs=[audio_output3, info3, audio_file3]
    )
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
        share=False,
        inbrowser=False,
    )