Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -111,6 +111,8 @@ def load_model(model_selector):
|
|
| 111 |
print('Model will use', dtype, 'precision...')
|
| 112 |
print('=' * 70)
|
| 113 |
|
|
|
|
|
|
|
| 114 |
#==================================================================================
|
| 115 |
|
| 116 |
def load_midi(input_midi):
|
|
@@ -195,7 +197,8 @@ def generate_music(prime,
|
|
| 195 |
num_mem_tokens,
|
| 196 |
num_gen_batches,
|
| 197 |
model_temperature,
|
| 198 |
-
# model_sampling_top_p
|
|
|
|
| 199 |
):
|
| 200 |
|
| 201 |
if not prime:
|
|
@@ -203,6 +206,8 @@ def generate_music(prime,
|
|
| 203 |
|
| 204 |
else:
|
| 205 |
inputs = prime[-num_mem_tokens:]
|
|
|
|
|
|
|
| 206 |
|
| 207 |
model.cuda()
|
| 208 |
model.eval()
|
|
@@ -240,7 +245,8 @@ def generate_callback(input_midi,
|
|
| 240 |
# model_sampling_top_p,
|
| 241 |
final_composition,
|
| 242 |
generated_batches,
|
| 243 |
-
block_lines
|
|
|
|
| 244 |
):
|
| 245 |
|
| 246 |
generated_batches = []
|
|
@@ -255,7 +261,8 @@ def generate_callback(input_midi,
|
|
| 255 |
num_mem_tokens,
|
| 256 |
NUM_OUT_BATCHES,
|
| 257 |
model_temperature,
|
| 258 |
-
# model_sampling_top_p
|
|
|
|
| 259 |
)
|
| 260 |
|
| 261 |
outputs = []
|
|
@@ -310,7 +317,8 @@ def generate_callback_wrapper(input_midi,
|
|
| 310 |
final_composition,
|
| 311 |
generated_batches,
|
| 312 |
block_lines,
|
| 313 |
-
model_selector
|
|
|
|
| 314 |
):
|
| 315 |
|
| 316 |
print('=' * 70)
|
|
@@ -325,7 +333,8 @@ def generate_callback_wrapper(input_midi,
|
|
| 325 |
|
| 326 |
print('Selected model type:', model_selector)
|
| 327 |
|
| 328 |
-
|
|
|
|
| 329 |
|
| 330 |
print('Num prime tokens:', num_prime_tokens)
|
| 331 |
print('Num gen tokens:', num_gen_tokens)
|
|
@@ -343,7 +352,8 @@ def generate_callback_wrapper(input_midi,
|
|
| 343 |
# model_sampling_top_p,
|
| 344 |
final_composition,
|
| 345 |
generated_batches,
|
| 346 |
-
block_lines
|
|
|
|
| 347 |
)
|
| 348 |
|
| 349 |
generated_batches = [sublist[-1] for sublist in result[0]]
|
|
@@ -481,6 +491,7 @@ with gr.Blocks() as demo:
|
|
| 481 |
final_composition = gr.State([])
|
| 482 |
generated_batches = gr.State([])
|
| 483 |
block_lines = gr.State([])
|
|
|
|
| 484 |
|
| 485 |
#==================================================================================
|
| 486 |
|
|
@@ -529,7 +540,8 @@ with gr.Blocks() as demo:
|
|
| 529 |
final_composition,
|
| 530 |
generated_batches,
|
| 531 |
block_lines,
|
| 532 |
-
model_selector
|
|
|
|
| 533 |
],
|
| 534 |
outputs
|
| 535 |
)
|
|
|
|
| 111 |
print('Model will use', dtype, 'precision...')
|
| 112 |
print('=' * 70)
|
| 113 |
|
| 114 |
+
return model
|
| 115 |
+
|
| 116 |
#==================================================================================
|
| 117 |
|
| 118 |
def load_midi(input_midi):
|
|
|
|
| 197 |
num_mem_tokens,
|
| 198 |
num_gen_batches,
|
| 199 |
model_temperature,
|
| 200 |
+
# model_sampling_top_p,
|
| 201 |
+
model_state
|
| 202 |
):
|
| 203 |
|
| 204 |
if not prime:
|
|
|
|
| 206 |
|
| 207 |
else:
|
| 208 |
inputs = prime[-num_mem_tokens:]
|
| 209 |
+
|
| 210 |
+
model = model_state
|
| 211 |
|
| 212 |
model.cuda()
|
| 213 |
model.eval()
|
|
|
|
| 245 |
# model_sampling_top_p,
|
| 246 |
final_composition,
|
| 247 |
generated_batches,
|
| 248 |
+
block_lines,
|
| 249 |
+
model_state
|
| 250 |
):
|
| 251 |
|
| 252 |
generated_batches = []
|
|
|
|
| 261 |
num_mem_tokens,
|
| 262 |
NUM_OUT_BATCHES,
|
| 263 |
model_temperature,
|
| 264 |
+
# model_sampling_top_p,
|
| 265 |
+
model_state
|
| 266 |
)
|
| 267 |
|
| 268 |
outputs = []
|
|
|
|
| 317 |
final_composition,
|
| 318 |
generated_batches,
|
| 319 |
block_lines,
|
| 320 |
+
model_selector,
|
| 321 |
+
model_state
|
| 322 |
):
|
| 323 |
|
| 324 |
print('=' * 70)
|
|
|
|
| 333 |
|
| 334 |
print('Selected model type:', model_selector)
|
| 335 |
|
| 336 |
+
if not model_State:
|
| 337 |
+
model_state = load_model(model_selector)
|
| 338 |
|
| 339 |
print('Num prime tokens:', num_prime_tokens)
|
| 340 |
print('Num gen tokens:', num_gen_tokens)
|
|
|
|
| 352 |
# model_sampling_top_p,
|
| 353 |
final_composition,
|
| 354 |
generated_batches,
|
| 355 |
+
block_lines,
|
| 356 |
+
model_state
|
| 357 |
)
|
| 358 |
|
| 359 |
generated_batches = [sublist[-1] for sublist in result[0]]
|
|
|
|
| 491 |
final_composition = gr.State([])
|
| 492 |
generated_batches = gr.State([])
|
| 493 |
block_lines = gr.State([])
|
| 494 |
+
model_state = gr.State([])
|
| 495 |
|
| 496 |
#==================================================================================
|
| 497 |
|
|
|
|
| 540 |
final_composition,
|
| 541 |
generated_batches,
|
| 542 |
block_lines,
|
| 543 |
+
model_selector,
|
| 544 |
+
model_state
|
| 545 |
],
|
| 546 |
outputs
|
| 547 |
)
|