# llama-midi / app.py
# Hugging Face Space by dx2102 — last commit f2e4fd7 "Update app.py" (verified)
# Log a marker before the imports and model loading below run.
print('Starting...')
import random
import os
import time
import requests
from queue import Queue
from threading import Thread
import symusic
import transformers
import spaces
import gradio as gr
# Scratch directory for the .mid/.wav/.mp3/.svg files produced by the UI handlers.
os.makedirs('./temp', exist_ok=True)
print('\n\n\n')
print('Loading model...')


def _load_llama_midi(torch_dtype, device):
    """Return a 'text-generation' pipeline for dx2102/llama-midi on `device`."""
    return transformers.pipeline(
        'text-generation',
        model='dx2102/llama-midi',
        torch_dtype=torch_dtype,
        device=device,
    )


# Two copies of the same checkpoint: float32 on CPU, bfloat16 on the GPU.
cpu_pipe = _load_llama_midi(torch_dtype='float32', device='cpu')
gpu_pipe = _load_llama_midi(torch_dtype='bfloat16', device='cuda:0')

# Report where each pipeline and its model actually live.
print(f"{gpu_pipe.device = }, {gpu_pipe.model.device = }")
print(f"{cpu_pipe.device = }, {cpu_pipe.model.device = }")
print('Done')
# Sample score excerpt: a header row naming the five columns, then one note per line.
example_prefix = '''pitch duration wait velocity instrument
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''

# Smoke test: run each pipeline for 10 new tokens after the example and print the text.
for _tag, _smoke_pipe in (('cpu', cpu_pipe), ('gpu', gpu_pipe)):
    _generated = _smoke_pipe(example_prefix, max_new_tokens=10)[0]['generated_text']
    print(_tag + ':', _generated)
del _tag, _smoke_pipe, _generated
def postprocess(txt, path):
    """Convert model output text into a MIDI file written to `path`.

    Only the text after the last blank line is parsed, which drops any prefix
    such as the score title. Each note line holds five space-separated fields,
    ``pitch duration wait velocity instrument``:

    * ``duration`` and ``wait`` are milliseconds (divided by 1000 for the
      seconds-based symusic track); ``wait`` is the gap from this note's onset
      to the next note's onset.
    * ``velocity`` is multiplied by 4 (and clamped to 0-127) for the MIDI note.
    * ``instrument`` is a MIDI program number or ``'drum'``; one track is
      created per distinct value.

    Lines that fail to parse (including the column header) are skipped with a
    log message. Errors while assembling or writing the score are logged and
    swallowed, so `path` may not exist afterwards.
    """
    # remove prefix (everything up to the last blank line)
    txt = txt.split('\n\n')[-1]
    tracks = {}  # instrument field -> symusic.core.TrackSecond
    now = 0  # onset of the next note, in milliseconds
    for line in txt.split('\n'):
        # we need to ignore the invalid output by the model
        try:
            pitch, duration, wait, velocity, instrument = line.split(' ')
            pitch, duration, wait, velocity = (int(x) for x in (pitch, duration, wait, velocity))
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
            note = symusic.core.NoteSecond(
                time=now / 1000,
                duration=duration / 1000,
                pitch=pitch,
                # MIDI velocities range 0-127; clamp the scaled value into that range.
                velocity=min(max(velocity * 4, 0), 127),
            )
            if instrument not in tracks:
                # Configure the track completely before storing it, so a non-numeric
                # instrument field raises before anything is registered instead of
                # leaving an empty track that later lines with the same field reuse.
                track = symusic.core.TrackSecond()
                if instrument == 'drum':
                    track.is_drum = True
                else:
                    track.program = int(instrument)
                tracks[instrument] = track
            tracks[instrument].notes.append(note)
            now += wait
        except Exception as e:
            print(f'Postprocess: Ignored line: "{line}" because of error:', e)
    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')
    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        # Best-effort by design: callers are not notified, so `path` may be missing.
        print('Postprocess: Ignored postprocessing error:', e)
with gr.Blocks() as demo:
    # ---- Layout ------------------------------------------------------------
    chatbot_box = gr.Chatbot(type='messages', render_markdown=False, sanitize_html=False)
    prefix_box = gr.TextArea(value='Twinkle Twinkle Little Star', label='Score title / text prefix')
    with gr.Row():
        submit_btn = gr.Button('Generate')
        continue_btn = gr.Button('Continue')
        clear_btn = gr.Button('Clear history')
    with gr.Row():
        get_audio_btn = gr.Button('Convert to audio')
        get_midi_btn = gr.Button('Convert to MIDI')
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()
    server_box = gr.Dropdown(
        choices=['CPU', 'Huggingface ZeroGPU'],
        # Explicit default so model_fn always receives one of the known choices.
        value='CPU',
        label='GPU Server',
    )
    gr.Markdown('''
ZeroGPU comes with a time limit currently:
- 3 minutes (not logged in)
- 5 minutes (logged in)
- 25 minutes (Pro user)
CPUs will be slower but there is no time limit.
'''.strip())
    example_box = gr.Examples(
        [
            # [example_prefix],
            ['Twinkle Twinkle Little Star'], ['Twinkle Twinkle Little Star (Minor Key Version)'],
            ['The Entertainer - Scott Joplin (Piano Solo)'], ['Clair de Lune – Debussy'], ['Nocturne | Frederic Chopin'],
            ['Fugue I in C major, BWV 846'], ['Beethoven Symphony No. 7 (2nd movement) Piano solo'],
            ['Guitar'],
            # ['Composer: Chopin'], ['Composer: Bach'], ['Composer: Beethoven'], ['Composer: Debussy'],
        ],
        inputs=prefix_box,
        examples_per_page=9999,
    )

    # ---- Chat handlers -----------------------------------------------------
    def user_fn(user_message, history: list):
        """Append the user's text as a new message and clear the input box."""
        # Treat a missing history (e.g. after 'Clear history' sets it to None) as empty.
        return '', (history or []) + [{'role': 'user', 'content': user_message}]

    def get_last(history: list):
        """Return the content of the most recent chat message.

        Raises gr.Error when there are no messages (empty or None history).
        """
        if not history:
            raise gr.Error('''No messages to read yet. Try the 'Generate' button first!''')
        return history[-1]['content']

    def generate_fn(history, server):
        """Start a new assistant message generated from the user's last message."""
        # continue from user input
        prefix = get_last(history)
        # A single non-empty line is the score title; terminating it with '\n'
        # prevents the model from continuing the title text itself.
        if prefix != '' and '\n' not in prefix:
            prefix += '\n'
        history.append({'role': 'assistant', 'content': ''})
        yield from model_fn(prefix, history, server)

    def continue_fn(history, server):
        """Extend the last message in place with more generated text."""
        # get_last gives a friendly error instead of an IndexError on empty history.
        prefix = get_last(history)
        yield from model_fn(prefix, history, server)

    def model_fn(prefix, history, server):
        """Dispatch generation to the backend selected in the server dropdown."""
        if server == 'Huggingface ZeroGPU':
            generator = zerogpu_model_fn(prefix, history, server)
        elif server == 'CPU':
            generator = cpu_model_fn(prefix, history, server)
        # elif server == 'RunPod':
        #     generator = runpod_model_fn(prefix, history)
        else:
            raise gr.Error(f'Unknown server: {server}')
        yield from generator

    def cpu_model_fn(prefix, history, server):
        """Generate up to 500 tokens after `prefix`, streaming them into history[-1].

        Uses `cpu_pipe` when `server == 'CPU'` and `gpu_pipe` otherwise. The
        pipeline runs in a background thread and hands decoded tokens over a
        bounded queue; every received token is appended to the last message and
        the whole history is yielded. Raises gr.Error if generation fails.
        """
        from queue import Empty, Full

        token_queue = Queue(maxsize=10)
        pipe = cpu_pipe if server == 'CPU' else gpu_pipe
        errors = []  # exceptions raised by the background generation thread

        class MyStreamer:
            # Passed as `streamer=`: put() receives token ids, end() marks completion.
            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    # Bounded wait: if nobody drains the queue (e.g. the event was
                    # cancelled) this raises queue.Full and aborts generation.
                    token_queue.put(text, block=True, timeout=5)

            def end(self):
                try:
                    token_queue.put(None, timeout=5)  # end-of-stream marker
                except Full:
                    # Consumer gone or stalled; it also stops once this thread exits.
                    pass

        def background_fn():
            print(f"{pipe.device = }, {pipe.model.device = }")
            print(f"{gpu_pipe.device = }, {gpu_pipe.model.device = }")
            print(f"{cpu_pipe.device = }, {cpu_pipe.model.device = }")
            try:
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_new_tokens=500,
                    top_p=0.9, temperature=0.6,
                )
            except Exception as e:
                # end() may not run when generation fails, so record the error for
                # the consumer loop below; re-raise so the traceback is still logged.
                errors.append(e)
                raise
            print('Generated text:')
            print(result[0]['generated_text'])
            print()

        worker = Thread(target=background_fn, daemon=True)
        worker.start()
        while True:
            try:
                text = token_queue.get(timeout=1)
            except Empty:
                # Keep waiting while the worker runs or tokens are still queued.
                if worker.is_alive() or not token_queue.empty():
                    continue
                # The worker exited without delivering the end-of-stream marker.
                if errors:
                    raise gr.Error(f'Generation failed: {errors[0]}')
                break
            if text is None:
                break
            history[-1]['content'] += text
            yield history

    # duration: GPU time out (seconds)
    zerogpu_model_fn = spaces.GPU(duration=15)(cpu_model_fn)

    def runpod_model_fn(prefix, history):
        # NOTE: not reachable from the UI at the moment (the RunPod branch in
        # model_fn is commented out). Sends one synchronous request and appends
        # the returned tokens to the last message.
        runpod_api_key = os.getenv('RUNPOD_API_KEY')
        runpod_endpoint = os.getenv('RUNPOD_ENDPOINT')
        # synchronized request
        response = requests.post(
            f'https://api.runpod.ai/v2/{runpod_endpoint}/runsync',
            headers={'Authorization': f'Bearer {runpod_api_key}'},
            json={'input': {'prompt': prefix}}
        ).json()['output'][0]['choices'][0]['tokens'][0]
        # yield just once
        history[-1]['content'] += response
        yield history

    # ---- Export handlers ---------------------------------------------------
    def get_audio_fn(history):
        """Render the last chat message to an MP3 file and return its path.

        The text goes to MIDI via postprocess(), then to WAV with timidity and to
        MP3 with lame. Raises gr.Error if any step fails.
        """
        import subprocess

        i = random.randint(0, 1000_000_000)
        mid_path = f'./temp/{i}.mid'
        wav_path = f'./temp/{i}.wav'
        mp3_path = f'./temp/{i}.mp3'
        text = get_last(history)
        try:
            postprocess(text, mid_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # postprocess() logs and swallows write errors, so confirm the file exists.
        if not os.path.exists(mid_path):
            raise gr.Error('Could not build a MIDI file from the last message.')
        try:
            # turn midi into audio with timidity, then wav to mp3 with lame
            # (argument lists, no shell involved)
            subprocess.run(['timidity', mid_path, '-Ow', '-o', wav_path])
            subprocess.run(['lame', '-b', '320', wav_path, mp3_path])
        except OSError as e:  # e.g. timidity or lame is not installed
            raise gr.Error(f'Audio conversion failed: {e}')
        finally:
            # The WAV is only an intermediate; remove it so ./temp does not keep growing.
            if os.path.exists(wav_path):
                os.remove(wav_path)
        if not os.path.exists(mp3_path):
            raise gr.Error('Audio conversion failed: no MP3 was produced.')
        return mp3_path

    def get_midi_fn(history):
        """Convert the last chat message to a MIDI file plus a piano-roll SVG.

        Returns (midi_path, svg_path). Raises gr.Error if conversion fails.
        """
        from matplotlib.figure import Figure

        i = random.randint(0, 1000_000_000)
        midi_path = f'./temp/{i}.mid'
        svg_path = f'./temp/{i}.svg'
        # turn the text into midi
        text = get_last(history)
        try:
            postprocess(text, midi_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # Also render the piano roll. A standalone Figure is not tracked by
        # pyplot's global figure registry, so no figure is leaked per request.
        fig = Figure(figsize=(12, 4))
        ax = fig.add_subplot()
        now = 0  # onset of the next note, same units as the `wait` column
        for line in text.split('\n\n')[-1].split('\n'):
            fields = line.split()
            if len(fields) != 5:
                continue
            try:
                pitch, duration, wait, _velocity = (int(x) for x in fields[:4])
                # The instrument may be 'drum' (not an int); other values must be
                # program numbers, matching what postprocess() accepts.
                if fields[4] != 'drum':
                    int(fields[4])
            except ValueError:
                continue
            # Drum lines are drawn and advance `now` too, so later notes stay
            # aligned with the MIDI written by postprocess().
            ax.plot([now, now + duration], [pitch, pitch], color='black', alpha=1)
            ax.scatter(now, pitch, s=6, color='black', alpha=0.3)
            now += wait
        fig.savefig(svg_path)
        return midi_path, svg_path

    # ---- Event wiring (must happen inside the Blocks context) --------------
    submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        generate_fn, [chatbot_box, server_box], chatbot_box
    )
    continue_event = continue_btn.click(
        continue_fn, [chatbot_box, server_box], chatbot_box
    )
    clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[submit_event, continue_event], queue=False)
    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)
    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

demo.launch()