# Scrape artifact from the Hugging Face Spaces page header (status: Sleeping);
# kept as comments so the file remains valid Python.
# Spaces: Sleeping / Sleeping
| import gradio as gr | |
| import json | |
| import rtmidi | |
| import os | |
| import argparse | |
| import base64 | |
| import io | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| import onnxruntime as rt | |
| import MIDI | |
| from midi_synthesizer import MidiSynthesizer | |
| from midi_tokenizer import MIDITokenizer | |
# Number of parallel output slots rendered in the "Outputs" tab.
# Must match the JavaScript constant of the same name in the front-end
# visualizer code — change both together.
MIDI_OUTPUT_BATCH_SIZE = 4
class MIDIDeviceManager:
    """Manages rtmidi MIDI input/output device handles."""

    def __init__(self):
        # Handles are opened lazily by rtmidi; constructing them only
        # enumerates the backend, it does not open a port.
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Return a human-readable listing of available MIDI in/out ports."""
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: a backslash inside an f-string
        # expression ('\n'.join(...)) is a SyntaxError before Python 3.12.
        out_list = "\n".join(out_ports)
        in_list = "\n".join(in_ports)
        return f"Output Devices:\n{out_list}\n\nInput Devices:\n{in_list}"

    def close(self):
        """Close any open MIDI ports and release the handles.

        Idempotent: the original `del self.midiout, self.midiin` made a
        second call raise AttributeError; setting the attributes to None
        keeps repeated calls safe.
        """
        if getattr(self, "midiout", None) is not None:
            if self.midiout.is_port_open():
                self.midiout.close_port()
            self.midiout = None
        if getattr(self, "midiin", None) is not None:
            if self.midiin.is_port_open():
                self.midiin.close_port()
            self.midiin = None
class MIDIManager:
    """Handles MIDI processing, generation, and playback."""

    def __init__(self):
        # Download soundfont, tokenizer config and ONNX models once at
        # startup from the Hugging Face Hub (cached locally by hf_hub).
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.tokenizer = self._load_tokenizer("skytnt/midi-model")
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        # Base64-encoded bytes of every MIDI file generated this session,
        # in generation order (consumed by the download UI).
        self.generated_files = []

    def _load_tokenizer(self, repo_id):
        """Load and configure the MIDI tokenizer from *repo_id*'s config.json."""
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load a MIDI file from the given path."""
        return MIDI.load(file_path)

    def generate_onnx(self, midi_data):
        """Generate a MIDI variation of *midi_data* using the ONNX models.

        Returns the detokenized MIDI object and appends a base64-encoded
        copy of its bytes to ``self.generated_files``.
        """
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        max_len = 1024
        while cur_len < max_len:
            # NOTE(review): only the most recent token is fed and no hidden
            # state is carried between steps — confirm the ONNX export is
            # recurrent/stateful, otherwise context is being discarded.
            inputs = {"x": input_tensor[:, -1:]}
            hidden = self.model_base.run(None, inputs)[0]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = self._softmax(logits, axis=-1)
            next_token = self._sample_top_p_k(probs, 0.98, 20)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1
        new_seq = input_tensor[0].tolist()
        generated_midi = self.tokenizer.detokenize(new_seq)
        # Store base64-encoded MIDI data for downloads
        midi_bytes = MIDI.save(generated_midi)
        self.generated_files.append(base64.b64encode(midi_bytes).decode('utf-8'))
        return generated_midi

    def play_midi(self, midi_data):
        """Render base64-encoded MIDI data to an audio byte stream.

        *midi_data* must be a base64 string/bytes (the format stored in
        ``generated_files``), not a MIDI object.
        """
        midi_bytes = base64.b64decode(midi_data)
        midi_file = MIDI.load(io.BytesIO(midi_bytes))
        audio = io.BytesIO()
        self.synthesizer.render_midi(midi_file, audio)
        audio.seek(0)
        return audio

    @staticmethod
    def _softmax(x, axis):
        """Compute numerically-stable softmax probabilities along *axis*.

        Declared @staticmethod: the original plain ``def _softmax(x, axis)``
        received ``self`` as ``x`` when called via ``self._softmax(...)``
        and raised a TypeError.
        """
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

    @staticmethod
    def _sample_top_p_k(probs, p, k, generator=None):
        """Sample token ids with top-k then top-p (nucleus) filtering.

        Replaces the previous uniform-random placeholder with real
        distribution-aware sampling.

        probs: array of shape (..., vocab) of probabilities.
        p: nucleus mass threshold in (0, 1].
        k: keep only the k most probable tokens before nucleus filtering.
        generator: optional np.random.Generator for reproducibility.
        Returns an int64 array of shape (..., 1).
        """
        rng = generator if generator is not None else np.random
        probs = np.asarray(probs, dtype=np.float64)
        flat = probs.reshape(-1, probs.shape[-1])
        out = np.empty((flat.shape[0], 1), dtype=np.int64)
        for i, row in enumerate(flat):
            row = row.copy()
            # Top-k: zero out everything but the k largest probabilities.
            if k < row.shape[0]:
                row[np.argpartition(row, -k)[:-k]] = 0.0
            # Top-p: keep the smallest prefix of the sorted distribution
            # whose cumulative mass reaches p of the remaining mass.
            order = np.argsort(row)[::-1]
            cum = np.cumsum(row[order])
            cutoff = np.searchsorted(cum, p * cum[-1]) + 1
            keep = order[:cutoff]
            weights = row[keep] / row[keep].sum()
            out[i, 0] = rng.choice(keep, p=weights)
        return out.reshape(probs.shape[:-1] + (1,))
def process_midi(files):
    """Process uploaded MIDI files and yield incremental Gradio updates.

    Each yielded list has length 1 + 2*MIDI_OUTPUT_BATCH_SIZE and follows
    the layout of the ``outputs`` wiring in __main__:
    [js_msg, audio_0..audio_{B-1}, midi_0..midi_{B-1}]  (flat, not
    interleaved — the original interleaved indexing wrote into the wrong
    components).
    """
    n_outputs = 1 + 2 * MIDI_OUTPUT_BATCH_SIZE
    if not files:
        yield [gr.update()] * n_outputs
        return
    for idx, file in enumerate(files):
        # Round-robin uploads across the fixed number of output slots.
        output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
        midi_data = midi_processor.load_midi(file.name)
        generated_midi = midi_processor.generate_onnx(midi_data)
        # Placeholder for MIDI events; in practice, extract from generated_midi.
        # Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
        events = [
            ["note", 0, 0, 0, 60, 100, 1000],  # Example event
            # Add logic to convert generated_midi to events using tokenizer
        ]
        updates = [gr.update()] * n_outputs
        # gr.update(...) replaces the removed Component.update() API
        # (js_msg.update / gr.File.update no longer exist in Gradio 4).
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_clear", "data": [output_idx, "v2"]}]))
        yield updates
        # Send MIDI events to the front-end visualizer.
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_append", "data": [output_idx, events]}]))
        yield updates
        # play_midi() base64-decodes its argument; generate_onnx() stored
        # exactly that encoding in generated_files, so render from the
        # latest entry (passing the raw MIDI object would crash b64decode).
        audio_stream = midi_processor.play_midi(midi_processor.generated_files[-1])
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_end", "data": output_idx}]))
        updates[1 + output_idx] = gr.update(value=audio_stream)  # audio slot
        # NOTE(review): gr.File expects a filepath-like value — confirm
        # that passing the MIDI object here works with this Gradio version.
        updates[1 + MIDI_OUTPUT_BATCH_SIZE + output_idx] = gr.update(
            value=generated_midi, label=f"Generated MIDI {output_idx}"
        )
        yield updates
    # Final yield to ensure all components are in a stable state
    yield [gr.update()] * n_outputs
if __name__ == "__main__":
    # Command-line options for the Gradio server.
    parser = argparse.ArgumentParser(description="MIDI Composer App")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    parser.add_argument("--share", action="store_true", help="Share the app publicly")
    opt = parser.parse_args()

    # Module-level managers: process_midi() reads midi_processor and the
    # Devices tab reads device_manager.
    device_manager = MIDIDeviceManager()
    midi_processor = MIDIManager()

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        # Hidden textbox for sending messages to JS
        js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")
        with gr.Tabs():
            # MIDI Prompt Tab: upload + trigger controls.
            with gr.Tab("MIDI Prompt"):
                midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
                generate_btn = gr.Button("Generate")
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
            # Outputs Tab: one visualizer + audio player + download slot
            # per batch index (MIDI_OUTPUT_BATCH_SIZE slots total).
            with gr.Tab("Outputs"):
                output_audios = []
                output_midis = []
                for i in range(MIDI_OUTPUT_BATCH_SIZE):
                    with gr.Column():
                        gr.Markdown(f"## Output {i+1}")
                        # Container filled by the front-end visualizer JS.
                        gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                        # NOTE(review): type="bytes" is not a documented
                        # gr.Audio type ("numpy"/"filepath") — confirm
                        # against the Gradio version pinned for this Space.
                        output_audio = gr.Audio(label="Generated Audio", type="bytes", autoplay=True, elem_id=f"midi_audio_{i}")
                        output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
                        output_audios.append(output_audio)
                        output_midis.append(output_midi)
            # Devices Tab: read-only listing of the host's MIDI ports.
            with gr.Tab("Devices"):
                device_info = gr.Textbox(label="Connected MIDI Devices", value=device_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("Refresh Devices")
                refresh_btn.click(fn=lambda: device_manager.get_device_info(), outputs=[device_info])
        # Output components for event handling, flat layout:
        # [js_msg, audio_0..audio_{B-1}, midi_0..midi_{B-1}].
        outputs = [js_msg] + output_audios + output_midis
        # Bind the generate button to the processing function
        generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)
    # Launch the app (blocks until the server exits), then release the
    # MIDI device handles.
    app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
    device_manager.close()