Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,6 +16,8 @@ import MIDI
|
|
| 16 |
from midi_synthesizer import synthesis
|
| 17 |
from midi_tokenizer import MIDITokenizer
|
| 18 |
|
|
|
|
|
|
|
| 19 |
def softmax(x, axis):
|
| 20 |
x_max = np.amax(x, axis=axis, keepdims=True)
|
| 21 |
exp_x_shifted = np.exp(x - x_max)
|
|
@@ -58,7 +60,7 @@ def generate(prompt=None, max_len=512, temp=1.0, top_p=0.98, top_k=20,
|
|
| 58 |
input_tensor = prompt
|
| 59 |
input_tensor = input_tensor[None, :, :]
|
| 60 |
cur_len = input_tensor.shape[1]
|
| 61 |
-
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len)
|
| 62 |
with bar:
|
| 63 |
while cur_len < max_len:
|
| 64 |
end = False
|
|
@@ -204,7 +206,7 @@ if __name__ == "__main__":
|
|
| 204 |
parser = argparse.ArgumentParser()
|
| 205 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
| 206 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 207 |
-
parser.add_argument("--max-gen", type=int, default=
|
| 208 |
opt = parser.parse_args()
|
| 209 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 210 |
model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
|
|
|
|
| 16 |
from midi_synthesizer import synthesis
|
| 17 |
from midi_tokenizer import MIDITokenizer
|
| 18 |
|
| 19 |
+
in_space = os.getenv("SYSTEM") == "spaces"
|
| 20 |
+
|
| 21 |
def softmax(x, axis):
|
| 22 |
x_max = np.amax(x, axis=axis, keepdims=True)
|
| 23 |
exp_x_shifted = np.exp(x - x_max)
|
|
|
|
| 60 |
input_tensor = prompt
|
| 61 |
input_tensor = input_tensor[None, :, :]
|
| 62 |
cur_len = input_tensor.shape[1]
|
| 63 |
+
bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
|
| 64 |
with bar:
|
| 65 |
while cur_len < max_len:
|
| 66 |
end = False
|
|
|
|
| 206 |
parser = argparse.ArgumentParser()
|
| 207 |
parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
|
| 208 |
parser.add_argument("--port", type=int, default=7860, help="gradio server port")
|
| 209 |
+
parser.add_argument("--max-gen", type=int, default=512, help="max")
|
| 210 |
opt = parser.parse_args()
|
| 211 |
soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
|
| 212 |
model_base_path = hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx")
|