# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import torch | |
| from torch import nn | |
| from transformers import AutoConfig | |
| from flashcosyvoice.config import CosyVoice2LLMConfig | |
| from flashcosyvoice.modules.qwen2_components.layers import ( | |
| ParallelLMHead, Qwen2DecoderLayer, RMSNorm, VocabParallelEmbedding) | |
class Qwen2Model(nn.Module):
    """Qwen2 transformer backbone: token embedding, decoder stack, final RMS norm."""

    def __init__(
        self,
        config: CosyVoice2LLMConfig,
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed `input_ids`, run every decoder layer (threading the residual
        through, as the layers expect), and return the normalized hidden states."""
        states = self.embed_tokens(input_ids)
        res = None
        for decoder_layer in self.layers:
            states, res = decoder_layer(positions, states, res)
        # Final norm fuses the pending residual; second return value is discarded.
        states, _ = self.norm(states, res)
        return states
class Qwen2ForCausalLM(nn.Module):
    """Qwen2 causal language model: `Qwen2Model` backbone plus an LM head.

    The head is chosen from the config: a speech-token head when
    `speech_vocab_size` is declared, otherwise a standard bias-free text head.
    """

    # Maps per-projection HF checkpoint weight names onto the fused
    # projections used by the decoder layers here.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: CosyVoice2LLMConfig | AutoConfig
    ):
        """Build the backbone, pick the LM head, and optionally tie embeddings.

        Raises:
            ValueError: if `tie_word_embeddings` is set but `vocab_size` and
                `speech_vocab_size` differ (the shared matrix must fit both).
        """
        super().__init__()
        self.model = Qwen2Model(config)
        if hasattr(config, "speech_vocab_size"):
            # Speech LLM: head projects to the speech-token vocabulary; bias
            # defaults to True unless the config overrides `lm_head_bias`.
            self.lm_head = ParallelLMHead(config.speech_vocab_size, config.hidden_size, bias=getattr(config, "lm_head_bias", True))
            self.model_type = "speech_llm"
        else:
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=False)
            self.model_type = "text_llm"
        self.tie_word_embeddings = config.tie_word_embeddings
        if self.tie_word_embeddings:
            # Fixed: was a bare `assert`, which is stripped under `python -O`
            # and would silently allow a shape-inconsistent weight tie.
            if self.model_type == "speech_llm" and config.vocab_size != config.speech_vocab_size:
                raise ValueError("vocab_size and speech_vocab_size must be the same when tie_word_embeddings is True")
            # Share the embedding matrix storage with the output projection.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (logits come from
        `compute_logits`, kept separate so sampling can batch differently)."""
        hidden_states = self.model(input_ids, positions)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.lm_head(hidden_states)
        return logits