Spaces:
Runtime error
Runtime error
add retry to avoid runtime error
Browse files
app.py
CHANGED
|
@@ -199,6 +199,18 @@ def load_javascript(dir="javascript"):
|
|
| 199 |
gr.routes.templates.TemplateResponse = template_response
|
| 200 |
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
| 203 |
40: "Blush", 48: "Orchestra"}
|
| 204 |
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
|
@@ -210,7 +222,7 @@ if __name__ == "__main__":
|
|
| 210 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 211 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
| 212 |
opt = parser.parse_args()
|
| 213 |
-
soundfont_path =
|
| 214 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 215 |
"j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 216 |
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
|
@@ -219,8 +231,8 @@ if __name__ == "__main__":
|
|
| 219 |
tokenizer = MIDITokenizer()
|
| 220 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 221 |
for name, (repo_id, path) in models_info.items():
|
| 222 |
-
model_base_path =
|
| 223 |
-
model_token_path =
|
| 224 |
model_base = rt.InferenceSession(model_base_path, providers=providers)
|
| 225 |
model_token = rt.InferenceSession(model_token_path, providers=providers)
|
| 226 |
models[name] = [model_base, model_token]
|
|
|
|
| 199 |
gr.routes.templates.TemplateResponse = template_response
|
| 200 |
|
| 201 |
|
| 202 |
+
def hf_hub_download_retry(repo_id, filename):
|
| 203 |
+
retry = 0
|
| 204 |
+
err = None
|
| 205 |
+
while retry < 30:
|
| 206 |
+
try:
|
| 207 |
+
return hf_hub_download(repo_id=repo_id, filename=filename)
|
| 208 |
+
except Exception as e:
|
| 209 |
+
err = e
|
| 210 |
+
retry += 1
|
| 211 |
+
if err:
|
| 212 |
+
raise err
|
| 213 |
+
|
| 214 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
| 215 |
40: "Blush", 48: "Orchestra"}
|
| 216 |
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
|
|
|
| 222 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 223 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
| 224 |
opt = parser.parse_args()
|
| 225 |
+
soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 226 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 227 |
"j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 228 |
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
|
|
|
| 231 |
tokenizer = MIDITokenizer()
|
| 232 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 233 |
for name, (repo_id, path) in models_info.items():
|
| 234 |
+
model_base_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
| 235 |
+
model_token_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
| 236 |
model_base = rt.InferenceSession(model_base_path, providers=providers)
|
| 237 |
model_token = rt.InferenceSession(model_token_path, providers=providers)
|
| 238 |
models[name] = [model_base, model_token]
|