from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_navil_vit import NaViLVisionConfig
from .modular_intern_vit import (
    InternMLP,
    InternVisionFlashAttention2,
    InternVisionRotaryEmbedding,
    InternVisionSdpaAttention,
    NORM2FN,
)

try:
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
    has_flash_attn = True
except ImportError:
    print('FlashAttention is not installed.')
    has_flash_attn = False

logger = logging.get_logger(__name__)


class NaViLVisionEmbeddingsAnyRes(nn.Module):
    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Number of raw patches merged per spatial axis downstream, e.g.
        # downsample_ratio=0.5 -> merge_size=2 (2x2 merge units).
        self.merge_size = int(1.0 / config.downsample_ratio)

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        # (batch, embed_dim, h, w) -> (batch, embed_dim * h * w); with inputs
        # packed as single patches (h == w == 1) this is one embedding per patch.
        return patch_embeds.flatten(1)
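
# A minimal shape sketch (the packed per-patch input layout is an assumption,
# inferred from the 4-d shape check in NaViLVisionModelAnyRes.forward):
#   config: patch_size=14, hidden_size=1024
#   pixel_values:          (total_patches, 3, 14, 14)
#   patch_embedding(...):  (total_patches, 1024, 1, 1)
#   .flatten(1):           (total_patches, 1024)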


class NaViLVisionEncoderLayerAnyRes(nn.Module):
    def __init__(self, config: NaViLVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        if has_flash_attn:
            self.attn = InternVisionFlashAttention2(config)
        else:
            self.attn = InternVisionSdpaAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # Layer scale (learnable per-channel gains) and stochastic depth on
        # both residual branches.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens: torch.Tensor,
            rotary_pos_emb: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Packed input to the layer of shape `(seq_len, embed_dim)`.
            cu_seqlens (`torch.Tensor`):
                Cumulative sequence lengths delimiting the attention segments
                (full images or local windows) in the packed sequence.
            rotary_pos_emb (`torch.Tensor`):
                2D rotary position embedding for each packed token.
        """
        hidden_states = hidden_states + self.drop_path1(
            self.attn(
                self.norm1(hidden_states),
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
            ) * self.ls1)

        hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)

        return hidden_states
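
# Written out, each NaViLVisionEncoderLayerAnyRes applies a pre-norm residual
# update with layer scale on both branches:
#   x <- x + DropPath(ls1 * Attn(Norm1(x)))
#   x <- x + DropPath(ls2 * MLP(Norm2(x)))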


class NaViLVisionEncoderAnyRes(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`NaViLVisionEncoderLayerAnyRes`].

    Args:
        config (`NaViLVisionConfig`):
            The corresponding vision configuration for the encoder.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config

        # Stochastic depth decay rule: drop-path rate grows linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList([
            NaViLVisionEncoderLayerAnyRes(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
        self.gradient_checkpointing = True

        head_dim = config.hidden_size // config.num_attention_heads
        self.rotary_pos_emb = InternVisionRotaryEmbedding(head_dim // 2)

        self.merge_size = int(1.0 / config.downsample_ratio)
        self.merge_unit = self.merge_size * self.merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
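        # Blocks whose index appears in `fullatt_block_indexes` attend over the
        # full per-image sequence; all other blocks use local window attention
        # over `window_size` x `window_size` raw patches (Qwen2-VL style).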

    def rot_pos_emb(self, grid_thw):
        # Build (h, w) position ids for every patch, ordered so that the
        # merge_size x merge_size patches of each merge unit are contiguous.
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.merge_size,
                self.merge_size,
                w // self.merge_size,
                self.merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.merge_size,
                self.merge_size,
                w // self.merge_size,
                self.merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb
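
    # A tiny worked example of the ordering above (assumed merge_size=2):
    # for a 4x4 patch grid, the flattened ids come out merge unit by merge unit,
    #   hpos_ids: 0 0 1 1 | 0 0 1 1 | 2 2 3 3 | 2 2 3 3
    #   wpos_ids: 0 1 0 1 | 2 3 2 3 | 0 1 0 1 | 2 3 2 3
    # so the tokens pooled into one merge unit sit next to each other in the
    # packed sequence, matching the reshapes in `forward`.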

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # Window size measured in merge units rather than raw patches.
        vit_merger_window_size = self.window_size // self.merge_size
        assert vit_merger_window_size > 0

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.merge_size,
                grid_w // self.merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # Pad the merge-unit grid to a multiple of the window size; padded
            # slots are marked -100 and dropped after windowing.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens
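
    # A worked example (assumed: merge_size=2, window_size=8, i.e. 4x4 merge
    # units per window). One image with grid_thw = [[1, 16, 16]] gives an 8x8
    # merge-unit grid, padded to 12x12 -> 3x3 windows, of which only the four
    # covering real units survive:
    #   seqlens           = [16, 16, 0, 16, 16, 0, 0, 0, 0]   (merge units)
    #   cu_window_seqlens = [0, 64, 128, 128, 192, 256, ...]  (x4 patch tokens)
    # `forward` later collapses the duplicates with torch.unique_consecutive.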

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`):
                Packed embedded representation of the input patches. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            grid_thw (`torch.Tensor`, *optional*):
                Per-image (temporal, height, width) patch grid sizes of shape `(num_images, 3)`.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder the packed sequence (and its rotary embedding) window by
        # window, keeping each merge unit contiguous.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        # Per-image cumulative sequence lengths for the full-attention blocks.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
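        # Example (hypothetical shapes): grid_thw = [[1, 16, 16], [1, 8, 8]]
        # gives per-image token counts [256, 64], hence
        # cu_seqlens = [0, 256, 320]; attention never crosses the boundary at 256.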

        for idx, encoder_layer in enumerate(self.layers):
            # Full attention for the configured block indexes, window attention
            # everywhere else.
            if (self.fullatt_block_indexes is None) or (idx in self.fullatt_block_indexes):
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    partial(encoder_layer, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb),
                    hidden_states)
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    rotary_pos_emb=rotary_pos_emb,
                )
            hidden_states = layer_outputs

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )


class NaViLVisionModelAnyRes(PreTrainedModel):
    main_input_name = 'pixel_values'
    config_class = NaViLVisionConfig
    _no_split_modules = ['NaViLVisionEncoderLayerAnyRes']

    def __init__(self, config: NaViLVisionConfig):
        super().__init__(config)
        self.config = config

        self.merge_size = int(1.0 / config.downsample_ratio)
        self.embeddings = NaViLVisionEmbeddingsAnyRes(config)
        self.encoder = NaViLVisionEncoderAnyRes(config)

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify either pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f'Unexpected pixel_values shape: {pixel_values.shape}')

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            grid_thw=grid_thw,
        )
        # Index instead of attribute access so this also works when the encoder
        # returns a plain tuple (return_dict=False).
        last_hidden_state = encoder_outputs[0]

        # Regroup the packed sequence into merge units:
        # (seq_len, hidden) -> (seq_len / merge_size**2, merge_size, merge_size, hidden).
        last_hidden_state = last_hidden_state.unsqueeze(1).reshape(
            -1, self.merge_size, self.merge_size, last_hidden_state.shape[-1])

        if not return_dict:
            return (last_hidden_state,) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
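

# A minimal usage sketch (hypothetical config values; the packed per-patch
# input layout is an assumption inferred from the shape checks above):
#
#   config = NaViLVisionConfig(...)  # e.g. patch_size=14, downsample_ratio=0.5
#   model = NaViLVisionModelAnyRes(config)
#   # One 224x224 image -> a 16x16 patch grid packed patch-by-patch:
#   pixel_values = torch.randn(256, 3, 14, 14)
#   grid_thw = torch.tensor([[1, 16, 16]])
#   out = model(pixel_values=pixel_values, grid_thw=grid_thw)
#   # out.last_hidden_state: (64, 2, 2, hidden_size), one 2x2 merge unit per row.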