import csv
import os
import base64
from datetime import datetime
from typing import Optional, Union, List

import gradio as gr
from huggingface_hub import HfApi, Repository
from optimum_neuron_export import convert, DIFFUSION_PIPELINE_MAPPING
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from apscheduler.schedulers.background import BackgroundScheduler

# Define transformer tasks and their categories for coloring
TRANSFORMER_TASKS = {
    "auto": {"color": "#6b7280", "category": "Auto"},
    "feature-extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
    "fill-mask": {"color": "#8b5cf6", "category": "NLP"},
    "multiple-choice": {"color": "#8b5cf6", "category": "NLP"},
    "question-answering": {"color": "#8b5cf6", "category": "NLP"},
    "text-classification": {"color": "#8b5cf6", "category": "NLP"},
    "token-classification": {"color": "#8b5cf6", "category": "NLP"},
    "text-generation": {"color": "#10b981", "category": "Text Generation"},
    "text2text-generation": {"color": "#10b981", "category": "Text Generation"},
    "audio-classification": {"color": "#f59e0b", "category": "Audio"},
    "automatic-speech-recognition": {"color": "#f59e0b", "category": "Audio"},
    "audio-frame-classification": {"color": "#f59e0b", "category": "Audio"},
    "audio-xvector": {"color": "#f59e0b", "category": "Audio"},
    "image-classification": {"color": "#ef4444", "category": "Vision"},
    "object-detection": {"color": "#ef4444", "category": "Vision"},
    "semantic-segmentation": {"color": "#ef4444", "category": "Vision"},
    "zero-shot-image-classification": {"color": "#ec4899", "category": "Multimodal"},
    "sentence-similarity": {"color": "#06b6d4", "category": "Similarity"},
}

# Define diffusion pipeline types - updated structure
DIFFUSION_PIPELINES = {
    "stable-diffusion": {"color": "#ec4899", "category": "Stable Diffusion", "tasks": ["text-to-image", "image-to-image", "inpaint"]},
    "stable-diffusion-xl": {"color": "#10b981", "category": "Stable Diffusion XL", "tasks": ["text-to-image", "image-to-image", "inpaint"]},
    "sdxl-turbo": {"color": "#f59e0b", "category": "SDXL Turbo", "tasks": ["text-to-image", "image-to-image", "inpaint"]},
    "lcm": {"color": "#8b5cf6", "category": "LCM", "tasks": ["text-to-image"]},
    "pixart-alpha": {"color": "#ef4444", "category": "PixArt", "tasks": ["text-to-image"]},
    "pixart-sigma": {"color": "#ef4444", "category": "PixArt", "tasks": ["text-to-image"]},
    "flux": {"color": "#06b6d4", "category": "Flux", "tasks": ["text-to-image", "inpaint"]},
    "flux-kontext": {"color": "#06b6d4", "category": "Flux Kontext", "tasks": ["text-to-image", "image-to-image"]},
}

TAGS = {
    "Feature Extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
    "NLP": {"color": "#8b5cf6", "category": "NLP"},
    "Text Generation": {"color": "#10b981", "category": "Text Generation"},
    "Audio": {"color": "#f59e0b", "category": "Audio"},
    "Vision": {"color": "#ef4444", "category": "Vision"},
    "Multimodal": {"color": "#ec4899", "category": "Multimodal"},
    "Similarity": {"color": "#06b6d4", "category": "Similarity"},
    "Stable Diffusion": {"color": "#ec4899", "category": "Stable Diffusion"},
    "Stable Diffusion XL": {"color": "#10b981", "category": "Stable Diffusion XL"},
    "ControlNet": {"color": "#f59e0b", "category": "ControlNet"},
    "ControlNet XL": {"color": "#f59e0b", "category": "ControlNet XL"},
    "PixArt": {"color": "#ef4444", "category": "PixArt"},
    "Latent Consistency": {"color": "#8b5cf6", "category": "Latent Consistency"},
    "Flux": {"color": "#06b6d4", "category": "Flux"},
}
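# For reference, these palettes drive the colored task tags rendered in the
# "Supported Architectures" tab; create_task_tag below turns an entry such as
# "fill-mask" into markup roughly like (illustrative sketch):
#   <span style="background-color: #8b5cf6; ...">fill-mask</span>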
# UPDATED: New choices for the Pull Request destination UI component
DEST_NEW_NEURON_REPO = "Create new Neuron-optimized repository"
DEST_CACHE_REPO = "Create a PR in the cache repository"
DEST_CUSTOM_REPO = "Create a PR in a custom repository"
PR_DESTINATION_CHOICES = [
    DEST_CACHE_REPO,
    DEST_NEW_NEURON_REPO,
    DEST_CUSTOM_REPO,
]
DEFAULT_CACHE_REPO = "aws-neuron/optimum-neuron-cache"

# Get all tasks and pipelines for dropdowns
ALL_TRANSFORMER_TASKS = list(TRANSFORMER_TASKS.keys())
ALL_DIFFUSION_PIPELINES = list(DIFFUSION_PIPELINES.keys())


def create_task_tag(task: str) -> str:
    """Create a colored HTML tag for a task."""
    if task in TRANSFORMER_TASKS:
        color = TRANSFORMER_TASKS[task]["color"]
    elif task in DIFFUSION_PIPELINES:
        color = DIFFUSION_PIPELINES[task]["color"]
    elif task in TAGS:
        color = TAGS[task]["color"]
    else:
        color = "#6b7280"  # neutral gray for unknown tasks
    return (
        f'<span style="background-color: {color}; color: white; padding: 2px 10px; '
        f'border-radius: 12px; font-size: 0.85em; white-space: nowrap; '
        f'display: inline-block; margin: 2px;">{task}</span>'
    )


def format_tasks_for_table(tasks_str: str) -> str:
    """Convert comma-separated tasks into colored tags."""
    tasks = [task.strip() for task in tasks_str.split(',')]
    return ' '.join([create_task_tag(task) for task in tasks])


def update_pipeline_and_task_dropdowns(model_type: str):
    """Update the pipeline and task dropdowns based on the selected model type."""
    if model_type == "transformers":
        return (
            gr.Dropdown(visible=False),  # pipeline dropdown hidden
            gr.Dropdown(
                choices=ALL_TRANSFORMER_TASKS,
                value="auto",
                label="Task (auto can infer task from model)",
                visible=True,
            ),
        )
    else:  # diffusers
        # Show the pipeline dropdown and reset the task dropdown to match
        return (
            gr.Dropdown(
                choices=ALL_DIFFUSION_PIPELINES,
                value="stable-diffusion",
                label="Pipeline Type",
                visible=True,
            ),
            gr.Dropdown(
                choices=DIFFUSION_PIPELINES["stable-diffusion"]["tasks"],
                value=DIFFUSION_PIPELINES["stable-diffusion"]["tasks"][0],
                label="Task",
                visible=True,
            ),
        )


def update_task_dropdown_for_pipeline(pipeline_name: str):
    """Update the task dropdown based on the selected pipeline."""
    if pipeline_name in DIFFUSION_PIPELINES:
        tasks = DIFFUSION_PIPELINES[pipeline_name]["tasks"]
        return gr.Dropdown(
            choices=tasks,
            value=tasks[0] if tasks else None,
            label="Task",
            visible=True,
        )
    return gr.Dropdown(visible=False)


def toggle_custom_repo_box(pr_destinations: List[str]):
    """Show or hide the custom repo ID textbox based on the checkbox selection."""
    if DEST_CUSTOM_REPO in pr_destinations:
        return gr.Textbox(visible=True)
    else:
        return gr.Textbox(visible=False, value="")


def neuron_export(
    model_id: str,
    model_type: str,
    pipeline_name: str,
    task_or_pipeline: str,
    pr_destinations: List[str],
    custom_repo_id: str,
    custom_cache_repo: str,
    oauth_token: Optional[gr.OAuthToken],
):
    log_buffer = ""

    def log(msg, in_progress: bool = False):
        nonlocal log_buffer
        # Handle cases where the message from the backend is not a string
        if not isinstance(msg, str):
            msg = str(msg)
        log_buffer += msg + "\n"
        return log_buffer, gr.update(visible=in_progress)

    if oauth_token is None or oauth_token.token is None:
        yield log("You must be logged in to use this space")
        return

    if not model_id:
        yield log("🚫 Invalid input. Please specify a model name from the hub.")
        return

    try:
        api = HfApi(token=oauth_token.token)

        # Set the custom cache repo as an environment variable
        if custom_cache_repo:
            os.environ['CUSTOM_CACHE_REPO'] = custom_cache_repo.strip()

        yield log("🔑 Logging in ...", in_progress=True)
        try:
            api.model_info(model_id, token=oauth_token.token)
        except Exception as e:
            yield log(f"❌ Could not access model `{model_id}`: {e}")
            return

        yield log(f"✅ Model `{model_id}` is accessible. Starting Neuron export...", in_progress=True)
        # UPDATED: Build pr_options with the new structure
        pr_options = {
            "create_cache_pr": DEST_CACHE_REPO in pr_destinations,
            "create_neuron_repo": DEST_NEW_NEURON_REPO in pr_destinations,
            "create_custom_pr": DEST_CUSTOM_REPO in pr_destinations,
            "custom_repo_id": custom_repo_id.strip() if custom_repo_id else "",
        }

        # The convert function is a generator, so we iterate through its messages
        for status_code, message in convert(
            api,
            model_id,
            task_or_pipeline,
            model_type,
            token=oauth_token.token,
            pr_options=pr_options,
            pipeline_name=pipeline_name if model_type == "diffusers (soon)" else None,
        ):
            if isinstance(message, str):
                yield log(message, in_progress=True)
            else:
                # It's the final result dictionary
                final_message = "🎉 Process finished.\n"
                if message.get("neuron_repo"):
                    final_message += f"🏗️ New Neuron Repository: {message['neuron_repo']}\n"
                if message.get("readme_pr"):
                    final_message += f"📝 README PR (Original Model): {message['readme_pr']}\n"
                if message.get("cache_pr"):
                    final_message += f"🔗 Cache PR: {message['cache_pr']}\n"
                if message.get("custom_pr"):
                    final_message += f"🔗 Custom PR: {message['custom_pr']}\n"
                yield log(final_message)

    except Exception as e:
        yield log(f"❗ An unexpected error occurred in the Gradio interface: {e}")
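
# For reference, the export loop above assumes `convert` (from
# optimum_neuron_export) is a generator following roughly this protocol —
# a sketch inferred from how the loop consumes it, not the actual implementation
# (status codes shown as 0 are illustrative):
#
#   def convert(api, model_id, task_or_pipeline, model_type, token=None,
#               pr_options=None, pipeline_name=None):
#       yield 0, "Compiling model for Neuron..."   # progress updates as strings
#       ...
#       yield 0, {                                 # final result as a dict
#           "neuron_repo": "user/model-neuron",    # each key may be absent or None
#           "readme_pr": None,
#           "cache_pr": None,
#           "custom_pr": None,
#       }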
TITLE = """
<div style="text-align: center; margin-top: 10px;">
    <h1>🤗 Optimum Neuron Model Exporter 🏎️</h1>
</div>
""" # UPDATED: Description to reflect new workflow DESCRIPTION = """ This Space allows you to automatically export ๐Ÿค— transformers to AWS Neuron-optimized format for Inferentia/Trainium acceleration. """ CUSTOM_CSS = """ /* Primary button styling with warm colors */ button.gradio-button.lg.primary { /* Changed the blue/green gradient to an orange/yellow one */ background: linear-gradient(135deg, #F97316, #FBBF24) !important; color: white !important; padding: 16px 32px !important; font-size: 1.1rem !important; font-weight: 700 !important; border: none !important; border-radius: 12px !important; /* Updated the shadow to match the new orange color */ box-shadow: 0 0 15px rgba(249, 115, 22, 0.5) !important; transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; position: relative; overflow: hidden; } /* Login button styling with glow effect using dark blue and violet colors */ #login-button { background: linear-gradient(135deg, #1a237e, #6a1b9a) !important; /* Dark Blue to Violet */ color: white !important; font-weight: 700 !important; border: none !important; border-radius: 12px !important; box-shadow: 0 0 15px rgba(106, 27, 154, 0.6) !important; /* Cool violet glow */ transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important; position: relative; overflow: hidden; animation: glow 1.5s ease-in-out infinite alternate; max-width: 350px !important; margin: 0 auto !important; } #login-button::before { content: "๐Ÿ”‘ "; display: inline-block !important; vertical-align: middle !important; margin-right: 5px !important; line-height: normal !important; } #login-button:hover { transform: translateY(-3px) scale(1.03) !important; box-shadow: 0 10px 25px rgba(26, 35, 126, 0.7) !important; /* Deeper blue glow */ } #login-button::after { content: ""; position: absolute; top: 0; left: -100%; width: 100%; height: 100%; background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.25), transparent); transition: 0.5s; } #login-button:hover::after { left: 100%; } .loader { width: 20px; height: 20px; border: 5px solid #d7d7d7; border-bottom-color: transparent; border-radius: 50%; display: inline-block; box-sizing: border-box; animation: rotation 1s linear infinite; } @keyframes rotation { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } } """ LOADING_MESSAGE = """
     Model export in progress...
""" # Download the title image at startup and encode as base64 to avoid CORS errors def get_title_image_html(): """Download image and return HTML with base64 encoded image""" image_data = None # Download image try: # Use huggingface_hub to download the file (handles auth automatically) from huggingface_hub import hf_hub_download downloaded_path = hf_hub_download( repo_id="optimum/neuron-exporter", filename="huggingfaceXneuron.png", repo_type="space" ) # Read directly from downloaded path with open(downloaded_path, 'rb') as f: image_data = f.read() except Exception as e: print(f"Warning: Could not download title image: {e}") return "" # Return empty if download fails # Encode as base64 if image_data: encoded = base64.b64encode(image_data).decode('utf-8') return f"""
""" return "" with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo: login_message = gr.Markdown("**You must be logged in to use this space**", visible=True) login_button = gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250) title_image = get_title_image_html() gr.HTML(title_image) gr.HTML(TITLE) gr.Markdown(DESCRIPTION) with gr.Tabs(): with gr.Tab("Export Model"): with gr.Group(): with gr.Row(): pr_destinations_checkbox = gr.CheckboxGroup( choices=PR_DESTINATION_CHOICES, label="Export Destination", value=[DEST_CACHE_REPO], info="Select one or more destinations for the compiled model." ) custom_repo_id_textbox = gr.Textbox( label="Custom Repository ID", placeholder="e.g., your-username/your-repo-name", visible=False, interactive=True ) custom_cache_repo_textbox = gr.Textbox( label="Custom Cache Repository", placeholder="e.g., your-org/your-cache-repo", value=DEFAULT_CACHE_REPO, info=f"Repository to store and fetch from compilation cache artifacts (default: {DEFAULT_CACHE_REPO}) ", interactive=True ) with gr.Row(): model_type = gr.Radio( choices=["transformers", "diffusers (soon)"], value="transformers", label="Model Type", info="Choose the type of model you want to export" ) with gr.Row(): input_model = HuggingfaceHubSearch( label="Hub model ID", placeholder="Search for a model on the Hub...", search_type="model", ) pipeline_dropdown = gr.Dropdown( choices=ALL_DIFFUSION_PIPELINES, value="stable-diffusion", label="Pipeline Type", visible=False ) task_dropdown = gr.Dropdown( choices=ALL_TRANSFORMER_TASKS, value="auto", label="Task (auto can infer from model)", ) btn = gr.Button("Export to Neuron", size="lg", variant="primary") loading_message = gr.HTML(LOADING_MESSAGE, visible=False, elem_id="loaging_message") log_box = gr.Textbox(label="Logs", lines=20, interactive=False, show_copy_button=True) # Event Handlers model_type.change( fn=update_pipeline_and_task_dropdowns, inputs=[model_type], outputs=[pipeline_dropdown, task_dropdown] ) pipeline_dropdown.change( fn=update_task_dropdown_for_pipeline, inputs=[pipeline_dropdown], outputs=[task_dropdown] ) pr_destinations_checkbox.change( fn=toggle_custom_repo_box, inputs=pr_destinations_checkbox, outputs=custom_repo_id_textbox ) btn.click( fn=neuron_export, inputs=[ input_model, model_type, pipeline_dropdown, task_dropdown, pr_destinations_checkbox, custom_repo_id_textbox, custom_cache_repo_textbox ], outputs=[log_box, loading_message], ) with gr.Tab("Get Started"): gr.Markdown( """ **optimum-neuron version:** 0.4.1 This Space allows you to automatically export ๐Ÿค— transformers to AWS Neuron-optimized format for Inferentia/Trainium acceleration. Simply provide a model ID from the Hugging Face Hub, and choose your desired output. ### โœจ Key Features * **๐Ÿš€ Create a New Optimized Repo**: Automatically converts your model and uploads it to a new repository under your username (e.g., `your-username/model-name-neuron`). * **๐Ÿ”— Link Back to Original**: Creates a Pull Request on the original model's repository to add a link to your optimized version, making it easier for the community to discover. * **๐Ÿ› ๏ธ PR to a Custom Repo**: For custom workflows, you can create a Pull Request to add the optimized files directly into an existing repository you own. * **๐Ÿ“ฆ Contribute to Cache**: Contribute the generated compilation artifacts to a centralized cache repository (or your own private cache), helping avoid recompilation of already exported models. ### โš™๏ธ How to Use 1. 
**Model ID**: Enter the ID of the model you want to export (e.g., `bert-base-uncased` or `stabilityai/stable-diffusion-xl-base-1.0`) and choose the corresponding task. 2. **Export Options**: Select at least one option for where to save the exported model. You can provide your own cache repo ID or use the default (`aws-neuron/optimum-neuron-cache`). 3. **Convert & Upload**: Click the button and follow the logs to track progress! """ ) with gr.Tab("Supported Architectures"): gr.HTML(f"""
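            # Hedged illustration: roughly the local CLI equivalent of what this
            # Space automates. The flags below are typical optimum-neuron export
            # arguments; exact values depend on your model, task, and instance.
            gr.Markdown(
                """
### Local CLI equivalent (sketch)

```bash
optimum-cli export neuron --model bert-base-uncased --task text-classification --batch_size 1 --sequence_length 128 bert_neuron/
```
"""
            )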
        with gr.Tab("Supported Architectures"):
            gr.HTML(f"""
            <div style="margin-bottom: 20px;">
                <h3>🎨 Task Categories Legend</h3>
                <div>
                    {create_task_tag("Feature Extraction")}
                    {create_task_tag("NLP")}
                    {create_task_tag("Text Generation")}
                    {create_task_tag("Audio")}
                    {create_task_tag("Vision")}
                    {create_task_tag("Multimodal")}
                    {create_task_tag("Similarity")}
                </div>
            </div>
            """)

            gr.HTML(f"""
            <h3>🤗 Transformers</h3>
            <table style="width: 100%; border-collapse: collapse; margin-bottom: 30px;">
                <tr style="text-align: left; border-bottom: 2px solid #9ca3af;"><th>Architecture</th><th>Supported Tasks</th></tr>
                <tr><td>ALBERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>AST</td><td>{format_tasks_for_table("feature-extraction, audio-classification")}</td></tr>
                <tr><td>BERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>BLOOM</td><td>{format_tasks_for_table("text-generation")}</td></tr>
                <tr><td>Beit</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>CamemBERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>CLIP</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>ConvBERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>ConvNext</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>ConvNextV2</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>CvT</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>DeBERTa (INF2 only)</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>DeBERTa-v2 (INF2 only)</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>Deit</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>DistilBERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>DonutSwin</td><td>{format_tasks_for_table("feature-extraction")}</td></tr>
                <tr><td>Dpt</td><td>{format_tasks_for_table("feature-extraction")}</td></tr>
                <tr><td>ELECTRA</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>ESM</td><td>{format_tasks_for_table("feature-extraction, fill-mask, text-classification, token-classification")}</td></tr>
                <tr><td>FlauBERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>GPT2</td><td>{format_tasks_for_table("text-generation")}</td></tr>
                <tr><td>Hubert</td><td>{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification")}</td></tr>
                <tr><td>Levit</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>Llama, Llama 2, Llama 3</td><td>{format_tasks_for_table("text-generation")}</td></tr>
                <tr><td>Mistral</td><td>{format_tasks_for_table("text-generation")}</td></tr>
                <tr><td>Mixtral</td><td>{format_tasks_for_table("text-generation")}</td></tr>
                <tr><td>MobileBERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>MobileNetV2</td><td>{format_tasks_for_table("feature-extraction, image-classification, semantic-segmentation")}</td></tr>
                <tr><td>MobileViT</td><td>{format_tasks_for_table("feature-extraction, image-classification, semantic-segmentation")}</td></tr>
                <tr><td>ModernBERT</td><td>{format_tasks_for_table("feature-extraction, fill-mask, text-classification, token-classification")}</td></tr>
                <tr><td>MPNet</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>OPT</td><td>{format_tasks_for_table("text-generation")}</td></tr>
                <tr><td>Phi</td><td>{format_tasks_for_table("feature-extraction, text-classification, token-classification")}</td></tr>
                <tr><td>RoBERTa</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>RoFormer</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>Swin</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>T5</td><td>{format_tasks_for_table("text2text-generation")}</td></tr>
                <tr><td>UniSpeech</td><td>{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification")}</td></tr>
                <tr><td>UniSpeech-SAT</td><td>{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")}</td></tr>
                <tr><td>ViT</td><td>{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
                <tr><td>Wav2Vec2</td><td>{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")}</td></tr>
                <tr><td>WavLM</td><td>{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")}</td></tr>
                <tr><td>Whisper</td><td>{format_tasks_for_table("automatic-speech-recognition")}</td></tr>
                <tr><td>XLM</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>XLM-RoBERTa</td><td>{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
                <tr><td>Yolos</td><td>{format_tasks_for_table("feature-extraction, object-detection")}</td></tr>
            </table>

            <h3>🧨 Diffusers</h3>
            <table style="width: 100%; border-collapse: collapse; margin-bottom: 30px;">
                <tr style="text-align: left; border-bottom: 2px solid #9ca3af;"><th>Architecture</th><th>Supported Tasks</th></tr>
                <tr><td>Stable Diffusion</td><td>{format_tasks_for_table("text-to-image, image-to-image, inpaint")}</td></tr>
                <tr><td>Stable Diffusion XL Base</td><td>{format_tasks_for_table("text-to-image, image-to-image, inpaint")}</td></tr>
                <tr><td>Stable Diffusion XL Refiner</td><td>{format_tasks_for_table("image-to-image, inpaint")}</td></tr>
                <tr><td>SDXL Turbo</td><td>{format_tasks_for_table("text-to-image, image-to-image, inpaint")}</td></tr>
                <tr><td>LCM</td><td>{format_tasks_for_table("text-to-image")}</td></tr>
                <tr><td>PixArt-α</td><td>{format_tasks_for_table("text-to-image")}</td></tr>
                <tr><td>PixArt-Σ</td><td>{format_tasks_for_table("text-to-image")}</td></tr>
                <tr><td>Flux</td><td>{format_tasks_for_table("text-to-image")}</td></tr>
            </table>

            <h3>🤖 Sentence Transformers</h3>
            <table style="width: 100%; border-collapse: collapse; margin-bottom: 30px;">
                <tr style="text-align: left; border-bottom: 2px solid #9ca3af;"><th>Architecture</th><th>Supported Tasks</th></tr>
                <tr><td>Transformer</td><td>{format_tasks_for_table("feature-extraction, sentence-similarity")}</td></tr>
                <tr><td>CLIP</td><td>{format_tasks_for_table("feature-extraction, zero-shot-image-classification")}</td></tr>
            </table>

            <p>💡 <b>Note:</b> Some architectures may have specific requirements or limitations. DeBERTa models are only supported on INF2 instances.</p>
            <p>For more details, check the <a href="https://huggingface.co/docs/optimum-neuron/index" target="_blank">Optimum Neuron documentation</a>.</p>
""") # Add spacing between tabs and content gr.Markdown("



") def update_login_visibility(oauth_token: gr.OAuthToken): if oauth_token.token is None: return gr.Markdown(visible=True) else: return gr.Markdown(visible=False) demo.load( fn=update_login_visibility, inputs=None, outputs=[login_message] ) if __name__ == "__main__": demo.launch(debug=True)