from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, CTCLoss

import transformers
from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.trainer_pt_utils import LabelSmoother
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers import (
    WhisperProcessor,
    WhisperModel,
)
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
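

# Right-pads `tensor` with `length` extra entries of value `pad` along dimension `dim`
# (used below to pad masks, labels, and hidden states to a common sequence length).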
def padding_tensor(tensor, length, dim=0, pad=False):
    if length == 0:
        return tensor
    assert length > 0, f"Wrong padding length: {length}"
    shape = list(tensor.shape)
    assert dim < len(shape), f"dim {dim} out of shape {shape}"
    shape[dim] = length
    padded = torch.cat(
        (
            tensor,
            torch.full(tuple(shape), pad, dtype=tensor.dtype, device=tensor.device),
        ),
        dim=dim,
    )
    return padded


class T2ULlamaConfig(LlamaConfig):
    model_type = "T2ULlama"
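

# Text-to-unit (T2U) decoder: a small LlamaModel that consumes text token embeddings
# or aligned text hidden states and autoregressively predicts discrete units.
# The text embedding table is shared with the backbone and extended with a trainable
# unit vocabulary (`unit_embedding`); the text rows are detached when the combined
# lookup table is built, so gradients through that lookup reach only the unit rows.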
class T2ULlamaForCausalLM(LlamaForCausalLM):
    config_class = T2ULlamaConfig

    def __init__(self, config, embedding_weight=None):
        self.current_step = 0
        self.log = {}
        # Bypass LlamaForCausalLM.__init__ and initialize from PreTrainedModel directly;
        # the smaller T2U backbone is built below instead of a full-size LlamaModel.
        super(LlamaForCausalLM, self).__init__(config)
        self.config = config
        self.training_stage = config.unit_output
        self.pad_token_id = 128009

        llama_config = T2ULlamaConfig(
            **config.to_dict(),
            batch_first=True,
            norm_first=True,
        )
        llama_config.architectures = ["T2ULlamaForCausalLM"]
        llama_config.pad_token_id = self.pad_token_id
        llama_config.vocab_size += llama_config.unit_vocab_size

        # Hard-coded T2U decoder settings.
        llama_config.unit_model = "medium"
        llama_config.max_position_embeddings = 2048  # tried 1024 / 1536 / 2048 (originally 1024, reduced 512)
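
        # Preset decoder sizes keyed by `unit_model`: "large" only reduces the layer
        # count and keeps the backbone's hidden size, "tiny" and the default ("medium")
        # define a small standalone decoder; the outer else-branch is the fallback when
        # no `unit_model` attribute is present.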
        if hasattr(llama_config, "unit_model"):
            if llama_config.unit_model == "large":
                llama_config.num_hidden_layers = 2
                # llama_config.hidden_size = 4096
                # llama_config.num_attention_heads = 32
                # llama_config.intermediate_size = 14336
                # llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
            elif llama_config.unit_model == "tiny":
                llama_config.num_hidden_layers = 4
                llama_config.hidden_size = 512
                llama_config.num_attention_heads = 8
                llama_config.intermediate_size = 2048
                llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
            else:
                llama_config.num_hidden_layers = 8
                llama_config.hidden_size = 768
                llama_config.num_attention_heads = 12
                llama_config.num_key_value_heads = 12
                llama_config.intermediate_size = 2048
                llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
        else:
            llama_config.num_hidden_layers = 6
            llama_config.hidden_size = 512
            llama_config.num_attention_heads = 8
            llama_config.intermediate_size = 2048
            llama_config.head_dim = llama_config.hidden_size // llama_config.num_attention_heads
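
        # Backbone and heads: a LlamaModel decoder, a shared text embedding table,
        # a trainable unit-vocabulary embedding (a bias-free Linear whose weight supplies
        # the extra embedding rows), an adapter projecting the text hidden size to the
        # decoder hidden size, and an LM head over the combined text + unit vocabulary.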
        self.model = LlamaModel(llama_config)
        # Shared embedding (0501, kkq): redefine embed_tokens on the text vocabulary
        # with the T2U pad token as padding_idx.
        self.model.embed_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=self.pad_token_id,
        )
        self.unit_embedding = nn.Linear(config.hidden_size, llama_config.unit_vocab_size, bias=False)
        self.adapter = nn.Linear(config.hidden_size, llama_config.hidden_size, bias=True)
        self.lm_head = nn.Linear(llama_config.hidden_size, llama_config.vocab_size, bias=False)
        if self.training_stage == "pretrain":
            pass
        elif self.training_stage in ("finetune", "finetune_kd", "finetune_kd_online"):
            self.aligner_MLP = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(config.intermediate_size, config.hidden_size),
            )
            torch.nn.init.ones_(self.aligner_MLP[0].weight)
            torch.nn.init.zeros_(self.aligner_MLP[0].bias)
            torch.nn.init.ones_(self.aligner_MLP[3].weight)
            torch.nn.init.zeros_(self.aligner_MLP[3].bias)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def insert_text_embedding(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        text_labels: Optional[torch.LongTensor] = None,
        shift_text_labels: Optional[torch.LongTensor] = None,
        shift_text_hidden_states: Optional[torch.FloatTensor] = None,
        unit_targets: Optional[torch.LongTensor] = None,
        sub_lengths: Optional[torch.LongTensor] = None,
        text_start_index: Optional[torch.LongTensor] = None,
        do_task: str = None,
        **kwargs: dict,
    ):
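        # Builds the decoder inputs for each training stage:
        #   "pretrain"            - embed input_ids with the shared text+unit table and,
        #                           when the perturbation modules exist, apply Π-Model
        #                           style perturbations to the response-text positions.
        #   "finetune"            - swap the response-text embeddings for teacher hidden
        #                           states passed through aligner_MLP.
        #   "finetune_kd"         - slice teacher hidden states per sub-segment, wrap them
        #                           with template ids, and concatenate unit-target embeddings.
        #   "finetune_kd_online"  - splice teacher hidden states into pre-built per-slice
        #                           embeddings, then append unit-target embeddings.
        #   otherwise (inference) - run aligner_MLP on the given embeddings and add the
        #                           hard-coded text/unit header prefixes.
        # Returns (emb_loss, inputs_embeds) after the adapter, plus (unit_targets,
        # atten_mask) for "finetune_kd".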
        if inputs_embeds is None:
            # Shared embedding (0501, kkq): text rows are detached, so only the unit
            # rows receive gradients through this lookup.
            embed_tokens_weight = torch.cat(
                [
                    self.model.embed_tokens.weight.detach(), self.unit_embedding.weight
                ],
                dim=0,
            )
            inputs_embeds = F.embedding(input_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
        emb_loss = None
        if do_task == "pretrain":
            if self.training:
                # `embedding_dropout` and `perturb` are optional modules attached outside
                # this file; the perturbation is only applied when they are present.
                if hasattr(self, "embedding_dropout"):
                    emb_origin_mask = text_labels != -100
                    origin_padding_length = labels.shape[-1] - emb_origin_mask.shape[-1]
                    extend_emb_origin_mask = padding_tensor(emb_origin_mask, origin_padding_length, 1, False)
                    extend_emb_origin_mask = ~extend_emb_origin_mask.unsqueeze(-1).expand_as(inputs_embeds)
                    # Π-Model view 2: additive Gaussian noise with learned log-variance
                    log_var = self.perturb(inputs_embeds)
                    perturbed_inputs_embeds_2 = inputs_embeds + torch.randn_like(inputs_embeds) * (torch.exp(0.5 * log_var) + 1e-6)
                    # Π-Model view 1: embedding dropout (also applied to view 2)
                    perturbed_inputs_embeds_1 = self.embedding_dropout(inputs_embeds)
                    perturbed_inputs_embeds_2 = self.embedding_dropout(perturbed_inputs_embeds_2)
                    # Only response-text positions are perturbed; all other positions keep
                    # the original embeddings.
                    perturbed_inputs_embeds_1 = torch.where(extend_emb_origin_mask, inputs_embeds, perturbed_inputs_embeds_1)
                    perturbed_inputs_embeds_2 = torch.where(extend_emb_origin_mask, inputs_embeds, perturbed_inputs_embeds_2)
                    inputs_embeds = torch.cat(
                        (perturbed_inputs_embeds_1, perturbed_inputs_embeds_2),
                        dim=0,
                    )
                    kl_loss = -0.5 * (1 + log_var - log_var.exp()).mean(dim=-1).sum(dim=-1).mean()
                    contrastive_loss = (1 - F.cosine_similarity(perturbed_inputs_embeds_1, perturbed_inputs_embeds_2, dim=-1)).sum(dim=-1).mean()
                    emb_loss = kl_loss + contrastive_loss
                    if kl_loss.device == torch.device("cuda:0"):
                        self.log["kl_loss"] = kl_loss.item()
                        self.log["std"] = torch.exp(0.5 * log_var).mean().item()
                        self.log["contrastive_loss"] = contrastive_loss.item()
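        # "finetune": replace the (detached) response-text embeddings with the teacher's
        # shifted text hidden states after the aligner MLP; the masks are padded to the
        # unit-label length and sanity-checked for token agreement, and a cosine loss
        # keeps the injected states close to the original embeddings.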
        elif do_task == "finetune":
            inputs_embeds = inputs_embeds.detach()
            inputs_embeds_refer = inputs_embeds.clone().detach()
            shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
            emb_origin_mask = text_labels != -100  # positions of the output text
            emb_shift_mask = shift_text_labels != -100
            origin_padding_length = labels.shape[-1] - emb_origin_mask.shape[-1]
            shift_padding_length = labels.shape[-1] - emb_shift_mask.shape[-1]
            extend_emb_origin_mask = padding_tensor(emb_origin_mask, origin_padding_length, 1, False)
            extend_emb_shift_mask = padding_tensor(emb_shift_mask, shift_padding_length, 1, False)
            extend_shift_text_hidden_states = padding_tensor(shift_text_hidden_states, shift_padding_length, 1, 1e-9)
            # Sanity check: both masks must select the same text tokens.
            extend_text_labels = padding_tensor(text_labels, origin_padding_length, 1, -100)
            extend_shift_text_labels = padding_tensor(shift_text_labels, shift_padding_length, 1, -100)
            assert torch.equal(
                extend_text_labels[extend_emb_origin_mask],
                extend_shift_text_labels[extend_emb_shift_mask]
            ), "{}\n{}\n{}\n{}".format(labels, extend_emb_origin_mask, extend_shift_text_labels, extend_emb_shift_mask)
            inputs_embeds[extend_emb_origin_mask.unsqueeze(-1).expand_as(inputs_embeds)] = \
                extend_shift_text_hidden_states[extend_emb_shift_mask.unsqueeze(-1).expand_as(extend_shift_text_hidden_states)].to(dtype=inputs_embeds.dtype)
            if self.training:
                contrastive_loss = (1 - F.cosine_similarity(inputs_embeds, inputs_embeds_refer, dim=-1)).sum(-1).mean()
                emb_loss = contrastive_loss
                if emb_loss.device == torch.device("cuda:0"):
                    self.log["contrastive_loss"] = contrastive_loss.item()
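        # "finetune_kd": rebuild the batch as one row per text sub-segment. Each teacher
        # hidden-state slice is wrapped with the template prefix ids and an eos id,
        # right-padded to a common length, and concatenated with the embeddings of its
        # unit targets; unit_targets and the attention mask are expanded to match, and a
        # cosine loss ties the injected teacher states to the shared-embedding
        # representation of the same text ids.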
        elif do_task == "finetune_kd":
            # inputs_embeds = inputs_embeds.detach()
            # inputs_embeds_refer = inputs_embeds.clone().detach()
            emb_origin_mask = text_labels != -100
            fetch_labels_list = []
            for batch in range(emb_origin_mask.shape[0]):
                fetch_labels_list.append(text_labels[batch][emb_origin_mask[batch]])
            shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
            # Split shift_text_hidden_states into per-segment slices.
            # [128006, 128000, 78191, 128007, 128000, 198, 128000]
            maxn_length = sub_lengths.max() + 8
            pad_ids = torch.full(size=(sub_lengths.shape[0], sub_lengths.shape[1], maxn_length), fill_value=self.pad_token_id, dtype=torch.long).to(shift_text_hidden_states.device)
            pad_text_ids = torch.full(size=(sub_lengths.shape[0], sub_lengths.shape[1], maxn_length), fill_value=self.pad_token_id, dtype=torch.long).to(shift_text_hidden_states.device)
            atten_mask = pad_ids.ne(self.pad_token_id)
            # target_mask_part1 = pad_ids.ne(self.pad_token_id)
            shift_text_hidden_states_slice = F.embedding(pad_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
            for batch in range(sub_lengths.shape[0]):
                cot = 0
                start_index = text_start_index[batch]
                for index, sub_length in enumerate(sub_lengths[batch]):
                    if sub_length == -1:
                        break
                    eos_id = torch.IntTensor([128009]).to(inputs_embeds.device)
                    eos = self.model.embed_tokens(eos_id)
                    if index == 0:
                        text_prefix_ids = torch.IntTensor([128006, 128000, 65576, 128007, 128000, 198]).to(inputs_embeds.device)
                        prefix_embed = self.model.embed_tokens(text_prefix_ids)
                        pad_text_ids[batch][index][:sub_length + 7] = torch.cat([text_prefix_ids, fetch_labels_list[batch][cot:cot + sub_length], eos_id], dim=0)
                        atten_mask[batch][index][:sub_length + 7] = True
                    else:
                        text_prefix_ids = torch.IntTensor([128006, 128000, 65576, 128007, 128000, 198, 12800]).to(inputs_embeds.device)
                        prefix_embed = self.model.embed_tokens(text_prefix_ids)
                        pad_text_ids[batch][index][:sub_length + 8] = torch.cat([text_prefix_ids, fetch_labels_list[batch][cot:cot + sub_length], eos_id], dim=0)
                        atten_mask[batch][index][:sub_length + 8] = True
                    new_shift_text_hidden_states = torch.cat([prefix_embed, shift_text_hidden_states[batch][start_index + cot:start_index + cot + sub_length], eos], dim=0)
                    shift_text_hidden_states_slice[batch][index][:new_shift_text_hidden_states.shape[0]] = new_shift_text_hidden_states
                    cot += sub_length
            shift_text_hidden_states_slice = shift_text_hidden_states_slice.reshape(shift_text_hidden_states_slice.shape[0] * shift_text_hidden_states_slice.shape[1], shift_text_hidden_states_slice.shape[2], shift_text_hidden_states_slice.shape[3])
            padding_unit_targets = unit_targets.clone()
            padding_unit_targets = torch.where(padding_unit_targets == IGNORE_TOKEN_ID, self.pad_token_id, padding_unit_targets)
            target_mask_part = padding_unit_targets.ne(self.pad_token_id)
            atten_mask = torch.cat([atten_mask, target_mask_part], dim=-1)
            atten_mask = atten_mask.reshape(atten_mask.shape[0] * atten_mask.shape[1], atten_mask.shape[2])
            pad_text_ids = pad_text_ids.reshape(pad_text_ids.shape[0] * pad_text_ids.shape[1], pad_text_ids.shape[2])
            shift_text_embeddings = F.embedding(pad_text_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
            unit_target_slice = F.embedding(padding_unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
            # unit_target_slice = F.embedding(unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
            unit_target_slice = unit_target_slice.reshape(unit_target_slice.shape[0] * unit_target_slice.shape[1], unit_target_slice.shape[2], unit_target_slice.shape[3])
            inputs_embeds = torch.cat([shift_text_hidden_states_slice, unit_target_slice], dim=1)
            ignore_ids = torch.full(size=(sub_lengths.shape[0], sub_lengths.shape[1], maxn_length), fill_value=IGNORE_TOKEN_ID, dtype=torch.long).to(shift_text_hidden_states.device)
            unit_targets = torch.cat([ignore_ids, unit_targets], dim=-1)
            unit_targets = unit_targets.reshape(unit_targets.shape[0] * unit_targets.shape[1], unit_targets.shape[2])
            if self.training:
                contrastive_loss = (1 - F.cosine_similarity(shift_text_hidden_states_slice, shift_text_embeddings, dim=-1)).sum(-1).mean()
                emb_loss = contrastive_loss
                if emb_loss.device == torch.device("cuda:0"):
                    self.log["contrastive_loss"] = contrastive_loss.item()
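        # "finetune_kd_online": inputs_embeds already holds per-slice sequences with a
        # 7-token header; overwrite the text positions (offset 7 onward) with the aligned
        # teacher hidden states, compare against the untouched copy with a masked cosine
        # loss, then append the unit-target embeddings along the time axis.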
        elif do_task == "finetune_kd_online":
            shift_text_hidden_states = self.aligner_MLP(shift_text_hidden_states)
            gold_inputs_embeds = inputs_embeds.clone()
            for batch in range(inputs_embeds.shape[0]):
                start_index = text_start_index[batch]
                for slice_index in range(inputs_embeds.shape[1]):
                    sub_length = sub_lengths[batch][slice_index]
                    inputs_embeds[batch][slice_index][7:7 + sub_length] = shift_text_hidden_states[batch][start_index + 1:start_index + 1 + sub_length]
                    start_index += sub_length
            if self.training:
                contrastive_loss = ((1 - F.cosine_similarity(inputs_embeds, gold_inputs_embeds, dim=-1)) * attention_mask).sum(-1).mean()
                emb_loss = contrastive_loss
                if emb_loss.device == torch.device("cuda:0"):
                    self.log["contrastive_loss"] = contrastive_loss.item()
            unit_embeds = F.embedding(unit_targets, embed_tokens_weight, padding_idx=self.pad_token_id)
            inputs_embeds = torch.cat([inputs_embeds, unit_embeds], dim=2)
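        # Inference / no-task path: align the given embeddings and wrap them with the
        # hard-coded header prefixes.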
        else:
            inputs_embeds = self.aligner_MLP(inputs_embeds)
            # [start_header_id] + _speaker + [end_header_id] + nl_tokens; hard-coded for batch size 1
            units_ids = torch.IntTensor([[128009, 128006, 128000, 65576, 128007, 128000, 198]]).to(inputs_embeds.device)
            units_prefix = self.model.embed_tokens(units_ids)
            text_ids = torch.IntTensor([[128006, 128000, 65576, 128007, 128000, 198, 12800]]).to(inputs_embeds.device)
            text_prefix = self.model.embed_tokens(text_ids)
            inputs_embeds = torch.cat([text_prefix, inputs_embeds, units_prefix], dim=1)

        inputs_embeds = self.adapter(inputs_embeds)
        if do_task == "finetune_kd":
            return (emb_loss, inputs_embeds, unit_targets, atten_mask,)
        else:
            return (emb_loss, inputs_embeds)
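
    # Standard causal-LM forward over the T2U decoder: when no precomputed inputs_embeds
    # are given, inputs are embedded with the shared text+unit table (text rows detached)
    # and passed through the adapter; the LM head scores the combined vocabulary and the
    # loss is a shifted cross-entropy over `labels`, with an optional auxiliary term
    # (`cr_loss`) folded in when present.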
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            # inputs_embeds = self.model.embed_tokens(input_ids)
            # Shared embedding (0501, kkq): text rows detached, unit rows trainable.
            embed_tokens_weight = torch.cat(
                [
                    self.model.embed_tokens.weight.detach(), self.unit_embedding.weight
                ],
                dim=0,
            )
            inputs_embeds = F.embedding(input_ids, embed_tokens_weight, padding_idx=self.pad_token_id)
            # The adapter is applied here only because precomputed inputs_embeds from
            # insert_text_embedding have already passed through it.
            inputs_embeds = self.adapter(inputs_embeds)
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        # Placeholder for an auxiliary consistency loss; it is not set in this path.
        cr_loss = None
        if labels is not None:
            shift_labels = labels
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = shift_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, (self.config.vocab_size + self.config.unit_vocab_size))
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if loss.device == torch.device("cuda:0"):
                self.log["unit_loss"] = loss.item()
            if cr_loss is not None:
                # Scale the auxiliary loss so it contributes at most ~20% of the CE loss.
                target_scale = loss.item() * 0.2
                cr_loss_weight = target_scale / cr_loss.item() if cr_loss > target_scale else 1.0
                loss = loss + cr_loss_weight * cr_loss
            if loss.device == torch.device("cuda:0") and (self.current_step - 10) % 100 == 0:
                print(self.log, loss.device)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
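

# Register the custom config and model so they can be loaded through
# AutoConfig / AutoModelForCausalLM via model_type "T2ULlama".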
AutoConfig.register("T2ULlama", T2ULlamaConfig)
AutoModelForCausalLM.register(T2ULlamaConfig, T2ULlamaForCausalLM)