Spaces:
Runtime error
Runtime error
tf32
Browse files
app.py
CHANGED
|
@@ -189,24 +189,22 @@ def run(model_name, tab, mid_seq, continuation_state, instruments, drum_kit, bpm
|
|
| 189 |
init_msgs += [create_msg("visualizer_clear", tokenizer.version),
|
| 190 |
create_msg("visualizer_append", events)]
|
| 191 |
yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
t = ct
|
| 209 |
-
events = []
|
| 210 |
|
| 211 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
| 212 |
mid = tokenizer.detokenize(mid_seq)
|
|
@@ -307,6 +305,10 @@ if __name__ == "__main__":
|
|
| 307 |
}
|
| 308 |
models = {}
|
| 309 |
if opt.device == "cuda":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 311 |
torch.backends.cuda.enable_flash_sdp(True)
|
| 312 |
for name, (repo_id, path, config) in models_info.items():
|
|
@@ -315,7 +317,7 @@ if __name__ == "__main__":
|
|
| 315 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 316 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 317 |
model.load_state_dict(state_dict, strict=False)
|
| 318 |
-
model.to(device="cpu", dtype=torch.
|
| 319 |
models[name] = model
|
| 320 |
|
| 321 |
load_javascript()
|
|
|
|
| 189 |
init_msgs += [create_msg("visualizer_clear", tokenizer.version),
|
| 190 |
create_msg("visualizer_append", events)]
|
| 191 |
yield mid_seq, continuation_state, None, None, seed, send_msgs(init_msgs)
|
| 192 |
+
midi_generator = generate(model, mid, max_len=max_len, temp=temp, top_p=top_p, top_k=top_k,
|
| 193 |
+
disable_patch_change=disable_patch_change, disable_control_change=not allow_cc,
|
| 194 |
+
disable_channels=disable_channels, generator=generator)
|
| 195 |
+
events = []
|
| 196 |
+
t = time.time() + 1
|
| 197 |
+
for i, token_seq in enumerate(midi_generator):
|
| 198 |
+
token_seq = token_seq.tolist()
|
| 199 |
+
mid_seq.append(token_seq)
|
| 200 |
+
events.append(tokenizer.tokens2event(token_seq))
|
| 201 |
+
ct = time.time()
|
| 202 |
+
if ct - t > 0.5:
|
| 203 |
+
yield (mid_seq, continuation_state, None, None, seed,
|
| 204 |
+
send_msgs([create_msg("visualizer_append", events),
|
| 205 |
+
create_msg("progress", [i + 1, gen_events])]))
|
| 206 |
+
t = ct
|
| 207 |
+
events = []
|
|
|
|
|
|
|
| 208 |
|
| 209 |
events = [tokenizer.tokens2event(tokens) for tokens in mid_seq]
|
| 210 |
mid = tokenizer.detokenize(mid_seq)
|
|
|
|
| 305 |
}
|
| 306 |
models = {}
|
| 307 |
if opt.device == "cuda":
|
| 308 |
+
torch.backends.cudnn.deterministic = True
|
| 309 |
+
torch.backends.cudnn.benchmark = False
|
| 310 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 311 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 312 |
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 313 |
torch.backends.cuda.enable_flash_sdp(True)
|
| 314 |
for name, (repo_id, path, config) in models_info.items():
|
|
|
|
| 317 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 318 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 319 |
model.load_state_dict(state_dict, strict=False)
|
| 320 |
+
model.to(device="cpu", dtype=torch.float32)
|
| 321 |
models[name] = model
|
| 322 |
|
| 323 |
load_javascript()
|