| ο»Ώ""" | |
| Enhanced OmniAvatar-14B Integration Module | |
| Provides complete avatar video generation with adaptive body animation | |
| """ | |
| import os | |
| import torch | |
| import subprocess | |
| import tempfile | |
| import yaml | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional, Tuple, Dict, Any | |
| import json | |
| logger = logging.getLogger(__name__) | |


class OmniAvatarEngine:
    """
    Complete OmniAvatar-14B integration for avatar video generation
    with adaptive body animation using audio-driven synthesis.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models_loaded = False
        self.model_paths = {
            "base_model": "./pretrained_models/Wan2.1-T2V-14B",
            "omni_model": "./pretrained_models/OmniAvatar-14B",
            "wav2vec": "./pretrained_models/wav2vec2-base-960h",
        }
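
        # The checkpoints are not bundled with this module. Assuming the
        # upstream Hugging Face repo IDs (not verified here), they can be
        # fetched with huggingface-cli into the directories above, e.g.:
        #   huggingface-cli download Wan-AI/Wan2.1-T2V-14B \
        #       --local-dir ./pretrained_models/Wan2.1-T2V-14B
        #   huggingface-cli download facebook/wav2vec2-base-960h \
        #       --local-dir ./pretrained_models/wav2vec2-base-960h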

        # Default configuration from the OmniAvatar documentation
        self.default_config = {
            "guidance_scale": 4.5,
            "audio_scale": 3.0,
            "num_steps": 25,
            "max_tokens": 30000,
            "overlap_frame": 13,
            "tea_cache_l1_thresh": 0.14,
            "use_fsdp": False,
            "sp_size": 1,
            "resolution": "480p",
        }

        logger.info(f"OmniAvatar Engine initialized on {self.device}")

    def check_models_available(self) -> Dict[str, bool]:
        """
        Check which OmniAvatar models are available.

        Returns a dictionary mapping model name to availability.
        """
        status = {}
        for name, path in self.model_paths.items():
            model_path = Path(path)
            # is_dir() guards against a stray file at the checkpoint path,
            # which would make iterdir() raise.
            if model_path.is_dir() and any(model_path.iterdir()):
                status[name] = True
                logger.info(f"✅ {name} model found at {path}")
            else:
                status[name] = False
                logger.warning(f"❌ {name} model not found at {path}")

        self.models_loaded = all(status.values())
        if self.models_loaded:
            logger.info("🎉 All OmniAvatar-14B models available!")
        else:
            missing = [name for name, available in status.items() if not available]
            logger.warning(f"⚠️ Missing models: {', '.join(missing)}")

        return status
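
    # Usage sketch (assumes the pretrained_models/ layout configured above):
    #   engine = OmniAvatarEngine()
    #   status = engine.check_models_available()
    #   if not all(status.values()):
    #       print("Run setup_omniavatar.py to download the missing weights.")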

    def load_models(self) -> bool:
        """
        Load the OmniAvatar models into memory.
        """
        try:
            model_status = self.check_models_available()
            if not all(model_status.values()):
                logger.error("Cannot load models - some models are missing")
                return False

            # TODO: Implement actual model loading.
            # This requires the full OmniAvatar implementation.
            logger.info("🔄 Model loading logic would be implemented here")
            logger.info("💡 For full implementation, integrate with the official OmniAvatar codebase")

            self.models_loaded = True
            return True

        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            return False
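
    # Illustrative sketch only: the wav2vec2 audio encoder can be loaded with
    # Hugging Face transformers as shown below. The rest of the pipeline (DiT
    # weights, LoRA merging, audio conditioning) is not reproduced here; this
    # helper is an assumption for illustration, not part of the upstream API.
    def _load_audio_encoder_sketch(self):
        """Hypothetical helper: load only the wav2vec2 audio encoder."""
        from transformers import Wav2Vec2Model, Wav2Vec2Processor

        processor = Wav2Vec2Processor.from_pretrained(self.model_paths["wav2vec"])
        encoder = Wav2Vec2Model.from_pretrained(self.model_paths["wav2vec"]).to(self.device)
        return processor, encoder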

    def create_inference_input(self, prompt: str, image_path: Optional[str],
                               audio_path: str) -> str:
        """
        Create the input file format required by OmniAvatar inference.

        Format: [prompt]@@[img_path]@@[audio_path]
        """
        if image_path:
            input_line = f"{prompt}@@{image_path}@@{audio_path}"
        else:
            input_line = f"{prompt}@@@@{audio_path}"

        # Write the line to a temporary input file for the inference script.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
            f.write(input_line)
            temp_input_file = f.name

        logger.info(f"Created inference input: {input_line}")
        return temp_input_file
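
    # Example (hypothetical paths): with image_path="ref.png" and
    # audio_path="speech.wav" the temporary file contains the single line
    #   A presenter talking@@ref.png@@speech.wav
    # and with no reference image the middle field is left empty:
    #   A presenter talking@@@@speech.wav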

    def generate_video(self, prompt: str, audio_path: str,
                       image_path: Optional[str] = None,
                       **config_overrides) -> Tuple[str, float]:
        """
        Generate an avatar video using OmniAvatar-14B.

        Args:
            prompt: Text description of the character and its behavior.
            audio_path: Path to the audio file used for lip-sync.
            image_path: Optional reference image path.
            **config_overrides: Overrides for the default configuration.

        Returns:
            Tuple of (output_video_path, processing_time).
        """
        start_time = time.time()

        if not self.models_loaded:
            # check_models_available() returns a per-model dict, so test its
            # values rather than the dict's truthiness.
            if not all(self.check_models_available().values()):
                raise RuntimeError("OmniAvatar models not available. Run setup_omniavatar.py first.")

        try:
            # Merge the default configuration with any overrides.
            config = {**self.default_config, **config_overrides}

            # Create the inference input file.
            temp_input_file = self.create_inference_input(prompt, image_path, audio_path)

            # Prepare the inference command per the OmniAvatar documentation.
            cmd = [
                "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={config['sp_size']}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
            ]

            # Add hyperparameters.
            hp_params = [
                f"sp_size={config['sp_size']}",
                f"max_tokens={config['max_tokens']}",
                f"guidance_scale={config['guidance_scale']}",
                f"overlap_frame={config['overlap_frame']}",
                f"num_steps={config['num_steps']}",
            ]
            if config.get('use_fsdp'):
                hp_params.append("use_fsdp=True")
            if config.get('tea_cache_l1_thresh'):
                hp_params.append(f"tea_cache_l1_thresh={config['tea_cache_l1_thresh']}")
            if config.get('audio_scale') != self.default_config['audio_scale']:
                hp_params.append(f"audio_scale={config['audio_scale']}")

            cmd.extend(["--hp", ",".join(hp_params)])

            logger.info("🚀 Running OmniAvatar inference:")
            logger.info(f"Command: {' '.join(cmd)}")

            # Run inference.
            result = subprocess.run(cmd, capture_output=True, text=True, cwd=Path.cwd())

            # Clean up the temporary input file.
            if os.path.exists(temp_input_file):
                os.unlink(temp_input_file)

            if result.returncode != 0:
                logger.error(f"OmniAvatar inference failed: {result.stderr}")
                raise RuntimeError(f"Inference failed: {result.stderr}")

            # Find the output video file.
            output_dir = Path("./outputs")
            if output_dir.exists():
                video_files = list(output_dir.glob("*.mp4")) + list(output_dir.glob("*.avi"))
                if video_files:
                    # Return the most recently written video file.
                    latest_video = max(video_files, key=lambda x: x.stat().st_mtime)
                    processing_time = time.time() - start_time

                    logger.info(f"✅ Video generated successfully: {latest_video}")
                    logger.info(f"⏱️ Processing time: {processing_time:.1f}s")
                    return str(latest_video), processing_time

            raise RuntimeError("No output video generated")

        except Exception as e:
            # Clean up the temporary input file on error as well.
            if 'temp_input_file' in locals() and os.path.exists(temp_input_file):
                os.unlink(temp_input_file)
            logger.error(f"OmniAvatar generation error: {e}")
            raise
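
    # Example call (hypothetical paths), overriding two defaults:
    #   video_path, seconds = omni_engine.generate_video(
    #       prompt="A friendly presenter explaining a chart, natural gestures",
    #       audio_path="samples/speech.wav",
    #       image_path="samples/reference.png",
    #       num_steps=30,
    #       guidance_scale=5.0,
    #   )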

    def get_model_info(self) -> Dict[str, Any]:
        """Get detailed information about the OmniAvatar setup."""
        model_status = self.check_models_available()

        return {
            "engine": "OmniAvatar-14B",
            "version": "1.0.0",
            "device": self.device,
            "cuda_available": torch.cuda.is_available(),
            "models_loaded": self.models_loaded,
            "model_status": model_status,
            "all_models_available": all(model_status.values()),
            "supported_features": [
                "Audio-driven avatar generation",
                "Adaptive body animation",
                "Lip-sync synthesis",
                "Reference image support",
                "Text prompt control",
                "480p video output",
                "TeaCache acceleration",
                "Multi-GPU support",
            ],
            "model_requirements": {
                "Wan2.1-T2V-14B": "~28GB - Base text-to-video model",
                "OmniAvatar-14B": "~2GB - LoRA and audio conditioning weights",
                "wav2vec2-base-960h": "~360MB - Audio encoder",
            },
            "configuration": self.default_config,
        }

    def optimize_for_hardware(self) -> Dict[str, Any]:
        """
        Suggest an optimal configuration for the available hardware,
        based on the performance table in the OmniAvatar documentation.
        """
        if not torch.cuda.is_available():
            return {
                "recommendation": "CPU mode - very slow, not recommended",
                "config": {
                    "num_steps": 10,      # Reduce steps for CPU
                    "max_tokens": 10000,  # Reduce tokens
                    "use_fsdp": False,
                },
                "expected_speed": "Very slow (minutes per video)",
            }

        gpu_count = torch.cuda.device_count()
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # GB

        recommendations = {
            1: {  # Single GPU
                "high_memory": {  # >32GB VRAM
                    "config": {
                        "sp_size": 1,
                        "use_fsdp": False,
                        "num_persistent_param_in_dit": None,
                        "max_tokens": 60000,
                    },
                    "expected_speed": "~16s/iteration",
                    "required_vram": "36GB",
                },
                "medium_memory": {  # 16-32GB VRAM
                    "config": {
                        "sp_size": 1,
                        "use_fsdp": False,
                        "num_persistent_param_in_dit": 7_000_000_000,
                        "max_tokens": 30000,
                    },
                    "expected_speed": "~19s/iteration",
                    "required_vram": "21GB",
                },
                "low_memory": {  # 8-16GB VRAM
                    "config": {
                        "sp_size": 1,
                        "use_fsdp": False,
                        "num_persistent_param_in_dit": 0,
                        "max_tokens": 15000,
                        "num_steps": 20,
                    },
                    "expected_speed": "~22s/iteration",
                    "required_vram": "8GB",
                },
            },
            4: {  # 4 GPUs
                "config": {
                    "sp_size": 4,
                    "use_fsdp": True,
                    "max_tokens": 60000,
                },
                "expected_speed": "~4.8s/iteration",
                "required_vram": "14.3GB per GPU",
            },
        }

        # Select a recommendation based on the detected hardware.
        if gpu_count >= 4:
            return {
                "recommendation": "Multi-GPU setup - optimal performance",
                "hardware": f"{gpu_count} GPUs, {gpu_memory:.1f}GB VRAM each",
                **recommendations[4],
            }
        elif gpu_memory > 32:
            return {
                "recommendation": "High-memory single GPU - excellent performance",
                "hardware": f"1 GPU, {gpu_memory:.1f}GB VRAM",
                **recommendations[1]["high_memory"],
            }
        elif gpu_memory > 16:
            return {
                "recommendation": "Medium-memory single GPU - good performance",
                "hardware": f"1 GPU, {gpu_memory:.1f}GB VRAM",
                **recommendations[1]["medium_memory"],
            }
        else:
            return {
                "recommendation": "Low-memory single GPU - basic performance",
                "hardware": f"1 GPU, {gpu_memory:.1f}GB VRAM",
                **recommendations[1]["low_memory"],
            }
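
    # Example: apply the hardware recommendation to a generation call.
    #   rec = omni_engine.optimize_for_hardware()
    #   video, secs = omni_engine.generate_video(prompt, audio, **rec["config"])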


# Global instance
omni_engine = OmniAvatarEngine()
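

if __name__ == "__main__":
    # Minimal smoke test (no video generation): report model availability
    # and the hardware recommendation. Safe to run without the weights.
    logging.basicConfig(level=logging.INFO)
    print(json.dumps(omni_engine.get_model_info(), indent=2))
    print(json.dumps(omni_engine.optimize_for_hardware(), indent=2))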