Spaces:
Runtime error
Runtime error
fix midi visualizer
Browse files- app.py +20 -29
- javascript/app.js +4 -0
app.py
CHANGED
|
@@ -111,7 +111,15 @@ def create_msg(name, data):
|
|
| 111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
| 112 |
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
|
|
|
| 115 |
mid_seq = []
|
| 116 |
gen_events = int(gen_events)
|
| 117 |
max_len = gen_events
|
|
@@ -146,7 +154,7 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
| 146 |
init_msgs = [create_msg("visualizer_clear", None)]
|
| 147 |
for tokens in mid_seq:
|
| 148 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 149 |
-
yield mid_seq, None, None, init_msgs
|
| 150 |
model = models[model_name]
|
| 151 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 152 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
|
@@ -155,22 +163,22 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te
|
|
| 155 |
token_seq = token_seq.tolist()
|
| 156 |
mid_seq.append(token_seq)
|
| 157 |
event = tokenizer.tokens2event(token_seq)
|
| 158 |
-
yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
|
| 159 |
mid = tokenizer.detokenize(mid_seq)
|
| 160 |
with open(f"output.mid", 'wb') as f:
|
| 161 |
f.write(MIDI.score2midi(mid))
|
| 162 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 163 |
-
yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
| 164 |
|
| 165 |
|
| 166 |
-
def cancel_run(mid_seq):
|
| 167 |
if mid_seq is None:
|
| 168 |
return None, None, []
|
| 169 |
mid = tokenizer.detokenize(mid_seq)
|
| 170 |
with open(f"output.mid", 'wb') as f:
|
| 171 |
f.write(MIDI.score2midi(mid))
|
| 172 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 173 |
-
return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
|
| 174 |
|
| 175 |
|
| 176 |
def load_javascript(dir="javascript"):
|
|
@@ -191,25 +199,6 @@ def load_javascript(dir="javascript"):
|
|
| 191 |
gr.routes.templates.TemplateResponse = template_response
|
| 192 |
|
| 193 |
|
| 194 |
-
# JSMsgReceiver
|
| 195 |
-
Textbox_postprocess_ori = gr.Textbox.postprocess
|
| 196 |
-
|
| 197 |
-
msg_history = []
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
# the change event may not trigger every time, so send msg history to avoid msg missing.
|
| 201 |
-
def JSMsgReceiver_postprocess(self, y):
|
| 202 |
-
global msg_history
|
| 203 |
-
if self.elem_id == "msg_receiver" and y:
|
| 204 |
-
msg_history.append(y)
|
| 205 |
-
if len(msg_history) > 50:
|
| 206 |
-
msg_history = msg_history[1:]
|
| 207 |
-
y = json.dumps(msg_history)
|
| 208 |
-
return Textbox_postprocess_ori(self, y)
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
gr.Textbox.postprocess = JSMsgReceiver_postprocess
|
| 212 |
-
|
| 213 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
| 214 |
40: "Blush", 48: "Orchestra"}
|
| 215 |
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
|
@@ -223,8 +212,8 @@ if __name__ == "__main__":
|
|
| 223 |
opt = parser.parse_args()
|
| 224 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 225 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 226 |
-
"j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 227 |
-
"touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 228 |
}
|
| 229 |
models = {}
|
| 230 |
tokenizer = MIDITokenizer()
|
|
@@ -247,6 +236,7 @@ if __name__ == "__main__":
|
|
| 247 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
| 248 |
" for faster running and longer generation"
|
| 249 |
)
|
|
|
|
| 250 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
| 251 |
js_msg.change(None, [js_msg], [], js="""
|
| 252 |
(msg_json) =>{
|
|
@@ -302,6 +292,7 @@ if __name__ == "__main__":
|
|
| 302 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
| 303 |
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
| 304 |
input_allow_cc],
|
| 305 |
-
[output_midi_seq, output_midi, output_audio, js_msg]
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
| 111 |
return {"name": name, "data": data, "uuid": uuid.uuid4().hex}
|
| 112 |
|
| 113 |
|
| 114 |
+
def send_msgs(msgs, msgs_history):
|
| 115 |
+
msgs_history.append(msgs)
|
| 116 |
+
if len(msgs_history) > 50:
|
| 117 |
+
msgs_history.pop(0)
|
| 118 |
+
return json.dumps(msgs_history)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, temp, top_p, top_k, allow_cc):
|
| 122 |
+
msgs_history = []
|
| 123 |
mid_seq = []
|
| 124 |
gen_events = int(gen_events)
|
| 125 |
max_len = gen_events
|
|
|
|
| 154 |
init_msgs = [create_msg("visualizer_clear", None)]
|
| 155 |
for tokens in mid_seq:
|
| 156 |
init_msgs.append(create_msg("visualizer_append", tokenizer.tokens2event(tokens)))
|
| 157 |
+
yield mid_seq, None, None, send_msgs(init_msgs, msgs_history), msgs_history
|
| 158 |
model = models[model_name]
|
| 159 |
generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 160 |
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
|
|
|
| 163 |
token_seq = token_seq.tolist()
|
| 164 |
mid_seq.append(token_seq)
|
| 165 |
event = tokenizer.tokens2event(token_seq)
|
| 166 |
+
yield mid_seq, None, None, send_msgs([create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])], msgs_history), msgs_history
|
| 167 |
mid = tokenizer.detokenize(mid_seq)
|
| 168 |
with open(f"output.mid", 'wb') as f:
|
| 169 |
f.write(MIDI.score2midi(mid))
|
| 170 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 171 |
+
yield mid_seq, "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history), msgs_history
|
| 172 |
|
| 173 |
|
| 174 |
+
def cancel_run(mid_seq, msgs_history):
|
| 175 |
if mid_seq is None:
|
| 176 |
return None, None, []
|
| 177 |
mid = tokenizer.detokenize(mid_seq)
|
| 178 |
with open(f"output.mid", 'wb') as f:
|
| 179 |
f.write(MIDI.score2midi(mid))
|
| 180 |
audio = synthesis(MIDI.score2opus(mid), soundfont_path)
|
| 181 |
+
return "output.mid", (44100, audio), send_msgs([create_msg("visualizer_end", None)], msgs_history)
|
| 182 |
|
| 183 |
|
| 184 |
def load_javascript(dir="javascript"):
|
|
|
|
| 199 |
gr.routes.templates.TemplateResponse = template_response
|
| 200 |
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
|
| 203 |
40: "Blush", 48: "Orchestra"}
|
| 204 |
patch2number = {v: k for k, v in MIDI.Number2patch.items()}
|
|
|
|
| 212 |
opt = parser.parse_args()
|
| 213 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 214 |
models_info = {"generic pretrain model": ["skytnt/midi-model", ""],
|
| 215 |
+
# "j-pop finetune model": ["skytnt/midi-model-ft", "jpop/"],
|
| 216 |
+
# "touhou finetune model": ["skytnt/midi-model-ft", "touhou/"],
|
| 217 |
}
|
| 218 |
models = {}
|
| 219 |
tokenizer = MIDITokenizer()
|
|
|
|
| 236 |
"(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
|
| 237 |
" for faster running and longer generation"
|
| 238 |
)
|
| 239 |
+
js_msg_history_state = gr.State(value=[])
|
| 240 |
js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
|
| 241 |
js_msg.change(None, [js_msg], [], js="""
|
| 242 |
(msg_json) =>{
|
|
|
|
| 292 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_midi,
|
| 293 |
input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
|
| 294 |
input_allow_cc],
|
| 295 |
+
[output_midi_seq, output_midi, output_audio, js_msg, js_msg_history_state],
|
| 296 |
+
concurrency_limit=3)
|
| 297 |
+
stop_btn.click(cancel_run, [output_midi_seq, js_msg_history_state], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
| 298 |
+
app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
|
javascript/app.js
CHANGED
|
@@ -316,6 +316,10 @@ class MidiVisualizer extends HTMLElement{
|
|
| 316 |
audio.addEventListener("pause", (event)=>{
|
| 317 |
this.pause()
|
| 318 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
}
|
| 320 |
|
| 321 |
bindWaveformCursor(cursor){
|
|
|
|
| 316 |
audio.addEventListener("pause", (event)=>{
|
| 317 |
this.pause()
|
| 318 |
})
|
| 319 |
+
audio.addEventListener("loadedmetadata", (event)=>{
|
| 320 |
+
//I don't know why the calculated totalTimeMs is different from audio.duration*10**3
|
| 321 |
+
this.totalTimeMs = audio.duration*10**3;
|
| 322 |
+
})
|
| 323 |
}
|
| 324 |
|
| 325 |
bindWaveformCursor(cursor){
|