Spaces:
Running
Running
import argparse
import yaml
import torch
import os
import sys
from pathlib import Path
import logging

# Set up logging: INFO level so per-sample progress messages are visible
# when this module is run as a script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_device(config_device: str) -> str:
    """Resolve the compute device string, auto-detecting when asked.

    Args:
        config_device: Either ``"auto"`` (pick ``"cuda"`` if available,
            else ``"cpu"``) or an explicit device string, which is
            passed through unchanged.

    Returns:
        The resolved device string (e.g. ``"cuda"`` or ``"cpu"``).
    """
    # Explicit configuration wins; only "auto" triggers detection.
    if config_device != "auto":
        logger.info("Using configured device: %s", config_device)
        return config_device
    if torch.cuda.is_available():
        logger.info("CUDA available, using GPU")
        return "cuda"
    logger.info("CUDA not available, using CPU")
    return "cpu"
def parse_args():
    """Build and evaluate the command-line interface for this script."""
    cli = argparse.ArgumentParser(description="OmniAvatar-14B Inference")
    cli.add_argument("--config", type=str, required=True, help="Path to config file")
    cli.add_argument("--input_file", type=str, required=True, help="Path to input samples file")
    cli.add_argument("--guidance_scale", type=float, default=5.0, help="Guidance scale")
    cli.add_argument("--audio_scale", type=float, default=3.0, help="Audio guidance scale")
    cli.add_argument("--num_steps", type=int, default=30, help="Number of inference steps")
    cli.add_argument("--sp_size", type=int, default=1, help="Multi-GPU size")
    cli.add_argument("--tea_cache_l1_thresh", type=float, default=None, help="TeaCache threshold")
    return cli.parse_args()
def load_config(config_path):
    """Read a YAML configuration file and return the parsed structure."""
    with open(config_path, 'r') as handle:
        parsed = yaml.safe_load(handle)
    return parsed
def process_input_file(input_file):
    """Parse input file with format: prompt@@image_path@@audio_path"""
    collected = []
    with open(input_file, 'r') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            # Skip blank lines and lines with fewer than three @@-fields.
            if not stripped:
                continue
            fields = stripped.split('@@')
            if len(fields) < 3:
                continue
            collected.append({
                'prompt': fields[0],
                'image_path': fields[1] or None,  # empty image field -> None
                'audio_path': fields[2],
            })
    return collected
def create_placeholder_video(output_path, duration=5.0, fps=24, width=480, height=480):
    """Write a simple animated placeholder MP4.

    Renders a green circle orbiting the frame centre with a per-frame
    caption, standing in for real avatar generation output.

    Args:
        output_path: Destination file path (anything ``str()`` accepts).
        duration: Clip length in seconds.
        fps: Frames per second.
        width: Frame width in pixels (defaults preserve prior 480x480 output).
        height: Frame height in pixels.
    """
    # Imported lazily so the module still loads without cv2/numpy installed.
    import numpy as np
    import cv2

    logger.info("Creating placeholder video: %s", output_path)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
    try:
        total_frames = int(duration * fps)
        for frame_idx in range(total_frames):
            frame = np.zeros((height, width, 3), dtype=np.uint8)
            # 60-frame cycle: the circle traces an ellipse around the centre.
            center_x = int(width / 2 + 100 * np.sin(2 * np.pi * frame_idx / 60))
            center_y = int(height / 2 + 50 * np.cos(2 * np.pi * frame_idx / 60))
            cv2.circle(frame, (center_x, center_y), 30, (0, 255, 0), -1)
            text = f"Avatar Placeholder Frame {frame_idx + 1}"
            font = cv2.FONT_HERSHEY_SIMPLEX
            cv2.putText(frame, text, (10, 30), font, 0.5, (255, 255, 255), 1)
            out.write(frame)
    finally:
        # Always release the writer so the MP4 container is finalized,
        # even if a frame write raises.
        out.release()

    logger.info("Placeholder video created: %s", output_path)
def main():
    """Run the (placeholder) OmniAvatar-14B inference pipeline.

    Loads the YAML config and the input sample list, resolves the compute
    device, then writes one placeholder video per sample into the
    configured output directory.
    """
    args = parse_args()
    logger.info("Starting OmniAvatar-14B Inference")
    logger.info("Arguments: %s", args)

    config = load_config(args.config)

    # Resolve "auto" to a concrete device and write it back so any
    # downstream consumer of the config sees the resolved value.
    device = get_device(config["hardware"]["device"])
    config["hardware"]["device"] = device

    samples = process_input_file(args.input_file)
    logger.info("Processing %d samples", len(samples))
    if not samples:
        logger.error("No valid samples found in input file")
        return

    # parents=True so a nested output path (e.g. "runs/out") does not crash.
    output_dir = Path(config["output"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, sample in enumerate(samples):
        logger.info("Processing sample %d/%d: %s...", i + 1, len(samples), sample["prompt"][:50])
        output_path = output_dir / f"avatar_output_{i:03d}.mp4"
        try:
            # Placeholder until actual OmniAvatar model inference is wired in.
            create_placeholder_video(output_path, duration=5.0, fps=24)
            logger.info("Sample %d completed: %s", i + 1, output_path)
        except Exception as e:
            # Best-effort batch: log the failure and continue with the rest.
            logger.error("Error processing sample %d: %s", i + 1, e)

    logger.info("Inference completed!")
    logger.info("Note: Currently generating placeholder videos.")
    logger.info("Future updates will include actual OmniAvatar model inference.")


if __name__ == "__main__":
    main()