# Captured from a Hugging Face Spaces page (Space status at capture: Sleeping).
import argparse
import base64
import io
import json
import os
import random

import gradio as gr
import numpy as np
import onnxruntime as rt
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Constants
MAX_SEED = np.iinfo(np.int32).max  # Upper bound for randomly drawn RNG seeds
IN_SPACE = os.getenv("SYSTEM") == "spaces"  # True when env var SYSTEM is "spaces"
MAX_LENGTH = 1024  # Maximum tokens for generation
# MIDI Device Manager
class MIDIDeviceManager:
    """Hold one rtmidi output handle and one rtmidi input handle.

    The handles are used to list the available port names and are released
    by :meth:`close`.
    """

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Return a plain-text report of output and input port names.

        Each section lists one port name per line; a placeholder line is
        shown for a section whose port list is empty.
        """
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        out_text = "\n".join(out_ports)
        in_text = "\n".join(in_ports)
        return f"Output Devices:\n{out_text}\n\nInput Devices:\n{in_text}"

    def close(self):
        """Close any open port and drop both handles.

        Safe to call more than once: handles that were already removed are
        skipped.
        """
        for attr in ("midiout", "midiin"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            if handle.is_port_open():
                handle.close_port()
            delattr(self, attr)
# MIDI Processor with ONNX Generation
class MIDIManager:
    """Generate MIDI continuations with two ONNX sessions and render audio.

    Construction downloads a soundfont, a tokenizer config and two ONNX
    model files from the ``skytnt/midi-model`` Hugging Face repo.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        self.soundfont = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont)
        self.tokenizer = self._load_tokenizer("skytnt/midi-model")
        # CUDA is listed first; the CPU provider is the fallback.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
            providers=providers,
        )
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
            providers=providers,
        )
        self.generated_files = []  # Store base64-encoded MIDI data
        self.is_playing = False  # Set by play_midi(), cleared by stop()

    def _load_tokenizer(self, repo_id):
        """Download ``config.json`` from *repo_id* and build the tokenizer.

        Reads ``config["tokenizer"]["version"]`` and
        ``config["tokenizer"]["optimise_midi"]`` from the downloaded file.
        """
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        tokenizer_cfg = config["tokenizer"]
        tokenizer = MIDITokenizer(tokenizer_cfg["version"])
        tokenizer.set_optimise_midi(tokenizer_cfg["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load *file_path* through ``MIDI.load``.

        Raises:
            ValueError: if loading fails for any reason; the original
                exception is attached as ``__cause__``.
        """
        try:
            return MIDI.load(file_path)
        except Exception as e:
            raise ValueError(f"Failed to load MIDI file: {e}") from e

    def generate_variation(self, midi_data, temp=1.0, top_p=0.98, top_k=20):
        """Extend *midi_data* token by token up to ``MAX_LENGTH`` tokens.

        Args:
            midi_data: MIDI object accepted by ``MIDI.midi2score``.
            temp: Sampling temperature; logits are divided by it.
            top_p: Cumulative-probability cutoff passed to the sampler.
            top_k: Number of highest-probability candidates kept.

        Returns:
            The generated MIDI file as a base64-encoded string. The same
            string is appended to ``self.generated_files``.

        Raises:
            ValueError: if *temp* is not strictly positive.
        """
        # Fail early: dividing logits by zero or a negative temperature
        # would produce an unusable distribution.
        if temp <= 0:
            raise ValueError(f"temp must be positive, got {temp}")

        # Tokenize input MIDI
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        # Fresh, randomly seeded RNG for every call.
        generator = np.random.RandomState(random.randint(0, MAX_SEED))

        # Generate up to MAX_LENGTH
        while cur_len < MAX_LENGTH:
            # NOTE(review): only the last token is fed to the base model on
            # each step; confirm the exported model expects this input.
            inputs = {"x": input_tensor[:, -1:]}  # Last token
            hidden = self.model_base.run(None, inputs)[0]  # Base model output
            logits = self.model_token.run(None, {"hidden": hidden})[0]  # Token model output
            probs = softmax(logits / temp, axis=-1)
            next_token = sample_top_p_k(probs, top_p, top_k, generator)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1

        # Detokenize and save as MIDI
        new_seq = input_tensor[0].tolist()
        new_midi = self.tokenizer.detokenize(new_seq)
        midi_output = io.BytesIO()
        MIDI.score2midi(new_midi, midi_output)
        encoded = base64.b64encode(midi_output.getvalue()).decode('utf-8')
        self.generated_files.append(encoded)
        return encoded

    def play_midi(self, midi_data):
        """Render base64-encoded MIDI to an in-memory audio buffer.

        Sets ``self.is_playing`` to True, decodes *midi_data*, loads it with
        ``MIDI.load`` and passes it to ``self.synthesizer.render_midi``
        together with a fresh ``io.BytesIO``.

        Returns:
            The ``io.BytesIO`` buffer, rewound to position 0.
        """
        self.is_playing = True
        midi_bytes = base64.b64decode(midi_data)
        midi_file = MIDI.load(io.BytesIO(midi_bytes))
        audio = io.BytesIO()
        self.synthesizer.render_midi(midi_file, audio)
        audio.seek(0)
        return audio

    def stop(self):
        """Clear the ``is_playing`` flag."""
        self.is_playing = False
# Helper Functions
def softmax(x, axis=-1):
    """Return the softmax of *x* along *axis*.

    The maximum along *axis* is subtracted before exponentiating so large
    inputs do not overflow.

    Args:
        x: Array-like of scores.
        axis: Axis to normalise over; defaults to the last axis.

    Returns:
        An array with the same shape as *x* whose values along *axis* sum
        to 1.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def sample_top_p_k(probs, p, k, generator):
    """Sample one token id from *probs* with top-p and top-k filtering.

    Args:
        probs: Probabilities over the vocabulary, either 1-D ``(vocab,)``
            or 2-D ``(rows, vocab)``. Only the first row is sampled.
        p: Top-p threshold; a candidate is dropped once the cumulative
            probability of the candidates ranked above it exceeds *p*.
        k: Number of highest-probability candidates to keep.
        generator: Object providing ``choice(n, p=...)``, such as
            ``numpy.random.RandomState``.

    Returns:
        An ``int64`` array of shape ``(1, 1)`` holding the sampled token id.

    Raises:
        ValueError: if *k* is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    probs = np.atleast_2d(np.asarray(probs))
    # Rank candidates from most to least probable; `order` maps each rank
    # back to its token id.
    order = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, order, axis=-1)  # new array
    cumulative = np.cumsum(sorted_probs, axis=-1)

    # Top-p: drop a candidate when the mass ranked above it already exceeds p.
    sorted_probs[cumulative - sorted_probs > p] = 0.0
    sorted_probs[..., k:] = 0.0  # Top-k filtering
    sorted_probs /= sorted_probs.sum(axis=-1, keepdims=True)

    # `choice` draws a rank in the sorted order, so translate it back to
    # the token id before returning.
    rank = generator.choice(sorted_probs.shape[-1], p=sorted_probs[0])
    return np.array([[order[0, rank]]], dtype=np.int64)
# UI Functions
def process_midi_upload(files):
    """Generate a variation of an uploaded MIDI file and render it.

    Args:
        files: The upload value from the file component. May be a single
            file or a list of files (only the first is processed); each
            file may be a path string or an object with a ``name``
            attribute holding the path.

    Returns:
        A 3-tuple ``(audio, status_message, downloads_html)``. ``audio`` is
        whatever ``midi_processor.play_midi`` returns, or ``None`` when
        nothing was uploaded or an error occurred.
    """
    if not files:
        return None, "No file uploaded", ""
    # A file component passes one file unless it is configured for multiple
    # uploads, so accept both shapes.
    upload = files[0] if isinstance(files, (list, tuple)) else files  # Process first file
    # Upload values are either plain path strings or wrappers with `.name`.
    path = getattr(upload, "name", upload)
    try:
        midi_data = midi_processor.load_midi(path)
        generated_midi = midi_processor.generate_variation(midi_data)
        # NOTE(review): play_midi returns an io.BytesIO; confirm the audio
        # output component accepts that type.
        audio = midi_processor.play_midi(generated_midi)
        download_html = create_download_list()
        return audio, "Generated and playing", download_html
    except Exception as e:
        # UI boundary: report the failure in the status box instead of
        # raising into the event handler.
        return None, f"Error: {e}", ""
def create_download_list():
    """Return HTML listing every generated MIDI file as a download link.

    Each entry of ``midi_processor.generated_files`` (base64 text) becomes a
    ``data:audio/midi;base64,...`` link named ``generated_<index>.mid``.
    A placeholder paragraph is returned when the list is empty.
    """
    generated = midi_processor.generated_files
    if not generated:
        return "<p>No generated files yet.</p>"
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{midi_data}" download="generated_{i}.mid">Download MIDI {i}</a></li>'
        for i, midi_data in enumerate(generated)
    )
    return f"<h3>Generated MIDI Files</h3><ul>{items}</ul>"
def refresh_devices():
    """Return the current device report from the global ``device_manager``."""
    return device_manager.get_device_info()
def stop_playback():
    """Call ``midi_processor.stop()`` and return the status text to show."""
    midi_processor.stop()
    return "Playback stopped"
# Main Application
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MIDI Composer with ONNX Generation")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    # Both objects stay module-level globals: the UI callbacks above look
    # them up by name.
    device_manager = MIDIDeviceManager()
    try:
        midi_processor = MIDIManager()

        with gr.Blocks(title="MIDI Composer", theme=gr.themes.Soft()) as app:
            gr.Markdown("# 🎵 MIDI Composer 🎵")
            with gr.Tabs():
                # MIDI Prompt Tab
                with gr.Tab("MIDI Prompt"):
                    midi_upload = gr.File(label="Upload MIDI File", file_types=[".mid", ".midi"])
                    # No `type` argument: gr.Audio only accepts "numpy" or
                    # "filepath" there, and the setting is irrelevant for a
                    # component used purely as an output.
                    audio_output = gr.Audio(label="Generated MIDI", autoplay=True)
                    status = gr.Textbox(label="Status", value="Ready", interactive=False)
                # Downloads Tab
                with gr.Tab("Downloads", elem_id="downloads"):
                    downloads_html = gr.HTML(value=create_download_list())
                # Devices Tab
                with gr.Tab("Devices"):
                    device_info = gr.Textbox(
                        label="MIDI Devices",
                        value=device_manager.get_device_info(),
                        interactive=False,
                    )
                    refresh_btn = gr.Button("Refresh Devices")
                    stop_btn = gr.Button("Stop Playback")
                    refresh_btn.click(refresh_devices, outputs=[device_info])
                    stop_btn.click(stop_playback, outputs=[status])

            # Wired after all tabs exist so the upload handler can update the
            # HTML component that lives in the Downloads tab.
            midi_upload.change(
                process_midi_upload,
                inputs=[midi_upload],
                outputs=[audio_output, status, downloads_html],
            )

        app.launch(server_port=args.port, share=args.share, inbrowser=True)
    finally:
        # Release the MIDI handles even if setup or launch raises.
        device_manager.close()