add lora
- app.py +32 -12
- midi_model.py +5 -1
app.py
CHANGED
@@ -142,7 +142,12 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
 def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
         key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
         seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
-    model = models[model_name]
+    model, lora_name = models[model_name]
+    if lora_name is None and model.peft_loaded():
+        model.disable_adapters()
+    elif lora_name is not None:
+        model.enable_adapters()
+        model.set_adapter(lora_name)
     model.to(device=opt.device)
     tokenizer = model.tokenizer
     bpm = int(bpm)
@@ -253,7 +258,7 @@ def finish_run(model_name, mid_seq):
     if mid_seq is None:
         outputs = [None] * OUTPUT_BATCH_SIZE
         return *outputs, []
-    tokenizer = models[model_name].tokenizer
+    tokenizer = models[model_name][0].tokenizer
     outputs = []
     end_msgs = [create_msg("progress", [0, 0])]
     if not os.path.exists("outputs"):
@@ -277,7 +282,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
     if (not should_render_audio) or mid_seq is None:
         outputs = [None] * OUTPUT_BATCH_SIZE
         return tuple(outputs)
-    tokenizer = models[model_name].tokenizer
+    tokenizer = models[model_name][0].tokenizer
     outputs = []
     if not os.path.exists("outputs"):
         os.mkdir("outputs")
@@ -294,7 +299,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
 def undo_continuation(model_name, mid_seq, continuation_state):
     if mid_seq is None or len(continuation_state) < 2:
         return mid_seq, continuation_state, send_msgs([])
-    tokenizer = models[model_name].tokenizer
+    tokenizer = models[model_name][0].tokenizer
     if isinstance(continuation_state[-1], list):
         mid_seq = continuation_state[-1]
     else:
@@ -364,12 +369,21 @@ if __name__ == "__main__":
     thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
     synthesizer = MidiSynthesizer(soundfont_path)
     models_info = {
-        "generic pretrain model (tv2o-medium) by skytnt": [
-        …
+        "generic pretrain model (tv2o-medium) by skytnt": [
+            "skytnt/midi-model-tv2o-medium", "", "tv2o-medium", {
+                "jpop": "skytnt/midi-model-tv2om-jpop-lora",
+                "touhou": "skytnt/midi-model-tv2om-touhou-lora"
+            }
+        ],
+        "generic pretrain model (tv2o-large) by asigalov61": [
+            "asigalov61/Music-Llama", "", "tv2o-large", {}
+        ],
+        "generic pretrain model (tv2o-medium) by asigalov61": [
+            "asigalov61/Music-Llama-Medium", "", "tv2o-medium", {}
+        ],
+        "generic pretrain model (tv1-medium) by skytnt": [
+            "skytnt/midi-model", "", "tv1-medium", {}
+        ]
     }
     models = {}
     if opt.device == "cuda":
@@ -379,14 +393,20 @@ if __name__ == "__main__":
         torch.backends.cudnn.allow_tf32 = True
         torch.backends.cuda.enable_mem_efficient_sdp(True)
         torch.backends.cuda.enable_flash_sdp(True)
-    for name, (repo_id, path, config) in models_info.items():
+    for name, (repo_id, path, config, loras) in models_info.items():
         model_path = hf_hub_download_retry(repo_id=repo_id, filename=f"{path}model.ckpt")
         model = MIDIModel(config=MIDIModelConfig.from_name(config))
         ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
         state_dict = ckpt.get("state_dict", ckpt)
         model.load_state_dict(state_dict, strict=False)
+        for lora_name, lora_repo in loras.items():
+            model.load_adapter(lora_repo, lora_name)
+        if loras:
+            model.disable_adapters()
         model.to(device="cpu", dtype=torch.float32).eval()
-        models[name] = model
+        models[name] = model, None
+        for lora_name, lora_repo in loras.items():
+            models[f"{name} with {lora_name} lora"] = model, lora_name
 
     load_javascript()
     app = gr.Blocks()
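After this change, `models` maps each display name to a `(model, lora_name)` tuple rather than to the model itself, and every LoRA listed for a base checkpoint gets its own entry pointing at the same shared model object. Below is a minimal sketch of the lookup-and-switch pattern `run()` now performs; the `DummyModel` class and `select()` helper are illustrative stand-ins, not part of the app, and the adapter methods mirror the transformers `PeftAdapterMixin` calls used in the diff.

# Illustrative stand-in for MIDIModel: only the adapter bookkeeping is modeled.
class DummyModel:
    def __init__(self):
        self.active_adapter = None
        self.adapters_enabled = False
        self._loaded = False

    def peft_loaded(self):              # mirrors MIDIModel.peft_loaded()
        return self._loaded

    def load_adapter(self, repo_id, adapter_name):
        self._loaded = True             # the real mixin downloads LoRA weights here

    def enable_adapters(self):
        self.adapters_enabled = True

    def disable_adapters(self):
        self.adapters_enabled = False

    def set_adapter(self, adapter_name):
        self.active_adapter = adapter_name


models = {}
model = DummyModel()
loras = {"jpop": "skytnt/midi-model-tv2om-jpop-lora"}  # repo id from models_info
for lora_name, lora_repo in loras.items():
    model.load_adapter(lora_repo, lora_name)

# One registry entry per flavor, all sharing the same underlying model:
models["base"] = model, None
for lora_name in loras:
    models[f"base with {lora_name} lora"] = model, lora_name


def select(model_name):
    # The same switch run() performs before generating.
    model, lora_name = models[model_name]
    if lora_name is None and model.peft_loaded():
        model.disable_adapters()        # fall back to the plain pretrain weights
    elif lora_name is not None:
        model.enable_adapters()
        model.set_adapter(lora_name)    # route forward passes through this LoRA
    return model


assert select("base with jpop lora").active_adapter == "jpop"
assert select("base").adapters_enabled is False

Because every registry entry shares the one model instance, switching between the base model and its LoRAs costs no extra memory: only the active-adapter state changes.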
midi_model.py
CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import tqdm
 from transformers import LlamaModel, LlamaConfig
+from transformers.integrations import PeftAdapterMixin
 
 from midi_tokenizer import MIDITokenizerV1, MIDITokenizerV2, MIDITokenizer
 
@@ -55,7 +56,7 @@ class MIDIModelConfig:
         raise ValueError(f"Unknown model size {size}")
 
 
-class MIDIModel(nn.Module):
+class MIDIModel(nn.Module, PeftAdapterMixin):
     def __init__(self, config: MIDIModelConfig, *args, **kwargs):
         super(MIDIModel, self).__init__()
         self.tokenizer = config.tokenizer
@@ -69,6 +70,9 @@ class MIDIModel(nn.Module):
         self.device = kwargs["device"]
         return super(MIDIModel, self).to(*args, **kwargs)
 
+    def peft_loaded(self):
+        return self._hf_peft_config_loaded
+
     def forward_token(self, hidden_state, x=None):
         """
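Since `MIDIModel` now inherits from `PeftAdapterMixin`, it picks up `load_adapter`, `set_adapter`, `enable_adapters`, and `disable_adapters` directly from transformers, and `peft_loaded()` simply exposes the mixin's `_hf_peft_config_loaded` flag, which becomes `True` after the first successful `load_adapter` call. A minimal usage sketch, assuming the `peft` package is installed and the LoRA repo registered in app.py is reachable on the Hub:

from midi_model import MIDIModel, MIDIModelConfig

model = MIDIModel(config=MIDIModelConfig.from_name("tv2o-medium"))
print(model.peft_loaded())    # False: no PEFT config loaded yet

# PeftAdapterMixin.load_adapter(peft_model_id, adapter_name) fetches the
# LoRA weights and injects them into the wrapped transformer layers.
model.load_adapter("skytnt/midi-model-tv2om-jpop-lora", "jpop")
print(model.peft_loaded())    # True: _hf_peft_config_loaded is now set

model.set_adapter("jpop")     # pick which loaded adapter is active
model.enable_adapters()       # apply it on forward passes
model.disable_adapters()      # or bypass all adapters for base-model output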