Felladrin commited on
Commit
5a05fd1
·
1 Parent(s): e1979fa

Temporarily patch transformers.js convert script to fix the model conversion

Browse files
Files changed (2) hide show
  1. app.py +58 -2
  2. requirements.txt +4 -1
app.py CHANGED
@@ -86,6 +86,8 @@ class ModelConverter:
86
  try:
87
  urlretrieve(archive_url, archive_path)
88
  self._extract_archive(archive_path)
 
 
89
  logger.info("Repository downloaded and extracted successfully")
90
  except Exception as e:
91
  raise RuntimeError(f"Failed to setup repository: {e}")
@@ -101,6 +103,60 @@ class ModelConverter:
101
  extracted_folder = next(Path(tmp_dir).iterdir())
102
  extracted_folder.rename(self.config.repo_path)
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def _run_conversion_subprocess(
105
  self, input_model_id: str, extra_args: List[str] = None
106
  ) -> subprocess.CompletedProcess:
@@ -110,8 +166,6 @@ class ModelConverter:
110
  "-m",
111
  "scripts.convert",
112
  "--quantize",
113
- "--opset",
114
- "18",
115
  "--model_id",
116
  input_model_id,
117
  ]
@@ -126,6 +180,8 @@ class ModelConverter:
126
  text=True,
127
  env={
128
  "HF_TOKEN": self.config.hf_token,
 
 
129
  },
130
  )
131
 
 
86
  try:
87
  urlretrieve(archive_url, archive_path)
88
  self._extract_archive(archive_path)
89
+ self._patch_convert_script()
90
+ self._install_scripts_requirements()
91
  logger.info("Repository downloaded and extracted successfully")
92
  except Exception as e:
93
  raise RuntimeError(f"Failed to setup repository: {e}")
 
103
  extracted_folder = next(Path(tmp_dir).iterdir())
104
  extracted_folder.rename(self.config.repo_path)
105
 
106
+ def _install_scripts_requirements(self) -> None:
107
+ req_path = self.config.repo_path / "scripts" / "requirements.txt"
108
+ if req_path.exists():
109
+ subprocess.run(
110
+ [
111
+ sys.executable,
112
+ "-m",
113
+ "pip",
114
+ "install",
115
+ "--no-cache-dir",
116
+ "-r",
117
+ str(req_path),
118
+ ],
119
+ check=True,
120
+ )
121
+
122
+ def _patch_convert_script(self) -> None:
123
+ """Patch transformers.js convert script to force eager attention for better ONNX compatibility."""
124
+ path = self.config.repo_path / "scripts" / "convert.py"
125
+ if not path.exists():
126
+ return
127
+ try:
128
+ text = path.read_text(encoding="utf-8")
129
+ marker = "export_kwargs = dict("
130
+ if marker in text:
131
+ lines = text.splitlines()
132
+ start_idx = None
133
+ paren = 0
134
+ for i, line in enumerate(lines):
135
+ if marker in line:
136
+ start_idx = i
137
+ break
138
+ if start_idx is not None:
139
+ for k in range(start_idx, len(lines)):
140
+ paren += lines[k].count("(")
141
+ paren -= lines[k].count(")")
142
+ if paren <= 0:
143
+ insert_at = k + 1
144
+ lines.insert(
145
+ insert_at,
146
+ " export_kwargs.setdefault('model_kwargs', {})",
147
+ )
148
+ lines.insert(
149
+ insert_at + 1,
150
+ " export_kwargs['model_kwargs']['attn_implementation'] = 'eager'",
151
+ )
152
+ patched = "\n".join(lines) + (
153
+ "\n" if text.endswith("\n") else ""
154
+ )
155
+ path.write_text(patched, encoding="utf-8")
156
+ break
157
+ except Exception:
158
+ pass
159
+
160
  def _run_conversion_subprocess(
161
  self, input_model_id: str, extra_args: List[str] = None
162
  ) -> subprocess.CompletedProcess:
 
166
  "-m",
167
  "scripts.convert",
168
  "--quantize",
 
 
169
  "--model_id",
170
  input_model_id,
171
  ]
 
180
  text=True,
181
  env={
182
  "HF_TOKEN": self.config.hf_token,
183
+ "TRANSFORMERS_ATTENTION_IMPLEMENTATION": "eager",
184
+ "PYTORCH_SDP_KERNEL": "math",
185
  },
186
  )
187
 
requirements.txt CHANGED
@@ -1,7 +1,10 @@
1
  huggingface_hub==0.35.3
2
  streamlit==1.50.0
3
  onnxscript==0.5.4
4
- transformers[torch]==4.49.0
 
 
 
5
  onnxruntime==1.20.1
6
  optimum@git+https://github.com/huggingface/optimum.git@b04feaea78cda58d79b8da67dca3fd0c4ab33435
7
  onnx==1.17.0
 
1
  huggingface_hub==0.35.3
2
  streamlit==1.50.0
3
  onnxscript==0.5.4
4
+ onnxconverter_common==1.16.0
5
+ onnx_graphsurgeon==0.5.8
6
+ torch==2.5.1
7
+ transformers==4.49.0
8
  onnxruntime==1.20.1
9
  optimum@git+https://github.com/huggingface/optimum.git@b04feaea78cda58d79b8da67dca3fd0c4ab33435
10
  onnx==1.17.0