Spaces: Running on Zero

Update app.py

app.py CHANGED
|
@@ -1,3 +1,5 @@
+print('Starting...')
+
 import random
 import os
 import time
@@ -18,13 +20,19 @@ os.makedirs('./temp', exist_ok=True)
 print('\n\n\n')
 print('Loading model...')
 pipe = transformers.pipeline(
-
-    model=
-    # revision=
-    torch_dtype=
-    device=
+    'text-generation',
+    model='dx2102/llama-midi',
+    # revision='c303c108399aba837146e893375849b918f413b3',
+    torch_dtype='bfloat16',
+    device='cuda',
+)
+cpu_pipe = transformers.pipeline(
+    'text-generation',
+    model='dx2102/llama-midi',
+    # revision='c303c108399aba837146e893375849b918f413b3',
+    torch_dtype='float32',
+    device='cpu',
 )
-cpu_pipe = pipe.to("cpu")
 print('Done')
 
 example_prefix = '''pitch duration wait velocity instrument
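A note on the hunk above (the motivation is an inference, not stated in the commit): the old `cpu_pipe = pipe.to("cpu")` could never produce two independent copies, because `torch.nn.Module.to()` converts a module in place and returns the same object; building two separate pipelines keeps an independent bfloat16 CUDA copy and a float32 CPU copy. A minimal sketch of that in-place behavior, in plain PyTorch:

import torch

m = torch.nn.Linear(4, 4)
m2 = m.to(torch.float64)  # nn.Module.to() converts in place and returns self
print(m2 is m)            # True: there is only ever one underlying model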
@@ -78,7 +86,7 @@ def postprocess(txt, path):
             ))
             now += wait
         except Exception as e:
-            print(f'Postprocess: Ignored line:
+            print(f"Postprocess: Ignored line: '{line}' because of error:", e)
 
     print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')
 
@@ -97,21 +105,21 @@
 
 
 with gr.Blocks() as demo:
-    chatbot_box = gr.Chatbot(type=
-    prefix_box = gr.TextArea(value=
+    chatbot_box = gr.Chatbot(type='messages', render_markdown=False, sanitize_html=False)
+    prefix_box = gr.TextArea(value='Twinkle Twinkle Little Star', label='Score title / text prefix')
     with gr.Row():
-        submit_btn = gr.Button(
-        continue_btn = gr.Button(
-        clear_btn = gr.Button(
+        submit_btn = gr.Button('Generate')
+        continue_btn = gr.Button('Continue')
+        clear_btn = gr.Button('Clear history')
     with gr.Row():
-        get_audio_btn = gr.Button(
-        get_midi_btn = gr.Button(
+        get_audio_btn = gr.Button('Convert to audio')
+        get_midi_btn = gr.Button('Convert to MIDI')
     audio_box = gr.Audio()
     midi_box = gr.File()
     piano_roll_box = gr.Image()
     server_box = gr.Dropdown(
-        choices=[
-        label=
+        choices=['Huggingface ZeroGPU', 'CPU'],
+        label='GPU Server',
     )
     gr.Markdown('''
 ZeroGPU comes with a time limit currently:
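For context on `type='messages'` above: in this mode Gradio's Chatbot history is a list of OpenAI-style role/content dicts rather than `[user, bot]` pairs, which is what drives all the `history[-1]['content']` rewrites in the hunks below. A minimal sketch with illustrative values:

history = [
    {'role': 'user', 'content': 'Twinkle Twinkle Little Star'},
    {'role': 'assistant', 'content': 'pitch duration wait velocity instrument\n'},
]
history[-1]['content'] += '62 230 0 90 0\n'  # streaming appends to the last message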
@@ -124,23 +132,23 @@ CPUs will be slower but there is no time limit.
     example_box = gr.Examples(
         [
             # [example_prefix],
-            [
-            [
-            [
-            [
-            # [
+            ['Twinkle Twinkle Little Star'], ['Twinkle Twinkle Little Star (Minor Key Version)'],
+            ['The Entertainer - Scott Joplin (Piano Solo)'], ['Clair de Lune – Debussy'], ['Nocturne | Frederic Chopin'],
+            ['Fugue I in C major, BWV 846'], ['Beethoven Symphony No. 7 (2nd movement) Piano solo'],
+            ['Guitar'],
+            # ['Composer: Chopin'], ['Composer: Bach'], ['Composer: Beethoven'], ['Composer: Debussy'],
         ],
         inputs=prefix_box,
         examples_per_page=9999,
     )
 
     def user_fn(user_message, history: list):
-        return
+        return '', history + [{'role': 'user', 'content': user_message}]
 
     def get_last(history: list):
         if len(history) == 0:
-            raise gr.Error('''No messages to read yet. Try the
-        return history[-1][
+            raise gr.Error('''No messages to read yet. Try the 'Generate' button first!''')
+        return history[-1]['content']
 
     def generate_fn(history, server):
         # continue from user input
@@ -151,14 +159,14 @@ CPUs will be slower but there is no time limit.
         # add '\n' to prevent model from continuing the title
         prefix += '\n'
 
-        history.append({
-        # history[-1][
+        history.append({'role': 'assistant', 'content': ''})
+        # history[-1]['content'] += 'Generating with the given prefix...\n'
        for history in model_fn(prefix, history, server):
             yield history
 
     def continue_fn(history, server):
         # continue from the last model output
-        prefix = history[-1][
+        prefix = history[-1]['content']
         for history in model_fn(prefix, history, server):
             yield history
 
@@ -166,14 +174,14 @@ CPUs will be slower but there is no time limit.
 
 
     def model_fn(prefix, history, server):
-        if server ==
+        if server == 'Huggingface ZeroGPU':
             generator = zerogpu_model_fn(prefix, history, pipe)
-        elif server ==
+        elif server == 'CPU':
             generator = cpu_model_fn(prefix, history, cpu_pipe)
-        # elif server ==
+        # elif server == 'RunPod':
         #     generator = runpod_model_fn(prefix, history)
         else:
-            raise gr.Error(f
+            raise gr.Error(f'Unknown server: {server}')
         for history in generator:
             yield history
 
@@ -203,7 +211,7 @@ CPUs will be slower but there is no time limit.
             text = queue.get()
             if text is None:
                 break
-            history[-1][
+            history[-1]['content'] += text
             yield history
 
     zerogpu_model_fn = spaces.GPU(cpu_model_fn)
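Two notes on this hunk. The `queue.get()` loop is the consumer half of a thread-based streaming pattern; a self-contained sketch of that pattern (the producer thread and the None sentinel convention are assumptions inferred from this consumer loop):

import threading
import queue

q = queue.Queue()

def producer():
    # stand-in for the generation thread pushing text chunks
    for chunk in ('62 230 0 90 0\n', '64 230 0 90 0\n'):
        q.put(chunk)
    q.put(None)  # sentinel: generation is finished

threading.Thread(target=producer).start()
while (text := q.get()) is not None:
    print(text, end='')

And `spaces.GPU(cpu_model_fn)` wraps the plain CPU function so that, on a ZeroGPU Space, each call runs with a GPU attached for its duration; the same wrapper is more often written as the `@spaces.GPU` decorator.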
@@ -215,12 +223,12 @@ CPUs will be slower but there is no time limit.
 
         # synchronized request
         response = requests.post(
-            f
-            headers={
-            json={
+            f'https://api.runpod.ai/v2/{runpod_endpoint}/runsync',
+            headers={'Authorization': f'Bearer {runpod_api_key}'},
+            json={'input': {'prompt': prefix}}
         ).json()['output'][0]['choices'][0]['tokens'][0]
         # yield just once
-        history[-1][
+        history[-1]['content'] += response
         yield history
 
 
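The chained indexing `['output'][0]['choices'][0]['tokens'][0]` implies a `runsync` response shaped roughly as below (a sketch matching RunPod's serverless vLLM worker output format; the field values are illustrative):

response_json = {
    'output': [
        {'choices': [
            {'tokens': ['62 230 0 90 0\n64 230 0 90 0\n']},
        ]},
    ],
}
text = response_json['output'][0]['choices'][0]['tokens'][0]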
@@ -266,7 +274,7 @@ CPUs will be slower but there is no time limit.
         import matplotlib.pyplot as plt
         plt.figure(figsize=(12, 4))
         now = 0
-        for line in history[-1][
+        for line in history[-1]['content'].split('\n\n')[-1].split('\n'):
             try:
                 pitch, duration, wait, velocity, instrument = [int(x) for x in line.split()]
             except Exception as e:
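For reference, each generated line is five space-separated integers in the `pitch duration wait velocity instrument` layout declared by `example_prefix`, so the piano-roll loop above recovers one note per line like this (stand-alone sketch, illustrative values):

line = '62 230 0 90 0'
pitch, duration, wait, velocity, instrument = (int(x) for x in line.split())
# the note starts at the running cursor `now`, lasts `duration` ticks,
# and the cursor then advances by `wait` ticks before the next note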