Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -71,8 +71,9 @@ print('=' * 70)
|
|
| 71 |
print('Loading MIDI GAS processed scores dataset...')
|
| 72 |
|
| 73 |
midi_gas_ps_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
|
| 74 |
-
|
| 75 |
-
|
|
|
|
| 76 |
|
| 77 |
midi_gas_ps = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_ps_pickle)
|
| 78 |
|
|
@@ -86,10 +87,11 @@ print('=' * 70)
|
|
| 86 |
print('Loading MIDI GAS processed scores embeddings dataset...')
|
| 87 |
|
| 88 |
midi_gas_pse_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
|
| 92 |
-
midi_gas_pse = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_pse_pickle)
|
| 93 |
|
| 94 |
print('=' * 70)
|
| 95 |
print('Done!')
|
|
@@ -139,7 +141,7 @@ print('=' * 70)
|
|
| 139 |
|
| 140 |
def load_midi(input_midi):
|
| 141 |
|
| 142 |
-
raw_score = TMIDIX.midi2single_track_ms_score(
|
| 143 |
|
| 144 |
escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
| 145 |
|
|
@@ -213,6 +215,38 @@ def get_embeddings(inputs):
|
|
| 213 |
|
| 214 |
#==================================================================================
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
# @spaces.GPU
|
| 217 |
def Classify_MIDI_Genre(input_midi,
|
| 218 |
input_melody,
|
|
@@ -224,10 +258,6 @@ def Classify_MIDI_Genre(input_midi,
|
|
| 224 |
|
| 225 |
#===============================================================================
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
#===============================================================================
|
| 230 |
-
|
| 231 |
print('=' * 70)
|
| 232 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 233 |
start_time = reqtime.time()
|
|
@@ -250,52 +280,33 @@ def Classify_MIDI_Genre(input_midi,
|
|
| 250 |
|
| 251 |
print('=' * 70)
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
print('Prepping melody...')
|
| 256 |
|
| 257 |
-
|
| 258 |
-
inp_mel = 'Custom MIDI'
|
| 259 |
-
score, score_list = load_midi(input_midi.name, melody_patch, use_nth_note)
|
| 260 |
|
| 261 |
-
|
| 262 |
-
mel_list = [m[0].lower() for m in popular_hook_melodies]
|
| 263 |
|
| 264 |
-
|
| 265 |
|
| 266 |
-
|
| 267 |
-
if input_melody.lower().strip() in m:
|
| 268 |
-
inp_mel = m.title()
|
| 269 |
-
break
|
| 270 |
|
| 271 |
-
|
| 272 |
-
score_list = [[[score[i]], score[i+1:i+3]] for i in range(0, len(score)-3, 3)]
|
| 273 |
|
| 274 |
-
|
| 275 |
|
| 276 |
-
print('Sample
|
| 277 |
|
| 278 |
-
|
| 279 |
|
| 280 |
print('=' * 70)
|
| 281 |
-
print('
|
| 282 |
|
| 283 |
model.to(device_type)
|
| 284 |
model.eval()
|
| 285 |
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
start_score_seq = [1792] + score + [1793]
|
| 289 |
-
|
| 290 |
-
#==================================================================
|
| 291 |
|
| 292 |
-
input_seq = generate_full_seq(start_score_seq,
|
| 293 |
-
max_toks=MAX_GEN_TOKS,
|
| 294 |
-
temperature=model_temperature,
|
| 295 |
-
top_k_value=model_sampling_top_k,
|
| 296 |
-
)
|
| 297 |
|
| 298 |
-
final_song = input_seq[len(start_score_seq):]
|
| 299 |
|
| 300 |
print('=' * 70)
|
| 301 |
print('Done!')
|
|
@@ -386,7 +397,7 @@ def Classify_MIDI_Genre(input_midi,
|
|
| 386 |
print('Done!')
|
| 387 |
print('=' * 70)
|
| 388 |
|
| 389 |
-
|
| 390 |
|
| 391 |
output_title = str(inp_mel)
|
| 392 |
output_midi = str(new_fn)
|
|
@@ -398,7 +409,7 @@ def Classify_MIDI_Genre(input_midi,
|
|
| 398 |
print('Output MIDI melody title:', output_title)
|
| 399 |
print('=' * 70)
|
| 400 |
|
| 401 |
-
|
| 402 |
|
| 403 |
print('-' * 70)
|
| 404 |
print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
|
|
|
| 71 |
print('Loading MIDI GAS processed scores dataset...')
|
| 72 |
|
| 73 |
midi_gas_ps_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
|
| 74 |
+
filename='MIDI_GAS_Processed_Scores_CC_BY_NC_SA.pickle',
|
| 75 |
+
repo_type='dataset'
|
| 76 |
+
)
|
| 77 |
|
| 78 |
midi_gas_ps = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_ps_pickle)
|
| 79 |
|
|
|
|
| 87 |
print('Loading MIDI GAS processed scores embeddings dataset...')
|
| 88 |
|
| 89 |
midi_gas_pse_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
|
| 90 |
+
filename='MIDI_GAS_Processed_Scores_Embeddings_CC_BY_NC_SA.pickle',
|
| 91 |
+
repo_type='dataset'
|
| 92 |
+
)
|
| 93 |
|
| 94 |
+
midi_gas_pse = [a[3] for a in TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_pse_pickle)]
|
| 95 |
|
| 96 |
print('=' * 70)
|
| 97 |
print('Done!')
|
|
|
|
| 141 |
|
| 142 |
def load_midi(input_midi):
|
| 143 |
|
| 144 |
+
raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
|
| 145 |
|
| 146 |
escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
|
| 147 |
|
|
|
|
| 215 |
|
| 216 |
#==================================================================================
|
| 217 |
|
| 218 |
+
def cosine_similarity_numpy(src_array, trg_array):
|
| 219 |
+
|
| 220 |
+
src_norm = np.linalg.norm(src_array)
|
| 221 |
+
|
| 222 |
+
trg_norms = np.linalg.norm(trg_array, axis=1)
|
| 223 |
+
|
| 224 |
+
dot_products = np.dot(trg_array, src_array)
|
| 225 |
+
|
| 226 |
+
cosine_sims = dot_products / (src_norm * trg_norms + 1e-10)
|
| 227 |
+
|
| 228 |
+
return cosine_sims.tolist()
|
| 229 |
+
|
| 230 |
+
#==================================================================================
|
| 231 |
+
|
| 232 |
+
def select_best_output(outputs, embeddings, src_embeddings, top_k=10):
|
| 233 |
+
|
| 234 |
+
emb_sims = cosine_similarity_numpy(np.array(src_embeddings), np.array(embeddings))
|
| 235 |
+
|
| 236 |
+
sorted_emb_sims = sorted(emb_sims, reverse=True)
|
| 237 |
+
|
| 238 |
+
hits = []
|
| 239 |
+
hits_idxs = []
|
| 240 |
+
|
| 241 |
+
for s in sorted_emb_sims[:top_k]:
|
| 242 |
+
idx = emb_sims.index(s)
|
| 243 |
+
hits_idxs.append(idx)
|
| 244 |
+
hits.extend([[str(s)] + outputs[idx][:3]])
|
| 245 |
+
|
| 246 |
+
return hits, hits_idxs
|
| 247 |
+
|
| 248 |
+
#==================================================================================
|
| 249 |
+
|
| 250 |
# @spaces.GPU
|
| 251 |
def Classify_MIDI_Genre(input_midi,
|
| 252 |
input_melody,
|
|
|
|
| 258 |
|
| 259 |
#===============================================================================
|
| 260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
print('=' * 70)
|
| 262 |
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|
| 263 |
start_time = reqtime.time()
|
|
|
|
| 280 |
|
| 281 |
print('=' * 70)
|
| 282 |
|
| 283 |
+
#===============================================================================
|
|
|
|
|
|
|
| 284 |
|
| 285 |
+
print('Loadin and prepping source MIDI...')
|
|
|
|
|
|
|
| 286 |
|
| 287 |
+
src_score = load_midi(input_midi)
|
|
|
|
| 288 |
|
| 289 |
+
inp = torch.LongTensor([src_score]).to(device_type)
|
| 290 |
|
| 291 |
+
src_emb = get_embeddings(inp).tolist()
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
print('Done!')
|
|
|
|
| 294 |
|
| 295 |
+
#===============================================================================
|
| 296 |
|
| 297 |
+
print('Sample embeddings values', src_emb[:3])
|
| 298 |
|
| 299 |
+
#===============================================================================
|
| 300 |
|
| 301 |
print('=' * 70)
|
| 302 |
+
print('Classifying...')
|
| 303 |
|
| 304 |
model.to(device_type)
|
| 305 |
model.eval()
|
| 306 |
|
| 307 |
+
#===============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
|
|
|
| 310 |
|
| 311 |
print('=' * 70)
|
| 312 |
print('Done!')
|
|
|
|
| 397 |
print('Done!')
|
| 398 |
print('=' * 70)
|
| 399 |
|
| 400 |
+
#===============================================================================
|
| 401 |
|
| 402 |
output_title = str(inp_mel)
|
| 403 |
output_midi = str(new_fn)
|
|
|
|
| 409 |
print('Output MIDI melody title:', output_title)
|
| 410 |
print('=' * 70)
|
| 411 |
|
| 412 |
+
#===============================================================================
|
| 413 |
|
| 414 |
print('-' * 70)
|
| 415 |
print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
|