import os
import html
import tempfile

import gradio as gr
import torch
from transformers import AutoModel, BitsAndBytesConfig
from huggingface_hub import HfApi, list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from bitsandbytes.nn import Linear4bit

def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    # Gradio injects a gr.OAuthProfile when a parameter is annotated with that
    # type; if the user is not logged in, profile is None.
    if profile is None:
        return "Hello! Please log in to Hugging Face to use the BitsAndBytes Quantizer."
    return f"Hello {profile.name}! Welcome to the BitsAndBytes Quantizer."

def check_model_exists(oauth_token: gr.OAuthToken | None, username, model_name, quantized_model_name):
    """Check whether the target repo already exists under the user's Hugging Face account."""
    try:
        models = list_models(author=username, token=oauth_token.token)
        model_names = [model.id for model in models]
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
        if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
        else:
            return None  # Model does not exist
    except Exception as e:
        return f"Error checking model existence: {str(e)}"

def create_model_card(model_name, repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4):
    # model_name is the original base model; repo_name is the quantized repo
    # being created. Keeping them separate lets the card's base_model metadata
    # point at the original model rather than at the quantized repo itself.
    model_card = f"""---
base_model:
- {model_name}
---

# {repo_name} (Quantized)

## Description
This model is a quantized version of the original model `{model_name}`. It was quantized to int4 using bitsandbytes.

## Quantization Details
- **Quantization Type**: int4
- **bnb_4bit_quant_type**: {quant_type_4}
- **bnb_4bit_use_double_quant**: {double_quant_4}
- **bnb_4bit_compute_dtype**: {compute_type_4}
- **bnb_4bit_quant_storage**: {quant_storage_4}

## Usage
You can use this model in your applications by loading it directly from the Hugging Face Hub:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{repo_name}")
```
"""
    return model_card

# Map the dropdown string choices to torch dtypes for BitsAndBytesConfig.
DTYPE_MAPPING = {
    "int8": torch.int8,
    "uint8": torch.uint8,
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}
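
# For illustration only: with the UI defaults used further below (nf4,
# double quant on, bfloat16 compute, uint8 storage), the config built in
# quantize_model is roughly equivalent to this sketch:
#
#   BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_use_double_quant=True,
#       bnb_4bit_compute_dtype=torch.bfloat16,
#       bnb_4bit_quant_storage=torch.uint8,
#   )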

def quantize_model(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, auth_token=None, progress=gr.Progress()):
    progress(0, desc="Starting")
    print(f"Quantizing model {model_name} (quant type: {quant_type_4})")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=quant_type_4,
        bnb_4bit_use_double_quant=(double_quant_4 == "True"),
        bnb_4bit_quant_storage=DTYPE_MAPPING[quant_storage_4],
        bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4],
    )
    model = AutoModel.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        token=auth_token.token,
        torch_dtype=torch.bfloat16,
    )
    # Moving a Linear4bit module to CUDA is what triggers the actual 4-bit
    # quantization of its weights; moving it back keeps the quantized model
    # resident on CPU.
    for _, module in progress.tqdm(model.named_modules(), desc="Quantizing model", total=len(list(model.named_modules())), unit="layers"):
        if isinstance(module, Linear4bit):
            module.to("cuda")
            module.to("cpu")
    return model
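
# A minimal standalone sketch of the same flow outside Gradio, assuming a
# plain Hub token and an available CUDA device (the model id is only an
# example):
#
#   config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
#   model = AutoModel.from_pretrained(
#       "facebook/opt-125m", quantization_config=config, device_map="cpu"
#   )
#   for _, m in model.named_modules():
#       if isinstance(m, Linear4bit):
#           m.to("cuda")  # the move to GPU performs the 4-bit quantization
#           m.to("cpu")   # move back so the whole model stays on CPU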

def save_model(model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, username=None, auth_token=None, quantized_model_name=None, public=False):
    print("Saving quantized model")
    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname, safe_serialization=True)
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
        model_card = create_model_card(model_name, repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)
        # Push to Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
    # Capture the model architecture, escape it for HTML, and format it for display
    model_architecture_str_html = html.escape(str(model)).replace('\n', '<br/>')
    model_architecture_info = f"""
    <div class="model-architecture" style="max-height: 500px; overflow-y: auto; overflow-x: auto; background-color: #f5f5f5; padding: 5px; border-radius: 8px; font-family: monospace; white-space: pre-wrap;">
    <div style="line-height: 1.2; font-size: 0.75em;">{model_architecture_str_html}</div>
    </div>
    """
    return f'🔗 Quantized Model <br/><h1> 🤗 DONE</h1><br/>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a><br/><br/>📊 Model Architecture<br/>{model_architecture_info}'

# Note: this definition is superseded by the progress-enabled quantize_and_save
# defined later in the file; the later definition is the one the UI uses.
def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public):
    if oauth_token is None or profile is None:
        return """
        <div class="error-box">
            <h3>❌ Authentication Error</h3>
            <p>Please sign in to your Hugging Face account to use the quantizer.</p>
        </div>
        """
    exists_message = check_model_exists(oauth_token, profile.username, model_name, quantized_model_name)
    if exists_message:
        return f"""
        <div class="warning-box">
            <h3>⚠️ Model Already Exists</h3>
            <p>{exists_message}</p>
        </div>
        """
    try:
        # Download and quantize phase
        quantized_model = quantize_model(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, oauth_token)
        final_message = save_model(quantized_model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, profile.username, oauth_token, quantized_model_name, public)
        return final_message
    except Exception as e:
        error_message = str(e).replace('\n', '<br/>')
        return f"""
        <div class="error-box">
            <h3>❌ Error Occurred</h3>
            <p>{error_message}</p>
        </div>
        """
| css="""/* Custom CSS to allow scrolling */ | |
| .gradio-container {overflow-y: auto;} | |
| /* Fix alignment for radio buttons and checkboxes */ | |
| .gradio-radio { | |
| display: flex !important; | |
| align-items: center !important; | |
| margin: 10px 0 !important; | |
| } | |
| .gradio-checkbox { | |
| display: flex !important; | |
| align-items: center !important; | |
| margin: 10px 0 !important; | |
| } | |
| /* Ensure consistent spacing and alignment */ | |
| .gradio-dropdown, .gradio-textbox, .gradio-radio, .gradio-checkbox { | |
| margin-bottom: 12px !important; | |
| width: 100% !important; | |
| } | |
| /* Align radio buttons and checkboxes horizontally */ | |
| .option-row { | |
| display: flex !important; | |
| justify-content: space-between !important; | |
| align-items: center !important; | |
| gap: 20px !important; | |
| margin-bottom: 12px !important; | |
| } | |
| .option-row .gradio-radio, .option-row .gradio-checkbox { | |
| margin: 0 !important; | |
| flex: 1 !important; | |
| } | |
| /* Horizontally align radio button options with text */ | |
| .gradio-radio label { | |
| display: flex !important; | |
| align-items: center !important; | |
| } | |
| .gradio-radio input[type="radio"] { | |
| margin-right: 5px !important; | |
| } | |
| /* Remove padding and margin from model name textbox for better alignment */ | |
| .model-name-textbox { | |
| padding-left: 0 !important; | |
| padding-right: 0 !important; | |
| margin-left: 0 !important; | |
| margin-right: 0 !important; | |
| } | |
| /* Quantize button styling with glow effect */ | |
| button[variant="primary"] { | |
| background: linear-gradient(135deg, #3B82F6, #10B981) !important; | |
| color: white !important; | |
| padding: 16px 32px !important; | |
| font-size: 1.1rem !important; | |
| font-weight: 700 !important; | |
| border: none !important; | |
| border-radius: 12px !important; | |
| box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important; | |
| transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; | |
| position: relative; | |
| overflow: hidden; | |
| animation: glow 1.5s ease-in-out infinite alternate; | |
| } | |
| button[variant="primary"]::before { | |
| content: "✨ "; | |
| } | |
| button[variant="primary"]:hover { | |
| transform: translateY(-5px) scale(1.05) !important; | |
| box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important; | |
| } | |
| @keyframes glow { | |
| from { | |
| box-shadow: 0 0 10px rgba(59, 130, 246, 0.5); | |
| } | |
| to { | |
| box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5); | |
| } | |
| } | |
| /* Login button styling with glow effect */ | |
| #login-button { | |
| background: linear-gradient(135deg, #3B82F6, #10B981) !important; | |
| color: white !important; | |
| font-weight: 700 !important; | |
| border: none !important; | |
| border-radius: 12px !important; | |
| box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important; | |
| transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; | |
| position: relative; | |
| overflow: hidden; | |
| animation: glow 1.5s ease-in-out infinite alternate; | |
| max-width: 300px !important; | |
| margin: 0 auto !important; | |
| } | |
| #login-button::before { | |
| content: "🔑 "; | |
| display: inline-block !important; | |
| vertical-align: middle !important; | |
| margin-right: 5px !important; | |
| line-height: normal !important; | |
| } | |
| #login-button:hover { | |
| transform: translateY(-3px) scale(1.03) !important; | |
| box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important; | |
| } | |
| #login-button::after { | |
| content: ""; | |
| position: absolute; | |
| top: 0; | |
| left: -100%; | |
| width: 100%; | |
| height: 100%; | |
| background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent); | |
| transition: 0.5s; | |
| } | |
| #login-button:hover::after { | |
| left: 100%; | |
| } | |
| /* Toggle instructions button styling */ | |
| #toggle-button { | |
| background: linear-gradient(135deg, #3B82F6, #10B981) !important; | |
| color: white !important; | |
| font-size: 0.85rem !important; | |
| font-weight: 600 !important; | |
| padding: 8px 16px !important; | |
| border: none !important; | |
| border-radius: 8px !important; | |
| box-shadow: 0 2px 10px rgba(59, 130, 246, 0.3) !important; | |
| transition: all 0.3s ease !important; | |
| margin: 0.5rem auto 1.5rem auto !important; | |
| display: block !important; | |
| max-width: 200px !important; | |
| text-align: center !important; | |
| position: relative; | |
| overflow: hidden; | |
| } | |
| #toggle-button:hover { | |
| transform: translateY(-2px) !important; | |
| box-shadow: 0 4px 12px rgba(59, 130, 246, 0.5) !important; | |
| } | |
| #toggle-button::after { | |
| content: ""; | |
| position: absolute; | |
| top: 0; | |
| left: -100%; | |
| width: 100%; | |
| height: 100%; | |
| background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent); | |
| transition: 0.5s; | |
| } | |
| #toggle-button:hover::after { | |
| left: 100%; | |
| } | |
| /* Progress Bar Styles */ | |
| .progress-container { | |
| font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; | |
| padding: 20px; | |
| background: white; | |
| border-radius: 12px; | |
| box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); | |
| } | |
| .progress-stage { | |
| font-size: 0.9rem; | |
| font-weight: 600; | |
| color: #64748b; | |
| } | |
| .progress-stage .stage { | |
| position: relative; | |
| padding: 8px 12px; | |
| border-radius: 6px; | |
| background: #f1f5f9; | |
| transition: all 0.3s ease; | |
| } | |
| .progress-stage .stage.completed { | |
| background: #ecfdf5; | |
| } | |
| .progress-bar { | |
| box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.1); | |
| } | |
| .progress { | |
| transition: width 0.8s cubic-bezier(0.4, 0, 0.2, 1); | |
| box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3); | |
| } | |
| """ | |

def quantize_model_with_progress(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, auth_token, progress=gr.Progress()):
    """Quantize a model, reporting progress to the UI."""
    progress(0, desc="Loading model")

    # Configure quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=quant_type_4,
        bnb_4bit_use_double_quant=(double_quant_4 == "True"),
        bnb_4bit_quant_storage=DTYPE_MAPPING[quant_storage_4],
        bnb_4bit_compute_dtype=DTYPE_MAPPING[compute_type_4],
    )

    # Load model
    model = AutoModel.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        token=auth_token.token,
        torch_dtype=torch.bfloat16,
    )
    progress(0.33, desc="Quantizing")

    # Quantize layer by layer: the round trip through CUDA is what
    # materializes the 4-bit weights for each Linear4bit module.
    modules = list(model.named_modules())
    for idx, (_, module) in enumerate(modules):
        if isinstance(module, Linear4bit):
            module.to("cuda")
            module.to("cpu")
        progress(0.33 + (0.33 * idx / len(modules)), desc="Quantizing")
    progress(0.66, desc="Quantized successfully")
    return model

def save_model_with_progress(model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, username=None, auth_token=None, quantized_model_name=None, public=False, progress=gr.Progress()):
    """Save the quantized model and push it to the Hub, reporting progress to the UI."""
    progress(0.67, desc="Preparing to push")
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Save model
        model.save_pretrained(tmpdirname, safe_serialization=True)
        progress(0.75, desc="Preparing to push")

        # Prepare repo name and model card
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-bnb-4bit"
        model_card = create_model_card(model_name, repo_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)
        progress(0.80, desc="Model card created")

        # Push to Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.85, desc="Pushing to Hub")

        # Upload files
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
        progress(1.00, desc="Model pushed to Hub")

    # Capture the model architecture, escape it for HTML, and format it for display
    model_architecture_str_html = html.escape(str(model)).replace('\n', '<br/>')
    model_architecture_info = f"""
    <div class="model-architecture" style="max-height: 500px; overflow-y: auto; overflow-x: auto; background-color: #f5f5f5; padding: 5px; border-radius: 8px; font-family: monospace; white-space: pre-wrap;">
    <div style="line-height: 1.2; font-size: 0.75em;">{model_architecture_str_html}</div>
    </div>
    """
    return f'🔗 Quantized Model <br/><h1> 🤗 DONE</h1><br/>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a><br/><br/>📊 Model Architecture<br/>{model_architecture_info}'

def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public, progress=gr.Progress()):
    if oauth_token is None or profile is None:
        return """
        <div class="error-box">
            <h3>❌ Authentication Error</h3>
            <p>Please sign in to your Hugging Face account to use the quantizer.</p>
        </div>
        """
    exists_message = check_model_exists(oauth_token, profile.username, model_name, quantized_model_name)
    if exists_message:
        return f"""
        <div class="warning-box">
            <h3>⚠️ Model Already Exists</h3>
            <p>{exists_message}</p>
        </div>
        """
    try:
        # Download and quantize phase
        progress(0, desc="Starting quantization process")
        quantized_model = quantize_model_with_progress(model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, oauth_token, progress)
        # Save and push phase
        final_message = save_model_with_progress(quantized_model, model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, profile.username, oauth_token, quantized_model_name, public, progress)
        return final_message
    except Exception as e:
        error_message = str(e).replace('\n', '<br/>')
        return f"""
        <div class="error-box">
            <h3>❌ Error Occurred</h3>
            <p>{error_message}</p>
        </div>
        """

with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
    gr.Markdown(
        """
        # 🤗 LLM Model BitsAndBytes Quantizer ✨
        """
    )
    gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)

    m1 = gr.Markdown()
    demo.load(hello, inputs=None, outputs=m1)

    instructions_visible = gr.State(False)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_name = HuggingfaceHubSearch(
                    label="🔍 Hub Model ID",
                    placeholder="Search for a model ID on the Hugging Face Hub",
                    search_type="model",
                )
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        """
                        ### ⚙️ Model Quantization Type Settings
                        """
                    )
                    quant_type_4 = gr.Dropdown(
                        info="The quantization data type used in the bnb.nn.Linear4bit layers",
                        choices=["fp4", "nf4"],
                        value="nf4",
                        visible=True,
                        show_label=False,
                    )
                    compute_type_4 = gr.Dropdown(
                        info="The compute dtype for the model",
                        choices=["float16", "bfloat16", "float32"],
                        value="bfloat16",
                        visible=True,
                        show_label=False,
                    )
                    quant_storage_4 = gr.Dropdown(
                        info="The storage dtype for the quantized weights",
                        choices=["float16", "float32", "int8", "uint8", "bfloat16"],
                        value="uint8",
                        visible=True,
                        show_label=False,
                    )
                    gr.Markdown(
                        """
                        ### 🔄 Double Quantization Settings
                        """
                    )
                    with gr.Row(elem_classes="option-row"):
                        double_quant_4 = gr.Radio(
                            ["True", "False"],
                            info="Use double quantization",
                            visible=True,
                            value="True",
                            show_label=False,
                        )
                    gr.Markdown(
                        """
                        ### 💾 Saving Settings
                        """
                    )
                    with gr.Row():
                        quantized_model_name = gr.Textbox(
                            label="✏️ Model Name",
                            info="Optional: custom name to override the default repo name",
                            value="",
                            interactive=True,
                            elem_classes="model-name-textbox",
                            show_label=False,
                        )
                    with gr.Row():
                        public = gr.Checkbox(
                            label="🌐 Make model public",
                            info="If checked, the model will be publicly accessible",
                            value=True,
                            interactive=True,
                            show_label=True,
                        )
        with gr.Column():
            quantize_button = gr.Button("🚀 Quantize and Push to the Hub", variant="primary")
            output_link = gr.Markdown("🔗 Quantized Model", container=True, min_height=100)

    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4, quantized_model_name, public],
        outputs=[output_link],
    )
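    # Note: `profile` and `oauth_token` are not listed in `inputs` above;
    # Gradio injects arguments annotated as gr.OAuthProfile / gr.OAuthToken
    # automatically for logged-in users.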

if __name__ == "__main__":
    demo.launch(share=True)