from typing import List, Optional, Tuple, Union

import torch
import torchvision
from torch import nn

from einops import rearrange
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_gar import GARConfig
from .modeling_perception_lm import PerceptionLMForConditionalGeneration


logger = logging.get_logger(__name__)


class GARModel(PreTrainedModel):
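    """Wraps a PerceptionLM multimodal model with a mask patch embedding and
    replaces crop-token placeholders in the prompt with RoIAlign-pooled
    features of the referenced image regions.
    """
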
    config_class = GARConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _no_split_modules = ['LlamaDecoderLayer']
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: GARConfig,
        mllm=None,
        mask_patch_embedding=None,
        use_flash_attn=True,
    ):
        super().__init__(config)

        # FlashAttention-2 is enabled only for the language model; the vision
        # tower always uses eager attention.
        config.mllm_config.use_flash_attn = bool(use_flash_attn)
        config.mllm_config.text_config.use_flash_attn = bool(use_flash_attn)
        config.mllm_config.vision_config.use_flash_attn = False
        config.mllm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
        config.mllm_config.vision_config._attn_implementation = 'eager'

        self.prompt_numbers = config.prompt_numbers

        # Use an injected multimodal LLM if given, otherwise build a
        # PerceptionLM model from the config.
        if mllm is not None:
            self.mllm = mllm
        else:
            self.mllm = PerceptionLMForConditionalGeneration(config.mllm_config)
        # Patch embedding for the binary region mask; its output is passed to
        # the vision encoder through `get_image_features`.
        if mask_patch_embedding is not None:
            self.mask_patch_embedding = mask_patch_embedding
        else:
            self.mask_patch_embedding = nn.Conv2d(
                in_channels=3,
                out_channels=config.mask_path_embedding_out_channels,
                kernel_size=config.kernel_size,
                stride=config.kernel_size,
                bias=False,
            )

        self.crop_tokens_ids = config.crop_tokens_ids

    @property
    def lm_head(self):
        return self.mllm.model.language_model.get_output_embeddings()

    def get_input_embeddings(self):
        return self.mllm.model.language_model.get_input_embeddings()

    def get_output_embeddings(self):
        return self.mllm.model.language_model.get_output_embeddings()

    def forward(self, data, data_samples=None, mode='loss'):
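        """Run one step on a batch dict.

        `data` must provide `pixel_values`, `global_mask_values`, `input_ids`,
        `labels`, `aspect_ratios`, and `bboxes`. With `mode='loss'` a dict
        containing the language-modeling loss is returned; the other modes are
        not implemented.
        """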
        crop_tokens = self.crop_tokens_ids

        pixel_values = data['pixel_values'].to(self.mllm.device).to(self.mllm.dtype)
        # `global_mask_values` is normalized to [-1, 1]; map it back to integer
        # ids and clamp them to [0, prompt_numbers].
        mask_values = torch.round((data['global_mask_values'] + 1.) / 2. * 255.).long().to(self.mllm.device)
        mask_values = torch.clamp(mask_values, min=0, max=self.prompt_numbers)
        assert mask_values.max() < self.prompt_numbers + 1 and mask_values.min() >= 0

        # Embed a binary map (1 where the value differs from `prompt_numbers`)
        # into mask patch features.
        mask_embeds = self.mask_patch_embedding((mask_values != self.prompt_numbers).to(self.mllm.dtype))
        input_ids = data['input_ids']
        aspect_ratios = data['aspect_ratios']
        bboxes = data['bboxes']
        assert input_ids.shape[0] == 1, "Currently only support batch_size=1"

        inputs_embeds = self.mllm.get_input_embeddings()(input_ids)
        labels = data['labels']

        image_features = None
        if pixel_values is not None:
            image_features = self.mllm.get_image_features(
                pixel_values=pixel_values.unsqueeze(0),
                mask_embeds=mask_embeds,
            )
            image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
            # Scatter the vision features into the image-placeholder positions
            # of the text embedding sequence.
            special_image_mask, _ = self.mllm.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # Replace every crop-token span in the prompt with RoIAlign-pooled
        # features of the corresponding image region.
        new_inputs_embeds = []
        new_labels = []
        # Skip the first feature chunk and reshape the remaining per-tile
        # features back into 16x16 spatial maps.
        image_features_tiles = rearrange(image_features[1:].unsqueeze(0), 'b n (h w) c -> b n c h w', h=16, w=16)
        for batch_idx in range(inputs_embeds.shape[0]):
            curr_inputs_embeds = inputs_embeds[batch_idx]
            curr_labels = labels[batch_idx]
            for crop_token in crop_tokens:
                if crop_token in input_ids[batch_idx]:
                    target_mask = input_ids[batch_idx].eq(crop_token)
                    target_indices = target_mask.nonzero().squeeze()
                    head_idx = target_indices.min().item()
                    tail_idx = target_indices.max().item()
                    # Stitch the tiles back into one feature map laid out
                    # according to the image aspect ratio.
                    image_features_recover = self._merge(image_features_tiles, aspect_ratios[batch_idx][0], aspect_ratios[batch_idx][1])
                    feat_h, feat_w = image_features_recover.shape[2:]

                    x1, y1, x2, y2 = bboxes[batch_idx][str(crop_token)]
                    orig_h, orig_w = feat_h * 28, feat_w * 28

                    # Project the box to the original image resolution
                    # (28 px per feature cell) ...
                    roi_orig_x1 = x1 * orig_w
                    roi_orig_y1 = y1 * orig_h
                    roi_orig_x2 = x2 * orig_w
                    roi_orig_y2 = y2 * orig_h

                    # ... and then into feature-map coordinates.
                    spatial_scale = feat_w / orig_w
                    roi_feat_x1 = roi_orig_x1 * spatial_scale
                    roi_feat_y1 = roi_orig_y1 * spatial_scale
                    roi_feat_x2 = roi_orig_x2 * spatial_scale
                    roi_feat_y2 = roi_orig_y2 * spatial_scale

                    # RoIAlign boxes are (batch_index, x1, y1, x2, y2).
                    roi = torch.tensor(
                        [0, roi_feat_x1, roi_feat_y1, roi_feat_x2, roi_feat_y2],
                        dtype=torch.float32, device=image_features_recover.device,
                    )

                    # Pool a fixed 16x16 feature crop for the region.
                    roi_features = torchvision.ops.roi_align(
                        input=image_features_recover.float(),
                        boxes=roi.unsqueeze(0),
                        output_size=(16, 16),
                        spatial_scale=spatial_scale,
                        sampling_ratio=2,
                        aligned=True,
                    )

                    # (1, C, 16, 16) -> (256, C)
                    image_features_replay = roi_features.permute(0, 2, 3, 1).flatten(1, 2).to(image_features_recover.dtype).squeeze()

                    # Splice the pooled region features into the token stream
                    # in place of the crop-token span, and mark those positions
                    # with the ignore index (-100) in the labels.
                    curr_inputs_embeds = torch.cat([
                        curr_inputs_embeds[:head_idx],
                        image_features_replay,
                        curr_inputs_embeds[tail_idx+1:],
                    ])
                    curr_labels = torch.cat([
                        curr_labels[:head_idx],
                        -100 * torch.ones(image_features_replay.shape[0], dtype=torch.long, device=labels.device),
                        curr_labels[tail_idx+1:],
                    ])

                    assert curr_inputs_embeds.shape[0] == curr_labels.shape[0], f"shape mismatch, got {curr_inputs_embeds.shape[0]} != {curr_labels.shape[0]}"

            new_inputs_embeds.append(curr_inputs_embeds.unsqueeze(0))
            new_labels.append(curr_labels)

        inputs_embeds = torch.cat(new_inputs_embeds, dim=0)
        labels = torch.cat(new_labels, dim=0)

        skip_this_batch = False

						        if mode == "loss": | 
					
					
						
						| 
							 | 
						            position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device).unsqueeze(0).repeat(inputs_embeds.shape[0], 1) | 
					
					
						
						| 
							 | 
						            attention_mask = torch.ones(inputs_embeds.shape[0], inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device) | 
					
					
						
						| 
							 | 
						            use_cache = False | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            outputs, _skip_this_case = self._llm_forward( | 
					
					
						
						| 
							 | 
						                inputs_embeds=inputs_embeds, | 
					
					
						
						| 
							 | 
						                position_ids=position_ids, | 
					
					
						
						| 
							 | 
						                attention_mask=attention_mask, | 
					
					
						
						| 
							 | 
						                labels=labels, | 
					
					
						
						| 
							 | 
						                use_cache=use_cache | 
					
					
						
						| 
							 | 
						            ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						            if skip_this_batch or _skip_this_case: | 
					
					
						
						| 
							 | 
						                print("skip this batch!") | 
					
					
						
						| 
							 | 
						                loss_dict = {'loss': outputs.loss * 0.0} | 
					
					
						
						| 
							 | 
						            else: | 
					
					
						
						| 
							 | 
						                loss_dict = {'loss': outputs.loss} | 
					
					
						
						| 
							 | 
						            return loss_dict | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        elif mode == "predict": | 
					
					
						
						| 
							 | 
						            pass | 
					
					
						
						| 
							 | 
						        elif mode == "tensor": | 
					
					
						
						| 
							 | 
						            pass | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            raise NotImplementedError | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return outputs | 
					
					
						
						| 
							 | 
						     | 
					
					
						
						| 
							 | 
    def _merge(self, tiles: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
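        """Stitch a grid of `nch` x `ncw` feature tiles back into a single map
        of shape (batch, channels, nch * tile_h, ncw * tile_w)."""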
        batch_size, num_tiles, num_channels, tile_height, tile_width = tiles.size()
        assert num_tiles == ncw * nch, f"{ncw * nch} != {num_tiles}"

        tiles = tiles.view(batch_size, nch, ncw, num_channels, tile_height, tile_width)
        tiles = tiles.permute(0, 3, 1, 4, 2, 5).contiguous()

        original_height = nch * tile_height
        original_width = ncw * tile_width

        image = tiles.view(batch_size, num_channels, original_height, original_width)

        return image

    def _llm_forward(
        self,
        inputs_embeds: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
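        """Thin wrapper around the wrapped MLLM's forward pass; returns the
        model outputs together with a skip flag."""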
        return_dict = (
            return_dict if return_dict is not None else self.mllm.config.use_return_dict
        )
        skip_this_case = False

        outputs = self.mllm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return outputs, skip_this_case

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        global_mask_values: Optional[torch.LongTensor] = None,
        aspect_ratios: Optional[torch.FloatTensor] = None,
        bboxes: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
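        """Generate text conditioned on the image, the region mask, and the
        crop-token regions spliced into the prompt (mirrors `forward`)."""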
        device = self.device

        if pixel_values is not None:
            pixel_values = pixel_values.to(device).to(self.mllm.dtype)
            if global_mask_values is not None:
                # Same mask preprocessing as in `forward`.
                mask_values = torch.round((global_mask_values + 1.) / 2. * 255.).long().to(device)
                mask_values = torch.clamp(mask_values, min=0, max=self.prompt_numbers)

                assert mask_values.max() < self.prompt_numbers + 1 and mask_values.min() >= 0, f"max: {mask_values.max()}, min: {mask_values.min()}"
                mask_embeds = self.mask_patch_embedding((mask_values != self.prompt_numbers).to(self.mllm.dtype))
            else:
                mask_embeds = None

            inputs_embeds = self.mllm.get_input_embeddings()(input_ids)

            image_features = self.mllm.get_image_features(
                pixel_values=pixel_values.unsqueeze(0),
                mask_embeds=mask_embeds,
            )
            image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
            special_image_mask, _ = self.mllm.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_features
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

            # Replace crop-token spans with RoIAlign-pooled region features,
            # mirroring the training-time path in `forward` (no labels here).
            new_inputs_embeds = []
            image_features_tiles = rearrange(image_features[1:].unsqueeze(0), 'b n (h w) c -> b n c h w', h=16, w=16)
            for batch_idx in range(inputs_embeds.shape[0]):
                curr_inputs_embeds = inputs_embeds[batch_idx]
                for crop_token in self.crop_tokens_ids:
                    if crop_token in input_ids[batch_idx]:
                        target_mask = input_ids[batch_idx].eq(crop_token)
                        target_indices = target_mask.nonzero().squeeze()
                        head_idx = target_indices.min().item()
                        tail_idx = target_indices.max().item()
                        image_features_recover = self._merge(image_features_tiles, aspect_ratios[batch_idx][0], aspect_ratios[batch_idx][1])
                        feat_h, feat_w = image_features_recover.shape[2:]
                        x1, y1, x2, y2 = bboxes[batch_idx][str(crop_token)]
                        orig_h, orig_w = feat_h * 28, feat_w * 28

                        # Project the box to the original image resolution and
                        # then into feature-map coordinates, as in `forward`.
                        roi_orig_x1 = x1 * orig_w
                        roi_orig_y1 = y1 * orig_h
                        roi_orig_x2 = x2 * orig_w
                        roi_orig_y2 = y2 * orig_h

                        spatial_scale = feat_w / orig_w
                        roi_feat_x1 = roi_orig_x1 * spatial_scale
                        roi_feat_y1 = roi_orig_y1 * spatial_scale
                        roi_feat_x2 = roi_orig_x2 * spatial_scale
                        roi_feat_y2 = roi_orig_y2 * spatial_scale

                        roi = torch.tensor(
                            [0, roi_feat_x1, roi_feat_y1, roi_feat_x2, roi_feat_y2],
                            dtype=torch.float32, device=image_features_recover.device,
                        )

                        roi_features = torchvision.ops.roi_align(
                            input=image_features_recover.float(),
                            boxes=roi.unsqueeze(0),
                            output_size=(16, 16),
                            spatial_scale=spatial_scale,
                            sampling_ratio=2,
                            aligned=True,
                        )

                        image_features_replay = roi_features.permute(0, 2, 3, 1).flatten(1, 2).to(image_features_recover.dtype).squeeze()

                        curr_inputs_embeds = torch.cat([
                            curr_inputs_embeds[:head_idx],
                            image_features_replay,
                            curr_inputs_embeds[tail_idx+1:],
                        ])

                new_inputs_embeds.append(curr_inputs_embeds.unsqueeze(0))
            inputs_embeds = torch.cat(new_inputs_embeds, dim=0)
        else:
            inputs_embeds = self.mllm.get_input_embeddings()(input_ids)

        outputs = self.mllm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            return_dict_in_generate=True,
        )

        return outputs
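

# Minimal usage sketch (illustrative only): the checkpoint path, the exact
# preprocessing, and the structure of the inputs depend on how GARConfig and
# the accompanying processor are set up, so treat every name below as a
# placeholder rather than a documented API.
#
#   config = GARConfig.from_pretrained("path/to/checkpoint")
#   model = GARModel(config).eval()
#   outputs = model.generate(
#       pixel_values=pixel_values,              # preprocessed image tiles
#       global_mask_values=global_mask_values,  # region mask scaled to [-1, 1]
#       aspect_ratios=aspect_ratios,            # number of tiles along (width, height)
#       bboxes=bboxes,                          # per-sample {str(crop_token_id): (x1, y1, x2, y2)}
#       input_ids=input_ids,
#       attention_mask=attention_mask,
#   )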