Felladrin committed
Commit 5fff74f · Parent: b6efc2b

Add docstrings and reorganize the code

Files changed (1): app.py (+302 -76)
app.py CHANGED
@@ -1,16 +1,26 @@
  import logging
  import os
  import subprocess
  import sys
- import shutil
- import re
  from pathlib import Path
  from typing import List, Optional, Tuple
- from dataclasses import dataclass

  import streamlit as st
- from huggingface_hub import HfApi, whoami, model_info, hf_hub_download
  import yaml

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
@@ -18,7 +28,15 @@ logger = logging.getLogger(__name__)

  @dataclass
  class Config:
-     """Application configuration."""

      hf_token: str
      hf_username: str
@@ -28,10 +46,22 @@ class Config:

      @classmethod
      def from_env(cls) -> "Config":
-         """Create config from environment variables and secrets."""
          system_token = os.getenv("HF_TOKEN")
          user_token = st.session_state.get("user_hf_token")

          if user_token:
              hf_username = whoami(token=user_token)["name"]
          else:
@@ -39,6 +69,7 @@ class Config:
                  os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
              )

          hf_token = user_token or system_token

          if not hf_token:
@@ -54,53 +85,118 @@ class Config:


  class ModelConverter:
-     """Handles model conversion and upload operations."""

      def __init__(self, config: Config):
          self.config = config
          self.api = HfApi(token=config.hf_token)

      def _fetch_original_readme(self, repo_id: str) -> str:
          try:
-             path = hf_hub_download(
                  repo_id=repo_id, filename="README.md", token=self.config.hf_token
              )
-             with open(path, "r", encoding="utf-8", errors="ignore") as f:
                  return f.read()
          except Exception:
              return ""

      def _strip_yaml_frontmatter(self, text: str) -> str:
          if not text:
              return ""
          if text.startswith("---"):
-             m = re.match(r"^---[\s\S]*?\n---\s*\n", text)
-             if m:
-                 return text[m.end() :]
          return text

      def _extract_yaml_frontmatter(self, text: str) -> Tuple[dict, str]:
-         """Return (frontmatter_dict, body). If no frontmatter, returns ({}, text)."""
          if not text or not text.startswith("---"):
              return {}, text or ""
-         m = re.match(r"^---\s*\n([\s\S]*?)\n---\s*\n", text)
-         if not m:
              return {}, text
-         fm_text = m.group(1)
-         body = text[m.end() :]
          try:
-             data = yaml.safe_load(fm_text)
-             if not isinstance(data, dict):
-                 data = {}
          except Exception:
-             data = {}
-         return data, body

-     def _pipeline_docs_url(self, pipeline_tag: Optional[str]) -> Optional[str]:
-         base = "https://huggingface.co/docs/transformers.js/api/pipelines"
          if not pipeline_tag:
-             return base
-         mapping = {
              "text-classification": "TextClassificationPipeline",
              "token-classification": "TokenClassificationPipeline",
              "question-answering": "QuestionAnsweringPipeline",
@@ -127,49 +223,85 @@ class ModelConverter:
              "image-to-image": "ImageToImagePipeline",
              "depth-estimation": "DepthEstimationPipeline",
          }
-         cls = mapping.get(pipeline_tag)
-         if not cls:
-             return base
-         return f"{base}#module_pipelines.{cls}"

-     def _map_pipeline_to_task(self, pipeline_tag: Optional[str]) -> Optional[str]:
          if not pipeline_tag:
              return None
-         synonyms = {
              "vqa": "visual-question-answering",
          }
-         return synonyms.get(pipeline_tag, pipeline_tag)

      def setup_repository(self) -> None:
-         """Ensure the bundled transformers.js repository is present."""
          if not self.config.repo_path.exists():
              raise RuntimeError(
-                 f"Expected transformers.js repository at {self.config.repo_path} but it was not found."
              )

      def _run_conversion_subprocess(
-         self, input_model_id: str, extra_args: List[str] = None
      ) -> subprocess.CompletedProcess:
-         """Run the conversion subprocess with the given arguments."""
-         cmd = [
              sys.executable,
              "-m",
              "scripts.convert",
-             "--quantize",
              "--model_id",
              input_model_id,
          ]

          if extra_args:
-             cmd.extend(extra_args)

          return subprocess.run(
-             cmd,
              cwd=self.config.repo_path,
              capture_output=True,
              text=True,
              env={
                  "HF_TOKEN": self.config.hf_token,
                  "TRANSFORMERS_ATTENTION_IMPLEMENTATION": "eager",
                  "PYTORCH_SDP_KERNEL": "math",
              },
@@ -178,34 +310,55 @@ class ModelConverter:
      def convert_model(
          self,
          input_model_id: str,
-         trust_remote_code=False,
-         output_attentions=False,
      ) -> Tuple[bool, Optional[str]]:
-         """Convert the model to ONNX format."""
          try:
-             extra_args: List[str] = []
              if trust_remote_code:
                  if not self.config.is_using_user_token:
                      raise Exception(
                          "Trust Remote Code requires your own HuggingFace token."
                      )
-                 extra_args.append("--trust_remote_code")

              if output_attentions:
-                 extra_args.append("--output_attentions")

              try:
                  info = model_info(repo_id=input_model_id, token=self.config.hf_token)
-                 task = self._map_pipeline_to_task(getattr(info, "pipeline_tag", None))
                  if task:
-                     extra_args.extend(["--task", task])
              except Exception:
                  pass

              result = self._run_conversion_subprocess(
-                 input_model_id, extra_args=extra_args or None
              )

              if result.returncode != 0:
                  return False, result.stderr

@@ -214,86 +367,148 @@ class ModelConverter:
          except Exception as e:
              return False, str(e)

      def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
-         """Upload the converted model to Hugging Face."""
          model_folder_path = self.config.repo_path / "models" / input_model_id

          try:
              self.api.create_repo(output_model_id, exist_ok=True, private=False)

-             readme_path = f"{model_folder_path}/README.md"
-
-             with open(readme_path, "w") as file:
-                 file.write(self.generate_readme(input_model_id))

              self.api.upload_folder(
                  folder_path=str(model_folder_path), repo_id=output_model_id
              )
-             return None
          except Exception as e:
              return str(e)
          finally:
              shutil.rmtree(model_folder_path, ignore_errors=True)

-     def generate_readme(self, imi: str):
          try:
-             info = model_info(repo_id=imi, token=self.config.hf_token)
              pipeline_tag = getattr(info, "pipeline_tag", None)
          except Exception:
              pipeline_tag = None

-         original_text = self._fetch_original_readme(imi)
          original_meta, original_body = self._extract_yaml_frontmatter(original_text)
          original_body = (
              original_body or self._strip_yaml_frontmatter(original_text)
          ).strip()

          merged_meta = {}
          if isinstance(original_meta, dict):
              merged_meta.update(original_meta)
          merged_meta["library_name"] = "transformers.js"
-         merged_meta["base_model"] = [imi]
          if pipeline_tag is not None:
              merged_meta["pipeline_tag"] = pipeline_tag

-         fm_yaml = yaml.safe_dump(merged_meta, sort_keys=False).strip()
-         header = f"---\n{fm_yaml}\n---\n\n"

-         parts: List[str] = []
-         parts.append(header)
-         parts.append(f"# {imi.split('/')[-1]} (ONNX)\n")
-         parts.append(
-             f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
              "It was automatically converted and uploaded using "
              "[this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx)."
          )

-         docs_url = self._pipeline_docs_url(pipeline_tag)
          if docs_url:
-             parts.append("\n## Usage with Transformers.js\n")
              if pipeline_tag:
-                 parts.append(
                      f"See the pipeline documentation for `{pipeline_tag}`: {docs_url}"
                  )
              else:
-                 parts.append(f"See the pipelines documentation: {docs_url}")

          if original_body:
-             parts.append("\n---\n")
-             parts.append(original_body)

-         return "\n\n".join(parts) + "\n"


  def main():
-     """Main application entry point."""
      st.write("## Convert a Hugging Face model to ONNX")

      try:
          config = Config.from_env()
          converter = ModelConverter(config)
          converter.setup_repository()

          input_model_id = st.text_input(
              "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
          )
@@ -301,23 +516,29 @@ def main():
          if not input_model_id:
              return

          st.text_input(
              f"Optional: Your Hugging Face write token. Fill it if you want to upload the model under your account.",
              type="password",
              key="user_hf_token",
          )
          trust_remote_code = st.toggle("Optional: Trust Remote Code.")
          if trust_remote_code:
              st.warning(
                  "This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
              )

          output_attentions = False
          if "whisper" in input_model_id.lower():
              output_attentions = st.toggle(
                  "Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
              )

          if config.hf_username == input_model_id.split("/")[0]:
              same_repo = st.checkbox(
                  "Upload the ONNX weights to the existing repository"
@@ -326,25 +547,29 @@ def main():
              same_repo = False

          model_name = input_model_id.split("/")[-1]
-
          output_model_id = f"{config.hf_username}/{model_name}"

          if not same_repo:
              output_model_id += "-ONNX"

          output_model_url = f"{config.hf_base_url}/{output_model_id}"

          if not same_repo and converter.api.repo_exists(output_model_id):
              st.write("This model has already been converted! 🎉")
              st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
              return

          st.write(f"URL where the model will be converted and uploaded to:")
          st.code(output_model_url, language="plaintext")

          if not st.button(label="Proceed", type="primary"):
              return

          with st.spinner("Converting model..."):
              success, stderr = converter.convert_model(
                  input_model_id,
@@ -358,6 +583,7 @@ def main():
              st.success("Conversion successful!")
              st.code(stderr)

          with st.spinner("Uploading model..."):
              error = converter.upload_model(input_model_id, output_model_id)
              if error:

@@ -1,16 +1,26 @@
+ """Convert Hugging Face models to ONNX format.
+
+ This application provides a Streamlit interface for converting Hugging Face models
+ to ONNX format using the Transformers.js conversion scripts. It handles:
+ - Model conversion with optional trust_remote_code and output_attentions
+ - Automatic task inference with fallback support
+ - README generation with merged metadata from the original model
+ - Upload to Hugging Face Hub
+ """
+
  import logging
  import os
+ import re
+ import shutil
  import subprocess
  import sys
+ from dataclasses import dataclass
  from pathlib import Path
  from typing import List, Optional, Tuple

  import streamlit as st
  import yaml
+ from huggingface_hub import HfApi, hf_hub_download, model_info, whoami

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
@@ -18,7 +28,15 @@ logger = logging.getLogger(__name__)

  @dataclass
  class Config:
+     """Application configuration containing authentication and path settings.
+
+     Attributes:
+         hf_token: Hugging Face API token (user token takes precedence over system token)
+         hf_username: Hugging Face username associated with the token
+         is_using_user_token: True if using a user-provided token, False if using system token
+         hf_base_url: Base URL for Hugging Face Hub
+         repo_path: Path to the bundled transformers.js repository
+     """

      hf_token: str
      hf_username: str
 
47
  @classmethod
48
  def from_env(cls) -> "Config":
49
+ """Create configuration from environment variables and Streamlit session state.
50
+
51
+ Priority order for tokens:
52
+ 1. User-provided token from Streamlit session (st.session_state.user_hf_token)
53
+ 2. System token from environment variable (HF_TOKEN)
54
+
55
+ Returns:
56
+ Config: Initialized configuration object
57
+
58
+ Raises:
59
+ ValueError: If no valid token is available
60
+ """
61
  system_token = os.getenv("HF_TOKEN")
62
  user_token = st.session_state.get("user_hf_token")
63
 
64
+ # Determine username based on which token is being used
65
  if user_token:
66
  hf_username = whoami(token=user_token)["name"]
67
  else:
 
69
  os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
70
  )
71
 
72
+ # User token takes precedence over system token
73
  hf_token = user_token or system_token
74
 
75
  if not hf_token:
 
@@ -54,53 +85,118 @@ class Config:


  class ModelConverter:
+     """Handles model conversion to ONNX format and upload to Hugging Face Hub.
+
+     This class manages the entire conversion workflow:
+     1. Fetching original model metadata and README
+     2. Running the ONNX conversion subprocess
+     3. Generating an enhanced README with merged metadata
+     4. Uploading the converted model to Hugging Face Hub
+
+     Attributes:
+         config: Application configuration containing tokens and paths
+         api: Hugging Face API client for repository operations
+     """

      def __init__(self, config: Config):
+         """Initialize the converter with configuration.
+
+         Args:
+             config: Application configuration object
+         """
          self.config = config
          self.api = HfApi(token=config.hf_token)

+     # ============================================================================
+     # README Processing Methods
+     # ============================================================================
+
      def _fetch_original_readme(self, repo_id: str) -> str:
+         """Download the README from the original model repository.
+
+         Args:
+             repo_id: Hugging Face model repository ID (e.g., 'username/model-name')
+
+         Returns:
+             str: Content of the README file, or empty string if not found
+         """
          try:
+             readme_path = hf_hub_download(
                  repo_id=repo_id, filename="README.md", token=self.config.hf_token
              )
+             with open(readme_path, "r", encoding="utf-8", errors="ignore") as f:
                  return f.read()
          except Exception:
+             # Silently fail if README doesn't exist or can't be downloaded
              return ""

      def _strip_yaml_frontmatter(self, text: str) -> str:
+         """Remove YAML frontmatter from text, returning only the body.
+
+         YAML frontmatter is delimited by '---' at the start and end.
+
+         Args:
+             text: Text that may contain YAML frontmatter
+
+         Returns:
+             str: Text with frontmatter removed, or original text if no frontmatter found
+         """
          if not text:
              return ""
          if text.startswith("---"):
+             match = re.match(r"^---[\s\S]*?\n---\s*\n", text)
+             if match:
+                 return text[match.end() :]
          return text

      def _extract_yaml_frontmatter(self, text: str) -> Tuple[dict, str]:
+         """Parse and extract YAML frontmatter from text.
+
+         Args:
+             text: Text that may contain YAML frontmatter
+
+         Returns:
+             Tuple containing:
+                 - dict: Parsed YAML frontmatter as a dictionary (empty dict if none found)
+                 - str: Remaining body text after the frontmatter
+         """
          if not text or not text.startswith("---"):
              return {}, text or ""
+
+         # Match YAML frontmatter pattern: ---\n...content...\n---\n
+         match = re.match(r"^---\s*\n([\s\S]*?)\n---\s*\n", text)
+         if not match:
              return {}, text
+
+         frontmatter_text = match.group(1)
+         body = text[match.end() :]
+
+         # Parse YAML safely, returning empty dict on any error
          try:
+             parsed_data = yaml.safe_load(frontmatter_text)
+             if not isinstance(parsed_data, dict):
+                 parsed_data = {}
          except Exception:
+             parsed_data = {}
+
+         return parsed_data, body
+
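
As a quick illustration of the extraction above (hypothetical input; converter stands in for a ModelConverter instance):

    sample = "---\nlicense: mit\ntags: [onnx]\n---\n\n# My model\n"
    meta, body = converter._extract_yaml_frontmatter(sample)
    # meta -> {'license': 'mit', 'tags': ['onnx']}
    # body -> '# My model\n'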
+     def _get_pipeline_docs_url(self, pipeline_tag: Optional[str]) -> str:
+         """Generate Transformers.js documentation URL for a given pipeline tag.
+
+         Args:
+             pipeline_tag: Hugging Face pipeline tag (e.g., 'text-generation')
+
+         Returns:
+             str: URL to the relevant Transformers.js pipeline documentation
+         """
+         base_url = "https://huggingface.co/docs/transformers.js/api/pipelines"

          if not pipeline_tag:
+             return base_url
+
+         # Map Hugging Face pipeline tags to Transformers.js pipeline class names
+         pipeline_class_mapping = {
              "text-classification": "TextClassificationPipeline",
              "token-classification": "TokenClassificationPipeline",
              "question-answering": "QuestionAnsweringPipeline",
@@ -127,49 +223,85 @@ class ModelConverter:
              "image-to-image": "ImageToImagePipeline",
              "depth-estimation": "DepthEstimationPipeline",
          }

+         pipeline_class = pipeline_class_mapping.get(pipeline_tag)
+         if not pipeline_class:
+             return base_url
+
+         return f"{base_url}#module_pipelines.{pipeline_class}"
+
+     def _normalize_pipeline_tag(self, pipeline_tag: Optional[str]) -> Optional[str]:
+         """Normalize pipeline tag to match expected task names.
+
+         Some pipeline tags use abbreviations that need to be expanded
+         for the conversion script to recognize them.
+
+         Args:
+             pipeline_tag: Original pipeline tag from model metadata
+
+         Returns:
+             Optional[str]: Normalized task name, or None if input is None
+         """
          if not pipeline_tag:
              return None
+
+         # Map abbreviated tags to their full names
+         tag_synonyms = {
              "vqa": "visual-question-answering",
          }
+
+         return tag_synonyms.get(pipeline_tag, pipeline_tag)
+
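
Traced by hand against the mapping and synonyms above, the two helpers behave like so (illustrative values only):

    converter._normalize_pipeline_tag("vqa")
    # -> 'visual-question-answering'
    converter._get_pipeline_docs_url("text-classification")
    # -> 'https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.TextClassificationPipeline'
    converter._get_pipeline_docs_url(None)
    # -> 'https://huggingface.co/docs/transformers.js/api/pipelines'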
+     # ============================================================================
+     # Model Conversion Methods
+     # ============================================================================

      def setup_repository(self) -> None:
+         """Verify that the transformers.js repository exists.
+
+         Raises:
+             RuntimeError: If the repository is not found at the expected path
+         """
          if not self.config.repo_path.exists():
              raise RuntimeError(
+                 f"Expected transformers.js repository at {self.config.repo_path} "
+                 f"but it was not found."
              )

      def _run_conversion_subprocess(
+         self, input_model_id: str, extra_args: Optional[List[str]] = None
      ) -> subprocess.CompletedProcess:
+         """Execute the ONNX conversion script as a subprocess.
+
+         Args:
+             input_model_id: Hugging Face model ID to convert
+             extra_args: Additional command-line arguments for the conversion script
+
+         Returns:
+             subprocess.CompletedProcess: Result of the subprocess execution
+         """
+         # Build the conversion command
+         command = [
              sys.executable,
              "-m",
              "scripts.convert",
+             "--quantize",  # Enable quantization for smaller model size
              "--model_id",
              input_model_id,
          ]

          if extra_args:
+             command.extend(extra_args)

+         # Run conversion in the transformers.js repository directory
          return subprocess.run(
+             command,
              cwd=self.config.repo_path,
              capture_output=True,
              text=True,
              env={
                  "HF_TOKEN": self.config.hf_token,
+                 # Force eager attention implementation for compatibility
                  "TRANSFORMERS_ATTENTION_IMPLEMENTATION": "eager",
                  "PYTORCH_SDP_KERNEL": "math",
              },
 
310
  def convert_model(
311
  self,
312
  input_model_id: str,
313
+ trust_remote_code: bool = False,
314
+ output_attentions: bool = False,
315
  ) -> Tuple[bool, Optional[str]]:
316
+ """Convert a Hugging Face model to ONNX format.
317
+
318
+ Args:
319
+ input_model_id: Hugging Face model repository ID
320
+ trust_remote_code: Whether to trust and execute remote code from the model
321
+ output_attentions: Whether to output attention weights (required for some tasks)
322
+
323
+ Returns:
324
+ Tuple containing:
325
+ - bool: True if conversion succeeded, False otherwise
326
+ - Optional[str]: Error message if failed, or conversion log if succeeded
327
+ """
328
  try:
329
+ conversion_args: List[str] = []
330
+
331
+ # Handle trust_remote_code option (requires user token for security)
332
  if trust_remote_code:
333
  if not self.config.is_using_user_token:
334
  raise Exception(
335
  "Trust Remote Code requires your own HuggingFace token."
336
  )
337
+ conversion_args.append("--trust_remote_code")
338
 
339
+ # Handle output_attentions option (needed for word-level timestamps in Whisper)
340
  if output_attentions:
341
+ conversion_args.append("--output_attentions")
342
 
343
+ # Try to infer the task from model metadata and pass it to the conversion script
344
+ # This helps the script choose the right export configuration
345
  try:
346
  info = model_info(repo_id=input_model_id, token=self.config.hf_token)
347
+ pipeline_tag = getattr(info, "pipeline_tag", None)
348
+ task = self._normalize_pipeline_tag(pipeline_tag)
349
  if task:
350
+ conversion_args.extend(["--task", task])
351
  except Exception:
352
+ # If we can't fetch the task, continue without it
353
+ # The conversion script will try to infer it automatically
354
  pass
355
 
356
+ # Run the conversion
357
  result = self._run_conversion_subprocess(
358
+ input_model_id, extra_args=conversion_args or None
359
  )
360
 
361
+ # Check if conversion succeeded
362
  if result.returncode != 0:
363
  return False, result.stderr
364
 
 
@@ -214,86 +367,148 @@ class ModelConverter:
          except Exception as e:
              return False, str(e)
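
A minimal sketch of driving convert_model directly (the Streamlit flow in main() below is what calls it in practice; the model ID is the example from the UI hint):

    success, log = converter.convert_model(
        "EleutherAI/pythia-14m",
        trust_remote_code=False,
        output_attentions=False,
    )
    if not success:
        print(f"Conversion failed: {log}")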
 
+     # ============================================================================
+     # Upload Methods
+     # ============================================================================
+
      def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
+         """Upload the converted ONNX model to Hugging Face Hub.
+
+         This method:
+         1. Creates the target repository (if it doesn't exist)
+         2. Generates an enhanced README with merged metadata
+         3. Uploads all model files to the repository
+         4. Cleans up local files after upload
+
+         Args:
+             input_model_id: Original model repository ID
+             output_model_id: Target repository ID for the ONNX model
+
+         Returns:
+             Optional[str]: Error message if upload failed, None if successful
+         """
          model_folder_path = self.config.repo_path / "models" / input_model_id

          try:
+             # Create the target repository (public by default)
              self.api.create_repo(output_model_id, exist_ok=True, private=False)

+             # Generate and write the enhanced README
+             readme_path = model_folder_path / "README.md"
+             readme_content = self.generate_readme(input_model_id)
+             readme_path.write_text(readme_content, encoding="utf-8")

+             # Upload all files from the model folder
              self.api.upload_folder(
                  folder_path=str(model_folder_path), repo_id=output_model_id
              )
+
+             return None  # Success
+
          except Exception as e:
              return str(e)
          finally:
+             # Always clean up local files, even if upload failed
              shutil.rmtree(model_folder_path, ignore_errors=True)

+     # ============================================================================
+     # README Generation Methods
+     # ============================================================================
+
+     def generate_readme(self, input_model_id: str) -> str:
+         """Generate an enhanced README for the ONNX model.
+
+         This method creates a README that:
+         1. Merges metadata from the original model with ONNX-specific metadata
+         2. Adds a description and link to the conversion space
+         3. Includes usage instructions with links to Transformers.js docs
+         4. Appends the original model's README content
+
+         Args:
+             input_model_id: Original model repository ID
+
+         Returns:
+             str: Complete README content in Markdown format with YAML frontmatter
+         """
+         # Fetch pipeline tag from model metadata (if available)
          try:
+             info = model_info(repo_id=input_model_id, token=self.config.hf_token)
              pipeline_tag = getattr(info, "pipeline_tag", None)
          except Exception:
              pipeline_tag = None

+         # Fetch and parse the original README
+         original_text = self._fetch_original_readme(input_model_id)
          original_meta, original_body = self._extract_yaml_frontmatter(original_text)
          original_body = (
              original_body or self._strip_yaml_frontmatter(original_text)
          ).strip()

+         # Merge original metadata with our ONNX-specific metadata (ours take precedence)
          merged_meta = {}
          if isinstance(original_meta, dict):
              merged_meta.update(original_meta)
          merged_meta["library_name"] = "transformers.js"
+         merged_meta["base_model"] = [input_model_id]
          if pipeline_tag is not None:
              merged_meta["pipeline_tag"] = pipeline_tag

+         # Generate YAML frontmatter
+         frontmatter_yaml = yaml.safe_dump(merged_meta, sort_keys=False).strip()
+         header = f"---\n{frontmatter_yaml}\n---\n\n"

+         # Build README sections
+         readme_sections: List[str] = []
+         readme_sections.append(header)
+
+         # Add title
+         model_name = input_model_id.split("/")[-1]
+         readme_sections.append(f"# {model_name} (ONNX)\n")
+
+         # Add description
+         readme_sections.append(
+             f"This is an ONNX version of [{input_model_id}](https://huggingface.co/{input_model_id}). "
              "It was automatically converted and uploaded using "
              "[this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx)."
          )

+         # Add usage section with Transformers.js docs link
+         docs_url = self._get_pipeline_docs_url(pipeline_tag)
          if docs_url:
+             readme_sections.append("\n## Usage with Transformers.js\n")
              if pipeline_tag:
+                 readme_sections.append(
                      f"See the pipeline documentation for `{pipeline_tag}`: {docs_url}"
                  )
              else:
+                 readme_sections.append(f"See the pipelines documentation: {docs_url}")

+         # Append original README content (if available)
          if original_body:
+             readme_sections.append("\n---\n")
+             readme_sections.append(original_body)

+         return "\n\n".join(readme_sections) + "\n"

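
For a hypothetical input_model_id of "user/my-model" with no pipeline tag and no original README, generate_readme would produce roughly this (a sketch, not captured output):

    ---
    library_name: transformers.js
    base_model:
    - user/my-model
    ---

    # my-model (ONNX)

    This is an ONNX version of [user/my-model](https://huggingface.co/user/my-model). It was automatically converted and uploaded using [this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx).

    ## Usage with Transformers.js

    See the pipelines documentation: https://huggingface.co/docs/transformers.js/api/pipelines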
 
  def main():
+     """Main application entry point for the Streamlit interface.
+
+     This function:
+     1. Initializes configuration and converter
+     2. Displays the UI for model input and options
+     3. Handles the conversion workflow
+     4. Shows progress and results to the user
+     """
      st.write("## Convert a Hugging Face model to ONNX")

      try:
+         # Initialize configuration and converter
          config = Config.from_env()
          converter = ModelConverter(config)
          converter.setup_repository()

+         # Get model ID from user
          input_model_id = st.text_input(
              "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
          )
@@ -301,23 +516,29 @@ def main():
          if not input_model_id:
              return

+         # Optional: User token input
          st.text_input(
              f"Optional: Your Hugging Face write token. Fill it if you want to upload the model under your account.",
              type="password",
              key="user_hf_token",
          )
+
+         # Optional: Trust remote code toggle (requires user token)
          trust_remote_code = st.toggle("Optional: Trust Remote Code.")
          if trust_remote_code:
              st.warning(
                  "This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
              )

+         # Optional: Output attentions (for Whisper models)
          output_attentions = False
          if "whisper" in input_model_id.lower():
              output_attentions = st.toggle(
                  "Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
              )

+         # Determine output repository
+         # If user owns the model, allow uploading to the same repo
          if config.hf_username == input_model_id.split("/")[0]:
              same_repo = st.checkbox(
                  "Upload the ONNX weights to the existing repository"
@@ -326,25 +547,29 @@ def main():
              same_repo = False

          model_name = input_model_id.split("/")[-1]
          output_model_id = f"{config.hf_username}/{model_name}"

+         # Add -ONNX suffix if creating a new repository
          if not same_repo:
              output_model_id += "-ONNX"

          output_model_url = f"{config.hf_base_url}/{output_model_id}"

+         # Check if model already exists
          if not same_repo and converter.api.repo_exists(output_model_id):
              st.write("This model has already been converted! 🎉")
              st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
              return

+         # Show where the model will be uploaded
          st.write(f"URL where the model will be converted and uploaded to:")
          st.code(output_model_url, language="plaintext")

+         # Wait for user confirmation before proceeding
          if not st.button(label="Proceed", type="primary"):
              return

+         # Step 1: Convert the model to ONNX
          with st.spinner("Converting model..."):
              success, stderr = converter.convert_model(
                  input_model_id,
@@ -358,6 +583,7 @@ def main():
              st.success("Conversion successful!")
              st.code(stderr)

+         # Step 2: Upload the converted model to Hugging Face
          with st.spinner("Uploading model..."):
              error = converter.upload_model(input_model_id, output_model_id)
              if error: