import csv import os import base64 from datetime import datetime from typing import Optional, Union, List import gradio as gr from huggingface_hub import HfApi, Repository from optimum_neuron_export import convert, DIFFUSION_PIPELINE_MAPPING from gradio_huggingfacehub_search import HuggingfaceHubSearch from apscheduler.schedulers.background import BackgroundScheduler # Define transformer tasks and their categories for coloring TRANSFORMER_TASKS = { "auto": {"color": "#6b7280", "category": "Auto"}, "feature-extraction": {"color": "#3b82f6", "category": "Feature Extraction"}, "fill-mask": {"color": "#8b5cf6", "category": "NLP"}, "multiple-choice": {"color": "#8b5cf6", "category": "NLP"}, "question-answering": {"color": "#8b5cf6", "category": "NLP"}, "text-classification": {"color": "#8b5cf6", "category": "NLP"}, "token-classification": {"color": "#8b5cf6", "category": "NLP"}, "text-generation": {"color": "#10b981", "category": "Text Generation"}, "text2text-generation": {"color": "#10b981", "category": "Text Generation"}, "audio-classification": {"color": "#f59e0b", "category": "Audio"}, "automatic-speech-recognition": {"color": "#f59e0b", "category": "Audio"}, "audio-frame-classification": {"color": "#f59e0b", "category": "Audio"}, "audio-xvector": {"color": "#f59e0b", "category": "Audio"}, "image-classification": {"color": "#ef4444", "category": "Vision"}, "object-detection": {"color": "#ef4444", "category": "Vision"}, "semantic-segmentation": {"color": "#ef4444", "category": "Vision"}, "zero-shot-image-classification": {"color": "#ec4899", "category": "Multimodal"}, "sentence-similarity": {"color": "#06b6d4", "category": "Similarity"}, } # Define diffusion pipeline types - updated structure DIFFUSION_PIPELINES = { "stable-diffusion": {"color": "#ec4899", "category": "Stable Diffusion", "tasks": ["text-to-image", "image-to-image", "inpaint"]}, "stable-diffusion-xl": {"color": "#10b981", "category": "Stable Diffusion XL", "tasks": ["text-to-image", "image-to-image", 
"inpaint"]}, "sdxl-turbo": {"color": "#f59e0b", "category": "SDXL Turbo", "tasks": ["text-to-image", "image-to-image", "inpaint"]}, "lcm": {"color": "#8b5cf6", "category": "LCM", "tasks": ["text-to-image"]}, "pixart-alpha": {"color": "#ef4444", "category": "PixArt", "tasks": ["text-to-image"]}, "pixart-sigma": {"color": "#ef4444", "category": "PixArt", "tasks": ["text-to-image"]}, "flux": {"color": "#06b6d4", "category": "Flux", "tasks": ["text-to-image", "inpaint"]}, "flux-kontext": {"color": "#06b6d4", "category": "Flux Kontext", "tasks": ["text-to-image", "image-to-image"]}, } TAGS = { "Feature Extraction": {"color": "#3b82f6", "category": "Feature Extraction"}, "NLP": {"color": "#8b5cf6", "category": "NLP"}, "Text Generation": {"color": "#10b981", "category": "Text Generation"}, "Audio": {"color": "#f59e0b", "category": "Audio"}, "Vision": {"color": "#ef4444", "category": "Vision"}, "Multimodal": {"color": "#ec4899", "category": "Multimodal"}, "Similarity": {"color": "#06b6d4", "category": "Similarity"}, "Stable Diffusion": {"color": "#ec4899", "category": "Stable Diffusion"}, "Stable Diffusion XL": {"color": "#10b981", "category": "Stable Diffusion XL"}, "ControlNet": {"color": "#f59e0b", "category": "ControlNet"}, "ControlNet XL": {"color": "#f59e0b", "category": "ControlNet XL"}, "PixArt": {"color": "#ef4444", "category": "PixArt"}, "Latent Consistency": {"color": "#8b5cf6", "category": "Latent Consistency"}, "Flux": {"color": "#06b6d4", "category": "Flux"}, } # UPDATED: New choices for the Pull Request destination UI component DEST_NEW_NEURON_REPO = "Create new Neuron-optimized repository" DEST_CACHE_REPO = "Create a PR in the cache repository" DEST_CUSTOM_REPO = "Create a PR in a custom repository" PR_DESTINATION_CHOICES = [ DEST_CACHE_REPO, DEST_NEW_NEURON_REPO, DEST_CUSTOM_REPO ] DEFAULT_CACHE_REPO = "aws-neuron/optimum-neuron-cache" # Get all tasks and pipelines for dropdowns ALL_TRANSFORMER_TASKS = list(TRANSFORMER_TASKS.keys()) ALL_DIFFUSION_PIPELINES 
= list(DIFFUSION_PIPELINES.keys()) def create_task_tag(task: str) -> str: """Create a colored HTML tag for a task""" if task in TRANSFORMER_TASKS: color = TRANSFORMER_TASKS[task]["color"] return f'{task}' elif task in DIFFUSION_PIPELINES: color = DIFFUSION_PIPELINES[task]["color"] return f'{task}' elif task in TAGS: color = TAGS[task]["color"] return f'{task}' else: return f'{task}' def format_tasks_for_table(tasks_str: str) -> str: """Convert comma-separated tasks into colored tags""" tasks = [task.strip() for task in tasks_str.split(',')] return ' '.join([create_task_tag(task) for task in tasks]) def update_pipeline_and_task_dropdowns(model_type: str): """Update the pipeline and task dropdowns based on selected model type""" if model_type == "transformers": return ( gr.Dropdown(visible=False), # pipeline dropdown hidden gr.Dropdown( choices=ALL_TRANSFORMER_TASKS, value="auto", label="Task (auto can infer task from model)", visible=True ) ) else: # diffusers # Show pipeline dropdown, hide task dropdown initially return ( gr.Dropdown( choices=ALL_DIFFUSION_PIPELINES, value="stable-diffusion", label="Pipeline Type", visible=True ), gr.Dropdown( choices=DIFFUSION_PIPELINES["stable-diffusion"]["tasks"], value=DIFFUSION_PIPELINES["stable-diffusion"]["tasks"][0], label="Task", visible=True ) ) def update_task_dropdown_for_pipeline(pipeline_name: str): """Update task dropdown based on selected pipeline""" if pipeline_name in DIFFUSION_PIPELINES: tasks = DIFFUSION_PIPELINES[pipeline_name]["tasks"] return gr.Dropdown( choices=tasks, value=tasks[0] if tasks else None, label="Task", visible=True ) return gr.Dropdown(visible=False) def toggle_custom_repo_box(pr_destinations: List[str]): """Show or hide the custom repo ID textbox based on checkbox selection.""" if DEST_CUSTOM_REPO in pr_destinations: return gr.Textbox(visible=True) else: return gr.Textbox(visible=False, value="") def neuron_export(model_id: str, model_type: str, pipeline_name: str, task_or_pipeline: str, 
pr_destinations: List[str], custom_repo_id: str, custom_cache_repo: str, oauth_token: gr.OAuthToken): log_buffer = "" def log(msg, in_progress: bool = False): nonlocal log_buffer # Handle cases where the message from the backend is not a string if not isinstance(msg, str): msg = str(msg) log_buffer += msg + "\n" return log_buffer, gr.update(visible=in_progress) if oauth_token.token is None: yield log("You must be logged in to use this space") return if not model_id: yield log("๐ซ Invalid input. Please specify a model name from the hub.") return try: api = HfApi(token=oauth_token.token) # Set custom cache repo as environment variable if custom_cache_repo: os.environ['CUSTOM_CACHE_REPO'] = custom_cache_repo.strip() yield log("๐ Logging in ...", in_progress=True) try: api.model_info(model_id, token=oauth_token.token) except Exception as e: yield log(f"โ Could not access model `{model_id}`: {e}") return yield log(f"โ Model `{model_id}` is accessible. Starting Neuron export...", in_progress=True) # UPDATED: Build pr_options with new structure pr_options = { "create_cache_pr": DEST_CACHE_REPO in pr_destinations, "create_neuron_repo": DEST_NEW_NEURON_REPO in pr_destinations, "create_custom_pr": DEST_CUSTOM_REPO in pr_destinations, "custom_repo_id": custom_repo_id.strip() if custom_repo_id else "" } # The convert function is a generator, so we iterate through its messages for status_code, message in convert( api, model_id, task_or_pipeline, model_type, token=oauth_token.token, pr_options=pr_options, pipeline_name=pipeline_name if model_type == "diffusers (soon)" else None ): if isinstance(message, str): yield log(message, in_progress=True) else: # It's the final result dictionary final_message = "๐ Process finished.\n" if message.get("neuron_repo"): final_message += f"๐๏ธ New Neuron Repository: {message['neuron_repo']}\n" if message.get("readme_pr"): final_message += f"๐ README PR (Original Model): {message['readme_pr']}\n" if message.get("cache_pr"): final_message += f"๐ 
Cache PR: {message['cache_pr']}\n" if message.get("custom_pr"): final_message += f"๐ Custom PR: {message['custom_pr']}\n" yield log(final_message) except Exception as e: yield log(f"โ An unexpected error occurred in the Gradio interface: {e}") TITLE = """
| Architecture | Supported Tasks |
|---|---|
| ALBERT | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| AST | {format_tasks_for_table("feature-extraction, audio-classification")} |
| BERT | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| BLOOM | {format_tasks_for_table("text-generation")} |
| Beit | {format_tasks_for_table("feature-extraction, image-classification")} |
| CamemBERT | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| CLIP | {format_tasks_for_table("feature-extraction, image-classification")} |
| ConvBERT | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| ConvNext | {format_tasks_for_table("feature-extraction, image-classification")} |
| ConvNextV2 | {format_tasks_for_table("feature-extraction, image-classification")} |
| CvT | {format_tasks_for_table("feature-extraction, image-classification")} |
| DeBERTa (INF2 only) | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| DeBERTa-v2 (INF2 only) | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| Deit | {format_tasks_for_table("feature-extraction, image-classification")} |
| DistilBERT | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| DonutSwin | {format_tasks_for_table("feature-extraction")} |
| Dpt | {format_tasks_for_table("feature-extraction")} |
| ELECTRA | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| ESM | {format_tasks_for_table("feature-extraction, fill-mask, text-classification, token-classification")} |
| FlauBERT | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| GPT2 | {format_tasks_for_table("text-generation")} |
| Hubert | {format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification")} |
| Levit | {format_tasks_for_table("feature-extraction, image-classification")} |
| Llama, Llama 2, Llama 3 | {format_tasks_for_table("text-generation")} |
| Mistral | {format_tasks_for_table("text-generation")} |
| Mixtral | {format_tasks_for_table("text-generation")} |
| MobileBERT | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| MobileNetV2 | {format_tasks_for_table("feature-extraction, image-classification, semantic-segmentation")} |
| MobileViT | {format_tasks_for_table("feature-extraction, image-classification, semantic-segmentation")} |
| ModernBERT | {format_tasks_for_table("feature-extraction, fill-mask, text-classification, token-classification")} |
| MPNet | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| OPT | {format_tasks_for_table("text-generation")} |
| Phi | {format_tasks_for_table("feature-extraction, text-classification, token-classification")} |
| RoBERTa | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| RoFormer | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| Swin | {format_tasks_for_table("feature-extraction, image-classification")} |
| T5 | {format_tasks_for_table("text2text-generation")} |
| UniSpeech | {format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification")} |
| UniSpeech-SAT | {format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")} |
| ViT | {format_tasks_for_table("feature-extraction, image-classification")} |
| Wav2Vec2 | {format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")} |
| WavLM | {format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")} |
| Whisper | {format_tasks_for_table("automatic-speech-recognition")} |
| XLM | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| XLM-RoBERTa | {format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")} |
| Yolos | {format_tasks_for_table("feature-extraction, object-detection")} |
| Architecture | Supported Tasks |
|---|---|
| Stable Diffusion | {format_tasks_for_table("text-to-image, image-to-image, inpaint")} |
| Stable Diffusion XL Base | {format_tasks_for_table("text-to-image, image-to-image, inpaint")} |
| Stable Diffusion XL Refiner | {format_tasks_for_table("image-to-image, inpaint")} |
| SDXL Turbo | {format_tasks_for_table("text-to-image, image-to-image, inpaint")} |
| LCM | {format_tasks_for_table("text-to-image")} |
| PixArt-α | {format_tasks_for_table("text-to-image")} |
| PixArt-Σ | {format_tasks_for_table("text-to-image")} |
| Flux | {format_tasks_for_table("text-to-image")} |
| Architecture | Supported Tasks |
|---|---|
| Transformer | {format_tasks_for_table("feature-extraction, sentence-similarity")} |
| CLIP | {format_tasks_for_table("feature-extraction, zero-shot-image-classification")} |
💡 Note: Some architectures may have specific requirements or limitations. DeBERTa models are only supported on INF2 instances.
For more details, check the Optimum Neuron documentation.