Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| """ | |
| Helper script to download the full fine-tuned model files at build time for Hugging Face Spaces | |
| """ | |
| import os | |
| import sys | |
| import subprocess | |
| import logging | |
| from pathlib import Path | |
# Hub repository that holds the full fine-tuned model, and the local
# directory its files are downloaded into.
MAIN_MODEL_ID = "Tonic/petite-elle-L-aime-3-sft"
LOCAL_MODEL_PATH = "./model"

# Route INFO-level progress messages from this script to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def download_model():
    """Download the full fine-tuned model files to ``LOCAL_MODEL_PATH``.

    Lists the files of the ``MAIN_MODEL_ID`` repository and keeps only the
    top-level ones; files inside subfolders (for example ``int4/``) are
    ignored. Each entry of the required-file list that is present at the top
    level is downloaded. Required files absent from the repository are logged
    as warnings and skipped. Completeness is not enforced here; callers are
    expected to validate the local files afterwards.

    Returns:
        bool: True if the repository listing and every attempted download
        completed without raising. False if any exception occurred (it is
        logged together with its traceback).
    """
    try:
        logger.info(f"Downloading full fine-tuned model from {MAIN_MODEL_ID}")

        # Create local directory if it doesn't exist
        os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)

        # Imported here rather than at module top; only this function needs it.
        from huggingface_hub import hf_hub_download, list_repo_files

        all_files = list_repo_files(MAIN_MODEL_ID)
        # Keep only top-level files: anything under a subfolder (e.g. int4/)
        # belongs to a different variant of the model.
        main_files = {f for f in all_files if "/" not in f}
        logger.info(f"Found {len(main_files)} files in main repository")

        required_files = [
            "config.json",
            "pytorch_model.bin",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "generation_config.json",
            "chat_template.jinja",
        ]

        downloaded_count = 0
        for file_name in required_files:
            if file_name not in main_files:
                logger.warning(f"File {file_name} not found in main repository")
                continue

            logger.info(f"Downloading {file_name}...")
            hf_hub_download(
                repo_id=MAIN_MODEL_ID,
                filename=file_name,
                local_dir=LOCAL_MODEL_PATH,
                # Ask for real files rather than symlinks into the HF cache.
                # NOTE(review): this argument is deprecated in newer
                # huggingface_hub releases; confirm against the pinned version.
                local_dir_use_symlinks=False,
            )
            logger.info(f"Downloaded {file_name}")
            downloaded_count += 1

        logger.info(f"Downloaded {downloaded_count} out of {len(required_files)} required files")
        missing_count = len(required_files) - downloaded_count
        if missing_count == 0:
            logger.info(f"Model downloaded successfully to {LOCAL_MODEL_PATH}")
        else:
            # Do not claim success when required files were unavailable.
            logger.warning(
                f"Model download to {LOCAL_MODEL_PATH} is incomplete: "
                f"{missing_count} required file(s) not found in repository"
            )
        return True

    except Exception as e:
        # Boundary handler: report failure to the caller as False, but keep
        # the traceback in the log for diagnosis.
        logger.exception(f"Error downloading model: {e}")
        return False
def check_model_files():
    """Report whether the essential model files are present locally.

    Looks for the config, weights and tokenizer files directly under
    ``LOCAL_MODEL_PATH``.

    Returns:
        bool: True if every essential file exists. False otherwise, after
        logging the list of missing names at ERROR level.
    """
    essential = (
        "config.json",
        "pytorch_model.bin",
        "tokenizer.json",
        "tokenizer_config.json",
    )
    absent = [
        name
        for name in essential
        if not os.path.exists(os.path.join(LOCAL_MODEL_PATH, name))
    ]

    if not absent:
        logger.info("All required model files found")
        return True

    logger.error(f"Missing model files: {absent}")
    return False
def verify_model_integrity():
    """Check that the local tokenizer and model config can be loaded.

    This is a lightweight sanity check. It loads the tokenizer and the
    ``AutoConfig`` from ``LOCAL_MODEL_PATH``, but it does NOT load or
    validate the model weights.

    Returns:
        bool: True if both load without raising. False otherwise (the
        exception is logged together with its traceback).
    """
    try:
        from transformers import AutoConfig, AutoTokenizer

        # The loaded objects are discarded: a load that does not raise is
        # the check itself.
        AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH)
        logger.info("Tokenizer loaded successfully from local files")

        AutoConfig.from_pretrained(LOCAL_MODEL_PATH)
        logger.info("Model config loaded successfully from local files")
        return True

    except Exception as e:
        logger.exception(f"Error verifying model integrity: {e}")
        return False
def main():
    """Make sure a usable copy of the model exists locally, downloading it if needed.

    Already-present files are accepted when they pass both the presence
    check and the integrity check. Otherwise the model is (re-)downloaded
    and checked again.

    Returns:
        bool: True when the local files end up passing both checks,
        False otherwise.
    """
    logger.info("Starting model download for Hugging Face Space...")

    # Fast path: reuse files left by a previous run if they still verify.
    if check_model_files():
        logger.info("Model files already exist, verifying integrity...")
        if verify_model_integrity():
            logger.info("Model files verified successfully")
            return True
        logger.warning("Model files exist but failed integrity check, re-downloading...")

    if not download_model():
        logger.error("Model download failed")
        return False

    logger.info("Model download completed successfully")

    # Re-run both checks against the freshly downloaded files.
    verified = check_model_files() and verify_model_integrity()
    if verified:
        logger.info("Model download and verification completed successfully")
    else:
        logger.error("Model download completed but verification failed")
    return verified
if __name__ == "__main__":
    # Translate the boolean result of main() into a process exit status:
    # 0 on success, 1 on failure.
    if main():
        sys.exit(0)
    sys.exit(1)