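"""Gradio demo for OmniAvatar: lipsynced avatar video generation.

Takes a reference image, a speech clip, and a text prompt, and shells out to
scripts/inference.py (Wan2.1 + OmniAvatar) to render an audio-driven avatar
video.
"""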
import gradio as gr
import subprocess
import os
import tempfile
import shutil
from pathlib import Path
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
DEFAULT_CONFIG_PATH = "configs/inference.yaml"
DEFAULT_INPUT_FILE = "examples/infer_samples.txt"
OUTPUT_DIR = Path("demo_out/gradio_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
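# NOTE: the paths above assume the app runs from the OmniAvatar repo root,
# alongside configs/, scripts/, examples/, and demo_out/.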
def generate_avatar_video(
    reference_image,
    audio_file,
    text_prompt,
    seed=42,
    num_steps=50,
    guidance_scale=4.5,
    audio_scale=None,
    overlap_frames=13,
    fps=25,
    silence_duration=0.3,
    resolution="720p",
    progress=gr.Progress()
):
    """Generate an avatar video using OmniAvatar.

    Args:
        reference_image: Path to reference avatar image
        audio_file: Path to audio file for lip sync
        text_prompt: Text description of the video to generate
        seed: Random seed for generation
        num_steps: Number of inference steps
        guidance_scale: Classifier-free guidance scale
        audio_scale: Audio guidance scale (uses guidance_scale if None)
        overlap_frames: Number of overlapping frames between chunks
        fps: Frames per second
        silence_duration: Duration of silence to add before/after audio
        resolution: Output resolution ("480p" or "720p")
        progress: Gradio progress callback

    Returns:
        str: Path to generated video file
    """
    # Basic input validation before doing any work
    if reference_image is None or audio_file is None:
        raise gr.Error("Please provide both a reference image and an audio file")

    try:
        progress(0.1, desc="Preparing inputs")

        # Create temporary directory for this generation
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # Copy input files to temp directory
            temp_image = temp_path / "input_image.jpeg"
            temp_audio = temp_path / "input_audio.mp3"
            shutil.copy(reference_image, temp_image)
            shutil.copy(audio_file, temp_audio)

            # Create input file for inference script
            # Format: prompt@@image_path@@audio_path
            input_file = temp_path / "input.txt"
            with open(input_file, 'w') as f:
                f.write(f"{text_prompt}@@{temp_image}@@{temp_audio}\n")
            progress(0.2, desc="Configuring generation parameters")

            # Determine max_hw based on resolution
            max_hw = 720 if resolution == "480p" else 1280

            # Build command to run inference script
            cmd = [
                "torchrun",
                "--nproc_per_node=1",
                "scripts/inference.py",
                "--config", DEFAULT_CONFIG_PATH,
                "--input_file", str(input_file),
                "-hp", f"seed={seed},num_steps={num_steps},guidance_scale={guidance_scale},"
                       f"overlap_frame={overlap_frames},fps={fps},silence_duration_s={silence_duration},"
                       f"max_hw={max_hw},use_audio=True,i2v=True"
            ]
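            # NOTE: "-hp" carries comma-separated key=value hyperparameter
            # overrides that scripts/inference.py applies on top of the YAML
            # config; further overrides can be appended to cmd[-1] below.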
            # Add audio scale if specified (the UI slider sends 0.0 for "unset")
            if audio_scale is not None and audio_scale > 0:
                cmd[-1] += f",audio_scale={audio_scale}"
            progress(0.3, desc="Running OmniAvatar generation")
            logger.info(f"Running command: {' '.join(cmd)}")

            # Run the inference script
            env = os.environ.copy()
            env['CUDA_VISIBLE_DEVICES'] = '0'  # Use first GPU
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                # Merge stderr into stdout so a full stderr pipe cannot
                # deadlock the read loop below
                stderr=subprocess.STDOUT,
                text=True,
                env=env
            )
            # Monitor progress (simplified - in reality you'd parse the output)
            output_lines = []
            while True:
                output = process.stdout.readline()
                if output:
                    output_lines.append(output.strip())
                    logger.info(output.strip())
                    # Update progress based on output
                    if "Starting video generation" in output:
                        progress(0.5, desc="Generating video frames")
                    elif "[1/" in output:  # First chunk
                        progress(0.6, desc="Processing video chunks")
                    elif "Saving video" in output:
                        progress(0.9, desc="Finalizing video")
                if process.poll() is not None:
                    break

            # Collect any output still buffered after the process exits
            remaining_output, _ = process.communicate()
            if remaining_output:
                output_lines.extend(remaining_output.strip().split('\n'))

            if process.returncode != 0:
                # stderr is merged into stdout, so report the tail of the log
                error_msg = '\n'.join(output_lines[-20:])
                logger.error(f"Inference failed with return code {process.returncode}")
                logger.error(f"Error output: {error_msg}")
                raise gr.Error(f"Video generation failed: {error_msg}")
            progress(0.95, desc="Retrieving generated video")

            # Find the generated video file. The inference script saves to
            # demo_out/{exp_name}/res_{input_file_name}_..., so take the most
            # recently modified match.
            generated_videos = list(Path("demo_out").rglob("result_000.mp4"))
            if not generated_videos:
                raise gr.Error("No video file was generated")
            latest_video = max(generated_videos, key=lambda p: p.stat().st_mtime)

            # Copy to the output directory under a unique name
            output_filename = f"avatar_video_{os.getpid()}_{torch.randint(1000, 9999, (1,)).item()}.mp4"
            output_path = OUTPUT_DIR / output_filename
            shutil.copy(latest_video, output_path)

            progress(1.0, desc="Generation complete")
            logger.info(f"Video saved to: {output_path}")
            return str(output_path)

    except gr.Error:
        # Let intentional UI errors propagate without re-wrapping
        raise
    except Exception as e:
        logger.error(f"Error generating video: {str(e)}")
        raise gr.Error(f"Error generating video: {str(e)}")
# Create the Gradio interface
with gr.Blocks(title="OmniAvatar - Lipsynced Avatar Video Generation") as app:
    gr.Markdown("""
    # 🎭 OmniAvatar - Lipsynced Avatar Video Generation

    Generate videos with lipsynced avatars using a reference image and an audio file.
    Based on Wan2.1 with OmniAvatar enhancements for audio-driven avatar animation.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Input components
            reference_image = gr.Image(
                label="Reference Avatar Image",
                type="filepath",
                elem_id="reference_image"
            )
            audio_file = gr.Audio(
                label="Speech Audio File",
                type="filepath",
                elem_id="audio_file"
            )
            text_prompt = gr.Textbox(
                label="Video Description",
                placeholder="Describe the video scene and actions...",
                lines=3,
                value="A person speaking naturally with subtle facial expressions"
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        step=1,
                        value=42
                    )
                    resolution = gr.Radio(
                        label="Resolution",
                        choices=["480p", "720p"],
                        value="720p"
                    )
                with gr.Row():
                    num_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=10,
                        maximum=100,
                        step=5,
                        value=50
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=4.5
                    )
                with gr.Row():
                    audio_scale = gr.Slider(
                        label="Audio Scale (leave 0 to use guidance scale)",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.5,
                        value=0.0
                    )
                    overlap_frames = gr.Slider(
                        label="Overlap Frames",
                        minimum=1,
                        maximum=25,
                        step=4,
                        value=13,
                        info="Must be 1 + 4*n"
                    )
                with gr.Row():
                    fps = gr.Slider(
                        label="FPS",
                        minimum=10,
                        maximum=30,
                        step=1,
                        value=25
                    )
                    silence_duration = gr.Slider(
                        label="Silence Duration (s)",
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=0.3
                    )
            generate_btn = gr.Button(
                "🎬 Generate Avatar Video",
                variant="primary"
            )

        with gr.Column(scale=1):
            # Output component
            output_video = gr.Video(
                label="Generated Avatar Video",
                elem_id="output_video"
            )

    # Examples
    gr.Examples(
        examples=[
            [
                "examples/images/0000.jpeg",
                "examples/audios/0000.MP3",
                "A professional woman giving a presentation with confident gestures"
            ],
        ],
        inputs=[reference_image, audio_file, text_prompt],
        label="Example Inputs"
    )
    # Connect the generate button
    generate_btn.click(
        fn=generate_avatar_video,
        inputs=[
            reference_image,
            audio_file,
            text_prompt,
            seed,
            num_steps,
            guidance_scale,
            audio_scale,
            overlap_frames,
            fps,
            silence_duration,
            resolution
        ],
        outputs=output_video
    )
    gr.Markdown("""
    ## 📝 Notes
    - The reference image should be a clear, frontal view of the person
    - Audio should be clear speech without background music
    - Generation may take several minutes depending on video length
    - For best results, use high-quality input images and audio
    """)
# Launch the app
if __name__ == "__main__":
    app.launch(share=True)