File size: 4,857 Bytes
19b19f0
 
109031b
19b19f0
 
 
 
 
 
 
 
 
 
 
 
 
 
109031b
19b19f0
 
109031b
19b19f0
109031b
19b19f0
 
 
 
5a6251e
 
19b19f0
384c439
 
 
109031b
 
 
5a6251e
 
 
 
 
 
 
 
109031b
 
5a6251e
 
384c439
5a6251e
109031b
5a6251e
 
 
109031b
5a6251e
 
 
 
384c439
5a6251e
109031b
19b19f0
384c439
19b19f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
"""
Helper script to download the full fine-tuned model files at build time for Hugging Face Spaces
"""

import os
import sys
import subprocess
import logging
from pathlib import Path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MAIN_MODEL_ID = "Tonic/petite-elle-L-aime-3-sft"
LOCAL_MODEL_PATH = "./model"

def download_model():
    """Download the full fine-tuned model files to the local directory.

    Downloads each file in ``required_files`` from ``MAIN_MODEL_ID`` into
    ``LOCAL_MODEL_PATH``, skipping (with a warning) any file the repository
    does not contain.

    Returns:
        bool: True if the download loop completed without an exception
        (even if some optional files were missing), False on any error.
    """
    try:
        logger.info(f"Downloading full fine-tuned model from {MAIN_MODEL_ID}")
        
        # Create local directory if it doesn't exist
        os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
        
        # Deferred import so a missing dependency surfaces as a logged
        # error from this function instead of failing at module import.
        from huggingface_hub import hf_hub_download, list_repo_files
        
        # List all files in the repository
        all_files = list_repo_files(MAIN_MODEL_ID)
        
        # BUG FIX: the original condition
        # `not "/" in f or f.startswith("int4/") == False`
        # kept every file outside the "int4/" subfolder, contradicting the
        # stated intent ("not in subfolders"). Top-level files are exactly
        # those whose path contains no "/".
        main_files = [f for f in all_files if "/" not in f]
        logger.info(f"Found {len(main_files)} files in main repository")
        
        # Download each required file
        required_files = [
            "config.json",
            "pytorch_model.bin", 
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "generation_config.json",
            "chat_template.jinja"
        ]
        
        downloaded_count = 0
        for file_name in required_files:
            if file_name in all_files:
                logger.info(f"Downloading {file_name}...")
                # NOTE: `local_dir_use_symlinks` was dropped — it is
                # deprecated and ignored in huggingface_hub >= 0.23;
                # passing `local_dir` already yields real files.
                hf_hub_download(
                    repo_id=MAIN_MODEL_ID,
                    filename=file_name,
                    local_dir=LOCAL_MODEL_PATH
                )
                logger.info(f"Downloaded {file_name}")
                downloaded_count += 1
            else:
                logger.warning(f"File {file_name} not found in main repository")
        
        logger.info(f"Downloaded {downloaded_count} out of {len(required_files)} required files")
        logger.info(f"Model downloaded successfully to {LOCAL_MODEL_PATH}")
        return True
        
    except Exception as e:
        logger.error(f"Error downloading model: {e}")
        return False

def check_model_files():
    """Report whether all required model files exist under LOCAL_MODEL_PATH.

    Returns:
        bool: True when every required file is present on disk,
        False (after logging the missing names) otherwise.
    """
    required_files = [
        "config.json",
        "pytorch_model.bin",
        "tokenizer.json",
        "tokenizer_config.json"
    ]
    
    # Collect the required files that are absent from the local directory.
    missing_files = [
        name for name in required_files
        if not os.path.exists(os.path.join(LOCAL_MODEL_PATH, name))
    ]
    
    if missing_files:
        logger.error(f"Missing model files: {missing_files}")
        return False
    
    logger.info("All required model files found")
    return True

def verify_model_integrity():
    """Sanity-check the downloaded files by loading tokenizer and config.

    Returns:
        bool: True when both the tokenizer and the model config load
        cleanly from LOCAL_MODEL_PATH, False on any failure.
    """
    try:
        # Deferred import: transformers is only needed when verification runs.
        from transformers import AutoConfig, AutoTokenizer
        
        # Loading the tokenizer exercises the tokenizer files on disk.
        AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
        logger.info("Tokenizer loaded successfully from local files")
        
        # Loading the config exercises config.json.
        AutoConfig.from_pretrained(LOCAL_MODEL_PATH)
        logger.info("Model config loaded successfully from local files")
        
        return True
        
    except Exception as e:
        logger.error(f"Error verifying model integrity: {e}")
        return False

def main():
    """Ensure the model is present and valid, downloading it if necessary.

    Returns:
        bool: True when the model files exist and pass verification,
        False when download or verification fails.
    """
    logger.info("Starting model download for Hugging Face Space...")
    
    # Fast path: reuse files already on disk if they pass verification.
    if check_model_files():
        logger.info("Model files already exist, verifying integrity...")
        if verify_model_integrity():
            logger.info("Model files verified successfully")
            return True
        logger.warning("Model files exist but failed integrity check, re-downloading...")
    
    # Slow path: fetch the model from the Hub.
    if not download_model():
        logger.error("Model download failed")
        return False
    
    logger.info("Model download completed successfully")
    
    # Re-check presence and integrity of what was just downloaded.
    if check_model_files() and verify_model_integrity():
        logger.info("Model download and verification completed successfully")
        return True
    
    logger.error("Model download completed but verification failed")
    return False

if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)