Spaces:
Sleeping
Sleeping
| import warnings | |
| from typing import Any, List, Optional, Tuple, Union | |
| import torch.utils.checkpoint | |
| import transformers | |
| from torch import nn | |
| from torch.nn import CrossEntropyLoss | |
| from transformers import GenerationConfig | |
| from transformers import Qwen2ForCausalLM | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.utils import ModelOutput, logging | |
| import torch.nn.functional as F | |
| from .configuration_llavaqw import LlavaQwConfig | |
| from .conversation import get_conv_template | |
| from .modeling_intern_vit import InternVisionModel, has_flash_attn | |
| logger = logging.get_logger(__name__) | |
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings with a named comparison operator.

    Args:
        v1: Left-hand version string (e.g. ``'4.44.2'``).
        v2: Right-hand version string.
        op: Name of a function in the :mod:`operator` module, such as
            ``'eq'``, ``'ge'`` or ``'lt'``.

    Returns:
        The result of ``operator.<op>(parse(v1), parse(v2))``.
    """
    # Imported lazily so the dependencies are only needed when a comparison
    # is actually performed.
    import operator

    from packaging import version

    compare = getattr(operator, op)
    lhs = version.parse(v1)
    rhs = version.parse(v2)
    return compare(lhs, rhs)
class LlavaQwModel(PreTrainedModel):
    """Vision-language model wiring a vision encoder to a Qwen2 causal LM.

    Image tiles are encoded by ``vision_model``, spatially down-sampled with
    :meth:`pixel_shuffle`, projected by ``mlp1`` to the LM hidden size and
    written over the embeddings of the image-context placeholder tokens
    before the sequence is handed to ``language_model``.
    """

    config_class = LlavaQwConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    _no_split_modules = ['InternVisionModel', 'Qwen2DecoderLayer']

    def __init__(self, config: LlavaQwConfig, vision_model=None, language_model=None, use_flash_attn=True):
        """Build (or adopt) the vision tower, the language model and the projector.

        Args:
            config: Model configuration; ``vision_config`` and ``llm_config``
                are updated in place with the chosen attention implementation.
            vision_model: Optional pre-built vision encoder to reuse.
            language_model: Optional pre-built language model to reuse.
            use_flash_attn: Request flash attention; silently disabled when
                ``has_flash_attn`` is false.

        Raises:
            AssertionError: If the installed ``transformers`` is older than 4.44.2.
            NotImplementedError: If the configured LLM architecture is not
                ``Qwen2ForCausalLM`` and no ``language_model`` was supplied.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.44.2', 'ge'), \
            'transformers>=4.44.2 is required'

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        llm_arch = config.llm_config.architectures[0]

        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.llm_arch_name = llm_arch
        self.template = config.template
        # Visual tokens contributed by one tile after pixel-shuffle down-sampling.
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        use_flash_attn = bool(use_flash_attn) and bool(has_flash_attn)
        config.vision_config.use_flash_attn = use_flash_attn
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)

        if language_model is not None:
            self.language_model = language_model
        elif llm_arch == 'Qwen2ForCausalLM':
            self.language_model = Qwen2ForCausalLM(config.llm_config)
        else:
            raise NotImplementedError(f'{llm_arch} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_intermediate_size = config.llm_config.intermediate_size
        llm_hidden_size = config.llm_config.hidden_size
        # Pixel shuffle folds (1 / ratio) ** 2 neighbouring patches into the channel dim.
        projector_in = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        # Layer order is kept as-is so the state-dict keys (mlp1.0, mlp1.1, mlp1.3) do not change.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, llm_intermediate_size, bias=False),
            nn.GELU(),
            nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False)
        )

        # Default placeholder id; every chat* method overwrites it from the tokenizer.
        self.img_context_token_id = 151654
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model on text embeddings with image features spliced in.

        Args:
            pixel_values: Image tiles passed to :meth:`extract_feature`.
            input_ids: Token ids; positions equal to ``self.img_context_token_id``
                receive the visual embeddings.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: Forwarded to the language model.
            image_flags: Accepted for interface compatibility; not used here.
            labels: Optional targets; when given, a next-token cross-entropy
                loss is computed.
            return_dict: Return a ``CausalLMOutputWithPast`` instead of a tuple;
                defaults to ``self.config.use_return_dict``.

        Returns:
            ``CausalLMOutputWithPast`` or a tuple ``(loss?, logits, *rest)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        vit_embeds = self.extract_feature(pixel_values)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        vit_embeds = vit_embeds.reshape(-1, C)

        # `* 0.0 + vit_embeds` (instead of a plain overwrite) keeps the original
        # embedding values in the expression while replacing their content.
        try:
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds
        except Exception as e:
            # Best-effort fallback when the placeholder count and the number of
            # visual tokens disagree: use only the first `n_token` visual tokens.
            logger.warning(
                f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channel depth on a 4-D feature map.

        Args:
            x: Tensor unpacked as ``(n, w, h, c)``.
            scale_factor: Spatial scale; ``0.5`` halves both spatial dims and
                multiplies the channel dim by four.

        Returns:
            Tensor with both spatial dims scaled by ``scale_factor`` and the
            channel dim scaled by ``1 / scale_factor ** 2``. With
            ``ps_version == 'v1'`` the two spatial axes are left swapped.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode image tiles into projected visual tokens.

        Takes the hidden states of layer ``self.select_layer`` (the final
        layer when it is ``-1``), drops the first token, reshapes the rest to
        a square grid, applies :meth:`pixel_shuffle` and projects with ``mlp1``.

        Args:
            pixel_values: Input accepted by ``self.vision_model``.

        Returns:
            Tensor of projected visual tokens, one row of tokens per tile.
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        # Drop the leading token (presumably the class token; confirm against the vision model).
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    # ------------------------------------------------------------------
    # Private helpers shared by the chat* / generate* entry points
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_num_patches(pixel_values, num_patches_list):
        """Return the per-image tile counts, defaulting to one image holding all tiles."""
        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
        return num_patches_list

    @staticmethod
    def _build_prompt(template, history, question):
        """Append the past turns plus the new question to ``template`` and render it."""
        for old_question, old_answer in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        return template.get_prompt()

    def _expand_image_placeholders(self, query, num_patches_list, img_start_token, img_end_token,
                                   img_context_token):
        """Replace successive ``<image>`` markers with tile tags and context tokens."""
        context_run = img_context_token * self.num_image_token
        for num_patches in num_patches_list:
            # Tiles 1..num_patches-1 get numbered tags; the last tile gets the thumbnail tag.
            tile_tags = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
            image_tokens = img_start_token + ''.join(tag + context_run for tag in tile_tags) + img_end_token
            query = query.replace('<image>', image_tokens, 1)
        return query

    @staticmethod
    def _print_exchange(query, response, img_start_token, img_end_token, img_context_token):
        """Print a compact form of the prompt (context tokens removed) and the response."""
        query_to_print = query.replace(img_context_token, '')
        query_to_print = query_to_print.replace(f'{img_start_token}{img_end_token}', '<image>')
        print(query_to_print, response)

    def _generate_response(self, tokenizer, query, pixel_values, visual_features, generation_config,
                           eos_token_id, sep):
        """Tokenize ``query``, generate, and return the decoded text up to ``sep``."""
        model_inputs = tokenizer(query, return_tensors='pt')
        # Follow the model's device instead of assuming CUDA.
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        # Merge into a copy so the caller's generation_config dict is not mutated.
        gen_kwargs = {**generation_config, 'eos_token_id': eos_token_id}
        generation_output = self.generate(
            pixel_values=pixel_values,
            visual_features=visual_features,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        return response.split(sep)[0].strip()

    def _embed_inputs(self, input_ids, pixel_values=None, visual_features=None):
        """Embed ``input_ids`` and overwrite image-context positions with visual tokens.

        Visual tokens come from ``mlp1(visual_features)`` when features are
        supplied, otherwise from :meth:`extract_feature`. Text-only inputs
        (``pixel_values is None``) are returned unchanged.
        """
        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        if pixel_values is None:
            return input_embeds

        if visual_features is not None:
            vit_embeds = self.mlp1(visual_features.to(self.device))
        else:
            vit_embeds = self.extract_feature(pixel_values)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        selected = (input_ids.reshape(B * N) == self.img_context_token_id)
        assert selected.sum() != 0, "No valid image context token IDs found."
        input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
        return input_embeds.reshape(B, N, C)

    # ------------------------------------------------------------------
    # Public chat / generation API
    # ------------------------------------------------------------------

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<|vision_start|>', IMG_END_TOKEN='<|vision_end|>',
             IMG_CONTEXT_TOKEN='<|vision_pad|>', verbose=False, visual_features=None):
        """Answer one question, optionally about images, using the chat template.

        Args:
            tokenizer: Tokenizer providing ``convert_tokens_to_ids``,
                ``__call__`` and ``batch_decode``.
            pixel_values: Image tiles, or ``None`` for a text-only turn.
            question: User question; ``<image>`` is prepended on the first
                turn when images are given and no marker is present.
            generation_config: Mapping of keyword arguments for
                :meth:`generate`; it is not modified.
            history: Previous ``(question, answer)`` pairs; the new pair is
                appended to this list in place.
            return_history: Also return the updated history.
            num_patches_list: Tiles per image; defaults to a single image
                holding every tile.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Special tokens
                used when expanding ``<image>``.
            verbose: Print the tile batch size and, when ``return_history`` is
                false, the prompt and response.
            visual_features: Optional pre-computed features that bypass the
                vision encoder (only ``mlp1`` is applied).

        Returns:
            ``response`` or ``(response, history)``.
        """
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        num_patches_list = self._resolve_num_patches(pixel_values, num_patches_list)
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        history = [] if history is None else history
        query = self._build_prompt(template, history, question)

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        query = self._expand_image_placeholders(
            query, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        response = self._generate_response(
            tokenizer, query, pixel_values, visual_features, generation_config, eos_token_id, template.sep)

        history.append((question, response))
        if return_history:
            return response, history
        if verbose:
            self._print_exchange(query, response, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        return response

    def chat_without_sys_prompt(self, tokenizer, pixel_values, question, generation_config, history=None,
                                return_history=False,
                                num_patches_list=None, IMG_START_TOKEN='<|vision_start|>',
                                IMG_END_TOKEN='<|vision_end|>',
                                IMG_CONTEXT_TOKEN='<|vision_pad|>', verbose=False, visual_features=None):
        """Same as :meth:`chat`, but the system prompt is removed from the final query.

        A fixed dummy system prompt is installed in the template and the same
        number of leading characters is then sliced off the rendered prompt.
        Arguments and return value are as for :meth:`chat`.
        """
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        num_patches_list = self._resolve_num_patches(pixel_values, num_patches_list)
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        template = get_conv_template(self.template)
        system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"  # override dummy system prompt
        template.system_message = system_prompt
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        history = [] if history is None else history
        query = self._build_prompt(template, history, question)

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # NOTE(review): assumes the rendered prompt starts with exactly
        # `system_prompt`; verify against the conversation template.
        query = query[len(system_prompt):]

        query = self._expand_image_placeholders(
            query, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        response = self._generate_response(
            tokenizer, query, pixel_values, visual_features, generation_config, eos_token_id, template.sep)

        history.append((question, response))
        if return_history:
            return response, history
        if verbose:
            self._print_exchange(query, response, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        return response

    def chat_without_chat_prompt(self, tokenizer, pixel_values, question, generation_config,
                                 num_patches_list=None, IMG_START_TOKEN='<|vision_start|>',
                                 IMG_END_TOKEN='<|vision_end|>',
                                 IMG_CONTEXT_TOKEN='<|vision_pad|>', verbose=False, visual_features=None):
        """Generate from the raw question without applying the chat template.

        The question itself (after ``<image>`` expansion) is the whole prompt;
        the template is only consulted for its separator / EOS token.
        Arguments are as for :meth:`chat`, minus the history handling.

        Returns:
            The decoded response string.
        """
        if pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        num_patches_list = self._resolve_num_patches(pixel_values, num_patches_list)
        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        query = self._expand_image_placeholders(
            question, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        response = self._generate_response(
            tokenizer, query, pixel_values, visual_features, generation_config, eos_token_id, template.sep)

        if verbose:
            self._print_exchange(query, response, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate tokens for one prompt, splicing in image features when given.

        Args:
            pixel_values: Image tiles; ``None`` means a text-only prompt.
            input_ids: Prompt token ids.
            attention_mask: Attention mask matching ``input_ids``.
            visual_features: Optional pre-computed features; when given
                together with ``pixel_values`` only ``mlp1`` is applied.
            generation_config, output_hidden_states, return_dict,
            **generate_kwargs: Forwarded to ``language_model.generate``.

        Returns:
            Whatever ``language_model.generate`` returns for ``inputs_embeds``.

        Raises:
            AssertionError: If images are given but the prompt contains no
                image-context token.
        """
        input_embeds = self._embed_inputs(input_ids, pixel_values, visual_features)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )
        return outputs

    def chat_batch(
            self,
            tokenizer,
            pixel_values_list,
            questions,
            generation_config,
            histories=None,
            return_histories=False,
            num_patches_lists=None,
            IMG_START_TOKEN='<|vision_start|>',
            IMG_END_TOKEN='<|vision_end|>',
            IMG_CONTEXT_TOKEN='<|vision_pad|>',
            verbose=False,
            visual_features_list=None
    ):
        """Answer several independent questions in one batched generation call.

        Args:
            tokenizer: Tokenizer, as in :meth:`chat`.
            pixel_values_list: Per-sample image tiles (``None`` for text-only).
            questions: Per-sample question strings.
            generation_config: Mapping of generation keyword arguments; it is
                not modified.
            histories: Per-sample lists of ``(question, answer)`` pairs; each
                is extended in place with the new turn.
            return_histories: Also return the updated histories.
            num_patches_lists: Optional per-sample tile counts (a list of
                ints, or a single int); defaults to one image holding all of
                the sample's tiles.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Special tokens
                used when expanding ``<image>``.
            verbose: Print each prompt/response pair when histories are not
                returned.
            visual_features_list: Optional per-sample pre-computed features
                forwarded to :meth:`generate_batch`.

        Returns:
            ``outputs`` or ``(outputs, histories)``.
        """
        if histories is None:
            histories = [[] for _ in questions]

        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

        # Get eos_token_id from the template
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
        # Merge into a copy so the caller's generation_config dict is not mutated.
        gen_kwargs = {**generation_config, 'eos_token_id': eos_token_id}

        queries = []
        asked_questions = []
        input_ids_list = []
        attention_mask_list = []

        for idx, question in enumerate(questions):
            history = histories[idx]
            pixel_values = pixel_values_list[idx]

            if num_patches_lists is not None and num_patches_lists[idx] is not None:
                num_patches_list = num_patches_lists[idx]
                if isinstance(num_patches_list, int):
                    num_patches_list = [num_patches_list]
            else:
                num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []

            if not history and pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question

            template_i = get_conv_template(self.template)
            template_i.system_message = self.system_message
            query = self._build_prompt(template_i, history, question)

            # Handle image tokens
            if pixel_values is not None:
                query = self._expand_image_placeholders(
                    query, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)

            queries.append(query)
            asked_questions.append(question)

            model_inputs = tokenizer(
                query,
                return_tensors='pt',
                padding=True,
                truncation=True
            )
            input_ids_list.append(model_inputs['input_ids'].to(self.device))
            attention_mask_list.append(model_inputs['attention_mask'].to(self.device))

        # Call the generate function
        generation_output = self.generate_batch(
            pixel_values_list=pixel_values_list,
            input_ids_list=input_ids_list,
            attention_mask_list=attention_mask_list,
            visual_features_list=visual_features_list,
            **gen_kwargs
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)

        outputs = []
        for idx, response in enumerate(responses):
            response = response.split(template.sep)[0].strip()
            # Store the question as it was actually asked (including any
            # '<image>' marker) so a follow-up turn rebuilds the image tokens.
            histories[idx].append((asked_questions[idx], response))
            outputs.append(response)

        if return_histories:
            return outputs, histories
        if verbose:
            for query, output in zip(queries, outputs):
                self._print_exchange(query, output, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        return outputs

    @torch.no_grad()
    def generate_batch(
            self,
            pixel_values_list: Optional[List[torch.FloatTensor]] = None,
            input_ids_list: Optional[List[torch.LongTensor]] = None,
            attention_mask_list: Optional[List[torch.LongTensor]] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            visual_features_list: Optional[List[Optional[torch.FloatTensor]]] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate for several prompts of different lengths in a single call.

        Each prompt is embedded separately (with image features spliced in),
        left-padded to the longest prompt, and the batch is passed to
        ``language_model.generate``.

        Args:
            pixel_values_list: Per-sample image tiles, or ``None`` entries for
                text-only samples; ``None`` means every sample is text-only.
            input_ids_list: Per-sample prompt token ids.
            attention_mask_list: Per-sample attention masks.
            visual_features: Pre-computed features applied to every sample
                that has images (used when ``visual_features_list`` is absent).
            generation_config, output_hidden_states, return_dict,
            **generate_kwargs: Forwarded to ``language_model.generate``.
            visual_features_list: Optional per-sample pre-computed features;
                takes precedence over ``visual_features``.

        Returns:
            Whatever ``language_model.generate`` returns for the batch.

        Raises:
            AssertionError: If a sample has images but no image-context token.
        """
        num_samples = len(input_ids_list)
        if pixel_values_list is None:
            pixel_values_list = [None] * num_samples
        if visual_features_list is None:
            visual_features_list = [visual_features] * num_samples

        input_embeds_list = []
        attention_mask_padded_list = []
        max_seq_length = max(input_ids.shape[1] for input_ids in input_ids_list)

        for pixel_values, input_ids, attention_mask, sample_features in zip(
                pixel_values_list, input_ids_list, attention_mask_list, visual_features_list):
            input_embeds = self._embed_inputs(input_ids, pixel_values, sample_features)

            pad_size = max_seq_length - input_embeds.shape[1]
            if pad_size > 0:
                # Left-pad: a decoder-only model continues from the last
                # position, which must be a real token rather than padding.
                input_embeds = F.pad(input_embeds, (0, 0, pad_size, 0))
                attention_mask = F.pad(attention_mask, (pad_size, 0))

            input_embeds_list.append(input_embeds)
            attention_mask_padded_list.append(attention_mask)

        input_embeds = torch.cat(input_embeds_list, dim=0)
        attention_mask = torch.cat(attention_mask_padded_list, dim=0)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )
        return outputs