# GAR-8B / modeling_gar.py
from typing import List, Optional, Tuple, Union

import torch
import torchvision
from einops import rearrange
from torch import nn
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_gar import GARConfig
from .modeling_perception_lm import PerceptionLMForConditionalGeneration
logger = logging.get_logger(__name__)
class GARModel(PreTrainedModel):
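    """GAR model: a PerceptionLM multimodal backbone augmented with region prompts.

    A convolutional `mask_patch_embedding` encodes binary region masks, which are
    fused with the vision features of the wrapped
    `PerceptionLMForConditionalGeneration` model. Spans of crop tokens in the
    prompt are then replaced by RoI-aligned features of the referenced bounding
    boxes ("feature replay"); see `forward` and `generate`.
    """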
config_class = GARConfig
main_input_name = 'pixel_values'
base_model_prefix = 'language_model'
_no_split_modules = ['LlamaDecoderLayer']
_supports_flash_attn_2 = True
supports_gradient_checkpointing = True
def __init__(
self,
config: GARConfig,
mllm=None,
mask_patch_embedding=None,
use_flash_attn=True,
):
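        """Build the GAR model.

        Args:
            config: `GARConfig` carrying the nested `mllm_config`, the mask
                embedding channels/kernel size, `prompt_numbers`, and
                `crop_tokens_ids`.
            mllm: optional pre-built `PerceptionLMForConditionalGeneration`;
                created from `config.mllm_config` when omitted.
            mask_patch_embedding: optional pre-built mask embedding `nn.Conv2d`;
                created from the config when omitted.
            use_flash_attn: enable FlashAttention-2 for the language model.
        """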
super().__init__(config)
        # Attention backends: FlashAttention-2 for the language model when
        # requested, eager attention for the vision tower.
        config.mllm_config.use_flash_attn = bool(use_flash_attn)
        config.mllm_config.text_config.use_flash_attn = bool(use_flash_attn)
        config.mllm_config.vision_config.use_flash_attn = False
        config.mllm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
        config.mllm_config.vision_config._attn_implementation = 'eager'
self.prompt_numbers = config.prompt_numbers
if mllm is not None:
self.mllm = mllm
else:
self.mllm = PerceptionLMForConditionalGeneration(config.mllm_config)
if mask_patch_embedding is not None:
self.mask_patch_embedding = mask_patch_embedding
else:
self.mask_patch_embedding = nn.Conv2d(
in_channels=3,
out_channels=config.mask_path_embedding_out_channels,
kernel_size=config.kernel_size,
stride=config.kernel_size,
bias=False,
)
self.crop_tokens_ids = config.crop_tokens_ids
@property
def lm_head(self):
return self.mllm.model.language_model.get_output_embeddings()
def get_input_embeddings(self):
return self.mllm.model.language_model.get_input_embeddings()
def get_output_embeddings(self):
return self.mllm.model.language_model.get_output_embeddings()
def forward(self, data, data_samples=None, mode='loss'):
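        """Training entry point.

        `data` is expected to provide `pixel_values`, `global_mask_values`,
        `input_ids`, `labels`, `aspect_ratios`, and `bboxes` (a per-sample dict
        keyed by the string form of each crop token id). Only `mode='loss'` is
        implemented, and batch size must currently be 1.
        """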
crop_tokens = self.crop_tokens_ids
# (batch_size, num_tiles, channels, height, width)
pixel_values = data['pixel_values'].to(self.mllm.device).to(self.mllm.dtype)
mask_values = torch.round((data['global_mask_values'] + 1.) / 2. * 255.).long().to(self.mllm.device)
mask_values = torch.clamp(mask_values, min=0, max=self.prompt_numbers)
assert mask_values.max() < self.prompt_numbers + 1 and mask_values.min() >= 0
mask_embeds = self.mask_patch_embedding((mask_values != self.prompt_numbers).to(self.mllm.dtype)) # binary mask
input_ids = data['input_ids']
aspect_ratios = data['aspect_ratios']
bboxes = data['bboxes']
assert input_ids.shape[0] == 1, "Currently only support batch_size=1"
inputs_embeds = self.mllm.get_input_embeddings()(input_ids)
labels = data['labels']
image_features = None
if pixel_values is not None:
image_features = self.mllm.get_image_features(
pixel_values=pixel_values.unsqueeze(0),
mask_embeds=mask_embeds,
)
image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
special_image_mask, _ = self.mllm.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# feature replay
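        # For every crop token present in the prompt: merge the per-tile image
        # features back into a single feature map, RoI-align the region given by
        # the matching bbox onto a 16x16 grid, and splice those 256 feature
        # vectors into the sequence in place of the crop-token placeholders,
        # masking the spliced span's labels with -100.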
new_inputs_embeds = []
new_labels = []
image_features_tiles = rearrange(image_features[1:].unsqueeze(0), 'b n (h w) c -> b n c h w', h=16, w=16)
for batch_idx in range(inputs_embeds.shape[0]):
curr_inputs_embeds = inputs_embeds[batch_idx]
curr_labels = labels[batch_idx]
for crop_token in crop_tokens:
if crop_token in input_ids[batch_idx]:
target_mask = input_ids[batch_idx].eq(crop_token)
target_indices = target_mask.nonzero().squeeze()
head_idx = target_indices.min().item()
tail_idx = target_indices.max().item()
image_features_recover = self._merge(image_features_tiles, aspect_ratios[batch_idx][0], aspect_ratios[batch_idx][1])
feat_h, feat_w = image_features_recover.shape[2:]
x1, y1, x2, y2 = bboxes[batch_idx][str(crop_token)]
orig_h, orig_w = feat_h * 28, feat_w * 28
# origin box
roi_orig_x1 = x1 * orig_w
roi_orig_y1 = y1 * orig_h
roi_orig_x2 = x2 * orig_w
roi_orig_y2 = y2 * orig_h
# feat box
spatial_scale = feat_w / orig_w
roi_feat_x1 = roi_orig_x1 * spatial_scale
roi_feat_y1 = roi_orig_y1 * spatial_scale
roi_feat_x2 = roi_orig_x2 * spatial_scale
roi_feat_y2 = roi_orig_y2 * spatial_scale
roi = torch.tensor(
[0, roi_feat_x1, roi_feat_y1, roi_feat_x2, roi_feat_y2],
dtype=torch.float32, device=image_features_recover.device,
)
roi_features = torchvision.ops.roi_align(
input=image_features_recover.float(),
boxes=roi.unsqueeze(0),
output_size=(16, 16),
spatial_scale=spatial_scale,
sampling_ratio=2,
aligned=True,
)
image_features_replay = roi_features.permute(0, 2, 3, 1).flatten(1, 2).to(image_features_recover.dtype).squeeze()
curr_inputs_embeds = torch.cat([
curr_inputs_embeds[:head_idx],
image_features_replay,
curr_inputs_embeds[tail_idx+1:],
])
curr_labels = torch.cat([
curr_labels[:head_idx],
-100 * torch.ones(image_features_replay.shape[0], dtype=torch.long, device=labels.device),
curr_labels[tail_idx+1:],
])
assert curr_inputs_embeds.shape[0] == curr_labels.shape[0], f"shape mismatch, got {curr_inputs_embeds.shape[0]} != {curr_labels.shape[0]}"
new_inputs_embeds.append(curr_inputs_embeds.unsqueeze(0))
new_labels.append(curr_labels)
inputs_embeds = torch.cat(new_inputs_embeds, dim=0)
labels = torch.cat(new_labels, dim=0)
skip_this_batch = False
if mode == "loss":
position_ids = torch.arange(0, inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device).unsqueeze(0).repeat(inputs_embeds.shape[0], 1)
attention_mask = torch.ones(inputs_embeds.shape[0], inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device)
use_cache = False
outputs, _skip_this_case = self._llm_forward(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
attention_mask=attention_mask,
labels=labels,
use_cache=use_cache
)
if skip_this_batch or _skip_this_case:
print("skip this batch!")
loss_dict = {'loss': outputs.loss * 0.0}
else:
loss_dict = {'loss': outputs.loss}
return loss_dict
        elif mode == "predict":
            raise NotImplementedError("'predict' mode is not implemented.")
        elif mode == "tensor":
            raise NotImplementedError("'tensor' mode is not implemented.")
        else:
            raise NotImplementedError(f"Unknown mode: {mode}")
def _merge(self, tiles: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
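        """Reassemble per-tile features into a single feature map.

        `tiles` has shape (batch, num_tiles, channels, tile_h, tile_w), and
        `ncw` / `nch` give the tile grid size along width / height; the result
        has shape (batch, channels, nch * tile_h, ncw * tile_w).
        """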
batch_size, num_tiles, num_channels, tile_height, tile_width = tiles.size()
assert num_tiles == ncw * nch, f"{ncw * nch} != {num_tiles}"
tiles = tiles.view(batch_size, nch, ncw, num_channels, tile_height, tile_width)
tiles = tiles.permute(0, 3, 1, 4, 2, 5).contiguous()
original_height = nch * tile_height
original_width = ncw * tile_width
image = tiles.view(batch_size, num_channels, original_height, original_width)
return image
def _llm_forward(
self,
inputs_embeds: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
image_flags: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
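        """Thin wrapper around the wrapped MLLM's forward pass.

        Returns `(outputs, skip_this_case)`; `skip_this_case` is currently
        always False and only exists for the caller's skip logic.
        """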
        return_dict = (
            return_dict if return_dict is not None else self.mllm.config.use_return_dict
        )
skip_this_case = False
outputs = self.mllm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
labels=labels,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return outputs, skip_this_case
@torch.no_grad()
def generate(
self,
pixel_values: Optional[torch.FloatTensor] = None,
global_mask_values: Optional[torch.LongTensor] = None,
aspect_ratios: Optional[torch.FloatTensor] = None,
bboxes: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**generate_kwargs,
) -> torch.LongTensor:
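        """Generate text conditioned on the image, region masks, and crop boxes.

        Mirrors the feature-replay logic of `forward`: mask embeddings are fused
        into the image features, crop-token spans are replaced by RoI-aligned
        region features, and the resulting embeddings are passed to
        `self.mllm.generate` with `return_dict_in_generate=True`.
        """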
device = self.device
if pixel_values is not None:
pixel_values = pixel_values.to(device).to(self.mllm.dtype)
if global_mask_values is not None:
mask_values = torch.round((global_mask_values + 1.) / 2. * 255.).long().to(device)
mask_values = torch.clamp(mask_values, min=0, max=self.prompt_numbers)
assert mask_values.max() < self.prompt_numbers + 1 and mask_values.min() >= 0, f"max: {mask_values.max()}, min: {mask_values.min()}"
mask_embeds = self.mask_patch_embedding((mask_values != self.prompt_numbers).to(self.mllm.dtype))
else:
mask_embeds = None
inputs_embeds = self.mllm.get_input_embeddings()(input_ids)
image_features = self.mllm.get_image_features(
pixel_values=pixel_values.unsqueeze(0),
mask_embeds=mask_embeds,
)
image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype)
special_image_mask, _ = self.mllm.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# feature replay
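        # Same feature-replay procedure as in `forward`, minus the label
        # bookkeeping, since generation has no training targets.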
new_inputs_embeds = []
image_features_tiles = rearrange(image_features[1:].unsqueeze(0), 'b n (h w) c -> b n c h w', h=16, w=16)
for batch_idx in range(inputs_embeds.shape[0]):
curr_inputs_embeds = inputs_embeds[batch_idx]
for crop_token in self.crop_tokens_ids:
if crop_token in input_ids[batch_idx]:
target_mask = input_ids[batch_idx].eq(crop_token)
target_indices = target_mask.nonzero().squeeze()
head_idx = target_indices.min().item()
tail_idx = target_indices.max().item()
image_features_recover = self._merge(image_features_tiles, aspect_ratios[batch_idx][0], aspect_ratios[batch_idx][1])
feat_h, feat_w = image_features_recover.shape[2:]
x1, y1, x2, y2 = bboxes[batch_idx][str(crop_token)]
orig_h, orig_w = feat_h * 28, feat_w * 28
# origin box
roi_orig_x1 = x1 * orig_w
roi_orig_y1 = y1 * orig_h
roi_orig_x2 = x2 * orig_w
roi_orig_y2 = y2 * orig_h
# feat box
spatial_scale = feat_w / orig_w
roi_feat_x1 = roi_orig_x1 * spatial_scale
roi_feat_y1 = roi_orig_y1 * spatial_scale
roi_feat_x2 = roi_orig_x2 * spatial_scale
roi_feat_y2 = roi_orig_y2 * spatial_scale
roi = torch.tensor(
[0, roi_feat_x1, roi_feat_y1, roi_feat_x2, roi_feat_y2],
dtype=torch.float32, device=image_features_recover.device,
)
roi_features = torchvision.ops.roi_align(
input=image_features_recover.float(),
boxes=roi.unsqueeze(0),
output_size=(16, 16),
spatial_scale=spatial_scale,
sampling_ratio=2,
aligned=True,
)
image_features_replay = roi_features.permute(0, 2, 3, 1).flatten(1, 2).to(image_features_recover.dtype).squeeze()
curr_inputs_embeds = torch.cat([
curr_inputs_embeds[:head_idx],
image_features_replay,
curr_inputs_embeds[tail_idx+1:],
])
new_inputs_embeds.append(curr_inputs_embeds.unsqueeze(0))
inputs_embeds = torch.cat(new_inputs_embeds, dim=0)
else:
inputs_embeds = self.mllm.get_input_embeddings()(input_ids)
outputs = self.mllm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
# return_dict=return_dict,
use_cache=True,
return_dict_in_generate=True,
)
return outputs
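

if __name__ == "__main__":
    # Illustrative sketch only, not part of the modeling code. The shapes, the
    # box, and the checkpoint id below are assumptions for demonstration.
    #
    # (1) Checkpoint-free demo of the RoI-align primitive behind "feature
    # replay": a box over a feature map is resampled to a fixed 16x16 grid,
    # which is the shape that gets spliced into the token sequence above.
    feat = torch.randn(1, 8, 48, 32)               # (batch, C, H, W) dummy feature map
    box = torch.tensor([[0., 8., 12., 24., 30.]])  # (batch_idx, x1, y1, x2, y2)
    region = torchvision.ops.roi_align(
        feat, box, output_size=(16, 16), sampling_ratio=2, aligned=True,
    )
    replay_tokens = region.permute(0, 2, 3, 1).flatten(1, 2)
    print(replay_tokens.shape)                     # torch.Size([1, 256, 8])

    # (2) Real usage (not executed here): assuming the checkpoint id
    # "HaochenWang/GAR-8B" and a preprocessing pipeline that yields the inputs
    # expected by `generate`, loading and running the model would look like:
    #
    #   model = GARModel.from_pretrained(
    #       "HaochenWang/GAR-8B", torch_dtype=torch.bfloat16
    #   ).eval().cuda()
    #   out = model.generate(
    #       pixel_values=pixel_values,
    #       global_mask_values=global_mask_values,
    #       aspect_ratios=aspect_ratios,
    #       bboxes=bboxes,
    #       input_ids=input_ids,
    #       attention_mask=attention_mask,
    #       generation_config=GenerationConfig(max_new_tokens=256),
    #   )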