Spaces:
Runtime error
Runtime error
update tokenizer
Browse files- app.py +15 -6
- midi_tokenizer.py +111 -9
app.py
CHANGED
|
@@ -121,7 +121,8 @@ def send_msgs(msgs):
|
|
| 121 |
return json.dumps(msgs)
|
| 122 |
|
| 123 |
|
| 124 |
-
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
|
|
|
|
| 125 |
gen_events, temp, top_p, top_k, allow_cc):
|
| 126 |
mid_seq = []
|
| 127 |
bpm = int(bpm)
|
|
@@ -153,8 +154,11 @@ def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events, midi_opt,
|
|
| 153 |
disable_patch_change = True
|
| 154 |
disable_channels = [i for i in range(16) if i not in patches]
|
| 155 |
elif mid is not None:
|
| 156 |
-
eps = 4 if
|
| 157 |
-
mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps
|
|
|
|
|
|
|
|
|
|
| 158 |
mid = np.asarray(mid, dtype=np.int64)
|
| 159 |
mid = mid[:int(midi_events)]
|
| 160 |
for token_seq in mid:
|
|
@@ -306,7 +310,10 @@ if __name__ == "__main__":
|
|
| 306 |
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
|
| 307 |
step=1,
|
| 308 |
value=128)
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
| 310 |
example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
|
| 311 |
[input_midi, input_midi_events])
|
| 312 |
|
|
@@ -330,8 +337,10 @@ if __name__ == "__main__":
|
|
| 330 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 331 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 332 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
| 333 |
-
input_midi, input_midi_events,
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
[output_midi_seq, output_midi, output_audio, input_seed, js_msg],
|
| 336 |
concurrency_limit=3)
|
| 337 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
|
|
|
| 121 |
return json.dumps(msgs)
|
| 122 |
|
| 123 |
|
| 124 |
+
def run(model_name, tab, instruments, drum_kit, bpm, mid, midi_events,
|
| 125 |
+
reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels, seed, seed_rand,
|
| 126 |
gen_events, temp, top_p, top_k, allow_cc):
|
| 127 |
mid_seq = []
|
| 128 |
bpm = int(bpm)
|
|
|
|
| 154 |
disable_patch_change = True
|
| 155 |
disable_channels = [i for i in range(16) if i not in patches]
|
| 156 |
elif mid is not None:
|
| 157 |
+
eps = 4 if reduce_cc_st else 0
|
| 158 |
+
mid = tokenizer.tokenize(MIDI.midi2score(mid), cc_eps=eps, tempo_eps=eps,
|
| 159 |
+
remap_track_channel=remap_track_channel,
|
| 160 |
+
add_default_instr=add_default_instr,
|
| 161 |
+
remove_empty_channels=remove_empty_channels)
|
| 162 |
mid = np.asarray(mid, dtype=np.int64)
|
| 163 |
mid = mid[:int(midi_events)]
|
| 164 |
for token_seq in mid:
|
|
|
|
| 310 |
input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
|
| 311 |
step=1,
|
| 312 |
value=128)
|
| 313 |
+
input_reduce_cc_st = gr.Checkbox(label="reduce control_change and set_tempo events", value=True)
|
| 314 |
+
input_remap_track_channel = gr.Checkbox(label="remap tracks and channels to have only one channel per track", value=True)
|
| 315 |
+
input_add_default_instr = gr.Checkbox(label="add a default instrument to channels that don't have an instrument", value=True)
|
| 316 |
+
input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
|
| 317 |
example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
|
| 318 |
[input_midi, input_midi_events])
|
| 319 |
|
|
|
|
| 337 |
output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
|
| 338 |
output_midi = gr.File(label="output midi", file_types=[".mid"])
|
| 339 |
run_event = run_btn.click(run, [input_model, tab_select, input_instruments, input_drum_kit, input_bpm,
|
| 340 |
+
input_midi, input_midi_events, input_reduce_cc_st, input_remap_track_channel,
|
| 341 |
+
input_add_default_instr, input_remove_empty_channels, input_seed,
|
| 342 |
+
input_seed_rand, input_gen_events, input_temp, input_top_p, input_top_k,
|
| 343 |
+
input_allow_cc],
|
| 344 |
[output_midi_seq, output_midi, output_audio, input_seed, js_msg],
|
| 345 |
concurrency_limit=3)
|
| 346 |
stop_btn.click(cancel_run, [output_midi_seq], [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
|
midi_tokenizer.py
CHANGED
|
@@ -42,9 +42,16 @@ class MIDITokenizer:
|
|
| 42 |
tempo = int((60 / bpm) * 10 ** 6)
|
| 43 |
return tempo
|
| 44 |
|
| 45 |
-
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4
|
|
|
|
| 46 |
ticks_per_beat = midi_score[0]
|
| 47 |
event_list = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
for track_idx, track in enumerate(midi_score[1:129]):
|
| 49 |
last_notes = {}
|
| 50 |
patch_dict = {}
|
|
@@ -53,9 +60,18 @@ class MIDITokenizer:
|
|
| 53 |
for event in track:
|
| 54 |
if event[0] not in self.events:
|
| 55 |
continue
|
|
|
|
| 56 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
| 57 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
| 58 |
if event[0] == "note":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
| 60 |
elif event[0] == "set_tempo":
|
| 61 |
if new_event[4] == 0: # invalid tempo
|
|
@@ -68,12 +84,18 @@ class MIDITokenizer:
|
|
| 68 |
key = tuple(new_event[:-1])
|
| 69 |
if event[0] == "patch_change":
|
| 70 |
c, p = event[2:]
|
|
|
|
|
|
|
| 71 |
last_p = patch_dict.setdefault(c, None)
|
| 72 |
if last_p == p:
|
| 73 |
continue
|
| 74 |
patch_dict[c] = p
|
|
|
|
|
|
|
| 75 |
elif event[0] == "control_change":
|
| 76 |
c, cc, v = event[2:]
|
|
|
|
|
|
|
| 77 |
last_v = control_dict.setdefault((c, cc), 0)
|
| 78 |
if abs(last_v - v) < cc_eps:
|
| 79 |
continue
|
|
@@ -84,6 +106,13 @@ class MIDITokenizer:
|
|
| 84 |
continue
|
| 85 |
last_tempo = tempo
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
| 88 |
cp = tuple(new_event[5:7])
|
| 89 |
if cp in last_notes:
|
|
@@ -95,8 +124,79 @@ class MIDITokenizer:
|
|
| 95 |
last_notes[cp] = (key, new_event)
|
| 96 |
event_list[key] = new_event
|
| 97 |
event_list = list(event_list.values())
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
setup_events = {}
|
| 101 |
notes_in_setup = False
|
| 102 |
for i, event in enumerate(event_list): # optimise setup
|
|
@@ -113,7 +213,7 @@ class MIDITokenizer:
|
|
| 113 |
pre_event = event_list[i - 1]
|
| 114 |
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
| 115 |
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
|
| 116 |
-
event_list = sorted(setup_events.values(), key=
|
| 117 |
break
|
| 118 |
else:
|
| 119 |
if event[0] == "note":
|
|
@@ -122,7 +222,10 @@ class MIDITokenizer:
|
|
| 122 |
setup_events[key] = new_event
|
| 123 |
|
| 124 |
last_t1 = 0
|
|
|
|
| 125 |
for event in event_list:
|
|
|
|
|
|
|
| 126 |
cur_t1 = event[1]
|
| 127 |
event[1] = event[1] - last_t1
|
| 128 |
tokens = self.event2tokens(event)
|
|
@@ -181,7 +284,7 @@ class MIDITokenizer:
|
|
| 181 |
if track_idx not in tracks_dict:
|
| 182 |
tracks_dict[track_idx] = []
|
| 183 |
tracks_dict[track_idx].append([event[0], t] + event[4:])
|
| 184 |
-
tracks = list(tracks_dict.
|
| 185 |
|
| 186 |
for i in range(len(tracks)): # to eliminate note overlap
|
| 187 |
track = tracks[i]
|
|
@@ -292,7 +395,6 @@ class MIDITokenizer:
|
|
| 292 |
notes_bandwidth_list = []
|
| 293 |
instruments = {}
|
| 294 |
piano_channels = []
|
| 295 |
-
undef_instrument = False
|
| 296 |
abs_t1 = 0
|
| 297 |
last_t = 0
|
| 298 |
for tsi, tokens in enumerate(midi_seq):
|
|
@@ -309,7 +411,9 @@ class MIDITokenizer:
|
|
| 309 |
time_hist[t2] += 1
|
| 310 |
if c != 9: # ignore drum channel
|
| 311 |
if c not in instruments:
|
| 312 |
-
|
|
|
|
|
|
|
| 313 |
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
| 314 |
if last_t != t:
|
| 315 |
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
|
@@ -330,8 +434,6 @@ class MIDITokenizer:
|
|
| 330 |
reasons.append("total_min")
|
| 331 |
if total_notes > total_notes_max:
|
| 332 |
reasons.append("total_max")
|
| 333 |
-
if undef_instrument:
|
| 334 |
-
reasons.append("undef_instr")
|
| 335 |
if len(note_windows) == 0 and total_notes > 0:
|
| 336 |
reasons.append("drum_only")
|
| 337 |
if reasons:
|
|
|
|
| 42 |
tempo = int((60 / bpm) * 10 ** 6)
|
| 43 |
return tempo
|
| 44 |
|
| 45 |
+
def tokenize(self, midi_score, add_bos_eos=True, cc_eps=4, tempo_eps=4,
|
| 46 |
+
remap_track_channel=False, add_default_instr=False, remove_empty_channels=False):
|
| 47 |
ticks_per_beat = midi_score[0]
|
| 48 |
event_list = {}
|
| 49 |
+
track_idx_map = {i: dict() for i in range(16)}
|
| 50 |
+
track_idx_dict = {}
|
| 51 |
+
channels = []
|
| 52 |
+
patch_channels = []
|
| 53 |
+
empty_channels = [True]*16
|
| 54 |
+
channel_note_tracks = {i: list() for i in range(16)}
|
| 55 |
for track_idx, track in enumerate(midi_score[1:129]):
|
| 56 |
last_notes = {}
|
| 57 |
patch_dict = {}
|
|
|
|
| 60 |
for event in track:
|
| 61 |
if event[0] not in self.events:
|
| 62 |
continue
|
| 63 |
+
c = -1
|
| 64 |
t = round(16 * event[1] / ticks_per_beat) # quantization
|
| 65 |
new_event = [event[0], t // 16, t % 16, track_idx] + event[2:]
|
| 66 |
if event[0] == "note":
|
| 67 |
+
c = event[3]
|
| 68 |
+
if c > 15 or c < 0:
|
| 69 |
+
continue
|
| 70 |
+
empty_channels[c] = False
|
| 71 |
+
track_idx_dict.setdefault(c, track_idx)
|
| 72 |
+
note_tracks = channel_note_tracks[c]
|
| 73 |
+
if track_idx not in note_tracks:
|
| 74 |
+
note_tracks.append(track_idx)
|
| 75 |
new_event[4] = max(1, round(16 * new_event[4] / ticks_per_beat))
|
| 76 |
elif event[0] == "set_tempo":
|
| 77 |
if new_event[4] == 0: # invalid tempo
|
|
|
|
| 84 |
key = tuple(new_event[:-1])
|
| 85 |
if event[0] == "patch_change":
|
| 86 |
c, p = event[2:]
|
| 87 |
+
if c > 15 or c < 0:
|
| 88 |
+
continue
|
| 89 |
last_p = patch_dict.setdefault(c, None)
|
| 90 |
if last_p == p:
|
| 91 |
continue
|
| 92 |
patch_dict[c] = p
|
| 93 |
+
if c not in patch_channels:
|
| 94 |
+
patch_channels.append(c)
|
| 95 |
elif event[0] == "control_change":
|
| 96 |
c, cc, v = event[2:]
|
| 97 |
+
if c > 15 or c < 0:
|
| 98 |
+
continue
|
| 99 |
last_v = control_dict.setdefault((c, cc), 0)
|
| 100 |
if abs(last_v - v) < cc_eps:
|
| 101 |
continue
|
|
|
|
| 106 |
continue
|
| 107 |
last_tempo = tempo
|
| 108 |
|
| 109 |
+
if c != -1:
|
| 110 |
+
if c not in channels:
|
| 111 |
+
channels.append(c)
|
| 112 |
+
tr_map = track_idx_map[c]
|
| 113 |
+
if track_idx not in tr_map:
|
| 114 |
+
tr_map[track_idx] = 0
|
| 115 |
+
|
| 116 |
if event[0] == "note": # to eliminate note overlap due to quantization
|
| 117 |
cp = tuple(new_event[5:7])
|
| 118 |
if cp in last_notes:
|
|
|
|
| 124 |
last_notes[cp] = (key, new_event)
|
| 125 |
event_list[key] = new_event
|
| 126 |
event_list = list(event_list.values())
|
| 127 |
+
|
| 128 |
+
empty_channels = [c for c in channels if empty_channels[c]]
|
| 129 |
+
|
| 130 |
+
if remap_track_channel:
|
| 131 |
+
patch_channels = []
|
| 132 |
+
channels_count = 0
|
| 133 |
+
channels_map = {9: 9} if 9 in channels else {}
|
| 134 |
+
for c in channels:
|
| 135 |
+
if c == 9:
|
| 136 |
+
continue
|
| 137 |
+
channels_map[c] = channels_count
|
| 138 |
+
channels_count += 1
|
| 139 |
+
if channels_count == 9:
|
| 140 |
+
channels_count = 10
|
| 141 |
+
channels = list(channels_map.values())
|
| 142 |
+
|
| 143 |
+
track_count = 0
|
| 144 |
+
track_idx_map_order = [k for k,v in sorted(list(channels_map.items()), key=lambda x: x[1])]
|
| 145 |
+
for c in track_idx_map_order: # tracks not to remove
|
| 146 |
+
if remove_empty_channels and c in empty_channels:
|
| 147 |
+
continue
|
| 148 |
+
tr_map = track_idx_map[c]
|
| 149 |
+
for track_idx in tr_map:
|
| 150 |
+
note_tracks = channel_note_tracks[c]
|
| 151 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
| 152 |
+
continue
|
| 153 |
+
track_count += 1
|
| 154 |
+
tr_map[track_idx] = track_count
|
| 155 |
+
for c in track_idx_map_order: # tracks to remove
|
| 156 |
+
if not (remove_empty_channels and c in empty_channels):
|
| 157 |
+
continue
|
| 158 |
+
tr_map = track_idx_map[c]
|
| 159 |
+
for track_idx in tr_map:
|
| 160 |
+
note_tracks = channel_note_tracks[c]
|
| 161 |
+
if not (len(note_tracks) != 0 and track_idx not in note_tracks):
|
| 162 |
+
continue
|
| 163 |
+
track_count += 1
|
| 164 |
+
tr_map[track_idx] = track_count
|
| 165 |
+
|
| 166 |
+
empty_channels = [channels_map[c] for c in empty_channels]
|
| 167 |
+
|
| 168 |
+
for event in event_list:
|
| 169 |
+
name = event[0]
|
| 170 |
+
track_idx = event[3]
|
| 171 |
+
if name == "note":
|
| 172 |
+
c = event[5]
|
| 173 |
+
event[5] = channels_map[c]
|
| 174 |
+
event[3] = track_idx_map[c][track_idx]
|
| 175 |
+
track_idx_dict[event[5]] = event[3]
|
| 176 |
+
elif name == "set_tempo":
|
| 177 |
+
event[3] = 0
|
| 178 |
+
elif name == "control_change" or name == "patch_change":
|
| 179 |
+
c = event[4]
|
| 180 |
+
event[4] = channels_map[c]
|
| 181 |
+
tr_map = track_idx_map[c]
|
| 182 |
+
# move the event to first track of the channel if it's original track is empty
|
| 183 |
+
note_tracks = channel_note_tracks[c]
|
| 184 |
+
if len(note_tracks) != 0 and track_idx not in note_tracks:
|
| 185 |
+
track_idx = channel_note_tracks[c][0]
|
| 186 |
+
new_track_idx = tr_map.setdefault(track_idx, next(iter(tr_map.values())))
|
| 187 |
+
event[3] = new_track_idx
|
| 188 |
+
if name == "patch_change" and event[4] not in patch_channels:
|
| 189 |
+
patch_channels.append(event[4])
|
| 190 |
+
|
| 191 |
+
if add_default_instr:
|
| 192 |
+
for c in channels:
|
| 193 |
+
if c not in patch_channels:
|
| 194 |
+
event_list.append(["patch_change", 0,0, track_idx_dict[c], c, 0])
|
| 195 |
+
|
| 196 |
+
events_name_order = {"set_tempo":0, "patch_change":1, "control_change":2, "note":3}
|
| 197 |
+
events_order = lambda e: e[1:4] + [events_name_order[e[0]]]
|
| 198 |
+
event_list = sorted(event_list, key=events_order)
|
| 199 |
+
|
| 200 |
setup_events = {}
|
| 201 |
notes_in_setup = False
|
| 202 |
for i, event in enumerate(event_list): # optimise setup
|
|
|
|
| 213 |
pre_event = event_list[i - 1]
|
| 214 |
has_pre = event[1] + event[2] == pre_event[1] + pre_event[2]
|
| 215 |
if (event[0] == "note" and not has_next) or (notes_in_setup and not has_pre) :
|
| 216 |
+
event_list = sorted(setup_events.values(), key=events_order) + event_list[i:]
|
| 217 |
break
|
| 218 |
else:
|
| 219 |
if event[0] == "note":
|
|
|
|
| 222 |
setup_events[key] = new_event
|
| 223 |
|
| 224 |
last_t1 = 0
|
| 225 |
+
midi_seq = []
|
| 226 |
for event in event_list:
|
| 227 |
+
if remove_empty_channels and event[0] in ["control_change", "patch_change"] and event[4] in empty_channels:
|
| 228 |
+
continue
|
| 229 |
cur_t1 = event[1]
|
| 230 |
event[1] = event[1] - last_t1
|
| 231 |
tokens = self.event2tokens(event)
|
|
|
|
| 284 |
if track_idx not in tracks_dict:
|
| 285 |
tracks_dict[track_idx] = []
|
| 286 |
tracks_dict[track_idx].append([event[0], t] + event[4:])
|
| 287 |
+
tracks = [tr for idx, tr in sorted(list(tracks_dict.items()), key=lambda it: it[0])]
|
| 288 |
|
| 289 |
for i in range(len(tracks)): # to eliminate note overlap
|
| 290 |
track = tracks[i]
|
|
|
|
| 395 |
notes_bandwidth_list = []
|
| 396 |
instruments = {}
|
| 397 |
piano_channels = []
|
|
|
|
| 398 |
abs_t1 = 0
|
| 399 |
last_t = 0
|
| 400 |
for tsi, tokens in enumerate(midi_seq):
|
|
|
|
| 411 |
time_hist[t2] += 1
|
| 412 |
if c != 9: # ignore drum channel
|
| 413 |
if c not in instruments:
|
| 414 |
+
instruments[c] = 0
|
| 415 |
+
if c not in piano_channels:
|
| 416 |
+
piano_channels.append(c)
|
| 417 |
note_windows.setdefault(abs_t1 // note_window_size, []).append(p)
|
| 418 |
if last_t != t:
|
| 419 |
notes_sametime = [(et, p_) for et, p_ in notes_sametime if et > last_t]
|
|
|
|
| 434 |
reasons.append("total_min")
|
| 435 |
if total_notes > total_notes_max:
|
| 436 |
reasons.append("total_max")
|
|
|
|
|
|
|
| 437 |
if len(note_windows) == 0 and total_notes > 0:
|
| 438 |
reasons.append("drum_only")
|
| 439 |
if reasons:
|