File size: 4,465 Bytes
dcf0937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05d082e
dcf0937
 
05d082e
dcf0937
 
 
05d082e
dcf0937
 
05d082e
dcf0937
 
 
05d082e
dcf0937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05d082e
dcf0937
 
 
 
 
 
 
 
 
 
 
05d082e
dcf0937
 
 
 
 
 
 
 
 
 
 
 
 
 
05d082e
 
dcf0937
 
 
 
 
 
 
 
 
05d082e
 
 
dcf0937
 
05d082e
 
dcf0937
 
 
 
 
 
05d082e
dcf0937
 
 
 
 
05d082e
dcf0937
 
 
 
05d082e
dcf0937
 
05d082e
dcf0937
 
 
 
 
 
 
05d082e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
"""
Safe Installation Script for OmniAvatar Dependencies
Handles problematic packages like flash-attn and xformers carefully
"""

import subprocess
import sys
import os
import logging

# Module logger; the root logger is configured below so INFO-level progress
# messages from this installer are printed.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")

def run_pip_command(cmd, description="", optional=False):
    """Run a pip command (any argv list) and log the outcome.

    The command runs without a shell; its stdout/stderr are captured and only
    surfaced in the log when it fails.

    Args:
        cmd: Argument vector passed to ``subprocess.run``.
        description: Human-readable label used in the log messages.
        optional: When True, a failure is logged as a warning and ``False`` is
            returned instead of raising.

    Returns:
        True if the command exited with status 0; False if it failed and
        ``optional`` is True.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero and
            ``optional`` is False.
        OSError: if the command cannot be started at all (e.g. the executable
            does not exist) and ``optional`` is False.
    """
    logger.info(f"[PROCESS] {description}")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, OSError) as e:
        if isinstance(e, subprocess.CalledProcessError):
            # pip normally reports errors on stderr; fall back to stdout so the
            # log never shows an empty failure reason.
            details = (e.stderr or e.stdout or "").strip()
        else:
            # The process could not be launched (missing executable, permissions...).
            details = str(e)
        if optional:
            logger.warning(f"WARNING: {description} - Failed (optional): {details}")
            return False
        logger.error(f"ERROR: {description} - Failed: {details}")
        raise
    logger.info(f"SUCCESS: {description} - Success")
    return True

def _install_pytorch():
    """Install torch/torchvision/torchaudio, preferring the CUDA 12.4 wheel index.

    If installing from the CUDA index fails, retries with a plain
    ``pip install`` against the default index.

    Raises:
        subprocess.CalledProcessError: if the fallback install also fails.
        OSError: if pip cannot be launched for the fallback install.
    """
    logger.info("📦 Installing PyTorch...")
    try:
        # Try CUDA version first
        run_pip_command([
            sys.executable, "-m", "pip", "install",
            "torch", "torchvision", "torchaudio",
            "--index-url", "https://download.pytorch.org/whl/cu124"
        ], "Installing PyTorch with CUDA support")
    except (subprocess.CalledProcessError, OSError):
        # Narrow catch: a bare except here would also swallow Ctrl-C / SystemExit.
        logger.warning("WARNING: CUDA PyTorch failed, installing CPU version")
        # NOTE(review): no --index-url is given, so this installs whatever build
        # the default index serves for this platform; it is not guaranteed to be
        # CPU-only -- confirm this is the intended fallback.
        run_pip_command([
            sys.executable, "-m", "pip", "install",
            "torch", "torchvision", "torchaudio"
        ], "Installing PyTorch CPU version")


def _install_optional_packages():
    """Best-effort install of xformers and flash-attn; failures are only logged."""
    logger.info("[TARGET] Installing optional performance packages...")

    # Try xformers (memory efficient attention)
    run_pip_command([
        sys.executable, "-m", "pip", "install", "xformers"
    ], "Installing xformers (memory efficient attention)", optional=True)

    # Try flash-attn (advanced attention mechanism)
    logger.info("🔥 Attempting flash-attn installation (this may take a while or fail)...")
    # With optional=True, run_pip_command reports failure by returning False
    # rather than raising, so the return value must be checked explicitly.
    installed = run_pip_command([
        sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"
    ], "Installing flash-attn from wheel", optional=True)
    if not installed:
        logger.warning("WARNING: flash-attn installation failed - this is common and not critical")
        logger.info("TIP: flash-attn can be installed later manually if needed")


def _verify_installation():
    """Import the installed packages and log their versions.

    Returns:
        True if torch, transformers, gradio and fastapi import cleanly, False
        otherwise. Missing optional packages (xformers, flash_attn) are only
        reported and do not affect the result.
    """
    import importlib

    logger.info("🔍 Verifying installation...")
    # These packages were installed by pip subprocesses after this interpreter
    # started; clear the import system's finder caches so they can be found.
    importlib.invalidate_caches()
    try:
        import torch
        import transformers
        import gradio
        import fastapi  # noqa: F401 -- imported only to prove it is installed
    except (ImportError, OSError) as e:
        # OSError is included because a failure to load a native shared library
        # during import may surface as OSError rather than ImportError.
        logger.error(f"ERROR: Installation verification failed: {e}")
        return False

    logger.info(f"SUCCESS: PyTorch: {torch.__version__}")
    logger.info(f"SUCCESS: Transformers: {transformers.__version__}")
    logger.info(f"SUCCESS: Gradio: {gradio.__version__}")

    if torch.cuda.is_available():
        logger.info(f"SUCCESS: CUDA: {torch.version.cuda}")
        logger.info(f"SUCCESS: GPU Count: {torch.cuda.device_count()}")
    else:
        logger.info("ℹ️ CUDA not available - will use CPU")

    # Check optional packages
    try:
        import xformers
        logger.info(f"SUCCESS: xformers: {xformers.__version__}")
    except ImportError:
        logger.info("ℹ️ xformers not available (optional)")

    try:
        import flash_attn  # noqa: F401 -- availability check only
        logger.info("SUCCESS: flash_attn: Available")
    except ImportError:
        logger.info("ℹ️ flash_attn not available (optional)")

    logger.info("🎉 Installation completed successfully!")
    logger.info("TIP: You can now run: python app.py")
    return True


def main():
    """Install OmniAvatar's dependencies and verify the result.

    Steps:
      1. upgrade pip, setuptools, wheel and packaging;
      2. install PyTorch (CUDA 12.4 index first, default index as fallback);
      3. install ``requirements.txt``, resolved relative to the current
         working directory;
      4. best-effort install of xformers and flash-attn;
      5. import-check the installed packages.

    Returns:
        True on success; False if a required install step or the verification
        step fails.
    """
    logger.info("[LAUNCH] Starting safe dependency installation for OmniAvatar")

    try:
        # Step 1: Upgrade pip and essential tools
        run_pip_command([
            sys.executable, "-m", "pip", "install", "--upgrade",
            "pip", "setuptools", "wheel", "packaging"
        ], "Upgrading pip and build tools")

        # Step 2: Install PyTorch with CUDA support (if available)
        _install_pytorch()

        # Step 3: Install main requirements
        run_pip_command([
            sys.executable, "-m", "pip", "install", "-r", "requirements.txt"
        ], "Installing main requirements")
    except (subprocess.CalledProcessError, OSError) as e:
        # Report a required-step failure through the bool return value instead
        # of letting the exception escape as a traceback.
        logger.error(f"ERROR: Required dependency installation failed - aborting: {e}")
        return False

    # Step 4: Try to install optional performance packages
    _install_optional_packages()

    # Step 5: Verify installation
    return _verify_installation()

if __name__ == "__main__":
    # main() returns a truthy value on success: exit 0 then, 1 otherwise.
    sys.exit(int(not main()))