import os
import tempfile

import gradio as gr
from transformers import AutoModel, BitsAndBytesConfig
from huggingface_hub import HfApi, list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch
def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    # Expects a gr.OAuthProfile object as input to read the user's profile;
    # if the user is not logged in, profile will be None.
    if profile is None:
        return "Hello!"
    return f"Hello {profile.name}! Welcome to the BitsAndBytes Space."
def check_model_exists(oauth_token: gr.OAuthToken | None, username, quantization_type, model_name, quantized_model_name):
    """Check if a model already exists in the user's Hugging Face repositories."""
    try:
        models = list_models(author=username, token=oauth_token.token)
        model_names = [model.id for model in models]
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-BNB-{quantization_type}"
        if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
        else:
            return None  # Model does not exist
    except Exception as e:
        return f"Error checking model existence: {str(e)}"
def create_model_card(model_name, quantization_type, threshold, quant_type_4, double_quant_4):
    model_card = f"""---
base_model:
- {model_name}
---

# {model_name} (Quantized)

## Description

This model is a quantized version of the original model `{model_name}`. It has been quantized using {quantization_type} quantization with bitsandbytes.

## Quantization Details

- **Quantization Type**: {quantization_type}
- **Threshold**: {threshold if quantization_type == "int8" else None}
- **bnb_4bit_quant_type**: {quant_type_4 if quantization_type == "int4" else None}
- **bnb_4bit_use_double_quant**: {double_quant_4 if quantization_type == "int4" else None}

## Usage

You can use this model in your applications by loading it directly from the Hugging Face Hub:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{model_name}")
```"""
    return model_card
def load_model(model_name, quantization_config, auth_token):
    # `token` replaces the deprecated `use_auth_token` argument in recent
    # transformers releases.
    return AutoModel.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        token=auth_token.token,
    )
def quantize_model(model_name, quantization_type, threshold, quant_type_4, double_quant_4, auth_token=None, username=None):
    print(f"Quantizing model: {quantization_type}")
    if quantization_type == "int4":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=quant_type_4,
            bnb_4bit_use_double_quant=(double_quant_4 == "True"),
        )
    else:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=threshold,
        )
    model = load_model(model_name, quantization_config=quantization_config, auth_token=auth_token)
    return model
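# A minimal standalone sketch of the two branches above (the model id
# "facebook/opt-350m" is only an illustrative placeholder):
#   from transformers import AutoModel, BitsAndBytesConfig
#   int8_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)
#   int4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
#                                    bnb_4bit_use_double_quant=True)
#   model = AutoModel.from_pretrained("facebook/opt-350m", quantization_config=int8_config)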
def save_model(model, model_name, quantization_type, threshold, quant_type_4, double_quant_4, username=None, auth_token=None, quantized_model_name=None):
    print("Saving quantized model")
    with tempfile.TemporaryDirectory() as tmpdirname:
        # save_pretrained writes to a local directory, so no auth token is needed here.
        model.save_pretrained(tmpdirname, safe_serialization=True)
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-BNB-{quantization_type}"
        model_card = create_model_card(repo_name, quantization_type, threshold, quant_type_4, double_quant_4)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)
        # Push to Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True)
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
    return f'<h1>🤗 DONE</h1><br/>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a>'
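# Hedged aside: one way to sanity-check the upload is to list the files in the
# new repo (the repo id below is hypothetical):
#   from huggingface_hub import list_repo_files
#   print(list_repo_files("username/my-model-BNB-int8"))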
def is_float(value):
    try:
        float(value)
        return True
    except ValueError:
        return False
def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quantization_type, threshold, quant_type_4, double_quant_4, quantized_model_name):
    if oauth_token is None or profile is None:
        return "Error: Please sign in to your Hugging Face account to use the quantizer."
    exists_message = check_model_exists(oauth_token, profile.username, quantization_type, model_name, quantized_model_name)
    if exists_message:
        return exists_message
    if not is_float(threshold):
        return "Threshold must be a float"
    threshold = float(threshold)
    try:
        quantized_model = quantize_model(model_name, quantization_type, threshold, quant_type_4, double_quant_4, oauth_token, profile.username)
        return save_model(quantized_model, model_name, quantization_type, threshold, quant_type_4, double_quant_4, profile.username, oauth_token, quantized_model_name)
    except Exception as e:
        print(e)
        return f"An error occurred: {str(e)}"
| css="""/* Custom CSS to allow scrolling */ | |
| .gradio-container {overflow-y: auto;} | |
| .custom-radio { | |
| margin-left: 20px; /* Adjust the value as needed */ | |
| } | |
| """ | |
with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
    gr.Markdown(
        """
        # 🤗 LLM Model BitsAndBytes Quantization App
        Quantize your favorite Hugging Face models using BitsAndBytes and save them to your profile!
        """
    )
    gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
    m1 = gr.Markdown()
    demo.load(hello, inputs=None, outputs=m1)
    instructions = gr.Markdown(
        """
        ## Instructions
        1. Login to your Hugging Face account.
        2. Enter the name of the Hugging Face LLM model you want to quantize (make sure you have access to it).
        3. Choose the quantization type.
        4. Optionally, adjust the quantization parameters (outlier threshold for int8; quant type and double quantization for int4).
        5. Optionally, choose a custom name for the quantized model.
        6. Click "Quantize and Save Model" to start the process.
        7. Once complete, you'll receive a link to the quantized model on Hugging Face.

        Note: This process may take some time depending on the model size and your hardware. You can check the container logs to see how far along the process is!
        """,
        visible=False
    )
    instructions_visible = gr.State(False)
    toggle_button = gr.Button("▼ Show Instructions", elem_id="toggle-button", elem_classes="toggle-button")

    def toggle_instructions(instructions_visible):
        new_visibility = not instructions_visible  # Toggle the state
        new_label = "▲ Hide Instructions" if new_visibility else "▼ Show Instructions"  # Match the label to the new state
        return gr.update(visible=new_visibility), new_visibility, gr.update(value=new_label)

    toggle_button.click(toggle_instructions, instructions_visible, [instructions, instructions_visible, toggle_button])
    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_name = HuggingfaceHubSearch(
                    label="Hub Model ID",
                    placeholder="Search for model id on Huggingface",
                    search_type="model",
                )
            with gr.Row():
                with gr.Column():
                    quantization_type = gr.Dropdown(
                        info="Quantization Type",
                        choices=["int4", "int8"],
                        value="int8",
                        filterable=False,
                        show_label=False,
                    )
                    threshold_8 = gr.Textbox(
                        info="Outlier threshold",
                        value="6",
                        interactive=True,
                        show_label=False,
                        visible=True,
                    )
                    quant_type_4 = gr.Dropdown(
                        info="The quantization data type in the bnb.nn.Linear4Bit layers",
                        choices=["fp4", "nf4"],
                        value="fp4",
                        visible=False,
                        show_label=False,
                    )
                    radio_4 = gr.Radio(["False", "True"], info="Use Double Quant", visible=False, value="False", elem_classes="custom-radio")

                    # Show only the options relevant to the selected quantization type.
                    def update_visibility(quantization_type):
                        return (
                            gr.update(visible=(quantization_type == "int8")),
                            gr.update(visible=(quantization_type == "int4")),
                            gr.update(visible=(quantization_type == "int4")),
                        )

                    quantization_type.change(fn=update_visibility, inputs=quantization_type, outputs=[threshold_8, quant_type_4, radio_4])

            quantized_model_name = gr.Textbox(
                info="Model Name (optional: overrides the default)",
                value="",
                interactive=True,
                show_label=False,
            )

        with gr.Column():
            quantize_button = gr.Button("Quantize and Save Model", variant="primary")
            output_link = gr.Markdown(label="Quantized Model Link", container=True, min_height=80)
    quantize_button.click(
        fn=quantize_and_save,
        # The gr.OAuthProfile and gr.OAuthToken parameters of quantize_and_save
        # are injected automatically by Gradio from the login session, so they
        # are not listed as inputs.
        inputs=[model_name, quantization_type, threshold_8, quant_type_4, radio_4, quantized_model_name],
        outputs=[output_link],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)