Spaces:
Paused
Paused
merge lora into model
Browse files- app.py +13 -15
- midi_model.py +8 -0
app.py
CHANGED
|
@@ -142,12 +142,7 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
|
|
| 142 |
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
| 143 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
| 144 |
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
| 145 |
-
model
|
| 146 |
-
if lora_name is None and model.peft_loaded():
|
| 147 |
-
model.disable_adapters()
|
| 148 |
-
elif lora_name is not None:
|
| 149 |
-
model.enable_adapters()
|
| 150 |
-
model.set_adapter(lora_name)
|
| 151 |
model.to(device=opt.device)
|
| 152 |
tokenizer = model.tokenizer
|
| 153 |
bpm = int(bpm)
|
|
@@ -258,7 +253,7 @@ def finish_run(model_name, mid_seq):
|
|
| 258 |
if mid_seq is None:
|
| 259 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
| 260 |
return *outputs, []
|
| 261 |
-
tokenizer = models[model_name]
|
| 262 |
outputs = []
|
| 263 |
end_msgs = [create_msg("progress", [0, 0])]
|
| 264 |
if not os.path.exists("outputs"):
|
|
@@ -282,7 +277,7 @@ def render_audio(model_name, mid_seq, should_render_audio):
|
|
| 282 |
if (not should_render_audio) or mid_seq is None:
|
| 283 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
| 284 |
return tuple(outputs)
|
| 285 |
-
tokenizer = models[model_name]
|
| 286 |
outputs = []
|
| 287 |
if not os.path.exists("outputs"):
|
| 288 |
os.mkdir("outputs")
|
|
@@ -293,13 +288,15 @@ def render_audio(model_name, mid_seq, should_render_audio):
|
|
| 293 |
audio_futures.append(audio_future)
|
| 294 |
for future in audio_futures:
|
| 295 |
outputs.append((44100, future.result()))
|
|
|
|
|
|
|
| 296 |
return tuple(outputs)
|
| 297 |
|
| 298 |
|
| 299 |
def undo_continuation(model_name, mid_seq, continuation_state):
|
| 300 |
if mid_seq is None or len(continuation_state) < 2:
|
| 301 |
return mid_seq, continuation_state, send_msgs([])
|
| 302 |
-
tokenizer = models[model_name]
|
| 303 |
if isinstance(continuation_state[-1], list):
|
| 304 |
mid_seq = continuation_state[-1]
|
| 305 |
else:
|
|
@@ -399,14 +396,15 @@ if __name__ == "__main__":
|
|
| 399 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 400 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 401 |
model.load_state_dict(state_dict, strict=False)
|
| 402 |
-
for lora_name, lora_repo in loras.items():
|
| 403 |
-
model.load_adapter(lora_repo, lora_name)
|
| 404 |
-
if loras:
|
| 405 |
-
model.disable_adapters()
|
| 406 |
model.to(device="cpu", dtype=torch.float32).eval()
|
| 407 |
-
models[name] = model
|
| 408 |
for lora_name, lora_repo in loras.items():
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
load_javascript()
|
| 412 |
app = gr.Blocks()
|
|
|
|
| 142 |
def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
|
| 143 |
key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
|
| 144 |
seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
|
| 145 |
+
model = models[model_name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
model.to(device=opt.device)
|
| 147 |
tokenizer = model.tokenizer
|
| 148 |
bpm = int(bpm)
|
|
|
|
| 253 |
if mid_seq is None:
|
| 254 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
| 255 |
return *outputs, []
|
| 256 |
+
tokenizer = models[model_name].tokenizer
|
| 257 |
outputs = []
|
| 258 |
end_msgs = [create_msg("progress", [0, 0])]
|
| 259 |
if not os.path.exists("outputs"):
|
|
|
|
| 277 |
if (not should_render_audio) or mid_seq is None:
|
| 278 |
outputs = [None] * OUTPUT_BATCH_SIZE
|
| 279 |
return tuple(outputs)
|
| 280 |
+
tokenizer = models[model_name].tokenizer
|
| 281 |
outputs = []
|
| 282 |
if not os.path.exists("outputs"):
|
| 283 |
os.mkdir("outputs")
|
|
|
|
| 288 |
audio_futures.append(audio_future)
|
| 289 |
for future in audio_futures:
|
| 290 |
outputs.append((44100, future.result()))
|
| 291 |
+
if OUTPUT_BATCH_SIZE == 1:
|
| 292 |
+
return outputs[0]
|
| 293 |
return tuple(outputs)
|
| 294 |
|
| 295 |
|
| 296 |
def undo_continuation(model_name, mid_seq, continuation_state):
|
| 297 |
if mid_seq is None or len(continuation_state) < 2:
|
| 298 |
return mid_seq, continuation_state, send_msgs([])
|
| 299 |
+
tokenizer = models[model_name].tokenizer
|
| 300 |
if isinstance(continuation_state[-1], list):
|
| 301 |
mid_seq = continuation_state[-1]
|
| 302 |
else:
|
|
|
|
| 396 |
ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 397 |
state_dict = ckpt.get("state_dict", ckpt)
|
| 398 |
model.load_state_dict(state_dict, strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
model.to(device="cpu", dtype=torch.float32).eval()
|
| 400 |
+
models[name] = model
|
| 401 |
for lora_name, lora_repo in loras.items():
|
| 402 |
+
model = MIDIModel(config=MIDIModelConfig.from_name(config))
|
| 403 |
+
model.load_state_dict(state_dict, strict=False)
|
| 404 |
+
print(f"loading lora {lora_repo} for {name}")
|
| 405 |
+
model = model.load_merge_lora(lora_repo)
|
| 406 |
+
model.to(device="cpu", dtype=torch.float32).eval()
|
| 407 |
+
models[f"{name} with {lora_name} lora"] = model
|
| 408 |
|
| 409 |
load_javascript()
|
| 410 |
app = gr.Blocks()
|
midi_model.py
CHANGED
|
@@ -5,6 +5,7 @@ import torch
|
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
import tqdm
|
|
|
|
| 8 |
from transformers import LlamaModel, LlamaConfig
|
| 9 |
from transformers.integrations import PeftAdapterMixin
|
| 10 |
|
|
@@ -75,6 +76,13 @@ class MIDIModel(nn.Module, PeftAdapterMixin):
|
|
| 75 |
def peft_loaded(self):
|
| 76 |
return self._hf_peft_config_loaded
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def forward_token(self, hidden_state, x=None):
|
| 79 |
"""
|
| 80 |
|
|
|
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
import tqdm
|
| 8 |
+
from peft import PeftConfig, LoraModel, load_peft_weights, set_peft_model_state_dict
|
| 9 |
from transformers import LlamaModel, LlamaConfig
|
| 10 |
from transformers.integrations import PeftAdapterMixin
|
| 11 |
|
|
|
|
| 76 |
def peft_loaded(self):
|
| 77 |
return self._hf_peft_config_loaded
|
| 78 |
|
| 79 |
+
def load_merge_lora(self, model_id):
|
| 80 |
+
peft_config = PeftConfig.from_pretrained(model_id)
|
| 81 |
+
model = LoraModel(self, peft_config, adapter_name="default")
|
| 82 |
+
adapter_state_dict = load_peft_weights(model_id, device=self.device)
|
| 83 |
+
set_peft_model_state_dict(self, adapter_state_dict, "default")
|
| 84 |
+
return model.merge_and_unload()
|
| 85 |
+
|
| 86 |
def forward_token(self, hidden_state, x=None):
|
| 87 |
"""
|
| 88 |
|