Spaces:
Runtime error
Runtime error
| from huggingface_hub import snapshot_download | |
# Fetch every model repository the pipeline needs with `snapshot_download`.
# Each (repo_id, local_dir) pair is downloaded in the order listed, and the
# values returned by `snapshot_download` are bound to the same module-level
# names as before.
wan_model_path, wav2vec_path, multitalk_path = (
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    for repo_id, local_dir in (
        # Wan2.1-I2V-14B-480P model
        ("Wan-AI/Wan2.1-I2V-14B-480P", "./weights/Wan2.1-I2V-14B-480P"),
        # Chinese wav2vec2 model
        ("TencentGameMate/chinese-wav2vec2-base", "./weights/chinese-wav2vec2-base"),
        # MeiGen MultiTalk weights
        ("MeiGen-AI/MeiGen-MultiTalk", "./weights/MeiGen-MultiTalk"),
    )
)
| import os | |
| import shutil | |
# Overlay the MultiTalk index and weights onto the base Wan2.1 checkpoint
# directory, keeping a one-time backup of the base model's original index.
base_model_dir = "./weights/Wan2.1-I2V-14B-480P"
multitalk_dir = "./weights/MeiGen-MultiTalk"

# Index file to back up, and the name of its backup.
original_index = os.path.join(base_model_dir, "diffusion_pytorch_model.safetensors.index.json")
backup_index = os.path.join(base_model_dir, "diffusion_pytorch_model.safetensors.index.json_old")

# Back up the original index only once. On a re-run the file at
# `original_index` may already be the MultiTalk copy written below; renaming
# it again would overwrite the real backup (or raise on Windows, where
# os.rename refuses to replace an existing file).
if os.path.exists(backup_index):
    print("Backup index file already exists; leaving it untouched.")
elif os.path.exists(original_index):
    os.rename(original_index, backup_index)
    print("Renamed original index file to .json_old")

# Copy the updated index file and the MultiTalk weights into the base model
# directory. copy2 overwrites an existing destination file of the same name.
for _multitalk_file in (
    "diffusion_pytorch_model.safetensors.index.json",
    "multitalk.safetensors",
):
    shutil.copy2(os.path.join(multitalk_dir, _multitalk_file), base_model_dir)

print("Copied MultiTalk files into base model directory.")
| import torch | |
# Refuse to start without a CUDA GPU from a supported family.
if not torch.cuda.is_available():
    raise RuntimeError("No CUDA-compatible GPU found. An A100 or L4 GPU is required.")

# Name of the currently selected CUDA device.
gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
print(f"Current GPU: {gpu_name}")

# Enforce the GPU requirement. These are substring checks, so name variants
# (and other names containing "L4", such as L40-series cards) are accepted.
if "A100" not in gpu_name and "L4" not in gpu_name:
    raise RuntimeError(f"This notebook requires an A100 or L4 GPU. Found: {gpu_name}")
if "L4" in gpu_name:
    print("Warning: L4 is supported, but A100 is recommended for faster inference.")

# Value passed as --num_persistent_param_in_dit for each exactly-known GPU name.
GPU_TO_VRAM_PARAMS = {
    "NVIDIA A100": 11000000000,
    "NVIDIA A100-SXM4-40GB": 11000000000,
    "NVIDIA A100-SXM4-80GB": 22000000000,
    "NVIDIA L4": 5000000000,
}

USED_VRAM_PARAMS = GPU_TO_VRAM_PARAMS.get(gpu_name)
if USED_VRAM_PARAMS is None:
    # The gate above matches by substring but this table is keyed by exact
    # name, so an accepted variant (e.g. a PCIe A100) used to crash here with
    # KeyError. Fall back to the closest listed family instead, choosing the
    # smaller value whenever the name does not clearly indicate an 80GB A100.
    if "A100" in gpu_name:
        _fallback_key = "NVIDIA A100-SXM4-80GB" if "80GB" in gpu_name else "NVIDIA A100"
    else:
        _fallback_key = "NVIDIA L4"
    USED_VRAM_PARAMS = GPU_TO_VRAM_PARAMS[_fallback_key]
    print(f"No exact entry for '{gpu_name}'; using the setting for '{_fallback_key}'.")

print("Using", USED_VRAM_PARAMS, "for num_persistent_param_in_dit")
| import subprocess | |
| import json | |
| import tempfile | |
| #import os | |
def create_temp_input_json(prompt: str, cond_image_path: str, cond_audio_path: str) -> str:
    """Write the inference inputs to a temporary JSON file and return its path.

    The file contains the keys ``prompt``, ``cond_image`` and ``cond_audio``;
    ``cond_audio`` is a mapping with the single key ``person1``.

    Args:
        prompt: Text prompt stored under ``prompt``.
        cond_image_path: Path stored under ``cond_image``.
        cond_audio_path: Path stored under ``cond_audio["person1"]``.

    Returns:
        Path of the created ``.json`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    data = {
        "prompt": prompt,
        "cond_image": cond_image_path,
        "cond_audio": {
            "person1": cond_audio_path,
        },
    }

    # The context manager closes the handle even if json.dump raises, so a
    # failed serialization no longer leaks an open file descriptor.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".json", mode="w", encoding="utf-8"
    ) as temp_json:
        json.dump(data, temp_json, indent=4)
        temp_json_path = temp_json.name

    print(f"Temporary input JSON saved to: {temp_json_path}")
    return temp_json_path
def infer(prompt, cond_image_path, cond_audio_path):
    """Run ``generate_multitalk.py`` on the given inputs and return the video path.

    Args:
        prompt: Text prompt describing the scene.
        cond_image_path: File path of the conditioning image.
        cond_audio_path: File path of the conditioning audio.

    Returns:
        The expected output video file name (``<save_file>.mp4``).

    Raises:
        ValueError: If the image or audio path is missing.
        subprocess.CalledProcessError: If the generation script exits non-zero.
    """
    if not cond_image_path or not cond_audio_path:
        raise ValueError("Both a conditioning image and a conditioning audio file are required.")

    # Single source of truth for the output name, so the value handed to the
    # script and the path returned to the caller cannot drift apart.
    save_file = "multi_long_mediumvram_exp"

    # Use the caller's inputs (previously overwritten by hard-coded examples).
    input_json_path = create_temp_input_json(prompt, cond_image_path, cond_audio_path)

    cmd = [
        "python3", "generate_multitalk.py",
        "--ckpt_dir", "weights/Wan2.1-I2V-14B-480P",
        "--wav2vec_dir", "weights/chinese-wav2vec2-base",
        # Pass the generated JSON rather than a fixed example file.
        "--input_json", input_json_path,
        "--sample_steps", "20",
        "--num_persistent_param_in_dit", str(USED_VRAM_PARAMS),
        "--mode", "streaming",
        "--use_teacache",
        "--save_file", save_file,
    ]

    try:
        subprocess.run(cmd, check=True)
    finally:
        # The temp JSON is created with delete=False; remove it once the
        # script has finished so requests do not accumulate files.
        try:
            os.remove(input_json_path)
        except OSError:
            # Best-effort cleanup: a failed delete must not mask the result
            # (or the original error) of the generation run.
            pass

    # NOTE(review): assumes the script writes "<save_file>.mp4" - confirm
    # against generate_multitalk.py.
    return f"{save_file}.mp4"
| import gradio as gr | |
# Gradio front end: prompt, image and audio inputs on the left, the generated
# video on the right; the button forwards the three inputs to `infer`.
with gr.Blocks(title="MultiTalk Inference") as demo:
    gr.Markdown("## 🎤 MultiTalk Inference Demo")

    with gr.Row():
        # Input column
        with gr.Column():
            prompt_input = gr.Textbox(label="Text Prompt", placeholder="Describe the scene...", lines=4)
            # Both uploads are delivered to the callback as file paths.
            image_input = gr.Image(type="filepath", label="Conditioning Image")
            audio_input = gr.Audio(type="filepath", label="Conditioning Audio (.wav)")
            submit_btn = gr.Button("Generate")

        # Output column
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    # Wire the button: inputs are passed positionally to `infer`, whose return
    # value is shown in the video component.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_input, image_input, audio_input],
        outputs=output_video,
    )

demo.launch()