import gradio as gr
import os
import shutil
import zipfile
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import traceback
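
# NOTE: This Space is assumed to ship a requirements.txt listing at least
# gradio, torch, transformers, bitsandbytes, and accelerate
# (accelerate is what backs the `device_map="auto"` call used below).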

# --- Configuration for output paths ---
# This directory stores the quantized models temporarily on the Space.
OUTPUT_DIR = Path("quantized_models_output")
OUTPUT_DIR.mkdir(exist_ok=True)


# --- The core quantization function ---
def quantize_model(model_id_or_path: str, quantization_level: str) -> str:
| """ | |
| Loads an AI model (from Hugging Face Hub or local path), quantizes it | |
| based on the specified level, and saves the quantized model. | |
| The quantized model directory is then zipped for easier download. | |
| Args: | |
| model_id_or_path: The Hugging Face model ID (e.g., "stabilityai/stablelm-zephyr-3b") | |
| or a local path to a model directory (less common for HF Spaces, | |
| but useful if you pre-upload models to the Space itself). | |
| quantization_level: String indicating the desired quantization (e.g., '8-bit (INT8)', '4-bit (INT4)'). | |
| Returns: | |
| A Gradio File object pointing to the path of the saved quantized model directory (as a zip). | |
| """ | |
    if not model_id_or_path:
        raise gr.Error("Please provide a Hugging Face Model ID or a path to a local model directory.")

    print(f"[{model_id_or_path}] Attempting to quantize model.")
    print(f"[{model_id_or_path}] Desired quantization level: {quantization_level}")

    # Create a unique name for the saved quantized model directory
    safe_model_name = model_id_or_path.replace('/', '__').replace('\\', '__').replace('.', '_')
    quantized_model_base_name = f"quantized_{safe_model_name}_{quantization_level.replace(' ', '_').replace('(', '').replace(')', '')}"
    quantized_model_save_path = OUTPUT_DIR / quantized_model_base_name
    try:
        # Determine the quantization configuration based on the selection
        bnb_config = None
        if "8-bit" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for 8-bit quantization (LLM.int8()).")
            # bitsandbytes 8-bit loading (LLM.int8()); unlike 4-bit, it has no
            # quant-type option, so load_in_8bit=True is all that is needed.
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        elif "4-bit" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for 4-bit quantization (NF4).")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,  # More memory savings
                bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            )
        elif "FP16" in quantization_level:
            print(f"[{model_id_or_path}] Configuring for FP16 (half precision).")
            # FP16 relies solely on `torch_dtype` during `from_pretrained`;
            # no BitsAndBytesConfig is needed for a plain half-precision load.
        else:
            raise gr.Error(f"Unsupported quantization level: {quantization_level}")

        # --- Load Model and Tokenizer ---
        print(f"[{model_id_or_path}] Loading model and tokenizer from: {model_id_or_path}...")

        # Pick the torch_dtype from GPU availability and the quantization level
        load_torch_dtype = torch.float32  # CPU / full-precision default
        if torch.cuda.is_available():
            if "FP16" in quantization_level:
                load_torch_dtype = torch.float16
            elif bnb_config is not None:
                # Keep the non-quantized modules in bfloat16 alongside the 4/8-bit weights
                load_torch_dtype = torch.bfloat16

        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            quantization_config=bnb_config,  # None for FP16, set for 4/8-bit
            device_map="auto",  # Automatically assigns layers to available devices (CPU/GPU)
            torch_dtype=load_torch_dtype,
            # trust_remote_code=True  # Uncomment ONLY if you trust the model and it has custom code
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
        print(f"[{model_id_or_path}] Model and tokenizer loaded successfully.")

        # --- Save the Quantized Model ---
        # First, clean up any previous output for this model/level combination
        if quantized_model_save_path.exists():
            print(f"[{model_id_or_path}] Cleaning up previous output directory: {quantized_model_save_path}")
            shutil.rmtree(quantized_model_save_path)

        model.save_pretrained(quantized_model_save_path)
        tokenizer.save_pretrained(quantized_model_save_path)
        print(f"[{model_id_or_path}] Quantized model and tokenizer saved to: {quantized_model_save_path}")

        # Zip the directory for easy download
        # (shutil.make_archive appends the .zip extension automatically)
        zip_file_path = shutil.make_archive(
            base_name=str(quantized_model_save_path),
            format="zip",
            root_dir=str(quantized_model_save_path),
        )
        print(f"[{model_id_or_path}] Quantized model zipped to: {zip_file_path}")

        # Return the path to the zip file; the gr.File output makes it downloadable
        return zip_file_path
    except Exception as e:
        print(f"[{model_id_or_path}] An error occurred during quantization: {e}")
        traceback.print_exc()  # Print the full traceback for debugging in the Space logs
        raise gr.Error(
            f"Quantization failed! Error: {e}. Check the Hugging Face Space logs for details. "
            "Ensure you have a CUDA-enabled GPU for 8/4-bit quantization, "
            "and that the model is compatible."
        )


# --- Gradio Interface Definition ---
iface = gr.Interface(
    fn=quantize_model,
    inputs=[
        gr.Textbox(
            label="Hugging Face Model ID (e.g., stabilityai/stablelm-zephyr-3b)",
            placeholder="Enter a model ID from the Hugging Face Hub (e.g., meta-llama/Llama-2-7b-hf)",
        ),
        gr.Dropdown(
            choices=["8-bit (INT8)", "4-bit (INT4)", "FP16 (Half-Precision)"],
            label="Select Quantization Level",
            value="8-bit (INT8)",  # Default selection
        ),
    ],
    outputs=gr.File(label="Quantized Model Download"),
    title="🌌 AI Model Shrinker: Quantize Your Models!",
    description=(
        "Enter a Hugging Face Model ID to effortlessly quantize it and reduce its size and memory footprint. "
        "This can significantly improve inference speed and allow larger models to run on more modest hardware. "
        "<br><b>Important Notes:</b>"
        "<ul>"
        "<li><b>GPU Required:</b> 8-bit and 4-bit quantization (using <code>bitsandbytes</code>) require a <b>CUDA-enabled GPU</b> to work properly. Choose a GPU hardware tier for your Space.</li>"
        "<li><b>Compatibility:</b> Not all models are guaranteed to work perfectly after quantization, especially 4-bit. Performance might vary.</li>"
        "<li><b>Downloading:</b> The output will be a <code>.zip</code> file containing the quantized model's directory.</li>"
        "<li><b>Experimental:</b> Embrace the experimental spirit! This tool pushes boundaries in AI accessibility.</li>"
        "</ul>"
    ),
    live=False,  # Live updates are not ideal for long-running jobs like this
    allow_flagging="manual",  # Lets users flag inputs/outputs, useful for debugging
)

# Launch the Gradio app
if __name__ == "__main__":
    # When running locally, share=True creates a public URL for easy sharing;
    # on Hugging Face Spaces, public access is handled automatically.
    iface.launch(share=True)