Add docstrings and reorganize the code
app.py CHANGED
@@ -1,16 +1,26 @@
+"""Convert Hugging Face models to ONNX format.
+
+This application provides a Streamlit interface for converting Hugging Face models
+to ONNX format using the Transformers.js conversion scripts. It handles:
+- Model conversion with optional trust_remote_code and output_attentions
+- Automatic task inference with fallback support
+- README generation with merged metadata from the original model
+- Upload to Hugging Face Hub
+"""
+
 import logging
 import os
+import re
+import shutil
 import subprocess
 import sys
-import shutil
-import re
+from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
-from dataclasses import dataclass
 
 import streamlit as st
-from huggingface_hub import HfApi, whoami, model_info, hf_hub_download
 import yaml
+from huggingface_hub import HfApi, hf_hub_download, model_info, whoami
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -18,7 +28,15 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class Config:
-    """Application configuration."""
+    """Application configuration containing authentication and path settings.
+
+    Attributes:
+        hf_token: Hugging Face API token (user token takes precedence over system token)
+        hf_username: Hugging Face username associated with the token
+        is_using_user_token: True if using a user-provided token, False if using system token
+        hf_base_url: Base URL for Hugging Face Hub
+        repo_path: Path to the bundled transformers.js repository
+    """
 
     hf_token: str
     hf_username: str
@@ -28,10 +46,22 @@ class Config:
 
     @classmethod
     def from_env(cls) -> "Config":
-        """Create
+        """Create configuration from environment variables and Streamlit session state.
+
+        Priority order for tokens:
+        1. User-provided token from Streamlit session (st.session_state.user_hf_token)
+        2. System token from environment variable (HF_TOKEN)
+
+        Returns:
+            Config: Initialized configuration object
+
+        Raises:
+            ValueError: If no valid token is available
+        """
         system_token = os.getenv("HF_TOKEN")
         user_token = st.session_state.get("user_hf_token")
 
+        # Determine username based on which token is being used
         if user_token:
            hf_username = whoami(token=user_token)["name"]
        else:
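The token precedence documented above is easy to sanity-check in isolation. A minimal sketch of the same rule, not part of the commit (the user token stand-in is hypothetical):

import os

# A user token supplied in the Streamlit session wins over the Space's HF_TOKEN.
system_token = os.getenv("HF_TOKEN")  # Space-level token, may be None
user_token = None                     # stand-in for st.session_state.get("user_hf_token")

hf_token = user_token or system_token
if not hf_token:
    print("No token available: from_env would raise ValueError here")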
@@ -39,6 +69,7 @@ class Config:
                 os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
             )
 
+        # User token takes precedence over system token
         hf_token = user_token or system_token
 
         if not hf_token:
@@ -54,53 +85,118 @@ class Config:
 
 
 class ModelConverter:
-    """Handles model conversion and upload
+    """Handles model conversion to ONNX format and upload to Hugging Face Hub.
+
+    This class manages the entire conversion workflow:
+    1. Fetching original model metadata and README
+    2. Running the ONNX conversion subprocess
+    3. Generating an enhanced README with merged metadata
+    4. Uploading the converted model to Hugging Face Hub
+
+    Attributes:
+        config: Application configuration containing tokens and paths
+        api: Hugging Face API client for repository operations
+    """
 
     def __init__(self, config: Config):
+        """Initialize the converter with configuration.
+
+        Args:
+            config: Application configuration object
+        """
         self.config = config
         self.api = HfApi(token=config.hf_token)
 
+    # ============================================================================
+    # README Processing Methods
+    # ============================================================================
+
     def _fetch_original_readme(self, repo_id: str) -> str:
+        """Download the README from the original model repository.
+
+        Args:
+            repo_id: Hugging Face model repository ID (e.g., 'username/model-name')
+
+        Returns:
+            str: Content of the README file, or empty string if not found
+        """
         try:
-
+            readme_path = hf_hub_download(
                 repo_id=repo_id, filename="README.md", token=self.config.hf_token
             )
-            with open(
+            with open(readme_path, "r", encoding="utf-8", errors="ignore") as f:
                 return f.read()
         except Exception:
+            # Silently fail if README doesn't exist or can't be downloaded
             return ""
 
     def _strip_yaml_frontmatter(self, text: str) -> str:
+        """Remove YAML frontmatter from text, returning only the body.
+
+        YAML frontmatter is delimited by '---' at the start and end.
+
+        Args:
+            text: Text that may contain YAML frontmatter
+
+        Returns:
+            str: Text with frontmatter removed, or original text if no frontmatter found
+        """
         if not text:
             return ""
         if text.startswith("---"):
-
-            if
-                return text[
+            match = re.match(r"^---[\s\S]*?\n---\s*\n", text)
+            if match:
+                return text[match.end() :]
         return text
 
     def _extract_yaml_frontmatter(self, text: str) -> Tuple[dict, str]:
-        """
+        """Parse and extract YAML frontmatter from text.
+
+        Args:
+            text: Text that may contain YAML frontmatter
+
+        Returns:
+            Tuple containing:
+            - dict: Parsed YAML frontmatter as a dictionary (empty dict if none found)
+            - str: Remaining body text after the frontmatter
+        """
         if not text or not text.startswith("---"):
             return {}, text or ""
-
-
+
+        # Match YAML frontmatter pattern: ---\n...content...\n---\n
+        match = re.match(r"^---\s*\n([\s\S]*?)\n---\s*\n", text)
+        if not match:
             return {}, text
-
-
+
+        frontmatter_text = match.group(1)
+        body = text[match.end() :]
+
+        # Parse YAML safely, returning empty dict on any error
         try:
-
-            if not isinstance(
-
+            parsed_data = yaml.safe_load(frontmatter_text)
+            if not isinstance(parsed_data, dict):
+                parsed_data = {}
         except Exception:
-
-
+            parsed_data = {}
+
+        return parsed_data, body
+
+    def _get_pipeline_docs_url(self, pipeline_tag: Optional[str]) -> str:
+        """Generate Transformers.js documentation URL for a given pipeline tag.
+
+        Args:
+            pipeline_tag: Hugging Face pipeline tag (e.g., 'text-generation')
+
+        Returns:
+            str: URL to the relevant Transformers.js pipeline documentation
+        """
+        base_url = "https://huggingface.co/docs/transformers.js/api/pipelines"
 
-    def _pipeline_docs_url(self, pipeline_tag: Optional[str]) -> Optional[str]:
-        base = "https://huggingface.co/docs/transformers.js/api/pipelines"
         if not pipeline_tag:
-            return
-        mapping = {
+            return base_url
+
+        # Map Hugging Face pipeline tags to Transformers.js pipeline class names
+        pipeline_class_mapping = {
             "text-classification": "TextClassificationPipeline",
             "token-classification": "TokenClassificationPipeline",
             "question-answering": "QuestionAnsweringPipeline",
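The frontmatter pattern introduced above can be exercised standalone. A minimal sketch using the same regex as the diff, with a hypothetical model card as input:

import re
import yaml

# Hypothetical model card with YAML frontmatter, as found in most Hugging Face READMEs.
card = """---
license: apache-2.0
tags:
  - text-classification
---

# My model
"""

# Same pattern as _extract_yaml_frontmatter above.
match = re.match(r"^---\s*\n([\s\S]*?)\n---\s*\n", card)
if match:
    meta = yaml.safe_load(match.group(1))  # {'license': 'apache-2.0', 'tags': ['text-classification']}
    body = card[match.end():]              # everything after the closing '---'
    print(meta)
    print(body)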
@@ -127,49 +223,85 @@ class ModelConverter:
             "image-to-image": "ImageToImagePipeline",
             "depth-estimation": "DepthEstimationPipeline",
         }
-        cls = mapping.get(pipeline_tag)
-        if not cls:
-            return base
-        return f"{base}#module_pipelines.{cls}"
 
-
+        pipeline_class = pipeline_class_mapping.get(pipeline_tag)
+        if not pipeline_class:
+            return base_url
+
+        return f"{base_url}#module_pipelines.{pipeline_class}"
+
+    def _normalize_pipeline_tag(self, pipeline_tag: Optional[str]) -> Optional[str]:
+        """Normalize pipeline tag to match expected task names.
+
+        Some pipeline tags use abbreviations that need to be expanded
+        for the conversion script to recognize them.
+
+        Args:
+            pipeline_tag: Original pipeline tag from model metadata
+
+        Returns:
+            Optional[str]: Normalized task name, or None if input is None
+        """
         if not pipeline_tag:
             return None
-
+
+        # Map abbreviated tags to their full names
+        tag_synonyms = {
             "vqa": "visual-question-answering",
         }
-
+
+        return tag_synonyms.get(pipeline_tag, pipeline_tag)
+
+    # ============================================================================
+    # Model Conversion Methods
+    # ============================================================================
 
     def setup_repository(self) -> None:
-        """
+        """Verify that the transformers.js repository exists.
+
+        Raises:
+            RuntimeError: If the repository is not found at the expected path
+        """
         if not self.config.repo_path.exists():
             raise RuntimeError(
-                f"Expected transformers.js repository at {self.config.repo_path}
+                f"Expected transformers.js repository at {self.config.repo_path} "
+                f"but it was not found."
             )
 
     def _run_conversion_subprocess(
-        self, input_model_id: str, extra_args: List[str] = None
+        self, input_model_id: str, extra_args: Optional[List[str]] = None
     ) -> subprocess.CompletedProcess:
-        """
-
+        """Execute the ONNX conversion script as a subprocess.
+
+        Args:
+            input_model_id: Hugging Face model ID to convert
+            extra_args: Additional command-line arguments for the conversion script
+
+        Returns:
+            subprocess.CompletedProcess: Result of the subprocess execution
+        """
+        # Build the conversion command
+        command = [
             sys.executable,
             "-m",
             "scripts.convert",
-            "--quantize",
+            "--quantize",  # Enable quantization for smaller model size
             "--model_id",
             input_model_id,
         ]
 
         if extra_args:
-
+            command.extend(extra_args)
 
+        # Run conversion in the transformers.js repository directory
         return subprocess.run(
-
+            command,
             cwd=self.config.repo_path,
             capture_output=True,
             text=True,
             env={
                 "HF_TOKEN": self.config.hf_token,
+                # Force eager attention implementation for compatibility
                 "TRANSFORMERS_ATTENTION_IMPLEMENTATION": "eager",
                 "PYTORCH_SDP_KERNEL": "math",
             },
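For a concrete picture of what `_run_conversion_subprocess` ends up executing, here is a sketch of the assembled command, assuming the example model ID from the UI below and a hypothetical inferred task:

import sys

# What the helper builds for a hypothetical conversion; the command runs
# inside the bundled transformers.js checkout via cwd=repo_path.
command = [
    sys.executable,
    "-m",
    "scripts.convert",
    "--quantize",
    "--model_id",
    "EleutherAI/pythia-14m",                   # hypothetical input model
]
command.extend(["--task", "text-generation"])  # hypothetical inferred task via extra_args

print(" ".join(command))
# e.g. /usr/bin/python3 -m scripts.convert --quantize --model_id EleutherAI/pythia-14m --task text-generation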
@@ -178,34 +310,55 @@ class ModelConverter:
     def convert_model(
         self,
         input_model_id: str,
-        trust_remote_code=False,
-        output_attentions=False,
+        trust_remote_code: bool = False,
+        output_attentions: bool = False,
     ) -> Tuple[bool, Optional[str]]:
-        """Convert
+        """Convert a Hugging Face model to ONNX format.
+
+        Args:
+            input_model_id: Hugging Face model repository ID
+            trust_remote_code: Whether to trust and execute remote code from the model
+            output_attentions: Whether to output attention weights (required for some tasks)
+
+        Returns:
+            Tuple containing:
+            - bool: True if conversion succeeded, False otherwise
+            - Optional[str]: Error message if failed, or conversion log if succeeded
+        """
         try:
-
+            conversion_args: List[str] = []
+
+            # Handle trust_remote_code option (requires user token for security)
             if trust_remote_code:
                 if not self.config.is_using_user_token:
                     raise Exception(
                         "Trust Remote Code requires your own HuggingFace token."
                     )
-
+                conversion_args.append("--trust_remote_code")
 
+            # Handle output_attentions option (needed for word-level timestamps in Whisper)
             if output_attentions:
-
+                conversion_args.append("--output_attentions")
 
+            # Try to infer the task from model metadata and pass it to the conversion script
+            # This helps the script choose the right export configuration
             try:
                 info = model_info(repo_id=input_model_id, token=self.config.hf_token)
-
+                pipeline_tag = getattr(info, "pipeline_tag", None)
+                task = self._normalize_pipeline_tag(pipeline_tag)
                 if task:
-
+                    conversion_args.extend(["--task", task])
             except Exception:
+                # If we can't fetch the task, continue without it
+                # The conversion script will try to infer it automatically
                 pass
 
+            # Run the conversion
             result = self._run_conversion_subprocess(
-                input_model_id, extra_args=
+                input_model_id, extra_args=conversion_args or None
             )
 
+            # Check if conversion succeeded
             if result.returncode != 0:
                 return False, result.stderr
 
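A sketch of the flag assembly in `convert_model`, assuming a hypothetical model whose metadata reports pipeline_tag="vqa" and with Trust Remote Code enabled (values chosen for illustration only):

# Same synonym table as _normalize_pipeline_tag above.
tag_synonyms = {"vqa": "visual-question-answering"}

conversion_args = []
conversion_args.append("--trust_remote_code")  # only allowed with a user-provided token
pipeline_tag = "vqa"                           # assumed value from model_info
task = tag_synonyms.get(pipeline_tag, pipeline_tag)
if task:
    conversion_args.extend(["--task", task])

print(conversion_args)
# ['--trust_remote_code', '--task', 'visual-question-answering']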
@@ -214,86 +367,148 @@ class ModelConverter:
         except Exception as e:
             return False, str(e)
 
+    # ============================================================================
+    # Upload Methods
+    # ============================================================================
+
     def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
-        """Upload the converted model to Hugging Face."""
+        """Upload the converted ONNX model to Hugging Face Hub.
+
+        This method:
+        1. Creates the target repository (if it doesn't exist)
+        2. Generates an enhanced README with merged metadata
+        3. Uploads all model files to the repository
+        4. Cleans up local files after upload
+
+        Args:
+            input_model_id: Original model repository ID
+            output_model_id: Target repository ID for the ONNX model
+
+        Returns:
+            Optional[str]: Error message if upload failed, None if successful
+        """
         model_folder_path = self.config.repo_path / "models" / input_model_id
 
         try:
+            # Create the target repository (public by default)
             self.api.create_repo(output_model_id, exist_ok=True, private=False)
 
-
-
-
-
+            # Generate and write the enhanced README
+            readme_path = model_folder_path / "README.md"
+            readme_content = self.generate_readme(input_model_id)
+            readme_path.write_text(readme_content, encoding="utf-8")
 
+            # Upload all files from the model folder
             self.api.upload_folder(
                 folder_path=str(model_folder_path), repo_id=output_model_id
             )
-
+
+            return None  # Success
+
         except Exception as e:
             return str(e)
         finally:
+            # Always clean up local files, even if upload failed
            shutil.rmtree(model_folder_path, ignore_errors=True)
 
-
+    # ============================================================================
+    # README Generation Methods
+    # ============================================================================
+
+    def generate_readme(self, input_model_id: str) -> str:
+        """Generate an enhanced README for the ONNX model.
+
+        This method creates a README that:
+        1. Merges metadata from the original model with ONNX-specific metadata
+        2. Adds a description and link to the conversion space
+        3. Includes usage instructions with links to Transformers.js docs
+        4. Appends the original model's README content
+
+        Args:
+            input_model_id: Original model repository ID
+
+        Returns:
+            str: Complete README content in Markdown format with YAML frontmatter
+        """
+        # Fetch pipeline tag from model metadata (if available)
         try:
-            info = model_info(repo_id=
+            info = model_info(repo_id=input_model_id, token=self.config.hf_token)
             pipeline_tag = getattr(info, "pipeline_tag", None)
         except Exception:
             pipeline_tag = None
 
-
+        # Fetch and parse the original README
+        original_text = self._fetch_original_readme(input_model_id)
         original_meta, original_body = self._extract_yaml_frontmatter(original_text)
         original_body = (
             original_body or self._strip_yaml_frontmatter(original_text)
         ).strip()
 
+        # Merge original metadata with our ONNX-specific metadata (ours take precedence)
         merged_meta = {}
         if isinstance(original_meta, dict):
             merged_meta.update(original_meta)
         merged_meta["library_name"] = "transformers.js"
-        merged_meta["base_model"] = [
+        merged_meta["base_model"] = [input_model_id]
         if pipeline_tag is not None:
             merged_meta["pipeline_tag"] = pipeline_tag
 
-
-
+        # Generate YAML frontmatter
+        frontmatter_yaml = yaml.safe_dump(merged_meta, sort_keys=False).strip()
+        header = f"---\n{frontmatter_yaml}\n---\n\n"
 
-
-
-
-
-
+        # Build README sections
+        readme_sections: List[str] = []
+        readme_sections.append(header)
+
+        # Add title
+        model_name = input_model_id.split("/")[-1]
+        readme_sections.append(f"# {model_name} (ONNX)\n")
+
+        # Add description
+        readme_sections.append(
+            f"This is an ONNX version of [{input_model_id}](https://huggingface.co/{input_model_id}). "
             "It was automatically converted and uploaded using "
             "[this Hugging Face Space](https://huggingface.co/spaces/onnx-community/convert-to-onnx)."
         )
 
-
+        # Add usage section with Transformers.js docs link
+        docs_url = self._get_pipeline_docs_url(pipeline_tag)
         if docs_url:
-
+            readme_sections.append("\n## Usage with Transformers.js\n")
             if pipeline_tag:
-
+                readme_sections.append(
                     f"See the pipeline documentation for `{pipeline_tag}`: {docs_url}"
                 )
             else:
-
+                readme_sections.append(f"See the pipelines documentation: {docs_url}")
 
+        # Append original README content (if available)
         if original_body:
-
-
+            readme_sections.append("\n---\n")
+            readme_sections.append(original_body)
 
-        return "\n\n".join(
+        return "\n\n".join(readme_sections) + "\n"
 
 
 def main():
-    """Main application entry point."""
+    """Main application entry point for the Streamlit interface.
+
+    This function:
+    1. Initializes configuration and converter
+    2. Displays the UI for model input and options
+    3. Handles the conversion workflow
+    4. Shows progress and results to the user
+    """
     st.write("## Convert a Hugging Face model to ONNX")
 
     try:
+        # Initialize configuration and converter
        config = Config.from_env()
        converter = ModelConverter(config)
        converter.setup_repository()
 
+        # Get model ID from user
        input_model_id = st.text_input(
            "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
        )
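The README header that `generate_readme` emits can be previewed with plain `yaml.safe_dump`. A sketch using hypothetical merged metadata:

import yaml

# Hypothetical metadata merged from an original model card.
merged_meta = {
    "license": "apache-2.0",                  # carried over from the original card
    "library_name": "transformers.js",
    "base_model": ["EleutherAI/pythia-14m"],  # hypothetical input model
    "pipeline_tag": "text-generation",
}

frontmatter_yaml = yaml.safe_dump(merged_meta, sort_keys=False).strip()
header = f"---\n{frontmatter_yaml}\n---\n\n"
print(header + "# pythia-14m (ONNX)\n")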
@@ -301,23 +516,29 @@ def main():
         if not input_model_id:
             return
 
+        # Optional: User token input
         st.text_input(
             f"Optional: Your Hugging Face write token. Fill it if you want to upload the model under your account.",
             type="password",
             key="user_hf_token",
         )
+
+        # Optional: Trust remote code toggle (requires user token)
         trust_remote_code = st.toggle("Optional: Trust Remote Code.")
         if trust_remote_code:
             st.warning(
                 "This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
             )
 
+        # Optional: Output attentions (for Whisper models)
         output_attentions = False
         if "whisper" in input_model_id.lower():
             output_attentions = st.toggle(
                 "Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
             )
 
+        # Determine output repository
+        # If user owns the model, allow uploading to the same repo
         if config.hf_username == input_model_id.split("/")[0]:
             same_repo = st.checkbox(
                 "Upload the ONNX weights to the existing repository"
@@ -326,25 +547,29 @@ def main():
             same_repo = False
 
         model_name = input_model_id.split("/")[-1]
-
         output_model_id = f"{config.hf_username}/{model_name}"
 
+        # Add -ONNX suffix if creating a new repository
         if not same_repo:
             output_model_id += "-ONNX"
 
         output_model_url = f"{config.hf_base_url}/{output_model_id}"
 
+        # Check if model already exists
         if not same_repo and converter.api.repo_exists(output_model_id):
             st.write("This model has already been converted! 🎉")
             st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
             return
 
+        # Show where the model will be uploaded
         st.write(f"URL where the model will be converted and uploaded to:")
         st.code(output_model_url, language="plaintext")
 
+        # Wait for user confirmation before proceeding
         if not st.button(label="Proceed", type="primary"):
             return
 
+        # Step 1: Convert the model to ONNX
         with st.spinner("Converting model..."):
             success, stderr = converter.convert_model(
                 input_model_id,
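The repository naming above reduces to a few string operations. A sketch assuming a hypothetical user "alice":

# How the target repository ID is derived in main().
hf_username = "alice"                     # hypothetical account
input_model_id = "EleutherAI/pythia-14m"
same_repo = False                         # only offered when the user owns the input repository

model_name = input_model_id.split("/")[-1]
output_model_id = f"{hf_username}/{model_name}"
if not same_repo:
    output_model_id += "-ONNX"

print(output_model_id)  # alice/pythia-14m-ONNX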
@@ -358,6 +583,7 @@ def main():
             st.success("Conversion successful!")
             st.code(stderr)
 
+        # Step 2: Upload the converted model to Hugging Face
         with st.spinner("Uploading model..."):
             error = converter.upload_model(input_model_id, output_model_id)
             if error: