"""Gradio demo for the dx2102/llama-midi model.

The model generates music as plain text: one note per line in the form
"pitch duration wait velocity instrument" (times in milliseconds).  The UI
streams generated text into a chatbot and can render the result to a MIDI
file, to audio (via the external `timidity` and `lame` tools) and to a
piano-roll image.
"""
print('Starting...')

import random
import os
import time
import requests
from queue import Queue
from threading import Thread

import symusic
import transformers
import spaces
import gradio as gr

os.makedirs('./temp', exist_ok=True)

print('\n\n\n')
print('Loading model...')
# Two copies of the same model: a float32 pipeline pinned to the CPU (slow,
# but no usage time limit) and a bfloat16 pipeline on CUDA for the
# Huggingface ZeroGPU path.
cpu_pipe = transformers.pipeline(
    'text-generation',
    model='dx2102/llama-midi',
    torch_dtype='float32',
    device='cpu',
)
gpu_pipe = transformers.pipeline(
    'text-generation',
    model='dx2102/llama-midi',
    torch_dtype='bfloat16',
    device='cuda:0',
)
# print devices
print(f"{gpu_pipe.device = }, {gpu_pipe.model.device = }")
print(f"{cpu_pipe.device = }, {cpu_pipe.model.device = }")
print('Done')

example_prefix = '''pitch duration wait velocity instrument
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''

# Smoke-test both pipelines once at startup so a broken checkpoint or CUDA
# problem fails fast instead of at the first user request.
print('cpu:', cpu_pipe(example_prefix, max_new_tokens=10)[0]['generated_text'])
print('gpu:', gpu_pipe(example_prefix, max_new_tokens=10)[0]['generated_text'])


def postprocess(txt, path):
    """Parse the model's text output and dump it as a MIDI file at `path`.

    Each valid line has the form "pitch duration wait velocity instrument"
    (durations/waits in milliseconds).  Lines the model garbled are skipped
    with a log message; MIDI-writing errors are logged and swallowed so the
    UI never crashes on bad model output.
    """
    # Drop the title/prefix: the score body follows the last blank line.
    txt = txt.split('\n\n')[-1]
    tracks = {}  # instrument id (or 'drum') -> symusic TrackSecond
    now = 0      # running time in milliseconds
    for line in txt.split('\n'):
        # we need to ignore the invalid output by the model
        try:
            pitch, duration, wait, velocity, instrument = line.split(' ')
            pitch, duration, wait, velocity = [int(x) for x in [pitch, duration, wait, velocity]]
            if instrument not in tracks:
                tracks[instrument] = symusic.core.TrackSecond()
                if instrument != 'drum':
                    tracks[instrument].program = int(instrument)
                else:
                    tracks[instrument].is_drum = True
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
            tracks[instrument].notes.append(symusic.core.NoteSecond(
                time=now/1000,
                duration=duration/1000,
                pitch=int(pitch),
                # The model emits quantized velocities; scale up, but clamp to
                # the MIDI maximum of 127 so large values don't overflow and
                # silently drop the note.
                velocity=min(int(velocity * 4), 127),
            ))
            now += wait
        except Exception as e:
            print(f'Postprocess: Ignored line: "{line}" because of error:', e)
    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')
    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)


with gr.Blocks() as demo:
    chatbot_box = gr.Chatbot(type='messages', render_markdown=False, sanitize_html=False)
    prefix_box = gr.TextArea(value='Twinkle Twinkle Little Star', label='Score title / text prefix')
    with gr.Row():
        submit_btn = gr.Button('Generate')
        continue_btn = gr.Button('Continue')
        clear_btn = gr.Button('Clear history')
    with gr.Row():
        get_audio_btn = gr.Button('Convert to audio')
        get_midi_btn = gr.Button('Convert to MIDI')
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()
    server_box = gr.Dropdown(
        choices=['CPU', 'Huggingface ZeroGPU'],
        label='GPU Server',
    )
    gr.Markdown('''
ZeroGPU comes with a time limit currently:

- 3 minutes (not logged in)
- 5 minutes (logged in)
- 25 minutes (Pro user)

CPUs will be slower but there is no time limit.
'''.strip())
    example_box = gr.Examples(
        [
            ['Twinkle Twinkle Little Star'],
            ['Twinkle Twinkle Little Star (Minor Key Version)'],
            ['The Entertainer - Scott Joplin (Piano Solo)'],
            ['Clair de Lune – Debussy'],
            ['Nocturne | Frederic Chopin'],
            ['Fugue I in C major, BWV 846'],
            ['Beethoven Symphony No. 7 (2nd movement) Piano solo'],
            ['Guitar'],
            ['Composer: Bach'],
            ['Composer: Beethoven'],
            ['Composer: Debussy'],
        ],
        inputs=prefix_box,
        examples_per_page=9999,
    )

    def user_fn(user_message, history: list):
        """Append the user's message to the chat history and clear the input box."""
        return '', history + [{'role': 'user', 'content': user_message}]

    def get_last(history: list):
        """Return the content of the most recent chat message.

        Raises a user-visible gr.Error when the history is still empty.
        """
        if len(history) == 0:
            raise gr.Error('''No messages to read yet. Try the 'Generate' button first!''')
        return history[-1]['content']

    def generate_fn(history, server):
        """Start a fresh generation from the user's latest message, streaming updates."""
        # continue from user input
        prefix = get_last(history)
        # prevent the model from continuing user's score title
        if prefix != '' and '\n' not in prefix:
            # prefix is a single line => prefix is the score title
            # add '\n' to prevent model from continuing the title
            prefix += '\n'
        history.append({'role': 'assistant', 'content': ''})
        for history in model_fn(prefix, history, server):
            yield history

    def continue_fn(history, server):
        """Continue generating from the last model output, streaming updates."""
        # Fix: go through get_last() so an empty history raises a friendly
        # gr.Error (as in generate_fn) instead of an unhandled IndexError.
        prefix = get_last(history)
        for history in model_fn(prefix, history, server):
            yield history

    def model_fn(prefix, history, server):
        """Dispatch generation to the backend selected in the dropdown."""
        if server == 'Huggingface ZeroGPU':
            generator = zerogpu_model_fn(prefix, history, server)
        elif server == 'CPU':
            generator = cpu_model_fn(prefix, history, server)
        else:
            raise gr.Error(f'Unknown server: {server}')
        for history in generator:
            yield history

    def cpu_model_fn(prefix, history, server):
        """Generate with a local pipeline, streaming tokens into the last message.

        Despite the name this serves both backends: zerogpu_model_fn below is
        this same function wrapped with spaces.GPU, and `server` selects
        which pipeline actually runs.
        """
        queue = Queue(maxsize=10)
        if server == 'CPU':
            pipe = cpu_pipe
        else:
            pipe = gpu_pipe

        class MyStreamer:
            # Minimal transformers streamer interface: put() receives token
            # tensors during generation, end() signals completion.
            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    # timeout guards against a consumer that stopped reading
                    queue.put(text, block=True, timeout=5)

            def end(self):
                queue.put(None)  # sentinel: generation finished

        def background_fn():
            # Run generation on a worker thread so this generator can yield
            # streamed updates to Gradio as tokens arrive.
            print(f"{pipe.device = }, {pipe.model.device = }")
            print(f"{gpu_pipe.device = }, {gpu_pipe.model.device = }")
            print(f"{cpu_pipe.device = }, {cpu_pipe.model.device = }")
            result = pipe(
                prefix,
                streamer=MyStreamer(),
                max_new_tokens=500,
                top_p=0.9,
                temperature=0.6,
            )
            print('Generated text:')
            print(result[0]['generated_text'])
            print()

        Thread(target=background_fn).start()
        while True:
            text = queue.get()
            if text is None:
                break
            history[-1]['content'] += text
            yield history

    # duration: GPU time out (seconds)
    zerogpu_model_fn = spaces.GPU(duration=15)(cpu_model_fn)

    def runpod_model_fn(prefix, history):
        """Generate via a RunPod serverless endpoint (currently not wired up)."""
        # NOTE
        runpod_api_key = os.getenv('RUNPOD_API_KEY')
        runpod_endpoint = os.getenv('RUNPOD_ENDPOINT')
        # synchronized request
        response = requests.post(
            f'https://api.runpod.ai/v2/{runpod_endpoint}/runsync',
            headers={'Authorization': f'Bearer {runpod_api_key}'},
            json={'input': {'prompt': prefix}},
        ).json()['output'][0]['choices'][0]['tokens'][0]
        # yield just once
        history[-1]['content'] += response
        yield history

    submit_event = submit_btn.click(
        user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False,
    ).then(
        generate_fn, [chatbot_box, server_box], chatbot_box,
    )
    continue_event = continue_btn.click(
        continue_fn, [chatbot_box, server_box], chatbot_box,
    )
    clear_btn.click(
        lambda: None, inputs=[], outputs=chatbot_box,
        cancels=[submit_event, continue_event], queue=False,
    )

    def get_audio_fn(history):
        """Render the latest message to MIDI, then to mp3 with timidity + lame."""
        i = random.randint(0, 1000_000_000)
        path = f'./temp/{i}.mid'
        text = get_last(history)
        try:
            postprocess(text, path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # turn midi into audio with timidity
        # (os.system keeps this best-effort: a missing binary just logs a
        # shell error instead of raising)
        os.system(f'timidity ./temp/{i}.mid -Ow -o ./temp/{i}.wav')
        # wav to mp3
        os.system(f'lame -b 320 ./temp/{i}.wav ./temp/{i}.mp3')
        return f'./temp/{i}.mp3'

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Render the latest message to a MIDI file and a piano-roll SVG."""
        i = random.randint(0, 1000_000_000)
        # turn the text into midi
        text = get_last(history)
        try:
            postprocess(text, f'./temp/{i}.mid')
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # also render the piano roll
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 4))
        now = 0
        for line in text.split('\n\n')[-1].split('\n'):
            try:
                pitch, duration, wait, velocity, instrument = [int(x) for x in line.split()]
            except Exception:
                # skip lines the model produced that are not valid notes
                continue
            plt.plot([now, now+duration], [pitch, pitch], color='black', alpha=1)
            plt.scatter(now, pitch, s=6, color='black', alpha=0.3)
            now += wait
        plt.savefig(f'./temp/{i}.svg')
        # Fix: close the figure so repeated requests don't leak memory —
        # plt.figure() keeps every figure alive until explicitly closed.
        plt.close()
        return f'./temp/{i}.mid', f'./temp/{i}.svg'

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

demo.launch()