Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -41,7 +41,7 @@ def sample_top_p_k(probs, p, k):
|
|
| 41 |
return next_token
|
| 42 |
|
| 43 |
|
| 44 |
-
def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
| 45 |
disable_patch_change=False, disable_control_change=False, disable_channels=None):
|
| 46 |
if disable_channels is not None:
|
| 47 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
|
@@ -63,7 +63,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
| 63 |
with bar:
|
| 64 |
while cur_len < max_len:
|
| 65 |
end = False
|
| 66 |
-
hidden =
|
| 67 |
next_token_seq = np.empty((1, 0), dtype=np.int64)
|
| 68 |
event_name = ""
|
| 69 |
for i in range(max_token_seq):
|
|
@@ -81,7 +81,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
| 81 |
if param_name == "channel":
|
| 82 |
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
| 83 |
mask[mask_ids] = 1
|
| 84 |
-
logits =
|
| 85 |
scores = softmax(logits / temp, -1) * mask
|
| 86 |
sample = sample_top_p_k(scores, top_p, top_k)
|
| 87 |
if i == 0:
|
|
@@ -107,7 +107,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
| 107 |
break
|
| 108 |
|
| 109 |
|
| 110 |
-
def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
| 111 |
mid_seq = []
|
| 112 |
max_len = int(gen_events)
|
| 113 |
img_len = 1024
|
|
@@ -172,7 +172,8 @@ def run(tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, t
|
|
| 172 |
for token_seq in mid:
|
| 173 |
mid_seq.append(token_seq)
|
| 174 |
draw_event(token_seq)
|
| 175 |
-
|
|
|
|
| 176 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
| 177 |
disable_channels=disable_channels)
|
| 178 |
for token_seq in generator:
|
|
@@ -208,13 +209,18 @@ if __name__ == "__main__":
|
|
| 208 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
| 209 |
opt = parser.parse_args()
|
| 210 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
| 214 |
tokenizer = MIDITokenizer()
|
| 215 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
app = gr.Blocks()
|
| 220 |
with app:
|
|
@@ -229,6 +235,8 @@ if __name__ == "__main__":
|
|
| 229 |
|
| 230 |
tab_select = gr.Variable(value=0)
|
| 231 |
with gr.Tabs():
|
|
|
|
|
|
|
| 232 |
with gr.TabItem("instrument prompt") as tab1:
|
| 233 |
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
|
| 234 |
multiselect=True, max_choices=15, type="value")
|
|
@@ -260,7 +268,7 @@ if __name__ == "__main__":
|
|
| 260 |
with gr.Accordion("options", open=False):
|
| 261 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
| 262 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
| 263 |
-
input_top_k = gr.Slider(label="top k", minimum=1, maximum=20, step=1, value=
|
| 264 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
| 265 |
example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
|
| 266 |
run_btn = gr.Button("generate", variant="primary")
|
|
@@ -269,8 +277,8 @@ if __name__ == "__main__":
|
|
| 269 |
output_midi_img = gr.Image(label="output image")
|
| 270 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 271 |
output_audio = gr.Audio(label="output audio", format="mp3")
|
| 272 |
-
run_event = run_btn.click(run, [tab_select, input_instruments, input_drum_kit, input_midi,
|
| 273 |
-
input_gen_events, input_temp, input_top_p, input_top_k,
|
| 274 |
input_allow_cc],
|
| 275 |
[output_midi_seq, output_midi_img, output_midi, output_audio])
|
| 276 |
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
|
|
|
|
| 41 |
return next_token
|
| 42 |
|
| 43 |
|
| 44 |
+
def generate(model, prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
| 45 |
disable_patch_change=False, disable_control_change=False, disable_channels=None):
|
| 46 |
if disable_channels is not None:
|
| 47 |
disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
|
|
|
|
| 63 |
with bar:
|
| 64 |
while cur_len < max_len:
|
| 65 |
end = False
|
| 66 |
+
hidden = model[0].run(None, {'x': input_tensor})[0][:, -1]
|
| 67 |
next_token_seq = np.empty((1, 0), dtype=np.int64)
|
| 68 |
event_name = ""
|
| 69 |
for i in range(max_token_seq):
|
|
|
|
| 81 |
if param_name == "channel":
|
| 82 |
mask_ids = [i for i in mask_ids if i not in disable_channels]
|
| 83 |
mask[mask_ids] = 1
|
| 84 |
+
logits = model[1].run(None, {'x': next_token_seq, "hidden": hidden})[0][:, -1:]
|
| 85 |
scores = softmax(logits / temp, -1) * mask
|
| 86 |
sample = sample_top_p_k(scores, top_p, top_k)
|
| 87 |
if i == 0:
|
|
|
|
| 107 |
break
|
| 108 |
|
| 109 |
|
| 110 |
+
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
| 111 |
mid_seq = []
|
| 112 |
max_len = int(gen_events)
|
| 113 |
img_len = 1024
|
|
|
|
| 172 |
for token_seq in mid:
|
| 173 |
mid_seq.append(token_seq)
|
| 174 |
draw_event(token_seq)
|
| 175 |
+
model = models[model_name]
|
| 176 |
+
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 177 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
| 178 |
disable_channels=disable_channels)
|
| 179 |
for token_seq in generator:
|
|
|
|
| 209 |
parser.add_argument("--max-gen", type=int, default=1024, help="max")
|
| 210 |
opt = parser.parse_args()
|
| 211 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 212 |
+
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 213 |
+
"symphony finetune model": ["skytnt/midi-model-ft", "symphony/"],
|
| 214 |
+
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"]}
|
| 215 |
+
models = {}
|
| 216 |
tokenizer = MIDITokenizer()
|
| 217 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
| 218 |
+
for name, (repo_id, path) in models_info.items():
|
| 219 |
+
model_base_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_base.onnx")
|
| 220 |
+
model_token_path = hf_hub_download(repo_id=repo_id, filename=f"{path}onnx/model_token.onnx")
|
| 221 |
+
model_base = rt.InferenceSession(model_base_path, providers=providers)
|
| 222 |
+
model_token = rt.InferenceSession(model_token_path, providers=providers)
|
| 223 |
+
models[name] = [model_base, model_token]
|
| 224 |
|
| 225 |
app = gr.Blocks()
|
| 226 |
with app:
|
|
|
|
| 235 |
|
| 236 |
tab_select = gr.Variable(value=0)
|
| 237 |
with gr.Tabs():
|
| 238 |
+
input_model = gr.Dropdown(label="select model", choices=list(models.keys()),
|
| 239 |
+
type="value", value=list(models.keys())[0])
|
| 240 |
with gr.TabItem("instrument prompt") as tab1:
|
| 241 |
input_instruments = gr.Dropdown(label="instruments (auto if empty)", choices=list(patch2number.keys()),
|
| 242 |
multiselect=True, max_choices=15, type="value")
|
|
|
|
| 268 |
with gr.Accordion("options", open=False):
|
| 269 |
input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
|
| 270 |
input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.98)
|
| 271 |
+
input_top_k = gr.Slider(label="top k", minimum=1, maximum=20, step=1, value=20)
|
| 272 |
input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
|
| 273 |
example3 = gr.Examples([[1, 0.98, 12], [1.2, 0.95, 8]], [input_temp, input_top_p, input_top_k])
|
| 274 |
run_btn = gr.Button("generate", variant="primary")
|
|
|
|
| 277 |
output_midi_img = gr.Image(label="output image")
|
| 278 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 279 |
output_audio = gr.Audio(label="output audio", format="mp3")
|
| 280 |
+
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
| 281 |
+
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
| 282 |
input_allow_cc],
|
| 283 |
[output_midi_seq, output_midi_img, output_midi, output_audio])
|
| 284 |
stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio], cancels=run_event, queue=False)
|