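"""ACLlama: a LLaMA backbone extended with a Whisper audio front-end.

Summary of what this module defines: audio clips are encoded by a Whisper
encoder (run without gradients in forward), projected into the LLaMA hidden
space, reduced by a CTC-style "shrink" step, refined by a look-back
cross-attention module, and spliced into the token embedding sequence in place
of audio placeholder tokens. An auxiliary CTC loss supervises the frame-level
logits, and an optional T2U (text-to-unit) translator (T2ULlamaForCausalLM)
is trained on top of the LLaMA hidden states to produce discrete speech units.
"""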
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, CTCLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.trainer_pt_utils import LabelSmoother
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers import (
    WhisperProcessor,
    WhisperModel,
)

from T2ULlama_CR_online import T2ULlamaForCausalLM

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


class ACLlamaConfig(LlamaConfig):
    model_type = "ACLlama"


def load_whisper(audio_tower_name, device="cuda"):
    # NOTE: the checkpoint is hardcoded below; the `audio_tower_name` argument
    # (passed as config.audio_tower) is currently not used to select the model.
    model = WhisperModel.from_pretrained(
        "openai/whisper-large-v3", torch_dtype=torch.float16, low_cpu_mem_usage=True).to(device)
    model.config.forced_decoder_ids = None
    return model
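

# LookBackModule: a single cross-attention block used after the CTC-style shrink
# step in ACLlamaModel.forward. The shrunk features act as queries and the
# pre-shrink audio features as keys/values; `bf_shrink_padding_mask` is a key
# padding mask derived from the shrink mask (True positions are excluded from
# attention). A residual connection and LayerNorm follow.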
class LookBackModule(nn.Module):
    def __init__(self, cfg: LlamaConfig):
        super().__init__()
        self.encoder_attn = nn.MultiheadAttention(
            cfg.hidden_size,
            cfg.num_attention_heads,
            dropout=0.1,
            batch_first=True
        )
        self.atten_layer_norm = nn.LayerNorm(cfg.hidden_size)

    def forward(self, x, wav_feature, bf_shrink_padding_mask):
        residual = x
        x, _ = self.encoder_attn(
            query=x,
            key=wav_feature,
            value=wav_feature,
            key_padding_mask=bf_shrink_padding_mask,
            #attn_mask=padding_mask,
        )
        x += residual
        x = self.atten_layer_norm(x)
        return x
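

# ACLlamaModel: LlamaModel plus an audio front-end. When the config carries
# `audio_tower` / `adapter_size`, __init__ registers:
#   * audio_tower             - a Whisper model wrapped in a plain list (FSDP hack, see forward)
#   * mm_projector1           - Linear(adapter_size * 2 -> hidden_size); the factor of 2 comes
#                               from concatenating pairs of adjacent Whisper frames in forward()
#   * asr_transformer_encoder - a single TransformerEncoder layer over the projected frames
#   * audio_feature_head      - frame-level vocabulary logits, used for the CTC loss and for
#                               the argmax-based shrink
#   * lbm / out_norm          - look-back cross-attention and output LayerNorm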
class ACLlamaModel(LlamaModel):
    config_class = ACLlamaConfig

    def __init__(self, config: LlamaConfig):
        super(ACLlamaModel, self).__init__(config)

        if hasattr(config, "audio_tower"):
            self.audio_tower = [load_whisper(config.audio_tower)]

        if hasattr(config, "adapter_size"):
            self.mm_projector1 = nn.Linear(config.adapter_size * 2, config.hidden_size)
            asr_encoder_layer = nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.hidden_size * 2,
                dropout=0.1,
                norm_first=True
            )
            self.lbm = LookBackModule(config)
            self.out_norm = nn.LayerNorm(config.hidden_size)
            self.audio_feature_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.asr_transformer_encoder = nn.TransformerEncoder(asr_encoder_layer, num_layers=1)

        self.mask_tensor = (torch.ones([1, 2048]) > 0)
        self.length = -1

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        audios: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # HACK: replace back original embeddings for LLaAA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        audio_tower = getattr(self, 'audio_tower', None)
        if audio_tower is not None and (input_ids.shape[1] != 1 or self.training) and audios is not None:
            audio_tower = audio_tower[0]  # HACK: for FSDP
            audio_list = []
            audio_config = audio_tower.config
            for audio in audios:
                with torch.no_grad():
                    audio_feature = audio_tower.encoder(audio).last_hidden_state
                audio_feature = audio_feature.view(audio_feature.shape[0], audio_feature.shape[1] // 2, 2 * audio_feature.shape[2])
                audio_feature = self.mm_projector1(audio_feature)
                audio_feature = self.asr_transformer_encoder(audio_feature)
                audio_feature = self.out_norm(audio_feature)
                audio_list.append(audio_feature)
            audio_features = torch.stack(audio_list, dim=0)

            batch = audio_features.shape[0]
            audio_turn = audio_features.shape[1]
            audio_features = audio_features.view((batch * audio_turn,) + audio_features.shape[2:])
            predict_logits = self.audio_feature_head(audio_features)
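            # Shape bookkeeping (assuming whisper-large-v3: 1500 encoder frames of
            # width 1280 per 30 s clip): pairing adjacent frames above gives
            # [turns, 750, 2560] per sample; after the projector / encoder / norm each
            # sample is [turns, 750, hidden_size]. `audio_features` is then flattened
            # to [batch * turns, frames, hidden_size] and `predict_logits` holds
            # frame-level vocabulary logits of shape [batch * turns, frames, vocab_size].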

            new_input_embeds = []
            label_shift = []
            speech_pos = []
            label_extend = -1
            new_input_ids = []

            tokens = predict_logits.argmax(dim=-1)
            shrink_mask = tokens.roll(1) != tokens
            shrink_mask[:, 0] = True
            lengths = shrink_mask.long().sum(-1)
            shrink_2d = audio_features[shrink_mask]
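            # CTC-style shrink: keep one frame per run of identical argmax labels.
            # E.g. if the frame-level argmax sequence is [blank, blank, 7, 7, 7, blank, 9],
            # shrink_mask marks positions 0, 2, 5 and 6, so four feature vectors survive
            # (blank runs are collapsed like any other run, not removed).
            # `lengths` holds the per-segment number of kept frames and `shrink_2d`
            # gathers the kept vectors of all segments into one flat tensor.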
            #num_patches = audio_features.shape[1]
            num_patches = audio_config.audio_patch_size

            l_index = 0
            shrink_features_raw = []
            for v, audio_feature, mask in zip(lengths, audio_features, ~shrink_mask):
                shrink_feature = shrink_2d[l_index:l_index + v]
                shrink_feature = self.lbm(shrink_feature, audio_feature, bf_shrink_padding_mask=mask)
                shrink_features_raw.append(shrink_feature)
                l_index += v

            shrink_features = []
            for i in range(0, len(shrink_features_raw), audio_turn):
                shrink_features.append(shrink_features_raw[i:i + audio_turn])

            if self.training:
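                # Training path: for every audio placeholder span (num_patches copies of
                # audio_patch_token) in each sample, splice in that segment's shrunk
                # features instead, then right-pad every sequence with llm_pad_token_id
                # so the batch keeps a common length (old length + label_extend).
                # label_shift / speech_pos record, per segment, how much the following
                # text moved and where the segment starts, so the labels can be realigned
                # later in ACLlamaForCausalLM.forward.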
                maxn_length = lengths.view(batch, audio_turn).sum(-1).max()
                label_extend = maxn_length - num_patches * audio_turn
                old_seq_length = inputs_embeds.shape[1]
                for cur_input_ids, cur_input_embeds, cur_shrink_features in zip(input_ids, inputs_embeds, shrink_features):
                    pad_ids = torch.full(size=(maxn_length,), fill_value=audio_config.llm_pad_token_id, dtype=torch.long).to(attention_mask.device)
                    pad_embeds = self.embed_tokens(pad_ids)
                    audio_start_token_pos_all = torch.where(cur_input_ids == audio_config.audio_patch_token)[0]
                    #print(cur_input_embeds.shape, cur_input_ids.shape)
                    inner_label_shift = []
                    inner_speech_pos = []
                    # Walk the audio spans right-to-left so earlier positions stay valid while splicing.
                    for audio_start_token_pos, shrink_feature in reversed(list(zip(audio_start_token_pos_all, cur_shrink_features))):
                        cur_speech_length = shrink_feature.shape[0]
                        cur_input_ids = torch.cat((cur_input_ids[:audio_start_token_pos],
                                                   cur_input_ids[audio_start_token_pos: audio_start_token_pos + 1].repeat(cur_speech_length),
                                                   cur_input_ids[audio_start_token_pos + num_patches:]), dim=0)
                        cur_input_embeds = torch.cat((
                            cur_input_embeds[:audio_start_token_pos],
                            shrink_feature,
                            cur_input_embeds[audio_start_token_pos + num_patches:]), dim=0)
                        inner_label_shift.insert(0, cur_speech_length - num_patches)
                        inner_speech_pos.insert(0, audio_start_token_pos)
                    label_shift = label_shift + inner_label_shift
                    speech_pos = speech_pos + inner_speech_pos
                    cur_new_input_embeds = torch.cat((cur_input_embeds, pad_embeds[:old_seq_length + label_extend - cur_input_embeds.shape[0]]), dim=0)
                    cur_new_input_ids = torch.cat((cur_input_ids, pad_ids[:old_seq_length + label_extend - cur_input_ids.shape[0]]), dim=0)
                    new_input_embeds.append(cur_new_input_embeds)
                    new_input_ids.append(cur_new_input_ids)

                input_ids = torch.stack(new_input_ids, dim=0)
                attention_mask = input_ids.ne(audio_config.llm_pad_token_id)
                inputs_embeds = torch.stack(new_input_embeds, dim=0)

                batch_label_shift = []
                batch_speech_pos = []
                for i in range(0, len(label_shift), audio_turn):
                    batch_label_shift.append(label_shift[i:i + audio_turn])
                    batch_speech_pos.append(speech_pos[i:i + audio_turn])
            else:
                # Inference mode with batch_size=1
                assert input_ids.shape[0] == 1, "This implementation only supports batch_size=1 during inference"
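                # Inference path: the same placeholder-to-feature splicing as above, but
                # for a single sample, walking the audio positions left to right and
                # tracking the cumulative length change in `position_shift`.
                # self.mask_tensor / self.length cache the expanded attention mask,
                # presumably so it can be reused across generation steps (later steps,
                # with a KV cache and a single new token, skip this audio branch entirely).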

                # Get all audio token positions in this sample
                audio_start_token_positions = torch.where(input_ids[0] == audio_config.audio_patch_token)[0]

                # Initialize with original embeddings
                current_embeds = inputs_embeds[0]  # [seq_len, embed_dim]
                current_ids = input_ids[0]  # [seq_len]

                # Process each audio token position sequentially
                position_shift = 0  # Track position changes due to expansions

                # Ensure shrink_features is properly formatted
                if isinstance(shrink_features[0], list):
                    # If it's a list of lists (batch_size=1 but multiple turns), flatten it
                    shrink_features = [item for sublist in shrink_features for item in sublist]

                for pos_idx, audio_pos in enumerate(audio_start_token_positions):
                    adjusted_pos = audio_pos + position_shift

                    # Get corresponding shrink feature (ensure it's a tensor)
                    shrink_feature = shrink_features[pos_idx]
                    if isinstance(shrink_feature, list):
                        shrink_feature = torch.stack(shrink_feature, dim=0)
                    v = shrink_feature.shape[0]
                    # print('len: ', v)

                    # Expand the input ids and embeddings
                    current_ids = torch.cat([
                        current_ids[:adjusted_pos],
                        current_ids[adjusted_pos:adjusted_pos + 1].repeat(v),
                        current_ids[adjusted_pos + num_patches:]
                    ], dim=0)
                    current_embeds = torch.cat([
                        current_embeds[:adjusted_pos],
                        shrink_feature,
                        current_embeds[adjusted_pos + num_patches:]
                    ], dim=0)

                    # Update position shift for next iteration
                    position_shift += (v - num_patches)

                # Update the tensors (unsqueeze to restore batch dim)
                input_ids = current_ids.unsqueeze(0)  # [1, new_seq_len]
                inputs_embeds = current_embeds.unsqueeze(0)  # [1, new_seq_len, embed_dim]
                attention_mask = input_ids.ne(audio_config.llm_pad_token_id)

                # Update inference state tracking
                if not hasattr(self, 'mask_tensor'):
                    # Initialize with current attention mask
                    self.mask_tensor = attention_mask.clone()
                    self.length = attention_mask.shape[1]
                else:
                    # Ensure mask tensor is on correct device
                    self.mask_tensor = self.mask_tensor.to(attention_mask.device)
                    # Expand mask tensor if needed
                    if self.mask_tensor.shape[1] < attention_mask.shape[1]:
                        new_mask = torch.zeros(1, attention_mask.shape[1],
                                               dtype=torch.bool,
                                               device=attention_mask.device)
                        new_mask[0, :self.mask_tensor.shape[1]] = self.mask_tensor
                        self.mask_tensor = new_mask
                    # Update mask tensor
                    self.mask_tensor[0, :attention_mask.shape[1]] = attention_mask[0]
                    self.length = attention_mask.shape[1]
                    attention_mask = self.mask_tensor[:, :self.length]
                    self.length += 1

        return_state = super(ACLlamaModel, self).forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        if self.training and audios is not None:
            return_state["audio_features"] = predict_logits
            return_state["label_shift"] = batch_label_shift
            return_state["label_extend"] = label_extend
            return_state["speech_pos"] = batch_speech_pos
            #return_state = {"audio_features": predict_logits}

        return return_state


class ACLlamaForCausalLM(LlamaForCausalLM):
    config_class = ACLlamaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = ACLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # t2u by kkq
        if hasattr(config, "unit_output"):
            self.unit_output = config.unit_output
            self.unit_translator = T2ULlamaForCausalLM(config, self.lm_head.weight)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def get_unit_translator(self):
        return self.unit_translator

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        t2u_input_ids: Optional[torch.LongTensor] = None,
        t2u_labels: Optional[torch.LongTensor] = None,
        t2u_attention_mask: Optional[torch.Tensor] = None,
        unit_targets: Optional[torch.Tensor] = None,
        sub_lengths: Optional[torch.Tensor] = None,
        asr_targets: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        audios: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        do_task: Optional[str] = None,
        assistant_after_audio_shifts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # t2u by kkq
        # pretrain (t2u only), finetune (s2t & e2u)
        do_task = do_task if do_task is not None else getattr(self, 'unit_output', None)

        outputs = None
        hidden_states = None
        new_shift_labels = None
        if do_task != "pretrain":
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                audios=audios
            )
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None and do_task != "pretrain" and do_task != "finetune_kd":
            if asr_targets is not None:
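                # Auxiliary CTC loss: the frame-level audio logits (one vocabulary
                # distribution per frame) are trained against the reference transcripts.
                # asr_targets is flattened from [batch, turns, len] to [batch * turns, len];
                # IGNORE_TOKEN_ID marks padding and gives the target lengths. The blank
                # symbol id is reused from audio_patch_token. The resulting loss_asr is
                # added to the LM cross-entropy with weight 0.3 further below.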
                asr_logits = outputs["audio_features"]
                asr_targets = asr_targets.view(asr_targets.shape[0] * asr_targets.shape[1], asr_targets.shape[2])
                mask_asr_targets = (asr_targets != IGNORE_TOKEN_ID)
                target_lengths = mask_asr_targets.sum(1)
                input_lengths = torch.full(size=(asr_logits.shape[0],), fill_value=asr_logits.shape[1], dtype=torch.long)
                loss_ctc = CTCLoss()
                log_probs = F.log_softmax(asr_logits, dim=-1).transpose(0, 1)
                #print(asr_targets.shape)
                #print(input_lengths, target_lengths)
                with torch.backends.cudnn.flags(enabled=False):
                    loss_asr = F.ctc_loss(
                        log_probs,
                        asr_targets,
                        input_lengths,
                        target_lengths,
                        blank=self.model.audio_tower[0].config.audio_patch_token,
                        reduction='mean',
                        zero_infinity=True,
                    )
            else:
                loss_asr = 0

            shift_labels = labels
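            # Realign the labels with the expanded input sequence: each audio segment
            # grew (or shrank) by label_shift[i] positions, so every label after a
            # segment is moved right by the cumulative shift into a new tensor of width
            # seq_len + label_extend. Positions that are not written stay at
            # IGNORE_TOKEN_ID (the audio placeholder positions are presumably already
            # IGNORE_TOKEN_ID in `labels`).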
            if "label_shift" in outputs.keys() and len(outputs["label_shift"]) > 0:
                if outputs["label_extend"] != -1:
                    new_shift_labels = torch.full(size=(shift_labels.shape[0], outputs["label_extend"] + shift_labels.shape[1]), fill_value=IGNORE_TOKEN_ID, dtype=torch.long).to(shift_labels.device)
                    for batch in range(len(outputs["label_shift"])):
                        it_label_shift = outputs["label_shift"][batch]
                        it_speech_pos = outputs["speech_pos"][batch]
                        prefix = 0
                        for i in range(len(it_label_shift)):
                            if i == len(it_label_shift) - 1:
                                length = shift_labels.shape[1] - it_speech_pos[i]
                            else:
                                length = it_speech_pos[i + 1] - it_speech_pos[i]
                            prefix += it_label_shift[i]
                            new_shift_labels[batch][it_speech_pos[i] + prefix: it_speech_pos[i] + length + prefix] = shift_labels[batch][it_speech_pos[i]:it_speech_pos[i] + length]
                    shift_labels = new_shift_labels
                else:
                    raise NotImplementedError

            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = shift_labels[..., 1:].contiguous()
            #print(shift_labels[:,:50])
            #print(shift_labels[:,:150])
            loss_fct = CrossEntropyLoss()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            loss = loss + 0.3 * loss_asr

        t2u_output = None
        if do_task is not None and do_task != "skip":
            if do_task == "finetune_kd":
                text_start_index = []
                for batch in range(len(outputs["label_shift"])):
                    text_start_index.append(outputs["speech_pos"][batch][0] + outputs["label_shift"][batch][0] + assistant_after_audio_shifts[batch])
                t2u_embeds_output = self.unit_translator.insert_text_embedding(
                    input_ids=t2u_input_ids,
                    attention_mask=t2u_attention_mask,
                    inputs_embeds=None,
                    labels=t2u_labels,
                    text_labels=labels,
                    shift_text_labels=new_shift_labels,
                    shift_text_hidden_states=hidden_states,
                    unit_targets=unit_targets,
                    sub_lengths=sub_lengths,
                    text_start_index=text_start_index,
                    do_task=do_task,
                )
                vae_loss, t2u_inputs_embeds, unit_targets, t2u_attention_mask = t2u_embeds_output
                t2u_output = self.unit_translator(
                    input_ids=None,
                    attention_mask=t2u_attention_mask,
                    past_key_values=past_key_values,
                    inputs_embeds=t2u_inputs_embeds,
                    use_cache=use_cache,
                    labels=unit_targets,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
            else:
                t2u_embeds_output = self.unit_translator.insert_text_embedding(
                    input_ids=t2u_input_ids,
                    attention_mask=t2u_attention_mask,
                    inputs_embeds=None,
                    labels=t2u_labels,
                    text_labels=labels,
                    shift_text_labels=new_shift_labels,
                    shift_text_hidden_states=hidden_states,
                    do_task=do_task,
                )
                vae_loss, t2u_inputs_embeds = t2u_embeds_output
                t2u_output = self.unit_translator(
                    input_ids=None,
                    attention_mask=t2u_attention_mask,
                    past_key_values=past_key_values,
                    inputs_embeds=t2u_inputs_embeds,
                    use_cache=use_cache,
                    labels=t2u_labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )

            t2u_loss = t2u_output[0]
            # print(do_task, t2u_loss, vae_loss)
            if vae_loss is not None:
                target_scale = t2u_loss.item() * 0.2
                vae_loss_weight = target_scale / vae_loss.item() if vae_loss > target_scale else 1.0
                t2u_loss = t2u_loss + vae_loss_weight * vae_loss
                #print(vae_loss)
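            # Combine the losses. The VAE regularizer (when present) is rescaled above so
            # it contributes at most ~20% of the current t2u loss. Below, once the S2T
            # cross-entropy has dropped under 1.0 it is down-weighted (x0.2) and the t2u
            # loss up-weighted (x2.0); otherwise the two are simply summed. In the
            # pretrain / finetune_kd cases there is no S2T loss and the t2u output is
            # returned directly.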
            if loss is not None:  # S2T + T2U loss
                # ignore LLM loss
                # t2u_output["loss"] = t2u_loss
                # return t2u_output
                # original version
                assert do_task in ["finetune"]
                if loss.item() < 1.0:  # 1.7
                    loss = 0.2 * loss + t2u_loss * 2.0
                else:
                    loss = loss + t2u_loss
            else:
                assert do_task in ["pretrain", "finetune_kd"]
                t2u_output["loss"] = t2u_loss
                return t2u_output

        #return CausalLMOutputWithPast(
        #    loss=loss,
        #    logits=outputs["audio_features"],
        #)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0]:]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        model_inputs.update({"audios": kwargs["audios"]} if "audios" in kwargs.keys() else {})
        model_inputs.update({"do_task": kwargs["do_task"]} if "do_task" in kwargs.keys() else {})
        model_inputs.update({"return_dict": kwargs["return_dict_in_generate"]} if "return_dict_in_generate" in kwargs.keys() else {})
        return model_inputs
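
    # Note: generate() calls prepare_inputs_for_generation on every decoding step;
    # the custom keyword arguments `audios` and `do_task` are forwarded into
    # forward() here, and `return_dict_in_generate` is mapped to `return_dict`.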


AutoConfig.register("ACLlama", ACLlamaConfig)
AutoModelForCausalLM.register(ACLlamaConfig, ACLlamaForCausalLM)
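

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a fine-tuned checkpoint whose configs provide
# `audio_tower`, `adapter_size`, `audio_patch_size`, `audio_patch_token` and
# `llm_pad_token_id`, and whose prompts contain the audio placeholder token
# repeated `audio_patch_size` times wherever a clip should be spliced in.
# The path, the `waveform` variable (a 16 kHz mono array) and the
# `build_prompt_ids` helper are hypothetical and shown for illustration only.
#
#   import torch
#   from transformers import AutoModelForCausalLM, AutoTokenizer, WhisperProcessor
#
#   ckpt = "path/to/ACLlama-checkpoint"          # hypothetical
#   model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).cuda().eval()
#   tokenizer = AutoTokenizer.from_pretrained(ckpt)
#   processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
#
#   feats = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features
#   audios = feats.unsqueeze(0).to("cuda", torch.float16)   # [batch=1, turns=1, n_mels, frames]
#   input_ids = build_prompt_ids(tokenizer, model)          # hypothetical helper
#
#   out = model.generate(input_ids.cuda(), audios=audios, max_new_tokens=256)
#   print(tokenizer.decode(out[0], skip_special_tokens=True))
# ---------------------------------------------------------------------------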