asigalov61 commited on
Commit
d011755
·
verified ·
1 Parent(s): 5019258

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -42
app.py CHANGED
@@ -71,8 +71,9 @@ print('=' * 70)
71
  print('Loading MIDI GAS processed scores dataset...')
72
 
73
  midi_gas_ps_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
74
- filename='MIDI_GAS_Processed_Scores_CC_BY_NC_SA.pickle'
75
- )
 
76
 
77
  midi_gas_ps = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_ps_pickle)
78
 
@@ -86,10 +87,11 @@ print('=' * 70)
86
  print('Loading MIDI GAS processed scores embeddings dataset...')
87
 
88
  midi_gas_pse_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
89
- filename='MIDI_GAS_Processed_Scores_Embeddings_CC_BY_NC_SA.pickle'
90
- )
 
91
 
92
- midi_gas_pse = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_pse_pickle)
93
 
94
  print('=' * 70)
95
  print('Done!')
@@ -139,7 +141,7 @@ print('=' * 70)
139
 
140
  def load_midi(input_midi):
141
 
142
- raw_score = TMIDIX.midi2single_track_ms_score(midi_file)
143
 
144
  escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
145
 
@@ -213,6 +215,38 @@ def get_embeddings(inputs):
213
 
214
  #==================================================================================
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  # @spaces.GPU
217
  def Classify_MIDI_Genre(input_midi,
218
  input_melody,
@@ -224,10 +258,6 @@ def Classify_MIDI_Genre(input_midi,
224
 
225
  #===============================================================================
226
 
227
-
228
-
229
- #===============================================================================
230
-
231
  print('=' * 70)
232
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
233
  start_time = reqtime.time()
@@ -250,52 +280,33 @@ def Classify_MIDI_Genre(input_midi,
250
 
251
  print('=' * 70)
252
 
253
- #==================================================================
254
-
255
- print('Prepping melody...')
256
 
257
- if input_midi:
258
- inp_mel = 'Custom MIDI'
259
- score, score_list = load_midi(input_midi.name, melody_patch, use_nth_note)
260
 
261
- else:
262
- mel_list = [m[0].lower() for m in popular_hook_melodies]
263
 
264
- inp_mel = random.choice(mel_list).title()
265
 
266
- for m in mel_list:
267
- if input_melody.lower().strip() in m:
268
- inp_mel = m.title()
269
- break
270
 
271
- score = popular_hook_melodies[[m[0] for m in popular_hook_melodies].index(inp_mel)][1]
272
- score_list = [[[score[i]], score[i+1:i+3]] for i in range(0, len(score)-3, 3)]
273
 
274
- print('Selected melody:', inp_mel)
275
 
276
- print('Sample score events', score[:12])
277
 
278
- #==================================================================
279
 
280
  print('=' * 70)
281
- print('Generating...')
282
 
283
  model.to(device_type)
284
  model.eval()
285
 
286
- #==================================================================
287
-
288
- start_score_seq = [1792] + score + [1793]
289
-
290
- #==================================================================
291
 
292
- input_seq = generate_full_seq(start_score_seq,
293
- max_toks=MAX_GEN_TOKS,
294
- temperature=model_temperature,
295
- top_k_value=model_sampling_top_k,
296
- )
297
 
298
- final_song = input_seq[len(start_score_seq):]
299
 
300
  print('=' * 70)
301
  print('Done!')
@@ -386,7 +397,7 @@ def Classify_MIDI_Genre(input_midi,
386
  print('Done!')
387
  print('=' * 70)
388
 
389
- #========================================================
390
 
391
  output_title = str(inp_mel)
392
  output_midi = str(new_fn)
@@ -398,7 +409,7 @@ def Classify_MIDI_Genre(input_midi,
398
  print('Output MIDI melody title:', output_title)
399
  print('=' * 70)
400
 
401
- #========================================================
402
 
403
  print('-' * 70)
404
  print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
 
71
  print('Loading MIDI GAS processed scores dataset...')
72
 
73
  midi_gas_ps_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
74
+ filename='MIDI_GAS_Processed_Scores_CC_BY_NC_SA.pickle',
75
+ repo_type='dataset'
76
+ )
77
 
78
  midi_gas_ps = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_ps_pickle)
79
 
 
87
  print('Loading MIDI GAS processed scores embeddings dataset...')
88
 
89
  midi_gas_pse_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
90
+ filename='MIDI_GAS_Processed_Scores_Embeddings_CC_BY_NC_SA.pickle',
91
+ repo_type='dataset'
92
+ )
93
 
94
+ midi_gas_pse = [a[3] for a in TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_pse_pickle)]
95
 
96
  print('=' * 70)
97
  print('Done!')
 
141
 
142
  def load_midi(input_midi):
143
 
144
+ raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
145
 
146
  escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
147
 
 
215
 
216
  #==================================================================================
217
 
218
+ def cosine_similarity_numpy(src_array, trg_array):
219
+
220
+ src_norm = np.linalg.norm(src_array)
221
+
222
+ trg_norms = np.linalg.norm(trg_array, axis=1)
223
+
224
+ dot_products = np.dot(trg_array, src_array)
225
+
226
+ cosine_sims = dot_products / (src_norm * trg_norms + 1e-10)
227
+
228
+ return cosine_sims.tolist()
229
+
230
+ #==================================================================================
231
+
232
+ def select_best_output(outputs, embeddings, src_embeddings, top_k=10):
233
+
234
+ emb_sims = cosine_similarity_numpy(np.array(src_embeddings), np.array(embeddings))
235
+
236
+ sorted_emb_sims = sorted(emb_sims, reverse=True)
237
+
238
+ hits = []
239
+ hits_idxs = []
240
+
241
+ for s in sorted_emb_sims[:top_k]:
242
+ idx = emb_sims.index(s)
243
+ hits_idxs.append(idx)
244
+ hits.extend([[str(s)] + outputs[idx][:3]])
245
+
246
+ return hits, hits_idxs
247
+
248
+ #==================================================================================
249
+
250
  # @spaces.GPU
251
  def Classify_MIDI_Genre(input_midi,
252
  input_melody,
 
258
 
259
  #===============================================================================
260
 
 
 
 
 
261
  print('=' * 70)
262
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
263
  start_time = reqtime.time()
 
280
 
281
  print('=' * 70)
282
 
283
+ #===============================================================================
 
 
284
 
285
+ print('Loadin and prepping source MIDI...')
 
 
286
 
287
+ src_score = load_midi(input_midi)
 
288
 
289
+ inp = torch.LongTensor([src_score]).to(device_type)
290
 
291
+ src_emb = get_embeddings(inp).tolist()
 
 
 
292
 
293
+ print('Done!')
 
294
 
295
+ #===============================================================================
296
 
297
+ print('Sample embeddings values', src_emb[:3])
298
 
299
+ #===============================================================================
300
 
301
  print('=' * 70)
302
+ print('Classifying...')
303
 
304
  model.to(device_type)
305
  model.eval()
306
 
307
+ #===============================================================================
 
 
 
 
308
 
 
 
 
 
 
309
 
 
310
 
311
  print('=' * 70)
312
  print('Done!')
 
397
  print('Done!')
398
  print('=' * 70)
399
 
400
+ #===============================================================================
401
 
402
  output_title = str(inp_mel)
403
  output_midi = str(new_fn)
 
409
  print('Output MIDI melody title:', output_title)
410
  print('=' * 70)
411
 
412
+ #===============================================================================
413
 
414
  print('-' * 70)
415
  print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))