Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Safe Installation Script for OmniAvatar Dependencies | |
| Handles problematic packages like flash-attn and xformers carefully | |
| """ | |
| import subprocess | |
| import sys | |
| import os | |
| import logging | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def run_pip_command(cmd, description="", optional=False):
    """Run a pip command with proper error handling.

    Args:
        cmd: Argument list passed straight to ``subprocess.run`` (no shell),
            e.g. ``[sys.executable, "-m", "pip", "install", "xformers"]``.
        description: Human-readable label used in every log message.
        optional: When True, a failed command is logged as a warning and
            ``False`` is returned instead of raising.

    Returns:
        True if the command exited with status 0; False if it failed (or could
        not be started) and ``optional`` is True.

    Raises:
        subprocess.CalledProcessError: the command exited non-zero and
            ``optional`` is False.
        OSError: the command could not be started and ``optional`` is False.
    """
    logger.info(f"[PROCESS] {description}")
    try:
        # Output is captured so that it is only surfaced (stderr) on failure.
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        if optional:
            logger.warning(f"WARNING: {description} - Failed (optional): {e.stderr}")
            return False
        logger.error(f"ERROR: {description} - Failed: {e.stderr}")
        raise
    except OSError as e:
        # The executable could not be launched at all (missing, not
        # executable, ...). Honour ``optional`` for this failure mode too, so
        # best-effort steps never abort the installation.
        if optional:
            logger.warning(f"WARNING: {description} - Failed (optional): {e}")
            return False
        logger.error(f"ERROR: {description} - Failed: {e}")
        raise
    logger.info(f"SUCCESS: {description} - Success")
    return True
def _install_pytorch():
    """Install torch, torchvision and torchaudio, preferring the cu124 wheels.

    First tries the PyTorch cu124 wheel index. If that pip run fails, retries
    with pip's default index.

    Raises:
        subprocess.CalledProcessError: the fallback install also failed.
    """
    logger.info("📦 Installing PyTorch...")
    try:
        # Try CUDA version first
        run_pip_command([
            sys.executable, "-m", "pip", "install",
            "torch", "torchvision", "torchaudio",
            "--index-url", "https://download.pytorch.org/whl/cu124"
        ], "Installing PyTorch with CUDA support")
    except subprocess.CalledProcessError:
        # Catch only a failed pip run: a bare ``except`` would also swallow
        # KeyboardInterrupt/SystemExit and silently start a second install.
        logger.warning("WARNING: CUDA PyTorch failed, installing CPU version")
        run_pip_command([
            sys.executable, "-m", "pip", "install",
            "torch", "torchvision", "torchaudio"
        ], "Installing PyTorch CPU version")


def _install_optional_packages():
    """Best-effort install of xformers and flash-attn; failures are only logged."""
    logger.info("[TARGET] Installing optional performance packages...")
    # Try xformers (memory efficient attention)
    run_pip_command([
        sys.executable, "-m", "pip", "install", "xformers"
    ], "Installing xformers (memory efficient attention)", optional=True)

    # Try flash-attn (advanced attention mechanism)
    logger.info("🔥 Attempting flash-attn installation (this may take a while or fail)...")
    try:
        # optional=True makes run_pip_command *return* False on a failed pip
        # run instead of raising, so success must be read from the result.
        installed = run_pip_command([
            sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"
        ], "Installing flash-attn from wheel", optional=True)
    except OSError:
        # pip could not even be launched; keep this step best-effort.
        installed = False
    if not installed:
        logger.warning("WARNING: flash-attn installation failed - this is common and not critical")
        logger.info("TIP: flash-attn can be installed later manually if needed")


def _report_optional(module_name, describe):
    """Log whether the optional package *module_name* can be imported.

    *describe* maps the imported module to the text logged on success. Import
    problems are logged, never raised, because these packages are optional.
    """
    import importlib

    try:
        module = importlib.import_module(module_name)
    except ImportError:
        logger.info(f"ℹ️ {module_name} not available (optional)")
    except Exception as e:  # e.g. a compiled extension that fails to load
        logger.warning(f"WARNING: {module_name} failed to import (optional): {e}")
    else:
        logger.info(f"SUCCESS: {module_name}: {describe(module)}")


def _verify_installation():
    """Import the core packages, log their versions and report optional extras.

    Returns:
        True if torch, transformers, gradio and fastapi all import, else False.
    """
    import importlib

    logger.info("🔍 Verifying installation...")
    # The packages were installed by a pip subprocess after this interpreter
    # started; drop the import system's cached directory listings so the new
    # files are picked up.
    importlib.invalidate_caches()
    try:
        import torch
        import transformers
        import gradio
        import fastapi
    except ImportError as e:
        logger.error(f"ERROR: Installation verification failed: {e}")
        return False

    logger.info(f"SUCCESS: PyTorch: {torch.__version__}")
    logger.info(f"SUCCESS: Transformers: {transformers.__version__}")
    logger.info(f"SUCCESS: Gradio: {gradio.__version__}")
    logger.info(f"SUCCESS: FastAPI: {getattr(fastapi, '__version__', 'unknown')}")
    if torch.cuda.is_available():
        logger.info(f"SUCCESS: CUDA: {torch.version.cuda}")
        logger.info(f"SUCCESS: GPU Count: {torch.cuda.device_count()}")
    else:
        logger.info("ℹ️ CUDA not available - will use CPU")

    # Check optional packages
    _report_optional("xformers", lambda m: getattr(m, "__version__", "unknown"))
    _report_optional("flash_attn", lambda m: "Available")

    logger.info("🎉 Installation completed successfully!")
    logger.info("TIP: You can now run: python app.py")
    return True


def main():
    """Install OmniAvatar's dependencies step by step and verify the result.

    Steps:
        1. Upgrade pip and the build tools.
        2. Install PyTorch (cu124 wheels, falling back to the default index).
        3. Install ``requirements.txt``.
        4. Best-effort install of xformers and flash-attn.
        5. Import-check the core packages.

    Returns:
        True if verification succeeds, False if a core package cannot be
        imported.

    Raises:
        subprocess.CalledProcessError: a mandatory pip step failed.
    """
    logger.info("[LAUNCH] Starting safe dependency installation for OmniAvatar")
    # Step 1: Upgrade pip and essential tools
    run_pip_command([
        sys.executable, "-m", "pip", "install", "--upgrade",
        "pip", "setuptools", "wheel", "packaging"
    ], "Upgrading pip and build tools")
    # Step 2: Install PyTorch with CUDA support (if available)
    _install_pytorch()
    # Step 3: Install main requirements. The path is relative to the current
    # working directory, so run from the directory holding requirements.txt.
    run_pip_command([
        sys.executable, "-m", "pip", "install", "-r", "requirements.txt"
    ], "Installing main requirements")
    # Step 4: Try to install optional performance packages
    _install_optional_packages()
    # Step 5: Verify installation
    return _verify_installation()
if __name__ == "__main__":
    # Map main()'s boolean outcome onto a conventional process exit status.
    exit_status = 0 if main() else 1
    sys.exit(exit_status)