asigalov61 commited on
Commit
3775575
·
verified ·
1 Parent(s): c37947a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -139,7 +139,7 @@ print('=' * 70)
139
 
140
  #==================================================================================
141
 
142
- def load_midi(input_midi):
143
 
144
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
145
 
@@ -229,9 +229,13 @@ def cosine_similarity_numpy(src_array, trg_array):
229
 
230
  #==================================================================================
231
 
232
- def select_best_output(outputs, embeddings, src_embeddings, top_k=10):
233
 
234
- emb_sims = cosine_similarity_numpy(src_embeddings, embeddings)
 
 
 
 
235
 
236
  sorted_emb_sims = sorted(emb_sims, reverse=True)
237
 
@@ -265,6 +269,7 @@ def Classify_MIDI_Genre(input_midi):
265
  fn1 = fn.split('.')[0]
266
 
267
  print('Input MIDI file name:', fn)
 
268
  print('=' * 70)
269
 
270
  #===============================================================================
@@ -295,7 +300,7 @@ def Classify_MIDI_Genre(input_midi):
295
 
296
  #===============================================================================
297
 
298
- result = select_best_output(midi_gas_ps, midi_gas_pse, src_emb)
299
 
300
  results_str = ''
301
 
@@ -455,6 +460,8 @@ with gr.Blocks() as demo:
455
  input_midi = gr.File(label="Input MIDI",
456
  file_types=[".midi", ".mid", ".kar"]
457
  )
 
 
458
 
459
  generate_btn = gr.Button("Classify", variant="primary")
460
 
@@ -467,7 +474,8 @@ with gr.Blocks() as demo:
467
  output_cls_results = gr.Textbox(label="MIDI classification results")
468
 
469
  generate_btn.click(Classify_MIDI_Genre,
470
- [input_midi
 
471
  ],
472
  [output_title,
473
  output_audio,
@@ -478,8 +486,8 @@ with gr.Blocks() as demo:
478
  )
479
 
480
  gr.Examples(
481
- [["Hotel California.mid"],
482
- ["Come To My Window.mid"]
483
  ],
484
  [input_midi
485
  ],
 
139
 
140
  #==================================================================================
141
 
142
+ def load_midi(input_midi, input_mixed_pooling):
143
 
144
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
145
 
 
229
 
230
  #==================================================================================
231
 
232
+ def select_best_output(outputs, embeddings, src_embeddings, input_mixed_pooling, top_k=10):
233
 
234
+ if input_mixed_pooling:
235
+ emb_sims = cosine_similarity_numpy(src_embeddings, embeddings)
236
+
237
+ else:
238
+ emb_sims = cosine_similarity_numpy(src_embeddings[:, :2048], embeddings[:, :2048])
239
 
240
  sorted_emb_sims = sorted(emb_sims, reverse=True)
241
 
 
269
  fn1 = fn.split('.')[0]
270
 
271
  print('Input MIDI file name:', fn)
272
+ print('Use mixed embeddings pooling:', input_mixed_pooling)
273
  print('=' * 70)
274
 
275
  #===============================================================================
 
300
 
301
  #===============================================================================
302
 
303
+ result = select_best_output(midi_gas_ps, midi_gas_pse, src_emb, input_mixed_pooling)
304
 
305
  results_str = ''
306
 
 
460
  input_midi = gr.File(label="Input MIDI",
461
  file_types=[".midi", ".mid", ".kar"]
462
  )
463
+
464
+ input_mixed_pooling = gr.Checkbox(value=False, label="Use mixed embeddings pooling")
465
 
466
  generate_btn = gr.Button("Classify", variant="primary")
467
 
 
474
  output_cls_results = gr.Textbox(label="MIDI classification results")
475
 
476
  generate_btn.click(Classify_MIDI_Genre,
477
+ [input_midi,
478
+ input_mixed_pooling
479
  ],
480
  [output_title,
481
  output_audio,
 
486
  )
487
 
488
  gr.Examples(
489
+ [["Hotel California.mid", False],
490
+ ["Come To My Window.mid", False]
491
  ],
492
  [input_midi
493
  ],