Spaces:
Runtime error
Runtime error
zerogpu
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import argparse
|
| 2 |
import glob
|
| 3 |
import json
|
|
@@ -94,11 +95,12 @@ def create_msg(name, data):
|
|
| 94 |
def send_msgs(msgs):
|
| 95 |
return json.dumps(msgs)
|
| 96 |
|
| 97 |
-
|
| 98 |
def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
|
| 99 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
| 100 |
gen_events, temp, top_p, top_k, allow_cc):
|
| 101 |
model = models[model_name]
|
|
|
|
| 102 |
tokenizer = model.tokenizer
|
| 103 |
bpm = int(bpm)
|
| 104 |
if time_sig == "auto":
|
|
@@ -300,10 +302,9 @@ if __name__ == "__main__":
|
|
| 300 |
for name, (repo_id, path, config) in models_info.items():
|
| 301 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
| 302 |
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
| 303 |
-
ckpt = torch.load(model_path, map_location="cpu")
|
| 304 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 305 |
model.load_state_dict(state_dict, strict=False)
|
| 306 |
-
model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
|
| 307 |
models[name] = model
|
| 308 |
|
| 309 |
load_javascript()
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
import argparse
|
| 3 |
import glob
|
| 4 |
import json
|
|
|
|
| 95 |
def send_msgs(msgs):
|
| 96 |
return json.dumps(msgs)
|
| 97 |
|
| 98 |
+
@spaces.GPU()
|
| 99 |
def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm, time_sig, key_sig, mid, midi_events,
|
| 100 |
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
| 101 |
gen_events, temp, top_p, top_k, allow_cc):
|
| 102 |
model = models[model_name]
|
| 103 |
+
model.to(device=opt.device, dtype=torch.bfloat16 if opt.device == "cuda" else torch.float32).eval()
|
| 104 |
tokenizer = model.tokenizer
|
| 105 |
bpm = int(bpm)
|
| 106 |
if time_sig == "auto":
|
|
|
|
| 302 |
for name, (repo_id, path, config) in models_info.items():
|
| 303 |
model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
|
| 304 |
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
| 305 |
+
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 306 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 307 |
model.load_state_dict(state_dict, strict=False)
|
|
|
|
| 308 |
models[name] = model
|
| 309 |
|
| 310 |
load_javascript()
|