import argparse
import base64
import glob
import io
import json
import os
import random
import tempfile
import threading

import gradio as gr
import numpy as np
import onnxruntime as rt
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer

MAX_SEED = np.iinfo(np.int32).max
in_space = os.getenv("SYSTEM") == "spaces"
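
# Thin wrapper around rtmidi for listing and releasing local MIDI ports.
# Used only by the "Devices" tab below.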
class MIDIDeviceManager:
    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_output_devices(self):
        return self.midiout.get_ports() or ["No MIDI output devices"]

    def get_input_devices(self):
        return self.midiin.get_ports() or ["No MIDI input devices"]

    def get_device_info(self):
        out_devices = self.get_output_devices()
        in_devices = self.get_input_devices()
        out_info = "\n".join(f"Out Port {i}: {name}" for i, name in enumerate(out_devices))
        in_info = "\n".join(f"In Port {i}: {name}" for i, name in enumerate(in_devices))
        return f"Output Devices:\n{out_info}\n\nInput Devices:\n{in_info}"

    def close(self):
        if self.midiout.is_port_open():
            self.midiout.close_port()
        if self.midiin.is_port_open():
            self.midiin.close_port()
        del self.midiout
        del self.midiin
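
# Loads the soundfont, tokenizer and ONNX models from the skytnt/midi-model
# repo, and keeps every generated MIDI file as a base64 string so the
# Downloads tab can expose them as data-URI links.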
class MIDIManager:
    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}
        self.modified_files = []
        self.is_playing = False
        self.tokenizer = self.load_tokenizer("skytnt/midi-model")
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
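
    # Tokenizer settings are read from the repo's config.json; the
    # "tokenizer" -> version / optimise_midi layout is assumed to match
    # what skytnt/midi-model ships.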
    def load_tokenizer(self, repo_id):
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer
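
    # NOTE: the loading/inspection below assumes a MIDI helper exposing load()
    # plus mido-style track/event attributes; the MIDI.py bundled with
    # skytnt/midi-model mainly works with midi2score()/score2midi() scores, so
    # this interface is an assumption of this sketch.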
    def load_midi(self, file_path):
        midi = MIDI.load(file_path)
        midi_id = f"midi_{len(self.loaded_midi)}"
        self.loaded_midi[midi_id] = (file_path, midi)
        return midi_id

    def extract_notes_and_instruments(self, midi):
        notes = []
        instruments = set()
        for track in midi.tracks:
            for event in track.events:
                if event.type == 'note_on' and event.velocity > 0:
                    notes.append((event.note, event.velocity, event.time))
                if hasattr(event, 'program'):
                    instruments.add(event.program)
        return notes, list(instruments)
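
    # Builds a longer variation by repeating the extracted note list and
    # randomly nudging pitch/velocity on a `variation` fraction of the notes.
    # The MIDIFile/addTrack/addNote builder used here follows a midiutil-style
    # API and is assumed to be provided by the MIDI helper; it is not part of
    # the standard MIDI.py score interface.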
    def generate_variation(self, midi_id, length_factor=10, variation=0.3):
        if midi_id not in self.loaded_midi:
            return None
        _, midi = self.loaded_midi[midi_id]
        notes, instruments = self.extract_notes_and_instruments(midi)
        new_notes = []
        for _ in range(int(length_factor)):  # repeat the source notes length_factor times (default 10x)
            for note, vel, time in notes:
                if random.random() < variation:
                    new_note = min(127, max(0, note + random.randint(-2, 2)))
                    new_vel = min(127, max(0, vel + random.randint(-10, 10)))
                    new_notes.append((new_note, new_vel, time))
                else:
                    new_notes.append((note, vel, time))
        new_midi = MIDI.MIDIFile(len(instruments) or 1)
        for i, inst in enumerate(instruments or [0]):
            new_midi.addTrack()
            new_midi.addProgramChange(i, 0, 0, inst)
            for note, vel, time in new_notes:
                new_midi.addNote(i, 0, note, time, 100, vel)
        midi_output = io.BytesIO()
        new_midi.writeFile(midi_output)
        midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
        self.modified_files.append(midi_data)
        return midi_data
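
    # Simplified autoregressive loop over the exported ONNX pair (model_base
    # produces a hidden state, model_token turns it into logits). This sketch
    # feeds only the most recent token and carries no key/value cache, and the
    # CUDA-bound hidden output of shape (1, 1, 1024) is an assumption about the
    # exported model's I/O; the full app_onnx.py pipeline is more involved.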
    def generate_onnx(self, midi_id, max_len=1024, temp=1.0, top_p=0.98, top_k=20):
        if midi_id not in self.loaded_midi:
            return None
        _, mid = self.loaded_midi[midi_id]
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(mid))
        mid = np.asarray([mid_seq], dtype=np.int64)
        generator = np.random.RandomState(random.randint(0, MAX_SEED))
        # Simplified ONNX generation from app_onnx.py
        input_tensor = mid
        cur_len = input_tensor.shape[1]
        model = [self.model_base, self.model_token, self.tokenizer]
        while cur_len < max_len:
            inputs = {"x": rt.OrtValue.ortvalue_from_numpy(input_tensor[:, -1:], device_type="cuda")}
            outputs = {"hidden": rt.OrtValue.ortvalue_from_shape_and_type((1, 1, 1024), np.float32, device_type="cuda")}
            io_binding = model[0].io_binding()
            for name, val in inputs.items():
                io_binding.bind_ortvalue_input(name, val)
            for name in outputs:
                io_binding.bind_ortvalue_output(name, outputs[name])
            model[0].run_with_iobinding(io_binding)
            hidden = outputs["hidden"].numpy()[:, -1:]
            logits = model[1].run(None, {"hidden": hidden})[0]
            scores = softmax(logits / temp, -1)
            next_token = sample_top_p_k(scores, top_p, top_k, generator)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1
        mid_seq = input_tensor.tolist()[0]
        new_midi = self.tokenizer.detokenize(mid_seq)
        # MIDI.score2midi is expected to return the raw .mid bytes for the detokenized score.
        midi_data = base64.b64encode(MIDI.score2midi(new_midi)).decode('utf-8')
        self.modified_files.append(midi_data)
        return midi_data
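
    # Loops the rendered file until stop_playback() clears the flag; this call
    # blocks, so the UI handler runs it on a daemon thread.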
    def play_with_loop(self, midi_data):
        self.is_playing = True
        midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        while self.is_playing:
            self.synthesizer.play_midi(midi_file)

    def stop_playback(self):
        self.is_playing = False
        return "Playback stopped"
def softmax(x, axis):
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
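
# Combined nucleus (top-p) and top-k sampling: sort probabilities in
# descending order, zero out everything past cumulative mass p, keep at most
# the first k candidates, renormalize, then draw one index per row.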
def sample_top_p_k(probs, p, k, generator=None):
    if generator is None:
        generator = np.random
    probs_idx = np.argsort(-probs, axis=-1)
    probs_sort = np.take_along_axis(probs, probs_idx, -1)
    probs_sum = np.cumsum(probs_sort, axis=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    mask = np.zeros(probs_sort.shape[-1])
    mask[:k] = 1
    probs_sort *= mask
    probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
    shape = probs_sort.shape
    probs_sort_flat = probs_sort.reshape(-1, shape[-1])
    probs_idx_flat = probs_idx.reshape(-1, shape[-1])
    next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
    return next_token.reshape(*shape[:-1])
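
# Renders every generated file as a clickable data-URI download link for the
# Downloads tab.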
def create_download_list():
    html = "<h3>Downloads</h3><ul>"
    for i, midi_data in enumerate(midi_processor.modified_files):
        html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="midi_{i}.mid">MIDI {i}</a></li>'
    html += "</ul>"
    return html

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    opt = parser.parse_args()

    midi_manager = MIDIDeviceManager()
    midi_processor = MIDIManager()
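
    # Three-tab UI: generation, download links, and device info. The upload
    # handler is wired up after the Downloads tab so it can target the
    # `downloads` HTML component directly.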
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
        with gr.Tabs():
            # Tab 1: MIDI Prompt (Main Tab)
            with gr.Tab("MIDI Prompt"):
                midi_upload = gr.File(label="Upload MIDI File", file_count="multiple")
                output = gr.File(label="Generated MIDI")
                status = gr.Textbox(label="Status", value="Ready", interactive=False)

                def process_midi(files):
                    if not files:
                        return None, "No file uploaded", create_download_list()
                    midi_data = None
                    for file in files:
                        midi_id = midi_processor.load_midi(file.name)
                        # Use ONNX generation for advanced synthesis
                        midi_data = midi_processor.generate_onnx(midi_id, max_len=1024)
                    # Loop playback on a background thread so the handler can return.
                    threading.Thread(target=midi_processor.play_with_loop,
                                     args=(midi_data,), daemon=True).start()
                    # Write the last generated file to disk so it can be served for download.
                    out_file = tempfile.NamedTemporaryFile(suffix=".mid", delete=False)
                    out_file.write(base64.b64decode(midi_data))
                    out_file.close()
                    return out_file.name, "Playing", create_download_list()

            # Tab 2: Downloads
            with gr.Tab("Downloads"):
                downloads = gr.HTML(value="No generated files yet")

            midi_upload.change(process_midi, inputs=[midi_upload],
                               outputs=[output, status, downloads])

            # Tab 3: Devices
            with gr.Tab("Devices"):
                device_info = gr.Textbox(label="Connected MIDI Devices",
                                         value=midi_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("Refresh Devices")
                stop_btn = gr.Button("Stop Playback")

                def refresh_devices():
                    return midi_manager.get_device_info()

                refresh_btn.click(refresh_devices, inputs=None, outputs=[device_info])
                stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])

        gr.Markdown("""
        <div style='text-align: center; margin-top: 20px;'>
            <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
            <strong>Hugging Face</strong><br>
            <a href='https://huggingface.co/models'>Models</a> |
            <a href='https://huggingface.co/datasets'>Datasets</a> |
            <a href='https://huggingface.co/spaces'>Spaces</a> |
            <a href='https://huggingface.co/posts'>Posts</a> |
            <a href='https://huggingface.co/docs'>Docs</a> |
            <a href='https://huggingface.co/enterprise'>Enterprise</a> |
            <a href='https://huggingface.co/pricing'>Pricing</a>
        </div>
        """)

    app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
    midi_manager.close()