import gradio as gr
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel
import tempfile
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from packaging import version
import os
from torchao.quantization import (
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Float8WeightOnlyConfig,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
)
MAP_QUANT_TYPE_TO_NAME = {
    "Int4WeightOnly": "int4wo",
    "GemliteUIntXWeightOnly": "intxwo-gemlite",
    "Int8WeightOnly": "int8wo",
    "Int8DynamicActivationInt8Weight": "int8da8w8",
    "Float8WeightOnly": "float8wo",
    "Float8DynamicActivationFloat8Weight": "float8da8w8",
    "autoquant": "autoquant",
}
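# Illustrative naming (example ids assumed): quantizing "meta-llama/Llama-3.2-1B"
# with Int4WeightOnly and group size 128 under user "alice" yields the repo id
# "alice/Llama-3.2-1B-ao-int4wo-gs128"; without a group size the "-gs128"
# suffix is dropped (see check_model_exists and save_model below).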
MAP_QUANT_TYPE_TO_CONFIG = {
    "Int4WeightOnly": Int4WeightOnlyConfig,
    "GemliteUIntXWeightOnly": GemliteUIntXWeightOnlyConfig,
    "Int8WeightOnly": Int8WeightOnlyConfig,
    "Int8DynamicActivationInt8Weight": Int8DynamicActivationInt8WeightConfig,
    "Float8WeightOnly": Float8WeightOnlyConfig,
    "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig,
}
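# Minimal sketch of how the config map is used (mirrors quantize_model below):
# a dropdown choice such as "Int8WeightOnly" is resolved to its torchao config
# class and wrapped in a transformers TorchAoConfig.
#
#     quant_config = MAP_QUANT_TYPE_TO_CONFIG["Int8WeightOnly"]()
#     quantization_config = TorchAoConfig(quant_config)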
def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    # Expect a gr.OAuthProfile object as input to get the user's profile;
    # if the user is not logged in, profile will be None.
    if profile is None:
        return "Hello!"
    return f"Hello {profile.name}!"
def check_model_exists(
    oauth_token: gr.OAuthToken | None,
    username,
    quantization_type,
    group_size,
    model_name,
    quantized_model_name,
):
    """Check if a model already exists in the user's Hugging Face repositories."""
    try:
        models = list_models(author=username, token=oauth_token.token)
        model_names = [model.id for model in models]
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            if (
                quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"]
            ) and (group_size is not None):
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}"
            else:
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"
        if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
        else:
            return None  # Model does not exist
    except Exception as e:
        return f"Error checking model existence: {str(e)}"
def create_model_card(model_name, quantization_type, group_size):
    # Try to download the original README
    original_readme = ""
    original_yaml_header = ""
    try:
        # Download the README.md file from the original model
        model_path = snapshot_download(
            repo_id=model_name, allow_patterns=["README.md"], repo_type="model"
        )
        readme_path = os.path.join(model_path, "README.md")
        if os.path.exists(readme_path):
            with open(readme_path, "r", encoding="utf-8") as f:
                content = f.read()
                if content.startswith("---"):
                    parts = content.split("---", 2)
                    if len(parts) >= 3:
                        original_yaml_header = parts[1]
                        original_readme = "---".join(parts[2:])
                    else:
                        original_readme = content
                else:
                    original_readme = content
    except Exception as e:
        print(f"Error reading original README: {str(e)}")
        original_readme = ""
    # Create new YAML header with base_model field
    yaml_header = f"""---
base_model:
- {model_name}"""
    # Add any original YAML fields except base_model
    if original_yaml_header:
        in_base_model_section = False
        found_tags = False
        for line in original_yaml_header.strip().split("\n"):
            # Skip if we're in a base_model section that continues onto following lines
            if in_base_model_section:
                if (
                    line.strip().startswith("-")
                    or not line.strip()
                    or line.startswith(" ")
                ):
                    continue
                else:
                    in_base_model_section = False
            # Check for the base_model field
            if line.strip().startswith("base_model:"):
                in_base_model_section = True
                # If base_model has an inline value (like "base_model: model_name")
                if ":" in line and len(line.split(":", 1)[1].strip()) > 0:
                    in_base_model_section = False
                continue
            # Check for the tags field and add torchao-my-repo
            if line.strip().startswith("tags:"):
                found_tags = True
                yaml_header += f"\n{line}"
                yaml_header += "\n- torchao-my-repo"
                continue
| yaml_header += f"\n{line}" | |
| # If tags field wasn't found, add it | |
| if not found_tags: | |
| yaml_header += "\ntags:" | |
| yaml_header += "\n- torchao-my-repo" | |
| # Complete the YAML header | |
| yaml_header += "\n---" | |
| # Create the quantization info section | |
| quant_info = f""" | |
| # {model_name} (Quantized) | |
| ## Description | |
| This model is a quantized version of the original model [`{model_name}`](https://huggingface.co/{model_name}). | |
| It's quantized using the TorchAO library using the [torchao-my-repo](https://huggingface.co/spaces/pytorch/torchao-my-repo) space. | |
| ## Quantization Details | |
| - **Quantization Type**: {quantization_type} | |
| - **Group Size**: {group_size} | |
| """ | |
| # Combine everything | |
| model_card = yaml_header + quant_info | |
| # Append original README content if available | |
| if original_readme and not original_readme.isspace(): | |
| model_card += "\n\n# π Original Model Information\n\n" + original_readme | |
| return model_card | |
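# Illustrative result (assuming model_name="org/model" and no tags in the
# original card): the generated card starts with a YAML header roughly like
#
#     ---
#     base_model:
#     - org/model
#     tags:
#     - torchao-my-repo
#     ---
#
# followed by the quantization details and the original README content.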
def quantize_model(
    model_name, quantization_type, group_size=128, auth_token=None, username=None, progress=gr.Progress()
):
    print(f"Quantizing model: {quantization_type}")
    progress(0, desc="Preparing Quantization")
    if quantization_type == "GemliteUIntXWeightOnly":
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
            group_size=group_size
        )
        quantization_config = TorchAoConfig(quant_config)
    elif quantization_type == "Int4WeightOnly":
        from torchao.dtypes import Int4CPULayout
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
            group_size=group_size, layout=Int4CPULayout()
        )
        quantization_config = TorchAoConfig(quant_config)
    elif quantization_type == "autoquant":
        quantization_config = TorchAoConfig(quantization_type)
    else:
        quant_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]()
        quantization_config = TorchAoConfig(quant_config)
    progress(0.10, desc="Quantizing model")
    model = AutoModel.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=quantization_config,
        device_map="cpu",
        use_auth_token=auth_token.token,
    )
    progress(0.45, desc="Quantization completed")
    return model
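# Reloading sketch (illustrative; repo id assumed): since the checkpoint is
# pushed with safe_serialization=False (see save_model below), the quantized
# model can be loaded back through transformers in the usual way:
#
#     model = AutoModel.from_pretrained(
#         "alice/Llama-3.2-1B-ao-int8wo", torch_dtype="auto", device_map="cpu"
#     )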
def save_model(
    model,
    model_name,
    quantization_type,
    group_size=128,
    username=None,
    auth_token=None,
    quantized_model_name=None,
    public=True,
    progress=gr.Progress(),
):
    progress(0.50, desc="Preparing to push")
    print("Saving quantized model")
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Load and save the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_auth_token=auth_token.token
        )
        tokenizer.save_pretrained(tmpdirname, use_auth_token=auth_token.token)
        # Save the model
        progress(0.60, desc="Saving model")
        model.save_pretrained(
            tmpdirname, safe_serialization=False, use_auth_token=auth_token.token
        )
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            if (
                quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"]
            ) and (group_size is not None):
                # Index with the raw quantization_type: calling .lower() here would
                # miss the CamelCase keys of MAP_QUANT_TYPE_TO_NAME and raise KeyError.
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}"
            else:
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"
        progress(0.70, desc="Creating model card")
        model_card = create_model_card(model_name, quantization_type, group_size)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)
        # Push to Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.80, desc="Pushing to Hub")
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
        progress(1.00, desc="Pushing to Hub completed")
    import io
    from contextlib import redirect_stdout
    import html
    # Capture the model architecture string
    f = io.StringIO()
    with redirect_stdout(f):
        print(model)
    model_architecture_str = f.getvalue()
    # Escape HTML characters and format with line breaks
    model_architecture_str_html = html.escape(model_architecture_str).replace(
        "\n", "<br/>"
    )
    # Format it for display in markdown with proper styling
    model_architecture_info = f"""
    <div class="model-architecture-container" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
        <h3 style="margin-top: 0; color: #2E7D32;">📋 Model Architecture</h3>
        <div class="model-architecture" style="max-height: 500px; overflow-y: auto; overflow-x: auto; background-color: #f5f5f5; padding: 5px; border-radius: 8px; font-family: monospace; white-space: pre-wrap;">
            <div style="line-height: 1.2; font-size: 0.75em;">{model_architecture_str_html}</div>
        </div>
    </div>
    """
    repo_link = f"""
    <div class="repo-link" style="margin-top: 20px; margin-bottom: 20px; background-color: #f8f9fa; padding: 15px; border-radius: 8px; border-left: 4px solid #4CAF50;">
        <h3 style="margin-top: 0; color: #2E7D32;">🔗 Repository Link</h3>
        <p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a></p>
    </div>
    """
    return (
        f"<h1>🎉 Quantization Completed</h1><br/>{repo_link}{model_architecture_info}"
    )
def quantize_and_save(
    profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    model_name,
    quantization_type,
    group_size,
    quantized_model_name,
    public,
):
    if oauth_token is None or profile is None:
        return """
        <div class="error-box">
            <h3>❌ Authentication Error</h3>
            <p>Please sign in to your HuggingFace account to use the quantizer.</p>
        </div>
        """
    if group_size != "" and not group_size.isdigit():
        return """
        <div class="error-box">
            <h3>❌ Group Size Error</h3>
            <p>Group Size must be a positive integer (it only applies to Int4WeightOnly and GemliteUIntXWeightOnly).</p>
        </div>
        """
    if group_size and group_size.strip():
        group_size = int(group_size)
    else:
        group_size = None
    exists_message = check_model_exists(
        oauth_token,
        profile.username,
        quantization_type,
        group_size,
        model_name,
        quantized_model_name,
    )
    if exists_message:
        return f"""
        <div class="warning-box">
            <h3>⚠️ Model Already Exists</h3>
            <p>{exists_message}</p>
        </div>
        """
    # if quantization_type == "int4_weight_only":
    #     return "int4_weight_only not supported on cpu"
    try:
        quantized_model = quantize_model(
            model_name, quantization_type, group_size, oauth_token, profile.username
        )
        return save_model(
            quantized_model,
            model_name,
            quantization_type,
            group_size,
            profile.username,
            oauth_token,
            quantized_model_name,
            public,
        )
    except Exception as e:
        # raise e
        return str(e)
def get_model_size(model):
    """
    Calculate the size of a PyTorch model in gigabytes.

    Args:
        model: PyTorch model

    Returns:
        float: Size of the model in GB
    """
    # Get the model state dict
    state_dict = model.state_dict()
    # Calculate total size in bytes
    total_size = 0
    for param in state_dict.values():
        # Bytes for each parameter: element count * bytes per element
        total_size += param.nelement() * param.element_size()
    # Convert bytes to gigabytes (1 GB = 1,073,741,824 bytes)
    size_gb = total_size / (1024**3)
    size_gb = round(size_gb, 2)
    return size_gb
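# Illustrative comparison (names assumed): get_model_size can report the
# footprint before and after quantization.
#
#     base = AutoModel.from_pretrained("org/model", torch_dtype="auto")
#     print(get_model_size(base))             # full-precision size in GB
#     print(get_model_size(quantized_model))  # ~half for int8, ~a quarter for int4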
# Add enhanced CSS styling
css = """
/* Custom CSS for enhanced UI */
.gradio-container {overflow-y: auto;}
/* Fix alignment for radio buttons and dropdowns */
.gradio-radio, .gradio-dropdown {
    display: flex !important;
    align-items: center !important;
    margin: 10px 0 !important;
}
/* Consistent spacing and alignment */
.gradio-dropdown, .gradio-textbox, .gradio-radio {
    margin-bottom: 12px !important;
    width: 100% !important;
}
button[variant="primary"]::before {
    content: "🔥 "; /* PyTorch flame icon */
}
button[variant="primary"]:hover {
    transform: translateY(-5px) scale(1.05) !important;
    box-shadow: 0 10px 25px rgba(238, 76, 44, 0.7) !important;
}
@keyframes pytorch-glow {
    from {
        box-shadow: 0 0 10px rgba(238, 76, 44, 0.5);
    }
    to {
        box-shadow: 0 0 20px rgba(238, 76, 44, 0.8), 0 0 30px rgba(255, 156, 0, 0.5);
    }
}
/* Login button styling */
#login-button {
    background: linear-gradient(135deg, #EE4C2C, #FF9C00) !important;
    color: white !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 15px !important;
    box-shadow: 0 0 15px rgba(238, 76, 44, 0.5) !important;
    transition: all 0.3s ease !important;
    max-width: 300px !important;
    margin: 0 auto !important;
}
.quantize-button {
    background: linear-gradient(135deg, #EE4C2C, #FF9C00) !important;
    color: white !important;
    font-weight: 700 !important;
    border: none !important;
    border-radius: 15px !important;
    box-shadow: 0 0 15px rgba(238, 76, 44, 0.5) !important;
    transition: all 0.3s ease !important;
    animation: pytorch-glow 1.5s infinite alternate !important;
    transform-origin: center !important;
    letter-spacing: 0.5px !important;
    text-shadow: 0 1px 2px rgba(0, 0, 0, 0.2) !important;
}
.quantize-button:hover {
    transform: translateY(-3px) scale(1.03) !important;
    box-shadow: 0 8px 20px rgba(238, 76, 44, 0.7) !important;
}
"""
# Update the main app layout
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # 🤗 TorchAO Model Quantizer ✨
        Quantize your favorite Hugging Face models using TorchAO and save them to your profile!
        <br/>
        """
    )
    gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
    m1 = gr.Markdown()
    demo.load(hello, inputs=None, outputs=m1)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_name = HuggingfaceHubSearch(
                    label="🔍 Hub Model ID",
                    placeholder="Search for a model id on the Hugging Face Hub",
                    search_type="model",
                )
| gr.Markdown("""### βοΈ Quantization Settings""") | |
| with gr.Row(): | |
| with gr.Column(): | |
| quantization_type = gr.Dropdown( | |
| info="Select the Quantization method", | |
| choices=[ | |
| "Int4WeightOnly", | |
| "GemliteUIntXWeightOnly" | |
| "Int8WeightOnly", | |
| "Int8DynamicActivationInt8Weight", | |
| "Float8WeightOnly", | |
| "Float8DynamicActivationFloat8Weight", | |
| "autoquant", | |
| ], | |
| value="int8_weight_only", | |
| filterable=False, | |
| show_label=False, | |
| ) | |
                    group_size = gr.Textbox(
                        info="Group Size (only used by Int4WeightOnly and GemliteUIntXWeightOnly)",
                        value="128",
                        # This flag is evaluated once at build time, so keep the box
                        # editable and let quantize_and_save validate the value.
                        interactive=True,
                        show_label=False,
                    )
            gr.Markdown(
                """
                ### 💾 Saving Settings
                """
            )
            with gr.Row():
                quantized_model_name = gr.Textbox(
                    label="✏️ Model Name",
                    info="Model Name (optional: overrides the default)",
                    value="",
                    interactive=True,
                    elem_classes="model-name-textbox",
                    show_label=False,
                )
            with gr.Row():
                public = gr.Checkbox(
                    label="🌐 Make model public",
                    info="If checked, the model will be publicly accessible",
                    value=True,
                    interactive=True,
                    show_label=True,
                )
        with gr.Column():
            quantize_button = gr.Button(
                "🚀 Quantize and Push to Hub", elem_classes="quantize-button", elem_id="quantize-button"
            )
            output_link = gr.Markdown(
                label="🔗 Quantized Model Info", container=True, min_height=200
            )
    # Add information section
    with gr.Accordion("📚 About TorchAO Quantization", open=True):
        gr.Markdown(
            """
            ## 🔍 Quantization Options
            ### Quantization Types
            - **Int4WeightOnly**: 4-bit weight-only quantization
            - **GemliteUIntXWeightOnly**: UIntX Gemlite quantization (defaults to 4-bit for now)
            - **Int8WeightOnly**: 8-bit weight-only quantization
            - **Int8DynamicActivationInt8Weight**: 8-bit quantization for both weights and activations
            - **Float8WeightOnly**: float8 weight-only quantization
            - **Float8DynamicActivationFloat8Weight**: float8 quantization for both weights and activations
            - **autoquant**: automatic quantization (picks the best quantization method for the model)
            ### Group Size
            - Only applicable to Int4WeightOnly and GemliteUIntXWeightOnly quantization
            - Default value is 128
            - Controls the granularity of quantization
            ## 🔄 How It Works
            1. Downloads the original model
            2. Applies TorchAO quantization with your selected settings
            3. Uploads the quantized model to your HuggingFace account
            ## 📉 Memory Benefits
            - int4 quantization can reduce model size by up to 75%
            - int8 quantization typically reduces size by about 50%
            """
        )
    # Click handler: gr.OAuthProfile and gr.OAuthToken are injected automatically
    # by Gradio from quantize_and_save's type hints, so they are not listed here.
    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quantization_type, group_size, quantized_model_name, public],
        outputs=[output_link],
    )

# Launch the app
demo.launch(share=True)