# Hugging Face Space (dx2102/llama-midi demo), running on ZeroGPU
print('Starting...')

import random
import os
import time
import requests
from queue import Queue
from threading import Thread

import symusic
import transformers
import spaces
import gradio as gr

os.makedirs('./temp', exist_ok=True)

print('\n\n\n')
print('Loading model...')
cpu_pipe = transformers.pipeline(
    'text-generation',
    model='dx2102/llama-midi',
    torch_dtype='float32',
    device='cpu',
)
gpu_pipe = transformers.pipeline(
    'text-generation',
    model='dx2102/llama-midi',
    torch_dtype='bfloat16',
    device='cuda:0',
)
# print devices
print(f"{gpu_pipe.device = }, {gpu_pipe.model.device = }")
print(f"{cpu_pipe.device = }, {cpu_pipe.model.device = }")
print('Done')
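# The same checkpoint is loaded twice: a float32 copy pinned to the CPU and a
# bfloat16 copy on the GPU. The "GPU Server" dropdown in the UI below chooses
# between them per request. On ZeroGPU, the actual GPU is only attached while a
# function wrapped with spaces.GPU is running (see zerogpu_model_fn further down).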
example_prefix = '''pitch duration wait velocity instrument
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''
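# Each line of the score text encodes one note: "pitch duration wait velocity instrument".
# Durations and waits are in milliseconds (postprocess() below divides by 1000 to get the
# seconds that symusic expects); "wait" is the delay before the next note starts.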
print('cpu:', cpu_pipe(example_prefix, max_new_tokens=10)[0]['generated_text'])
print('gpu:', gpu_pipe(example_prefix, max_new_tokens=10)[0]['generated_text'])


def postprocess(txt, path):
    # keep only the generated score (the part after the last blank line);
    # this drops the title / text prefix
    txt = txt.split('\n\n')[-1]
    # track = symusic.core.TrackSecond()
    tracks = {}
    now = 0
    for line in txt.split('\n'):
        # ignore invalid lines in the model output
        try:
            pitch, duration, wait, velocity, instrument = line.split(' ')
            pitch, duration, wait, velocity = [int(x) for x in [pitch, duration, wait, velocity]]
            if instrument not in tracks:
                tracks[instrument] = symusic.core.TrackSecond()
                if instrument != 'drum':
                    tracks[instrument].program = int(instrument)
                else:
                    tracks[instrument].is_drum = True
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
            tracks[instrument].notes.append(symusic.core.NoteSecond(
                time=now/1000,
                duration=duration/1000,
                pitch=int(pitch),
                velocity=int(velocity * 4),
            ))
            now += wait
        except Exception as e:
            print(f'Postprocess: Ignored line: "{line}" because of error:', e)
    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')
    try:
        # track = symusic.core.TrackSecond()
        # track.notes = symusic.core.NoteSecondList(notes)
        score = symusic.Score(ttype='Second')
        # score.tracks.append(track)
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)
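# Minimal usage sketch (hypothetical, not executed by the app): generate a score
# from a title and write it to disk.
#   txt = cpu_pipe('Twinkle Twinkle Little Star\n', max_new_tokens=500)[0]['generated_text']
#   postprocess(txt, './temp/example.mid')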
with gr.Blocks() as demo:
    chatbot_box = gr.Chatbot(type='messages', render_markdown=False, sanitize_html=False)
    prefix_box = gr.TextArea(value='Twinkle Twinkle Little Star', label='Score title / text prefix')
    with gr.Row():
        submit_btn = gr.Button('Generate')
        continue_btn = gr.Button('Continue')
        clear_btn = gr.Button('Clear history')
    with gr.Row():
        get_audio_btn = gr.Button('Convert to audio')
        get_midi_btn = gr.Button('Convert to MIDI')
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()
    server_box = gr.Dropdown(
        choices=['CPU', 'Huggingface ZeroGPU'],
        label='GPU Server',
    )
    gr.Markdown('''
ZeroGPU currently comes with a time limit:

- 3 minutes (not logged in)
- 5 minutes (logged in)
- 25 minutes (Pro user)

The CPU is slower, but it has no time limit.
'''.strip())
    example_box = gr.Examples(
        [
            # [example_prefix],
            ['Twinkle Twinkle Little Star'], ['Twinkle Twinkle Little Star (Minor Key Version)'],
            ['The Entertainer - Scott Joplin (Piano Solo)'], ['Clair de Lune – Debussy'], ['Nocturne | Frederic Chopin'],
            ['Fugue I in C major, BWV 846'], ['Beethoven Symphony No. 7 (2nd movement) Piano solo'],
            ['Guitar'],
            # ['Composer: Chopin'], ['Composer: Bach'], ['Composer: Beethoven'], ['Composer: Debussy'],
        ],
        inputs=prefix_box,
        examples_per_page=9999,
    )
    def user_fn(user_message, history: list):
        return '', history + [{'role': 'user', 'content': user_message}]

    def get_last(history: list):
        if len(history) == 0:
            raise gr.Error('''No messages to read yet. Try the 'Generate' button first!''')
        return history[-1]['content']

    def generate_fn(history, server):
        # continue from user input
        prefix = get_last(history)
        # prevent the model from continuing user's score title
        if prefix != '' and '\n' not in prefix:
            # prefix is a single line => prefix is the score title
            # add '\n' to prevent model from continuing the title
            prefix += '\n'
        history.append({'role': 'assistant', 'content': ''})
        # history[-1]['content'] += 'Generating with the given prefix...\n'
        for history in model_fn(prefix, history, server):
            yield history

    def continue_fn(history, server):
        # continue from the last model output
        prefix = history[-1]['content']
        for history in model_fn(prefix, history, server):
            yield history

    def model_fn(prefix, history, server):
        if server == 'Huggingface ZeroGPU':
            generator = zerogpu_model_fn(prefix, history, server)
        elif server == 'CPU':
            generator = cpu_model_fn(prefix, history, server)
        # elif server == 'RunPod':
        #     generator = runpod_model_fn(prefix, history)
        else:
            raise gr.Error(f'Unknown server: {server}')
        for history in generator:
            yield history
    def cpu_model_fn(prefix, history, server):
        # despite the name, this function serves both backends: it picks the CPU or
        # GPU pipeline based on the selected server
        queue = Queue(maxsize=10)
        if server == 'CPU':
            pipe = cpu_pipe
        else:
            pipe = gpu_pipe

        class MyStreamer:
            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    queue.put(text, block=True, timeout=5)

            def end(self):
                queue.put(None)

        def background_fn():
            print(f"{pipe.device = }, {pipe.model.device = }")
            print(f"{gpu_pipe.device = }, {gpu_pipe.model.device = }")
            print(f"{cpu_pipe.device = }, {cpu_pipe.model.device = }")
            result = pipe(
                prefix,
                streamer=MyStreamer(),
                max_new_tokens=500,
                top_p=0.9, temperature=0.6,
            )
            print('Generated text:')
            print(result[0]['generated_text'])
            print()

        Thread(target=background_fn).start()
        while True:
            text = queue.get()
            if text is None:
                break
            history[-1]['content'] += text
            yield history
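    # Streaming design: MyStreamer only implements the put()/end() interface that
    # transformers expects of a streamer. The pipeline runs in a background thread
    # and pushes decoded tokens into a bounded Queue; the generator above drains the
    # queue and yields a growing history so Gradio can update the chatbot live.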
    # duration: GPU time out (seconds)
    zerogpu_model_fn = spaces.GPU(duration=15)(cpu_model_fn)
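    # spaces.GPU reserves a ZeroGPU device only for the duration of each call
    # (roughly 15 seconds here). Wrapping cpu_model_fn directly works because it
    # already selects gpu_pipe whenever the chosen server is not 'CPU'.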
    def runpod_model_fn(prefix, history):
        # NOTE: the RunPod credentials are read from environment variables
        runpod_api_key = os.getenv('RUNPOD_API_KEY')
        runpod_endpoint = os.getenv('RUNPOD_ENDPOINT')
        # synchronous request
        response = requests.post(
            f'https://api.runpod.ai/v2/{runpod_endpoint}/runsync',
            headers={'Authorization': f'Bearer {runpod_api_key}'},
            json={'input': {'prompt': prefix}}
        ).json()['output'][0]['choices'][0]['tokens'][0]
        # yield just once
        history[-1]['content'] += response
        yield history
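    # The response path ['output'][0]['choices'][0]['tokens'][0] is specific to the
    # serverless worker assumed to be deployed on RunPod; a different handler would
    # need different parsing. Unlike the local pipelines, this path returns the whole
    # completion at once, so the generator yields a single update.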
    submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        generate_fn, [chatbot_box, server_box], chatbot_box
    )
    continue_event = continue_btn.click(
        continue_fn, [chatbot_box, server_box], chatbot_box
    )
    clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[submit_event, continue_event], queue=False)
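    # Event wiring: the Generate click first appends the user message (user_fn runs
    # with queue=False so the chat updates immediately), then .then() streams the
    # partial histories from generate_fn into the chatbot. Clearing the history also
    # cancels any generation that is still streaming via cancels=[...].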
    def get_audio_fn(history):
        i = random.randint(0, 1000_000_000)
        path = f'./temp/{i}.mid'
        text = get_last(history)
        try:
            postprocess(text, path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # turn midi into audio with timidity
        os.system(f'timidity ./temp/{i}.mid -Ow -o ./temp/{i}.wav')
        # wav to mp3
        os.system(f'lame -b 320 ./temp/{i}.wav ./temp/{i}.mp3')
        return f'./temp/{i}.mp3'

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)
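    # Audio rendering shells out to external tools: timidity synthesizes the MIDI file
    # to WAV and lame encodes it to MP3 at 320 kbps. Both binaries are assumed to be
    # installed in the Space image (e.g. via packages.txt, not shown here).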
    def get_midi_fn(history):
        i = random.randint(0, 1000_000_000)
        # turn the text into midi
        text = get_last(history)
        try:
            postprocess(text, f'./temp/{i}.mid')
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # also render the piano roll
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 4))
        now = 0
        for line in history[-1]['content'].split('\n\n')[-1].split('\n'):
            try:
                pitch, duration, wait, velocity, instrument = [int(x) for x in line.split()]
            except Exception:
                continue
            plt.plot([now, now+duration], [pitch, pitch], color='black', alpha=1)
            plt.scatter(now, pitch, s=6, color='black', alpha=0.3)
            now += wait
        plt.savefig(f'./temp/{i}.svg')
        return f'./temp/{i}.mid', f'./temp/{i}.svg'

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)
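    # The piano-roll preview above re-parses the generated text: each note is drawn as
    # a horizontal segment from its onset to onset + duration at height "pitch", with
    # the onset advanced by "wait", and the figure is saved as an SVG for the Image box.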
demo.launch()