badaoui (HF Staff) committed
Commit ea87f8a · verified · 1 Parent(s): 2d82228

Update app.py

Files changed (1): app.py (+763 -482)
app.py CHANGED
@@ -1,526 +1,807 @@
1
- import csv
2
  import os
3
- from datetime import datetime
4
- from typing import Optional, Union, List
5
  import gradio as gr
6
- from huggingface_hub import HfApi, Repository
7
- from optimum_neuron_export import convert, DIFFUSION_PIPELINE_MAPPING
8
- from gradio_huggingfacehub_search import HuggingfaceHubSearch
9
- from apscheduler.schedulers.background import BackgroundScheduler
10
-
11
- # Define transformer tasks and their categories for coloring
12
- TRANSFORMER_TASKS = {
13
- "auto": {"color": "#6b7280", "category": "Auto"},
14
- "feature-extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
15
- "fill-mask": {"color": "#8b5cf6", "category": "NLP"},
16
- "multiple-choice": {"color": "#8b5cf6", "category": "NLP"},
17
- "question-answering": {"color": "#8b5cf6", "category": "NLP"},
18
- "text-classification": {"color": "#8b5cf6", "category": "NLP"},
19
- "token-classification": {"color": "#8b5cf6", "category": "NLP"},
20
- "text-generation": {"color": "#10b981", "category": "Text Generation"},
21
- "text2text-generation": {"color": "#10b981", "category": "Text Generation"},
22
- "audio-classification": {"color": "#f59e0b", "category": "Audio"},
23
- "automatic-speech-recognition": {"color": "#f59e0b", "category": "Audio"},
24
- "audio-frame-classification": {"color": "#f59e0b", "category": "Audio"},
25
- "audio-xvector": {"color": "#f59e0b", "category": "Audio"},
26
- "image-classification": {"color": "#ef4444", "category": "Vision"},
27
- "object-detection": {"color": "#ef4444", "category": "Vision"},
28
- "semantic-segmentation": {"color": "#ef4444", "category": "Vision"},
29
- "zero-shot-image-classification": {"color": "#ec4899", "category": "Multimodal"},
30
- "sentence-similarity": {"color": "#06b6d4", "category": "Similarity"},
31
- }
32
 
33
- # Define diffusion pipeline types - updated structure
34
- DIFFUSION_PIPELINES = {
35
- "stable-diffusion": {"color": "#ec4899", "category": "Stable Diffusion", "tasks": ["text-to-image", "image-to-image", "inpaint"]},
36
- "stable-diffusion-xl": {"color": "#10b981", "category": "Stable Diffusion XL", "tasks": ["text-to-image", "image-to-image", "inpaint"]},
37
- "sdxl-turbo": {"color": "#f59e0b", "category": "SDXL Turbo", "tasks": ["text-to-image", "image-to-image", "inpaint"]},
38
- "lcm": {"color": "#8b5cf6", "category": "LCM", "tasks": ["text-to-image"]},
39
- "pixart-alpha": {"color": "#ef4444", "category": "PixArt", "tasks": ["text-to-image"]},
40
- "pixart-sigma": {"color": "#ef4444", "category": "PixArt", "tasks": ["text-to-image"]},
41
- "flux": {"color": "#06b6d4", "category": "Flux", "tasks": ["text-to-image", "inpaint"]},
42
- "flux-kont": {"color": "#06b6d4", "category": "Flux Kont", "tasks": ["text-to-image", "image-to-image"]},
43
  }
44
 
45
- TAGS = {
46
- "Feature Extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
47
- "NLP": {"color": "#8b5cf6", "category": "NLP"},
48
- "Text Generation": {"color": "#10b981", "category": "Text Generation"},
49
- "Audio": {"color": "#f59e0b", "category": "Audio"},
50
- "Vision": {"color": "#ef4444", "category": "Vision"},
51
- "Multimodal": {"color": "#ec4899", "category": "Multimodal"},
52
- "Similarity": {"color": "#06b6d4", "category": "Similarity"},
53
- "Stable Diffusion": {"color": "#ec4899", "category": "Stable Diffusion"},
54
- "Stable Diffusion XL": {"color": "#10b981", "category": "Stable Diffusion XL"},
55
- "ControlNet": {"color": "#f59e0b", "category": "ControlNet"},
56
- "ControlNet XL": {"color": "#f59e0b", "category": "ControlNet XL"},
57
- "PixArt": {"color": "#ef4444", "category": "PixArt"},
58
- "Latent Consistency": {"color": "#8b5cf6", "category": "Latent Consistency"},
59
- "Flux": {"color": "#06b6d4", "category": "Flux"},
60
  }
61
 
62
- # UPDATED: New choices for the Pull Request destination UI component
63
- DEST_NEW_NEURON_REPO = "Create new Neuron-optimized repository"
64
- DEST_CACHE_REPO = "Create a PR in the cache repository"
65
- DEST_CUSTOM_REPO = "Create a PR in a custom repository"
66
-
67
- PR_DESTINATION_CHOICES = [
68
- DEST_NEW_NEURON_REPO,
69
- DEST_CACHE_REPO,
70
- DEST_CUSTOM_REPO
71
- ]
72
-
73
- DEFAULT_CACHE_REPO = "aws-neuron/optimum-neuron-cache"
74
-
75
- # Get all tasks and pipelines for dropdowns
76
- ALL_TRANSFORMER_TASKS = list(TRANSFORMER_TASKS.keys())
77
- ALL_DIFFUSION_PIPELINES = list(DIFFUSION_PIPELINES.keys())
78
-
79
- def create_task_tag(task: str) -> str:
80
- """Create a colored HTML tag for a task"""
81
- if task in TRANSFORMER_TASKS:
82
- color = TRANSFORMER_TASKS[task]["color"]
83
- return f'<span style="background-color: {color}; color: white; padding: 2px 6px; border-radius: 12px; font-size: 0.75rem; font-weight: 500; margin: 1px;">{task}</span>'
84
- elif task in DIFFUSION_PIPELINES:
85
- color = DIFFUSION_PIPELINES[task]["color"]
86
- return f'<span style="background-color: {color}; color: white; padding: 2px 6px; border-radius: 12px; font-size: 0.75rem; font-weight: 500; margin: 1px;">{task}</span>'
87
- elif task in TAGS:
88
- color = TAGS[task]["color"]
89
- return f'<span style="background-color: {color}; color: white; padding: 2px 6px; border-radius: 12px; font-size: 0.75rem; font-weight: 500; margin: 1px;">{task}</span>'
90
  else:
91
- return f'<span style="background-color: #6b7280; color: white; padding: 2px 6px; border-radius: 12px; font-size: 0.75rem; font-weight: 500; margin: 1px;">{task}</span>'
92
-
93
- def format_tasks_for_table(tasks_str: str) -> str:
94
- """Convert comma-separated tasks into colored tags"""
95
- tasks = [task.strip() for task in tasks_str.split(',')]
96
- return ' '.join([create_task_tag(task) for task in tasks])
97
-
98
- def update_pipeline_and_task_dropdowns(model_type: str):
99
- """Update the pipeline and task dropdowns based on selected model type"""
100
- if model_type == "transformers":
101
- return (
102
- gr.Dropdown(visible=False), # pipeline dropdown hidden
103
- gr.Dropdown(
104
- choices=ALL_TRANSFORMER_TASKS,
105
- value="auto",
106
- label="Task (auto can infer task from model)",
107
- visible=True
108
- )
109
- )
110
- else: # diffusers
111
- # Show pipeline dropdown, hide task dropdown initially
112
- return (
113
- gr.Dropdown(
114
- choices=ALL_DIFFUSION_PIPELINES,
115
- value="stable-diffusion",
116
- label="Pipeline Type",
117
- visible=True
118
- ),
119
- gr.Dropdown(
120
- choices=DIFFUSION_PIPELINES["stable-diffusion"]["tasks"],
121
- value=DIFFUSION_PIPELINES["stable-diffusion"]["tasks"][0],
122
- label="Task",
123
- visible=True
124
- )
125
- )
126
 
127
- def update_task_dropdown_for_pipeline(pipeline_name: str):
128
- """Update task dropdown based on selected pipeline"""
129
- if pipeline_name in DIFFUSION_PIPELINES:
130
- tasks = DIFFUSION_PIPELINES[pipeline_name]["tasks"]
131
- return gr.Dropdown(
132
- choices=tasks,
133
- value=tasks[0] if tasks else None,
134
- label="Task",
135
- visible=True
136
  )
137
- return gr.Dropdown(visible=False)
138
 
139
- def toggle_custom_repo_box(pr_destinations: List[str]):
140
- """Show or hide the custom repo ID textbox based on checkbox selection."""
141
- if DEST_CUSTOM_REPO in pr_destinations:
142
- return gr.Textbox(visible=True)
143
- else:
144
- return gr.Textbox(visible=False, value="")
145
 
146
- def neuron_export(model_id: str, model_type: str, pipeline_name: str, task_or_pipeline: str,
147
- pr_destinations: List[str], custom_repo_id: str, custom_cache_repo: str, oauth_token: gr.OAuthToken):
 
148
 
149
- log_buffer = ""
150
- def log(msg):
151
- nonlocal log_buffer
152
- # Handle cases where the message from the backend is not a string
153
- if not isinstance(msg, str):
154
- msg = str(msg)
155
- log_buffer += msg + "\n"
156
- return log_buffer
157
-
158
- if oauth_token.token is None:
159
- yield log("You must be logged in to use this space")
160
- return
 
161
 
162
- if not model_id:
163
- yield log("🚫 Invalid input. Please specify a model name from the hub.")
164
- return
165
 
166
  try:
167
- api = HfApi(token=oauth_token.token)
168
- # Set custom cache repo as environment variable
169
- if custom_cache_repo:
170
- os.environ['CUSTOM_CACHE_REPO'] = custom_cache_repo.strip()
171
 
172
- yield log(f"🔑 Logging in ...")
173
  try:
174
- api.model_info(model_id, token=oauth_token.token)
175
  except Exception as e:
176
- yield log(f"❌ Could not access model `{model_id}`: {e}")
177
- return
178
 
179
- yield log(f"✅ Model `{model_id}` is accessible. Starting Neuron export...")
 
180
 
181
- # UPDATED: Build pr_options with new structure
182
- pr_options = {
183
- "create_neuron_repo": DEST_NEW_NEURON_REPO in pr_destinations,
184
- "create_cache_pr": DEST_CACHE_REPO in pr_destinations,
185
- "create_custom_pr": DEST_CUSTOM_REPO in pr_destinations,
186
- "custom_repo_id": custom_repo_id.strip() if custom_repo_id else ""
187
- }
188
 
189
- # The convert function is a generator, so we iterate through its messages
190
- for status_code, message in convert(
191
- api, model_id, task_or_pipeline, model_type,
192
- token=oauth_token.token, pr_options=pr_options,
193
- pipeline_name=pipeline_name if model_type == "diffusers" else None
194
- ):
195
- if isinstance(message, str):
196
- yield log(message)
197
- else: # It's the final result dictionary
198
- final_message = "🎉 Process finished.\n"
199
- if message.get("neuron_repo"):
200
- final_message += f"🏗️ New Neuron Repository: {message['neuron_repo']}\n"
201
- if message.get("readme_pr"):
202
- final_message += f"📝 README PR (Original Model): {message['readme_pr']}\n"
203
- if message.get("cache_pr"):
204
- final_message += f"🔗 Cache PR: {message['cache_pr']}\n"
205
- if message.get("custom_pr"):
206
- final_message += f"🔗 Custom PR: {message['custom_pr']}\n"
207
- yield log(final_message)
208
 
209
- except Exception as e:
210
- yield log(f"❗ An unexpected error occurred in the Gradio interface: {e}")
 
211
 
212
- TITLE_IMAGE = """
213
- <div style="display: block; margin-left: auto; margin-right: auto; width: 50%;">
214
- <img src="https://huggingface.co/spaces/optimum/neuron-export/resolve/main/huggingfaceXneuron.png"/>
215
- </div>
216
  """
217
 
218
- TITLE = """
219
- <div style="text-align: center; max-width: 1400px; margin: 0 auto;">
220
- <h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 10px; font-size: 2.2rem;">
221
- 🤗 Optimum Neuron Model Exporter 🏎️
222
- </h1>
223
- </div>
224
  """
225
 
226
- # UPDATED: Description to reflect new workflow
227
- DESCRIPTION = """
228
- This Space allows you to automatically export 🤗 transformers and 🧨 diffusion models to AWS Neuron-optimized format for Inferentia/Trainium acceleration.
229
 
230
- Simply provide a model ID from the Hugging Face Hub, and choose your desired output.
 
231
 
232
- ### Key Features
233
 
234
- * **🚀 Create a New Optimized Repo**: Automatically converts your model and uploads it to a new repository under your username (e.g., `your-username/model-name-neuron`).
235
- * **🔗 Link Back to Original**: Creates a Pull Request on the original model's repository to add a link to your optimized version, making it easier for the community to discover.
236
- * **🛠️ PR to a Custom Repo**: For custom workflows, you can create a Pull Request to add the optimized files directly into an existing repository you own.
237
- * **📦 Contribute to Cache**: Contribute the generated compilation artifacts to a centralized cache repository (or your own private cache), helping avoid recompilation of already exported models.
238
 
239
- ### ⚙️ How to Use
240
- 1. **Model ID**: Enter the ID of the model you want to export (e.g., `bert-base-uncased` or `stabilityai/stable-diffusion-xl-base-1.0`) and choose the corresponding task.
241
- 2. **Export Options**: Select at least one option for where to save the exported model. You can provide your own cache repo ID or use the default (`aws-neuron/optimum-neuron-cache`).
242
- 3. **Convert & Upload**: Click the button and follow the logs to track progress!
243
 
244
  """
245
 
246
- CUSTOM_CSS = """
247
- /* Primary button styling with warm colors */
248
- button.gradio-button.lg.primary {
249
- /* Changed the blue/green gradient to an orange/yellow one */
250
- background: linear-gradient(135deg, #F97316, #FBBF24) !important;
251
- color: white !important;
252
- padding: 16px 32px !important;
253
- font-size: 1.1rem !important;
254
- font-weight: 700 !important;
255
- border: none !important;
256
- border-radius: 12px !important;
257
- /* Updated the shadow to match the new orange color */
258
- box-shadow: 0 0 15px rgba(249, 115, 22, 0.5) !important;
259
- transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
260
- position: relative;
261
- overflow: hidden;
262
- }
263
- /* Login button styling with glow effect using dark blue and violet colors */
264
- #login-button {
265
- background: linear-gradient(135deg, #1a237e, #6a1b9a) !important; /* Dark Blue to Violet */
266
- color: white !important;
267
- font-weight: 700 !important;
268
- border: none !important;
269
- border-radius: 12px !important;
270
- box-shadow: 0 0 15px rgba(106, 27, 154, 0.6) !important; /* Cool violet glow */
271
- transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
272
- position: relative;
273
- overflow: hidden;
274
- animation: glow 1.5s ease-in-out infinite alternate;
275
- max-width: 350px !important;
276
- margin: 0 auto !important;
277
- }
278
- #login-button::before {
279
- content: "🔑 ";
280
- display: inline-block !important;
281
- vertical-align: middle !important;
282
- margin-right: 5px !important;
283
- line-height: normal !important;
284
- }
285
- #login-button:hover {
286
- transform: translateY(-3px) scale(1.03) !important;
287
- box-shadow: 0 10px 25px rgba(26, 35, 126, 0.7) !important; /* Deeper blue glow */
288
- }
289
- #login-button::after {
290
- content: "";
291
- position: absolute;
292
- top: 0;
293
- left: -100%;
294
- width: 100%;
295
- height: 100%;
296
- background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.25), transparent);
297
- transition: 0.5s;
298
- }
299
- #login-button:hover::after {
300
- left: 100%;
301
- }
302
 
303
  """
304
 
305
- with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
306
- gr.Markdown("**You must be logged in to use this space**")
307
- gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
308
- gr.HTML(TITLE_IMAGE)
309
- gr.HTML(TITLE)
310
- gr.Markdown(DESCRIPTION)
311
 
312
- with gr.Tabs():
313
- with gr.Tab("Export Model"):
314
- with gr.Group():
315
- with gr.Row():
316
- pr_destinations_checkbox = gr.CheckboxGroup(
317
- choices=PR_DESTINATION_CHOICES,
318
- label="Export Destination",
319
- value=[DEST_NEW_NEURON_REPO],
320
- info="Select one or more destinations for the compiled model."
321
- )
322
- custom_repo_id_textbox = gr.Textbox(
323
- label="Custom Repository ID",
324
- placeholder="e.g., your-username/your-repo-name",
325
- visible=False,
326
- interactive=True
327
- )
328
- custom_cache_repo_textbox = gr.Textbox(
329
- label="Custom Cache Repository",
330
- placeholder="e.g., your-org/your-cache-repo",
331
- value=DEFAULT_CACHE_REPO,
332
- info=f"Repository to store and fetch from compilation cache artifacts (default: {DEFAULT_CACHE_REPO}) ",
333
- interactive=True
334
  )
335
- with gr.Row():
336
- model_type = gr.Radio(
337
- choices=["transformers", "diffusers"],
338
- value="transformers",
339
- label="Model Type",
340
- info="Choose the type of model you want to export"
341
- )
342
- with gr.Row():
343
- input_model = HuggingfaceHubSearch(
344
- label="Hub model ID",
345
- placeholder="Search for a model on the Hub...",
346
- search_type="model",
347
- )
348
- pipeline_dropdown = gr.Dropdown(
349
- choices=ALL_DIFFUSION_PIPELINES,
350
- value="stable-diffusion",
351
- label="Pipeline Type",
352
- visible=False
353
  )
354
- task_dropdown = gr.Dropdown(
355
- choices=ALL_TRANSFORMER_TASKS,
356
- value="auto",
357
- label="Task (auto can infer from model)",
358
  )
359
 
360
- btn = gr.Button("Export to Neuron", size="lg", variant="primary")
361
-
362
- log_box = gr.Textbox(label="Logs", lines=20, interactive=False, show_copy_button=True)
363
-
364
- # Event Handlers
365
- model_type.change(
366
- fn=update_pipeline_and_task_dropdowns,
367
- inputs=[model_type],
368
- outputs=[pipeline_dropdown, task_dropdown]
369
- )
370
-
371
- pipeline_dropdown.change(
372
- fn=update_task_dropdown_for_pipeline,
373
- inputs=[pipeline_dropdown],
374
- outputs=[task_dropdown]
375
- )
376
-
377
- pr_destinations_checkbox.change(
378
- fn=toggle_custom_repo_box,
379
- inputs=pr_destinations_checkbox,
380
- outputs=custom_repo_id_textbox
381
- )
382
-
383
- btn.click(
384
- fn=neuron_export,
385
- inputs=[
386
- input_model,
387
- model_type,
388
- pipeline_dropdown,
389
- task_dropdown,
390
- pr_destinations_checkbox,
391
- custom_repo_id_textbox,
392
- custom_cache_repo_textbox
393
- ],
394
- outputs=log_box,
395
- )
396
-
397
- with gr.Tab("Supported Architectures"):
398
- gr.HTML(f"""
399
- <div style="margin-bottom: 20px;">
400
- <h3>🎨 Task Categories Legend</h3>
401
- <div class="task-tags">
402
- {create_task_tag("Feature Extraction")}
403
- {create_task_tag("NLP")}
404
- {create_task_tag("Text Generation")}
405
- {create_task_tag("Audio")}
406
- {create_task_tag("Vision")}
407
- {create_task_tag("Multimodal")}
408
- {create_task_tag("Similarity")}
409
- </div>
410
- </div>
411
- """)
412
-
413
- gr.HTML(f"""
414
- <h2>🤗 Transformers</h2>
415
- <table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
416
- <colgroup>
417
- <col style="width: 30%;">
418
- <col style="width: 70%;">
419
- </colgroup>
420
- <thead>
421
- <tr style="background-color: var(--background-fill-secondary);">
422
- <th style="border: 1px solid var(--border-color-primary); padding: 12px; text-align: left;">Architecture</th>
423
- <th style="border: 1px solid var(--border-color-primary); padding: 12px; text-align: left;">Supported Tasks</th>
424
- </tr>
425
- </thead>
426
- <tbody>
427
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ALBERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
428
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">AST</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, audio-classification")}</td></tr>
429
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">BERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
430
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">BLOOM</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-generation")}</td></tr>
431
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Beit</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
432
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">CamemBERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
433
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">CLIP</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
434
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ConvBERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
435
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ConvNext</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
436
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ConvNextV2</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
437
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">CvT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
438
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">DeBERTa (INF2 only)</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
439
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">DeBERTa-v2 (INF2 only)</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
440
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Deit</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
441
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">DistilBERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
442
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">DonutSwin</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction")}</td></tr>
443
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Dpt</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction")}</td></tr>
444
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ELECTRA</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
445
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ESM</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, text-classification, token-classification")}</td></tr>
446
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">FlauBERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
447
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">GPT2</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-generation")}</td></tr>
448
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Hubert</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification")}</td></tr>
449
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Levit</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
450
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Llama, Llama 2, Llama 3</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-generation")}</td></tr>
451
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Mistral</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-generation")}</td></tr>
452
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Mixtral</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-generation")}</td></tr>
453
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">MobileBERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
454
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">MobileNetV2</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification, semantic-segmentation")}</td></tr>
455
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">MobileViT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification, semantic-segmentation")}</td></tr>
456
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ModernBERT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, text-classification, token-classification")}</td></tr>
457
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">MPNet</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
458
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">OPT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-generation")}</td></tr>
459
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Phi</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, text-classification, token-classification")}</td></tr>
460
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">RoBERTa</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
461
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">RoFormer</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
462
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Swin</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
463
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">T5</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text2text-generation")}</td></tr>
464
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">UniSpeech</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification")}</td></tr>
465
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">UniSpeech-SAT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")}</td></tr>
466
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">ViT</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, image-classification")}</td></tr>
467
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Wav2Vec2</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")}</td></tr>
468
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">WavLM</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector")}</td></tr>
469
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Whisper</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("automatic-speech-recognition")}</td></tr>
470
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">XLM</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
471
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">XLM-RoBERTa</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification")}</td></tr>
472
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Yolos</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, object-detection")}</td></tr>
473
- </tbody>
474
- </table>
475
- <h2>🧨 Diffusers</h2>
476
- <table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
477
- <colgroup>
478
- <col style="width: 30%;">
479
- <col style="width: 70%;">
480
- </colgroup>
481
- <thead>
482
- <tr style="background-color: var(--background-fill-secondary);">
483
- <th style="border: 1px solid var(--border-color-primary); padding: 12px; text-align: left;">Architecture</th>
484
- <th style="border: 1px solid var(--border-color-primary); padding: 12px; text-align: left;">Supported Tasks</th>
485
- </tr>
486
- </thead>
487
- <tbody>
488
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Stable Diffusion</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image, image-to-image, inpaint")}</td></tr>
489
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Stable Diffusion XL Base</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image, image-to-image, inpaint")}</td></tr>
490
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Stable Diffusion XL Refiner</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("image-to-image, inpaint")}</td></tr>
491
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">SDXL Turbo</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image, image-to-image, inpaint")}</td></tr>
492
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">LCM</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
493
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">PixArt-α</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
494
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">PixArt-Σ</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
495
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Flux</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
496
-
497
- </tbody>
498
- </table>
499
- <h2>🤖 Sentence Transformers</h2>
500
- <table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
501
- <colgroup>
502
- <col style="width: 30%;">
503
- <col style="width: 70%;">
504
- </colgroup>
505
- <thead>
506
- <tr style="background-color: var(--background-fill-secondary);">
507
- <th style="border: 1px solid var(--border-color-primary); padding: 12px; text-align: left;">Architecture</th>
508
- <th style="border: 1px solid var(--border-color-primary); padding: 12px; text-align: left;">Supported Tasks</th>
509
- </tr>
510
- </thead>
511
- <tbody>
512
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Transformer</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, sentence-similarity")}</td></tr>
513
- <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">CLIP</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, zero-shot-image-classification")}</td></tr>
514
- </tbody>
515
- </table>
516
- <div style="margin-top: 20px;">
517
- <p>💡 <strong>Note</strong>: Some architectures may have specific requirements or limitations. DeBERTa models are only supported on INF2 instances.</p>
518
- <p>For more details, check the <a href="https://huggingface.co/docs/optimum-neuron" target="_blank">Optimum Neuron documentation</a>.</p>
519
- </div>
520
- """)
521
-
522
- # Add spacing between tabs and content
523
- gr.Markdown("<br><br><br><br>")
524
 
525
- if __name__ == "__main__":
526
- demo.launch(debug=True)
1
  import os
2
+ import shutil
3
+ from tempfile import TemporaryDirectory, NamedTemporaryFile
4
+ from typing import List, Union, Optional, Tuple, Dict, Any, Generator
5
+ from pathlib import Path
6
+ import torch
7
  import gradio as gr
8
+ from huggingface_hub import (
9
+ CommitOperationAdd,
10
+ HfApi,
11
+ ModelCard,
12
+ Discussion,
13
+ CommitInfo,
14
+ create_repo,
15
+ RepoUrl,
16
+ )
17
+ from huggingface_hub.file_download import repo_folder_name
18
+ from optimum.exporters.tasks import TasksManager
19
+ from optimum.exporters.neuron.model_configs import *
20
+ from optimum.exporters.neuron import build_stable_diffusion_components_mandatory_shapes
21
+ from optimum.exporters.neuron.model_configs import *
22
+ from optimum.exporters.neuron import get_submodels_and_neuron_configs, export_models
23
+ from optimum.neuron import (
24
+ NeuronModelForFeatureExtraction,
25
+ NeuronModelForSentenceTransformers,
26
+ NeuronModelForMaskedLM,
27
+ NeuronModelForQuestionAnswering,
28
+ NeuronModelForSequenceClassification,
29
+ NeuronModelForTokenClassification,
30
+ NeuronModelForMultipleChoice,
31
+ NeuronModelForImageClassification,
32
+ NeuronModelForSemanticSegmentation,
33
+ NeuronModelForObjectDetection,
34
+ NeuronModelForAudioClassification,
35
+ NeuronModelForAudioFrameClassification,
36
+ NeuronModelForCTC,
37
+ NeuronModelForXVector,
38
+ NeuronModelForCausalLM,
39
+ NeuronModelForSeq2SeqLM,
40
+ )
41
+
42
+ # Import diffusers pipelines
43
+ from diffusers import (
44
+ StableDiffusionPipeline,
45
+ StableDiffusionImg2ImgPipeline,
46
+ StableDiffusionInpaintPipeline,
47
+ StableDiffusionXLPipeline,
48
+ StableDiffusionXLImg2ImgPipeline,
49
+ StableDiffusionXLInpaintPipeline,
50
+ LatentConsistencyModelPipeline,
51
+ PixArtAlphaPipeline,
52
+ PixArtSigmaPipeline,
53
+ FluxPipeline,
54
+ FluxInpaintPipeline,
55
+ FluxImg2ImgPipeline,
56
+ )
57
+
58
+ from optimum.neuron.cache import synchronize_hub_cache
59
+ from synchronizer import synchronize_hub_cache_with_pr
60
 
61
+ SPACES_URL = "https://huggingface.co/spaces/optimum/neuron-export"
62
+ CUSTOM_CACHE_REPO = os.getenv("CUSTOM_CACHE_REPO")
63
+ HF_TOKEN = os.environ.get("HF_TOKEN")
64
+
65
+ # Task to NeuronModel mapping for transformers
66
+ TASK_TO_MODEL_CLASS = {
67
+ "feature-extraction": NeuronModelForFeatureExtraction,
68
+ "sentence-transformers": NeuronModelForSentenceTransformers,
69
+ "fill-mask": NeuronModelForMaskedLM,
70
+ "question-answering": NeuronModelForQuestionAnswering,
71
+ "text-classification": NeuronModelForSequenceClassification,
72
+ "token-classification": NeuronModelForTokenClassification,
73
+ "multiple-choice": NeuronModelForMultipleChoice,
74
+ "image-classification": NeuronModelForImageClassification,
75
+ "semantic-segmentation": NeuronModelForSemanticSegmentation,
76
+ "object-detection": NeuronModelForObjectDetection,
77
+ "audio-classification": NeuronModelForAudioClassification,
78
+ "audio-frame-classification": NeuronModelForAudioFrameClassification,
79
+ "automatic-speech-recognition": NeuronModelForCTC,
80
+ "audio-xvector": NeuronModelForXVector,
81
+ "text-generation": NeuronModelForCausalLM,
82
+ "text2text-generation": NeuronModelForSeq2SeqLM,
83
  }
84
 
85
+ # Diffusion pipeline mapping with their corresponding diffusers classes and supported tasks
86
+ DIFFUSION_PIPELINE_MAPPING = {
87
+ "stable-diffusion": {
88
+ "class": StableDiffusionPipeline,
89
+ "tasks": ["text-to-image"],
90
+ "default_task": "text-to-image"
91
+ },
92
+ "stable-diffusion-img2img": {
93
+ "class": StableDiffusionImg2ImgPipeline,
94
+ "tasks": ["image-to-image"],
95
+ "default_task": "image-to-image"
96
+ },
97
+ "stable-diffusion-inpaint": {
98
+ "class": StableDiffusionInpaintPipeline,
99
+ "tasks": ["inpaint"],
100
+ "default_task": "inpaint"
101
+ },
102
+ "stable-diffusion-xl": {
103
+ "class": StableDiffusionXLPipeline,
104
+ "tasks": ["text-to-image"],
105
+ "default_task": "text-to-image"
106
+ },
107
+ "stable-diffusion-xl-img2img": {
108
+ "class": StableDiffusionXLImg2ImgPipeline,
109
+ "tasks": ["image-to-image"],
110
+ "default_task": "image-to-image"
111
+ },
112
+ "stable-diffusion-xl-inpaint": {
113
+ "class": StableDiffusionXLInpaintPipeline,
114
+ "tasks": ["inpaint"],
115
+ "default_task": "inpaint"
116
+ },
117
+ "lcm": {
118
+ "class": LatentConsistencyModelPipeline,
119
+ "tasks": ["text-to-image"],
120
+ "default_task": "text-to-image"
121
+ },
122
+ "pixart-alpha": {
123
+ "class": PixArtAlphaPipeline,
124
+ "tasks": ["text-to-image"],
125
+ "default_task": "text-to-image"
126
+ },
127
+ "pixart-sigma": {
128
+ "class": PixArtSigmaPipeline,
129
+ "tasks": ["text-to-image"],
130
+ "default_task": "text-to-image"
131
+ },
132
+ "flux": {
133
+ "class": FluxPipeline,
134
+ "tasks": ["text-to-image"],
135
+ "default_task": "text-to-image"
136
+ },
137
+ "flux-inpaint": {
138
+ "class": FluxInpaintPipeline,
139
+ "tasks": ["inpaint"],
140
+ "default_task": "inpaint"
141
+ },
142
+ "flux-img2img": {
143
+ "class": FluxImg2ImgPipeline,
144
+ "tasks": ["image-to-image"],
145
+ "default_task": "image-to-image"
146
+ },
147
  }
148
 
149
+ def get_default_inputs(task_or_pipeline: str, pipeline_name: str = None) -> Dict[str, int]:
150
+ """Get default input shapes based on task type or diffusion pipeline type."""
151
+ if task_or_pipeline in ["feature-extraction", "sentence-transformers", "fill-mask", "question-answering", "text-classification", "token-classification","text-generation"]:
152
+ return {"batch_size": 1, "sequence_length": 128}
153
+ elif task_or_pipeline == "multiple-choice":
154
+ return {"batch_size": 1, "num_choices": 4, "sequence_length": 128}
155
+ elif task_or_pipeline == "text2text-generation":
156
+ return {"batch_size": 1, "sequence_length": 128, "num_beams":4}
157
+ elif task_or_pipeline in ["image-classification", "semantic-segmentation", "object-detection"]:
158
+ return {"batch_size": 1, "num_channels": 3, "height": 224, "width": 224}
159
+ elif task_or_pipeline in ["audio-classification", "audio-frame-classification", "audio-xvector"]:
160
+ return {"batch_size": 1, "audio_sequence_length": 16000}
161
+ elif pipeline_name and pipeline_name in DIFFUSION_PIPELINE_MAPPING:
162
+ # For diffusion models, use appropriate sizes based on pipeline
163
+ if "xl" in pipeline_name.lower():
164
+ return {"batch_size": 1, "height": 1024, "width": 1024, "num_images_per_prompt": 1}
165
+ else:
166
+ return {"batch_size": 1, "height": 512, "width": 512, "num_images_per_prompt": 1}
167
  else:
168
+ # Default to text-based shapes
169
+ return {"batch_size": 1, "sequence_length": 128}
170
 
171
+ def find_neuron_cache_artifacts(cache_base_dir: str = "/var/tmp/neuron-compile-cache") -> Optional[str]:
172
+ """
173
+ Find the most recently created Neuron cache artifacts directory.
174
+ Returns the path to the MODULE directory containing the compiled artifacts.
175
+ """
176
+ if not os.path.exists(cache_base_dir):
177
+ return None
178
+
179
+ # Find all MODULE directories
180
+ module_dirs = []
181
+ for root, dirs, files in os.walk(cache_base_dir):
182
+ for d in dirs:
183
+ if d.startswith("MODULE_"):
184
+ full_path = os.path.join(root, d)
185
+ # Check if it contains the expected files (for transformers)
186
+ if os.path.exists(os.path.join(full_path, "model.neuron")):
187
+ module_dirs.append(full_path)
188
+
189
+ if not module_dirs:
190
+ return None
191
+
192
+ # Return the most recently modified directory
193
+ return max(module_dirs, key=os.path.getmtime)
194
+
195
+ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
196
+ try:
197
+ discussions = api.get_repo_discussions(repo_id=model_id)
198
+ except Exception:
199
+ return None
200
+ for discussion in discussions:
201
+ if (
202
+ discussion.status == "open"
203
+ and discussion.is_pull_request
204
+ and discussion.title == pr_title
205
+ ):
206
+ return discussion
207
+ return None
208
+
209
+ def export_diffusion_model(model_id: str, pipeline_name: str, task: str, folder: str, token: str) -> Generator:
210
+ """Export diffusion model using optimum.exporters.neuron"""
211
+
212
+ yield f"📦 Exporting diffusion model `{model_id}` with pipeline `{pipeline_name}` for task `{task}`..."
213
+
214
+ if pipeline_name not in DIFFUSION_PIPELINE_MAPPING:
215
+ supported = list(DIFFUSION_PIPELINE_MAPPING.keys())
216
+ raise Exception(f"❌ Unsupported pipeline: {pipeline_name}. Supported: {supported}")
217
+
218
+ pipeline_config = DIFFUSION_PIPELINE_MAPPING[pipeline_name]
219
+ pipeline_class = pipeline_config["class"]
220
+
221
+ # Get default inputs
222
+ inputs = get_default_inputs(task, pipeline_name)
223
+ yield f"🔧 Using default inputs: {inputs}"
224
+
225
+ try:
226
+ # Load the pipeline
227
+ yield "📥 Loading diffusion pipeline..."
228
+ model = pipeline_class.from_pretrained(model_id, token=token)
229
+
230
+ # Build input shapes for compilation
231
+ input_shapes = build_stable_diffusion_components_mandatory_shapes(**inputs)
232
+
233
+ # Compiler arguments
234
+ compiler_kwargs = {
235
+ "auto_cast": "matmul",
236
+ "auto_cast_type": "bf16",
237
+ }
238
+
239
+ yield "🔨 Starting compilation process..."
240
+
241
+ # Get submodels and neuron configs
242
+ models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
243
+ model=model,
244
+ input_shapes=input_shapes,
245
+ task=task,
246
+ library_name="diffusers",
247
+ output=Path(folder),
248
+ model_name_or_path=model_id,
249
  )
250
+
251
+ # Export models
252
+ _, neuron_outputs = export_models(
253
+ models_and_neuron_configs=models_and_neuron_configs,
254
+ task=task,
255
+ output_dir=Path(folder),
256
+ output_file_names=output_model_names,
257
+ compiler_kwargs=compiler_kwargs,
258
+ )
259
+
260
+ yield f"✅ Diffusion model export completed. Files saved to {folder}"
261
+
262
+ except Exception as e:
263
+ yield f"❌ Export failed with error: {e}"
264
+ raise
265
 
266
+ def export_transformer_model(model_id: str, task: str, folder: str, token: str) -> Generator:
267
+ """Export transformer model using optimum.neuron"""
268
+
269
+ yield f"📦 Exporting transformer model `{model_id}` for task `{task}`..."
270
+
271
+ model_class = TASK_TO_MODEL_CLASS.get(task)
272
+ if model_class is None:
273
+ supported = list(TASK_TO_MODEL_CLASS.keys())
274
+ raise Exception(f"❌ Unsupported task: {task}. Supported: {supported}")
275
 
276
+ inputs = get_default_inputs(task)
277
+ compiler_configs = {"auto_cast": "matmul", "auto_cast_type": "bf16", "instance_type": "inf2"}
278
+ yield f"🔧 Using default inputs: {inputs}"
279
 
280
+ # Clear any old cache artifacts before export
281
+ cache_base_dir = "/var/tmp/neuron-compile-cache"
282
+
283
+ try:
284
+ # Trigger the export/compilation
285
+ model = model_class.from_pretrained(
286
+ model_id,
287
+ export=True,
288
+ tensor_parallel_size=4,
289
+ token=token,
290
+ **compiler_configs,
291
+ **inputs,
292
+ )
293
 
294
+ yield "✅ Export/compilation completed successfully."
295
+
296
+ # Find the newly created cache artifacts
297
+ yield "🔍 Locating compiled artifacts in Neuron cache..."
298
+ cache_artifact_dir = find_neuron_cache_artifacts(cache_base_dir)
299
+
300
+ if not cache_artifact_dir:
301
+ raise Exception("❌ Could not find compiled artifacts in Neuron cache")
302
 
303
+ yield f"📂 Found artifacts at: {cache_artifact_dir}"
304
+
305
+ # Copy artifacts from cache to our target folder
306
+ yield f"📋 Copying artifacts to export folder..."
307
+ if os.path.exists(folder):
308
+ shutil.rmtree(folder)
309
+ shutil.copytree(cache_artifact_dir, folder)
310
+
311
+ yield f"✅ Artifacts successfully copied to {folder}"
312
+
313
+ except Exception as e:
314
+ yield f"❌ Export failed with error: {e}"
315
+ raise
316
+
317
+ def export_and_git_add(model_id: str, task_or_pipeline: str, model_type: str, folder: str, token: str, pipeline_name: str = None) -> Any:
318
+
319
+ operations = []
320
+
321
  try:
322
+ if model_type == "diffusers":
323
+ # For diffusion models, use the new export function
324
+ export_gen = export_diffusion_model(model_id, pipeline_name, task_or_pipeline, folder, token)
325
+ for message in export_gen:
326
+ yield message
327
+ else:
328
+ # For transformer models, use the existing function
329
+ export_gen = export_transformer_model(model_id, task_or_pipeline, folder, token)
330
+ for message in export_gen:
331
+ yield message
332
+
333
+ # Create operations from exported files
334
+ for root, _, files in os.walk(folder):
335
+ for filename in files:
336
+ file_path = os.path.join(root, filename)
337
+ repo_path = os.path.relpath(file_path, folder)
338
+ operations.append(CommitOperationAdd(path_in_repo=repo_path, path_or_fileobj=file_path))
339
 
340
+ # Update model card
341
+ try:
342
+ card = ModelCard.load(model_id, token=token)
343
+ if not hasattr(card.data, "tags") or card.data.tags is None:
344
+ card.data.tags = []
345
+ if "neuron" not in card.data.tags:
346
+ card.data.tags.append("neuron")
347
+
348
+ readme_path = os.path.join(folder, "README.md")
349
+ card.save(readme_path)
350
+
351
+ # Check if README.md is already in operations, if so update, else add
352
+ readme_op = next((op for op in operations if op.path_in_repo == "README.md"), None)
353
+ if readme_op:
354
+ readme_op.path_or_fileobj = readme_path
355
+ else:
356
+ operations.append(CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=readme_path))
357
+
358
+ except Exception as e:
359
+ yield f"⚠️ Warning: Could not update model card: {e}"
360
+
361
+ except Exception as e:
362
+ yield f"❌ Export failed with error: {e}"
363
+ raise
364
+
365
+ yield ("__RETURN__", operations)
366
+
367
+ def generate_neuron_repo_name(api, original_model_id: str, task_or_pipeline: str, token:str) -> str:
368
+ """Generate a name for the Neuron-optimized repository."""
369
+ requesting_user = api.whoami(token=token)["name"]
370
+ base_name = original_model_id.replace('/', '-')
371
+ return f"{requesting_user}/{base_name}-neuron"
372
+
373
+ def create_neuron_repo_and_upload(
374
+ operations: List[CommitOperationAdd],
375
+ original_model_id: str,
376
+ model_type: str,
377
+ task_or_pipeline: str,
378
+ requesting_user: str,
379
+ token: str,
380
+ pipeline_name: str = None,
381
+ ) -> Generator[Union[str, RepoUrl], None, None]:
382
+ """
383
+ Creates a new repository with Neuron files and uploads them.
384
+ """
385
+ api = HfApi(token=token)
386
+
387
+ if task_or_pipeline == "auto" and model_type == "transformers":
388
  try:
389
+ task_or_pipeline = TasksManager.infer_task_from_model(original_model_id, token=token)
390
  except Exception as e:
391
+ raise Exception(f"❌ Could not infer task for model {original_model_id}: {e}")
392
+
393
+ # Generate repository name
394
+ neuron_repo_name = generate_neuron_repo_name(api, original_model_id, task_or_pipeline, token)
395
+
396
+ try:
397
+ # Create the repository
398
+ repo_url = create_repo(
399
+ repo_id=neuron_repo_name,
400
+ token=token,
401
+ repo_type="model",
402
+ private=False,
403
+ exist_ok=True,
404
+ )
405
+
406
+ # Get the appropriate class name for the Python example
407
+ if model_type == "transformers":
408
+ model_class = TASK_TO_MODEL_CLASS.get(task_or_pipeline)
409
+ model_class_name = model_class.__name__ if model_class else "NeuronModel"
410
+ usage_example = f"""```python
411
+ from optimum.neuron import {model_class_name}
412
+
413
+ model = {model_class_name}.from_pretrained("{neuron_repo_name}")
414
+ ```"""
415
+ else:
416
+ # For diffusion models
417
+ pipeline_config = DIFFUSION_PIPELINE_MAPPING.get(pipeline_name, {})
418
+ pipeline_class = pipeline_config.get("class")
419
+ if pipeline_class:
420
+ class_name = pipeline_class.__name__.replace("Pipeline", "")
421
+ model_class_name = f"Neuron{class_name}Pipeline"
422
+ else:
423
+ model_class_name = "NeuronStableDiffusionPipeline"
424
+
425
+ usage_example = f"""```python
426
+ from optimum.neuron import {model_class_name}
427
 
428
+ pipeline = {model_class_name}.from_pretrained("{neuron_repo_name}")
429
+ ```"""
430
 
431
+ # Create enhanced model card for the Neuron repo
432
+ neuron_readme_content = f"""---
433
+ tags:
434
+ - neuron
435
+ - optimized
436
+ - aws-neuron
437
+ - {task_or_pipeline}
438
+ base_model: {original_model_id}
439
+ ---
440
 
441
+ # Neuron-Optimized {original_model_id}
442
 
443
+ This repository contains AWS Neuron-optimized files for [{original_model_id}](https://huggingface.co/{original_model_id}).
444
+
445
+ ## Model Details
446
 
447
+ - **Base Model**: [{original_model_id}](https://huggingface.co/{original_model_id})
448
+ - **Task**: {task_or_pipeline}
449
+ - **Optimization**: AWS Neuron compilation
450
+ - **Generated by**: [{requesting_user}](https://huggingface.co/{requesting_user})
451
+ - **Generated using**: [Optimum Neuron Compiler Space]({SPACES_URL})
452
+
453
+ ## Usage
454
+
455
+ This model has been optimized for AWS Neuron devices (Inferentia/Trainium). To use it:
456
+
457
+ {usage_example}
458
+
459
+ ## Performance
460
+
461
+ These files are pre-compiled for AWS Neuron devices and should provide improved inference performance compared to the original model when deployed on Inferentia or Trainium instances.
462
+
463
+ ## Original Model
464
+
465
+ For the original model, training details, and more information, please visit: [{original_model_id}](https://huggingface.co/{original_model_id})
466
  """
467
+
468
+ # Update the README in operations
469
+ readme_op = next((op for op in operations if op.path_in_repo == "README.md"), None)
470
+ if readme_op:
471
+ # Create a temporary file with the new content
472
+ with NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
473
+ f.write(neuron_readme_content)
474
+ readme_op.path_or_fileobj = f.name
475
+ else:
476
+ # Add new README operation
477
+ with NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
478
+ f.write(neuron_readme_content)
479
+ operations.append(CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=f.name))
480
+
+        # Upload files to the new repository
+        commit_message = f"Add Neuron-optimized files for {original_model_id}"
+        commit_description = f"""
+🤖 Neuron Export Bot: Adding AWS Neuron-optimized model files.
+
+Original model: [{original_model_id}](https://huggingface.co/{original_model_id})
+Task: {task_or_pipeline}
+Generated by: [{requesting_user}](https://huggingface.co/{requesting_user})
+Generated using: [Optimum Neuron Compiler Space]({SPACES_URL})
+
+These files have been pre-compiled for AWS Neuron devices (Inferentia/Trainium) and should provide improved inference performance.
 """
+
+        commit_info = api.create_commit(
+            repo_id=neuron_repo_name,
+            operations=operations,
+            commit_message=commit_message,
+            commit_description=commit_description,
+            token=token,
+        )
+        yield f"✅ Repository created: {repo_url}"
+
+    except Exception as e:
+        yield f"❌ Failed to create/upload to Neuron repository: {e}"
+        raise
+
+def create_readme_pr_for_original_model(
+    original_model_id: str,
+    neuron_repo_name: str,
+    task_or_pipeline: str,
+    requesting_user: str,
+    token: str,
+) -> Generator[Union[str, CommitInfo], None, None]:
+    """
+    Creates a PR on the original model repository to add a link to the Neuron-optimized version.
+    """
+    api = HfApi(token=token)
+
+    yield f"📝 Creating PR to add Neuron repo link in {original_model_id}..."
+
+    try:
+        # Check if there's already an open PR
+        pr_title = "Add link to Neuron-optimized version"
+        existing_pr = previous_pr(api, original_model_id, pr_title)
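+        # previous_pr is expected to return the existing open discussion with this title
+        # (if any), so repeated runs do not open duplicate PRs.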
+
+        if existing_pr:
+            yield f"⚠️ PR already exists: https://huggingface.co/{original_model_id}/discussions/{existing_pr.num}"
+            return
+
+        # Get the current README
+        try:
+            current_readme_path = api.hf_hub_download(
+                repo_id=original_model_id,
+                filename="README.md",
+                token=token,
+            )
+            with open(current_readme_path, 'r', encoding='utf-8') as f:
+                readme_content = f.read()
+        except Exception:
+            # If README doesn't exist, create a basic one
+            readme_content = f"# {original_model_id}\n\n"
+
+        # Add Neuron optimization section, separated by a horizontal rule
+        neuron_section = f"""
+---
+## 🚀 AWS Neuron Optimized Version Available
+
+A Neuron-optimized version of this model is available for improved performance on AWS Inferentia/Trainium instances:
+
+**[{neuron_repo_name}](https://huggingface.co/{neuron_repo_name})**
+
+The Neuron-optimized version provides:
+- Pre-compiled artifacts for faster loading
+- Optimized performance on AWS Neuron devices
+- Same model capabilities with improved inference speed
 """
 
+        # Append the Neuron section to the end of the README
+        updated_readme = readme_content.rstrip() + "\n" + neuron_section
+
+        # Create temporary file with updated README
+        with NamedTemporaryFile(mode='w', suffix='.md', delete=False, encoding="utf-8") as f:
+            f.write(updated_readme)
+        temp_readme_path = f.name
+
+        # Create the PR
+        operations = [CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=temp_readme_path)]
+
+        commit_description = f"""
+🤖 Neuron Export Bot: Adding link to Neuron-optimized version.
+
+A Neuron-optimized version of this model has been created at [{neuron_repo_name}](https://huggingface.co/{neuron_repo_name}).
+
+The optimized version provides improved performance on AWS Inferentia/Trainium instances with pre-compiled artifacts.
+
+Generated by: [{requesting_user}](https://huggingface.co/{requesting_user})
+Generated using: [Optimum Neuron Compiler Space]({SPACES_URL})
 """
 
+        pr = api.create_commit(
+            repo_id=original_model_id,
+            operations=operations,
+            commit_message=pr_title,
+            commit_description=commit_description,
+            create_pr=True,
+            token=token,
+        )
+
+        yield f"✅ README PR created: https://huggingface.co/{original_model_id}/discussions/{pr.pr_num}"
+        yield pr
+
+        # Clean up temporary file
+        os.unlink(temp_readme_path)
+
+    except Exception as e:
+        yield f"❌ Failed to create README PR: {e}"
+        raise
+
+def upload_to_custom_repo(
+    operations: List[CommitOperationAdd],
+    custom_repo_id: str,
+    original_model_id: str,
+    requesting_user: str,
+    token: str,
+) -> Generator[Union[str, CommitInfo], None, None]:
+    """
+    Uploads neuron files to a custom repository and creates a PR.
+    """
+    api = HfApi(token=token)
+
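+    # Progress is reported as yielded strings; the resulting CommitInfo is yielded last
+    # so the caller can read its pr_num (see the consumption loop in convert below).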
+    try:
+        # Ensure the custom repo exists
+        api.repo_info(repo_id=custom_repo_id, repo_type="model")
+    except Exception as e:
+        yield f"❌ Could not access custom repository `{custom_repo_id}`. Please ensure it exists and you have write access. Error: {e}"
+        raise
+
+    pr_title = f"Add Neuron-optimized files for {original_model_id}"
+    commit_description = f"""
+🤖 Neuron Export Bot: On behalf of [{requesting_user}](https://huggingface.co/{requesting_user}), adding AWS Neuron-optimized model files for `{original_model_id}`.
+
+These files were generated using the [Optimum Neuron Compiler Space](https://huggingface.co/spaces/optimum/neuron-export).
+"""
+
+    try:
+        custom_pr = api.create_commit(
+            repo_id=custom_repo_id,
+            operations=operations,
+            commit_message=pr_title,
+            commit_description=commit_description,
+            create_pr=True,
+            token=token,
+        )
+        yield f"✅ Custom PR created successfully: https://huggingface.co/{custom_repo_id}/discussions/{custom_pr.pr_num}"
+        yield custom_pr
+
+    except Exception as e:
+        yield f"❌ Failed to create PR in custom repository: {e}"
+        raise
+
+def convert(
+    api: "HfApi",
+    model_id: str,
+    task_or_pipeline: str,
+    model_type: str = "transformers",
+    token: str = None,
+    pr_options: Dict = None,
+    pipeline_name: str = None,
+) -> Generator[Tuple[str, Any], None, None]:
+    if pr_options is None:
+        pr_options = {}
+
+    info = api.model_info(model_id, token=token)
+    filenames = {s.rfilename for s in info.siblings}
+    requesting_user = api.whoami(token=token)["name"]
+
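+    # pr_options holds the options selected in the UI, e.g. (illustrative values):
+    #     {"create_neuron_repo": True, "create_cache_pr": False,
+    #      "create_custom_pr": True, "custom_repo_id": "my-org/my-neuron-repo"}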
+    if not any(pr_options.values()):
+        yield "1", "⚠️ No option selected. Please choose at least one option."
+        return
+
+    if pr_options.get("create_custom_pr") and not pr_options.get("custom_repo_id"):
+        yield "1", "⚠️ Custom PR selected but no repository ID was provided."
+        return
+
+    yield "0", f"🚀 Starting export process with options: {pr_options}..."
+
+    if task_or_pipeline == "auto" and model_type == "transformers":
+        try:
+            task_or_pipeline = TasksManager.infer_task_from_model(model_id, token=token)
+        except Exception as e:
+            raise Exception(f"❌ Could not infer task for model {model_id}: {e}")
+
+    with TemporaryDirectory() as temp_dir:
+        export_folder = os.path.join(temp_dir, "export")
+        cache_mirror_dir = os.path.join(temp_dir, "cache_mirror")
+        os.makedirs(export_folder, exist_ok=True)
+        os.makedirs(cache_mirror_dir, exist_ok=True)
+
+        result_info = {}
+
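+        # result_info collects the URLs produced below ("cache_pr", "neuron_repo",
+        # "readme_pr", "custom_pr") and is yielded as the final status payload.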
+        try:
+            # --- Export Logic ---
+            export_gen = export_and_git_add(model_id, task_or_pipeline, model_type, export_folder, token=token, pipeline_name=pipeline_name)
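+            # export_and_git_add yields progress strings and finishes with a
+            # ("__RETURN__", operations) tuple carrying the CommitOperationAdd list.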
+            operations = None
+            for message in export_gen:
+                if isinstance(message, tuple) and message[0] == "__RETURN__":
+                    operations = message[1]
+                    break
+                else:
+                    yield "0", message
+
+            if not operations:
+                raise Exception("Export process did not produce any files to commit.")
+
+            # --- Cache Handling ---
+            if pr_options.get("create_cache_pr"):
+                yield "0", "📤 Creating a Pull Request for the cache repository ..."
+
+                try:
+                    pr_title = f"Add Neuron cache artifacts for {model_id}"
+                    custom_pr_description = f"""
+🤖 **Neuron Cache Sync Bot**
+
+This PR adds newly compiled cache artifacts for the model:
+- **Original Model ID:** `{model_id}`
+- **Task:** `{task_or_pipeline}`
+
+These files were generated to accelerate model loading on AWS Neuron devices.
+"""
+
+                    commit_message = f"Synchronizing local compiler cache of {model_id}"
+                    inputs = get_default_inputs(task_or_pipeline, pipeline_name)
+                    commit_description = f"""
+🤖 **Neuron Cache Sync Bot**
+
+This commit adds newly compiled cache artifacts for the model:
+- **Original Model ID:** `{model_id}`
+- **Task:** `{task_or_pipeline}`
+- **Compilation inputs:** {inputs}
+- **Generated by:** [{requesting_user}](https://huggingface.co/{requesting_user})
+- **Generated using:** [Optimum Neuron Model Exporter]({SPACES_URL})
+
+These files were generated to accelerate model loading on AWS Neuron devices.
+"""
+
+                    # 1. Create an instance of the cache-sync generator
+                    pr_generator = synchronize_hub_cache_with_pr(
+                        cache_repo_id=CUSTOM_CACHE_REPO,
+                        commit_message=commit_message,
+                        commit_description=commit_description,
+                        token=token,
                     )
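+
+                    # The loop below drives pr_generator by hand: a generator can both
+                    # yield progress messages and `return` a final value, and that return
+                    # value surfaces as StopIteration.value, e.g. (illustrative):
+                    #     def sync():
+                    #         yield "uploading..."
+                    #         return "https://huggingface.co/<cache-repo>/discussions/1"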
+                    pr_url = None
+                    # 2. Loop to process yielded status messages and capture the final return value
+                    while True:
+                        try:
+                            # Get the next status message from the generator
+                            status_message = next(pr_generator)
+                            yield "0", status_message
+                        except StopIteration as e:
+                            # The generator is finished. Its `return` value is in e.value.
+                            pr_url = e.value
+                            break  # Exit the loop
+
+                    # 3. Process the final result
+                    if pr_url:
+                        yield "0", "✅ Successfully captured PR URL."
+                        result_info["cache_pr"] = pr_url
+                    else:
+                        yield "0", "⚠️ PR process finished, but no URL was returned. This may be expected in non-blocking mode."
+
+                except Exception as e:
+                    yield "0", f"❌ Failed to create cache PR: {e}"
+
+            # --- New Repository Creation (Replaces Model PR) ---
+            if pr_options.get("create_neuron_repo"):
+                yield "0", "🏗️ Creating new Neuron-optimized repository..."
+                neuron_repo_url = None
+                # Generate the repo name first so we can use it consistently
+                neuron_repo_name = generate_neuron_repo_name(api, model_id, task_or_pipeline, token)
+
+                repo_creation_gen = create_neuron_repo_and_upload(
+                    operations, model_id, model_type, task_or_pipeline, requesting_user, token, pipeline_name
                 )
+
+                for msg in repo_creation_gen:
+                    if isinstance(msg, str):
+                        yield "0", msg
+                    else:
+                        neuron_repo_url = msg
+
+                result_info["neuron_repo"] = f"https://huggingface.co/{neuron_repo_name}"
+
+                # Automatically create a PR on the original model to add a link
+                readme_pr = None
+                readme_pr_gen = create_readme_pr_for_original_model(
+                    model_id, neuron_repo_name, task_or_pipeline, requesting_user, token
                 )
+                for msg in readme_pr_gen:
+                    if isinstance(msg, str):
+                        yield "0", msg
+                    else:
+                        readme_pr = msg
+
+                if readme_pr:
+                    result_info["readme_pr"] = f"https://huggingface.co/{model_id}/discussions/{readme_pr.pr_num}"
+
+            # --- Custom Repository PR ---
+            if pr_options.get("create_custom_pr"):
+                custom_repo_id = pr_options["custom_repo_id"]
+                yield "0", f"📤 Creating PR in custom repository: {custom_repo_id}..."
+                custom_pr = None
+                custom_upload_gen = upload_to_custom_repo(operations, custom_repo_id, model_id, requesting_user, token)
+                for msg in custom_upload_gen:
+                    if isinstance(msg, str):
+                        yield "0", msg
+                    else:
+                        custom_pr = msg
+                if custom_pr:
+                    result_info["custom_pr"] = f"https://huggingface.co/{custom_repo_id}/discussions/{custom_pr.pr_num}"
+
+            yield "0", result_info
+
+        except Exception as e:
+            yield "1", f"❌ Conversion failed with a critical error: {e}"
+            # Re-raise the exception to be caught by the outer try-except in the Gradio app if needed
+            raise