import os
from glob import glob

import torch
from safetensors import safe_open
from torch import nn

from flashcosyvoice.config import CosyVoice2LLMConfig


def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    """Fallback loader: copy the checkpoint tensor into the parameter in place."""
    param.data.copy_(loaded_weight)


def load_text_llm(model: nn.Module, path: str):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, framework="pt", device="cpu") as f:
            for weight_name in f.keys():
                for k in packed_modules_mapping:
                    if k in weight_name:
                        # Fused layer: rewrite the name to the packed parameter
                        # and let its weight_loader place the shard.
                        v, shard_id = packed_modules_mapping[k]
                        param_name = weight_name.replace(k, v)
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    # for-else: no packed-module key matched, load directly.
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, f.get_tensor(weight_name))
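
# For reference: models with fused linear layers typically expose a class-level
# `packed_modules_mapping` along the lines of the sketch below. The exact
# entries are an assumption for illustration; they live on the concrete model
# class, not in this file:
#
#     packed_modules_mapping = {
#         "q_proj": ("qkv_proj", "q"),
#         "k_proj": ("qkv_proj", "k"),
#         "v_proj": ("qkv_proj", "v"),
#         "gate_proj": ("gate_up_proj", 0),
#         "up_proj": ("gate_up_proj", 1),
#     }
#
# load_text_llm then rewrites e.g. "...q_proj.weight" to "...qkv_proj.weight"
# and passes the shard id ("q") to that parameter's weight_loader, so each
# slice lands in the right part of the fused tensor.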


def load_speech_llm(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    # NOTE(xcsong): 1. load speech embedding + sos/taskid embedding + lm head
    embedding_weights = {}
    tmp_weights = torch.load(f"{path}/llm.pt", map_location="cpu", weights_only=True)
    missed, missed_names = 0, []
    for k, v in tmp_weights.items():
        if k == "speech_embedding.weight":  # torch.Size([6564, 896])
            speech_embedding_size = hf_config.speech_vocab_size  # 6562
            # NOTE(xcsong): the checkpoint table is padded (6564 rows vs.
            # speech_vocab_size 6562, e.g. for vllm tensor parallel); trim it.
            if speech_embedding_size != v.shape[0]:  # [6564, 896] -> [6562, 896]
                assert speech_embedding_size <= v.shape[0], f"speech_embedding_size should be less than or equal to {v.shape[0]}, but got {speech_embedding_size}"
                v = v[:speech_embedding_size, :]
            embedding_weights["speech_embedding.weight"] = v
        elif k == "llm_embedding.weight":  # torch.Size([2, 896]), sos/eos and task_id
            assert v.shape[0] == 2, f"llm_embedding.weight should be of shape [2, 896], but got {v.shape}"
            embedding_weights["llm_embedding.weight"] = v
        elif k == "llm.model.model.embed_tokens.weight":  # torch.Size([151936, 896])
            embedding_weights["model.embed_tokens.weight"] = v
elif k == "llm_decoder.weight": # torch.Size([6564, 896])
lm_head_size = hf_config.speech_vocab_size # 6562
if lm_head_size != v.shape[0]: # [6564, 896] -> [6562, 896]
assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
v = v[:lm_head_size, :]
param = model.get_parameter("lm_head.weight")
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, v)
elif k == "llm_decoder.bias": # torch.Size([6564])
lm_head_size = hf_config.speech_vocab_size # 6562
if lm_head_size != v.shape[0]: # [6564] -> [6562]
assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
v = v[:lm_head_size]
param = model.get_parameter("lm_head.bias")
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, v)
elif "llm.model." in k:
weight_name = k.replace("llm.model.", "")
for kk in packed_modules_mapping:
if kk in weight_name:
vv, shard_id = packed_modules_mapping[kk]
param_name = weight_name.replace(kk, vv)
try:
param = model.get_parameter(param_name)
weight_loader = param.weight_loader
weight_loader(param, v, shard_id)
break
except Exception as e:
print(e)
print(f"skip parameter (1): {weight_name}")
continue
else:
try:
param = model.get_parameter(weight_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, v)
except Exception as e:
print(e)
print(f"skip parameter (2): {weight_name}")
continue
else:
missed += 1
missed_names.append(weight_name)
continue
print(f"missed {missed} parameters: {missed_names}")
    # NOTE(xcsong): 2. merge text embedding, sos/taskid embedding, and speech embedding
    text_embedding_weight = embedding_weights["model.embed_tokens.weight"].cpu()  # [151936, 896]
    sos_taskid_embedding_weight = embedding_weights["llm_embedding.weight"].cpu()  # [2, 896]
    speech_embedding_weight = embedding_weights["speech_embedding.weight"].cpu()  # [6562, 896]
    final_embedding_weight = torch.cat([speech_embedding_weight, sos_taskid_embedding_weight, text_embedding_weight], dim=0)  # [158500, 896]
    param = model.get_parameter("model.embed_tokens.weight")
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, final_embedding_weight)
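    # Resulting token-id layout of the merged table (follows the concat order
    # above; row counts taken from the shape comments):
    #   [0, 6562)       -> speech tokens (speech_vocab_size rows)
    #   [6562, 6564)    -> sos/eos and task_id rows
    #   [6564, 158500)  -> text tokens (151936 rows)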


def load_model(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig | None = None):
    if model.model_type == "speech_llm":
        assert hf_config is not None, "hf_config is required for speech_llm checkpoints"
        load_speech_llm(model, path, hf_config)
    elif model.model_type == "text_llm":
        load_text_llm(model, path)
    else:
        raise ValueError(f"Unsupported model type: {model.model_type}")
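

# Minimal runnable sketch (an addition, not part of the original loader): it
# exercises the dispatch contract the functions above rely on, namely that a
# parameter may carry a `weight_loader` attribute and that default_weight_loader
# is the fallback. `_ToyModel` is a hypothetical stand-in; a real model would
# also set `model_type` and, for fused layers, `packed_modules_mapping`.
if __name__ == "__main__":
    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4, bias=False)

    toy = _ToyModel()
    new_weight = torch.randn(4, 4)
    param = toy.get_parameter("linear.weight")
    # Same lookup the loaders use: prefer a custom weight_loader, else fall back.
    loader = getattr(param, "weight_loader", default_weight_loader)
    loader(param, new_weight)
    assert torch.equal(param.data, new_weight)
    print("default_weight_loader copied the checkpoint tensor in place")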