Spaces:
Running
Running
| ο»Ώ#!/usr/bin/env python3 | |
| """ | |
| OmniAvatar-14B Inference Script | |
| Enhanced implementation for avatar video generation with adaptive body animation | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| import yaml | |
| import torch | |
| import logging | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, Any | |
| # Set up logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
def load_config(config_path: str) -> Dict[str, Any]:
    """Read the YAML file at *config_path* and return its contents as a dict.

    Any failure (missing file, unreadable YAML) is logged and re-raised to
    the caller.
    """
    try:
        with open(config_path, 'r') as stream:
            cfg = yaml.safe_load(stream)
            logger.info(f"β Configuration loaded from {config_path}")
            return cfg
    except Exception as e:
        logger.error(f"β Failed to load config: {e}")
        raise
def parse_input_file(input_file: str) -> list:
    """
    Parse the input file with format:
    [prompt]@@[img_path]@@[audio_path]

    Blank lines and lines starting with '#' are skipped.  Entries whose
    audio file is missing are dropped; a missing image path is recorded
    as None.  Returns a list of sample dicts.
    """
    try:
        samples = []
        with open(input_file, 'r') as fh:
            # Stream the file line by line; line numbers are 1-based for
            # human-readable diagnostics.
            for line_num, raw in enumerate(fh, 1):
                line = raw.strip()
                if not line or line.startswith('#'):
                    continue
                fields = line.split('@@')
                if len(fields) != 3:
                    logger.warning(f"β οΈ Line {line_num} has invalid format, skipping: {line}")
                    continue
                prompt, img_path, audio_path = fields
                # Validate referenced media paths before accepting the sample.
                if img_path and not os.path.exists(img_path):
                    logger.warning(f"β οΈ Image not found: {img_path}")
                    img_path = None
                if not os.path.exists(audio_path):
                    logger.error(f"β Audio file not found: {audio_path}")
                    continue
                samples.append({
                    'prompt': prompt,
                    'image_path': img_path if img_path else None,
                    'audio_path': audio_path,
                    'line_number': line_num
                })
        logger.info(f"π Parsed {len(samples)} valid samples from {input_file}")
        return samples
    except Exception as e:
        logger.error(f"β Failed to parse input file: {e}")
        raise
def validate_models(config: Dict[str, Any]) -> bool:
    """Validate that all required models are available.

    Checks that each configured model path exists and, when it is a
    directory, that it is non-empty.

    Args:
        config: loaded configuration; reads the three paths under the
            'model' section (base_model_path, omni_model_path, wav2vec_path).

    Returns:
        True when all models are present; False (after logging what is
        missing) otherwise.
    """
    model_paths = [
        config['model']['base_model_path'],
        config['model']['omni_model_path'],
        config['model']['wav2vec_path']
    ]
    missing_models = []
    for path in model_paths:
        p = Path(path)
        if not p.exists():
            missing_models.append(path)
        # BUGFIX: guard with is_dir() — calling iterdir() on a path that
        # exists but is a regular file raised NotADirectoryError and
        # crashed validation instead of reporting a result.
        elif p.is_dir() and not any(p.iterdir()):
            missing_models.append(f"{path} (empty directory)")
    if missing_models:
        logger.error("β Missing required models:")
        for model in missing_models:
            logger.error(f" - {model}")
        logger.info("π‘ Run 'python setup_omniavatar.py' to download models")
        return False
    logger.info("β All required models found")
    return True
def setup_output_directory(output_dir: str) -> str:
    """Ensure *output_dir* exists, create a timestamped per-run
    subdirectory inside it, and return that subdirectory's path."""
    os.makedirs(output_dir, exist_ok=True)
    # A unique run_<timestamp> folder keeps outputs from separate
    # invocations apart.
    stamp = time.strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(output_dir, f"run_{stamp}")
    os.makedirs(run_dir, exist_ok=True)
    logger.info(f"π Output directory: {run_dir}")
    return run_dir
def mock_inference(sample: Dict[str, Any], config: Dict[str, Any],
                   output_dir: str, args: argparse.Namespace) -> str:
    """
    Mock inference implementation

    In a real implementation, this would:
    1. Load the OmniAvatar models
    2. Process the audio with wav2vec2
    3. Generate video frames using the text-to-video model
    4. Apply audio-driven animation
    5. Render final video

    Args:
        sample: one parsed entry from the input file, with 'prompt',
            'image_path', 'audio_path' and 'line_number' keys.
        config: loaded configuration; only inference.max_tokens is read here.
        output_dir: directory that receives the placeholder output file.
        args: parsed CLI namespace (guidance_scale, audio_scale, num_steps,
            tea_cache_l1_thresh, ...).

    Returns:
        The would-be .mp4 output path.  NOTE(review): the file actually
        written is the companion '*_info.txt' placeholder, not the .mp4.
    """
    logger.info(f"π¬ Processing sample {sample['line_number']}")
    logger.info(f"π Prompt: {sample['prompt']}")
    logger.info(f"π΅ Audio: {sample['audio_path']}")
    if sample['image_path']:
        logger.info(f"πΌοΈ Image: {sample['image_path']}")
    # Configuration
    logger.info("βοΈ Configuration:")
    logger.info(f" - Guidance Scale: {args.guidance_scale}")
    logger.info(f" - Audio Scale: {args.audio_scale}")
    logger.info(f" - Steps: {args.num_steps}")
    logger.info(f" - Max Tokens: {config.get('inference', {}).get('max_tokens', 30000)}")
    # BUGFIX: compare against None rather than truthiness, so an explicit
    # threshold of 0.0 is still reported instead of being silently skipped.
    if args.tea_cache_l1_thresh is not None:
        logger.info(f" - TeaCache Threshold: {args.tea_cache_l1_thresh}")
    # Simulate processing time
    logger.info("π Generating avatar video...")
    time.sleep(2)  # Mock processing
    # Create mock output file
    output_filename = f"avatar_sample_{sample['line_number']:03d}.mp4"
    output_path = os.path.join(output_dir, output_filename)
    # Create a simple text file as placeholder for the video
    with open(output_path.replace('.mp4', '_info.txt'), 'w') as f:
        f.write("OmniAvatar-14B Output Information\n")
        f.write(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Prompt: {sample['prompt']}\n")
        f.write(f"Audio: {sample['audio_path']}\n")
        f.write(f"Image: {sample['image_path'] or 'None'}\n")
        f.write(f"Configuration: {args.__dict__}\n")
    logger.info(f"β Mock output created: {output_path}")
    return output_path
def main():
    """Command-line entry point.

    Parses arguments, loads the configuration, validates model
    availability, then runs mock inference over every sample in the
    input file.  Returns a process exit code: 0 on success, 1 on any
    failure.
    """
    parser = argparse.ArgumentParser(
        description="OmniAvatar-14B Inference - Avatar Video Generation with Adaptive Body Animation"
    )
    parser.add_argument("--config", type=str, required=True,
                        help="Configuration file path")
    parser.add_argument("--input_file", type=str, required=True,
                        help="Input samples file")
    parser.add_argument("--guidance_scale", type=float, default=4.5,
                        help="Guidance scale (4-6 recommended)")
    parser.add_argument("--audio_scale", type=float, default=3.0,
                        help="Audio scale for lip-sync consistency")
    parser.add_argument("--num_steps", type=int, default=25,
                        help="Number of inference steps (20-50 recommended)")
    parser.add_argument("--tea_cache_l1_thresh", type=float, default=None,
                        help="TeaCache L1 threshold (0.05-0.15 recommended)")
    parser.add_argument("--sp_size", type=int, default=1,
                        help="Sequence parallel size (number of GPUs)")
    parser.add_argument("--hp", type=str, default="",
                        help="Additional hyperparameters (comma-separated)")
    args = parser.parse_args()

    logger.info("π OmniAvatar-14B Inference Starting")
    logger.info(f"π Config: {args.config}")
    logger.info(f"π Input: {args.input_file}")
    logger.info(f"π― Parameters: guidance_scale={args.guidance_scale}, audio_scale={args.audio_scale}, steps={args.num_steps}")

    try:
        config = load_config(args.config)
        if not validate_models(config):
            return 1

        samples = parse_input_file(args.input_file)
        if not samples:
            logger.error("β No valid samples found in input file")
            return 1

        output_dir = setup_output_directory(config.get('inference', {}).get('output_dir', './outputs'))

        # Run each sample independently so one failure doesn't stop the batch.
        total_samples = len(samples)
        successful_outputs = []
        for idx, item in enumerate(samples, 1):
            logger.info(f"π Processing sample {idx}/{total_samples}")
            try:
                successful_outputs.append(mock_inference(item, config, output_dir, args))
            except Exception as e:
                logger.error(f"β Failed to process sample {item['line_number']}: {e}")
                continue

        # Summary
        logger.info("π Inference completed!")
        logger.info(f"β Successfully processed: {len(successful_outputs)}/{total_samples} samples")
        logger.info(f"π Output directory: {output_dir}")
        if successful_outputs:
            logger.info("πΉ Generated videos:")
            for output in successful_outputs:
                logger.info(f" - {output}")
        logger.info("π‘ NOTE: This is a mock implementation.")
        logger.info("π For full OmniAvatar functionality, integrate with:")
        logger.info(" https://github.com/Omni-Avatar/OmniAvatar")
        return 0
    except Exception as e:
        logger.error(f"β Inference failed: {e}")
        return 1
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())