NaViL-2B / modeling_navil_vit_anyres.py
Changyao's picture
Upload folder using huggingface_hub
20f5bc9 verified
# --------------------------------------------------------
# NaViL
# Copyright (c) 2025 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional, Tuple, Union
from functools import partial
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_navil_vit import NaViLVisionConfig
from .modular_intern_vit import (
InternVisionFlashAttention2,
InternVisionSdpaAttention,
InternMLP,
NORM2FN,
InternVisionRotaryEmbedding,
)
# flash-attn is an optional dependency. `has_flash_attn` records whether its
# kernels could be imported; the encoder layer uses the flag to choose between
# the flash-attention and the SDPA attention implementation.
try:
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
    has_flash_attn = True
except Exception:
    # Kept broad on purpose so that any failure while importing the optional
    # package still results in the fallback, but (unlike a bare `except:`)
    # KeyboardInterrupt and SystemExit are no longer swallowed.
    print('FlashAttention is not installed.')
    has_flash_attn = False
# Module-level logger created through the `transformers.utils.logging` helper.
logger = logging.get_logger(__name__)
class NaViLVisionEmbeddingsAnyRes(nn.Module):
    """Patch embedding used by the any-resolution NaViL vision tower.

    A strided convolution (kernel size and stride both equal to
    ``config.patch_size``) projects 3-channel inputs to ``config.hidden_size``
    channels. The result is flattened from dimension 1 onwards, so each entry
    of the leading dimension yields one row.

    Args:
        config (`NaViLVisionConfig`):
            Reads ``hidden_size``, ``image_size``, ``patch_size`` and
            ``downsample_ratio``.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # e.g. downsample_ratio = 0.5 -> merge_size = 2
        self.merge_size = int(1.0 / config.downsample_ratio)

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        # Bookkeeping attributes derived from the nominal image size; they are
        # not used by `forward` in this class.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` and flatten everything after the leading dimension.

        Args:
            pixel_values: 4-D tensor ``(N, 3, height, width)``.
                NOTE(review): callers presumably pass one ``patch_size`` x
                ``patch_size`` patch per row so that the output is
                ``(N, hidden_size)`` - confirm against the preprocessing code.

        Returns:
            Tensor of shape ``(N, hidden_size * out_height * out_width)``.

        Raises:
            ValueError: if the convolution output is not 4-D.
        """
        # Cast to the dtype of the convolution weights so that e.g. float32
        # pixels can be fed to a half / bfloat16 model without a dtype mismatch.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # [N, channel, height, width]
        if patch_embeds.dim() != 4:
            raise ValueError(
                f'expected a 4-D patch embedding, got shape {tuple(patch_embeds.shape)}'
            )
        return patch_embeds.flatten(1)
class NaViLVisionEncoderLayerAnyRes(nn.Module):
    """One pre-norm transformer block of the NaViL vision encoder.

    The block applies two residual branches in sequence, attention then MLP.
    Each branch output is multiplied by a learnable per-channel scale
    (``ls1`` / ``ls2``) and passed through stochastic depth before being added
    back to the input.

    Args:
        config (`NaViLVisionConfig`):
            Reads ``hidden_size``, ``intermediate_size``, ``norm_type``,
            ``layer_norm_eps`` and ``initializer_factor``.
        drop_path_rate (`float`):
            Stochastic-depth rate for this layer; ``0`` disables it.
    """

    def __init__(self, config: NaViLVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        # Use the flash-attention implementation only when the optional
        # flash-attn import succeeded at module load time.
        if has_flash_attn:
            self.attn = InternVisionFlashAttention2(config)
        else:
            self.attn = InternVisionSdpaAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        # Per-channel scales for the two residual branches.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
            cu_seqlens,
            rotary_pos_emb
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer; the last dimension must be `embed_dim` so that
                `ls1` / `ls2` broadcast over it.
            cu_seqlens:
                Forwarded unchanged to the attention module.
            rotary_pos_emb:
                Forwarded unchanged to the attention module.

        Returns:
            `torch.Tensor`: the updated hidden states (a single tensor, not a tuple).
        """
        attn_branch = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        ) * self.ls1
        hidden_states = hidden_states + self.drop_path1(attn_branch)

        mlp_branch = self.mlp(self.norm2(hidden_states)) * self.ls2
        hidden_states = hidden_states + self.drop_path2(mlp_branch)

        return hidden_states
class NaViLVisionEncoderAnyRes(nn.Module):
    """
    Transformer encoder made of `config.num_hidden_layers` [`NaViLVisionEncoderLayerAnyRes`] blocks that
    processes a packed, variable-resolution patch sequence.

    Layers whose index is listed in `config.fullatt_block_indexes` (or all layers when it is `None`)
    receive per-frame cumulative sequence lengths; the remaining layers receive per-window ones.

    Args:
        config (`NaViLVisionConfig`):
            The vision configuration. Reads `drop_path_rate`, `num_hidden_layers`, `hidden_size`,
            `num_attention_heads`, `downsample_ratio`, `patch_size`, `fullatt_block_indexes` and
            `window_size`.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config

        # stochastic depth decay rule: rates grow linearly from 0 to `drop_path_rate` over the depth
        drop_rates = torch.linspace(0, config.drop_path_rate, config.num_hidden_layers).tolist()
        self.layers = nn.ModuleList(
            [NaViLVisionEncoderLayerAnyRes(config, rate) for rate in drop_rates]
        )
        self.gradient_checkpointing = True

        per_head_dim = config.hidden_size // config.num_attention_heads
        self.rotary_pos_emb = InternVisionRotaryEmbedding(per_head_dim // 2)

        self.merge_size = int(1.0 / config.downsample_ratio)
        # number of patches inside one merge_size x merge_size group
        self.merge_unit = self.merge_size * self.merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size

    def _to_merge_block_order(self, ids, h, w):
        """Flatten an (h, w) index map so that each merge_size x merge_size block is contiguous."""
        blocks = ids.reshape(
            h // self.merge_size,
            self.merge_size,
            w // self.merge_size,
            self.merge_size,
        )
        return blocks.permute(0, 2, 1, 3).flatten()

    def _gather_windows(self, tensor, window_index, seq_len):
        """Reorder the rows of `tensor` in groups of `merge_unit` rows according to `window_index`."""
        grouped = tensor.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        return grouped[window_index, :, :].reshape(seq_len, -1)

    def rot_pos_emb(self, grid_thw):
        """Build the rotary embedding rows for every patch described by `grid_thw`.

        For each `(t, h, w)` row the (row, column) index of every patch is computed, reordered
        so that merge blocks are contiguous, and repeated `t` times. The index pairs are then
        looked up in the table produced by `self.rotary_pos_emb` for the largest h / w value.
        """
        per_image_ids = []
        for t, h, w in grid_thw:
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            row_ids = self._to_merge_block_order(row_ids, h, w)
            col_ids = self._to_merge_block_order(col_ids, h, w)
            per_image_ids.append(torch.stack([row_ids, col_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(per_image_ids, dim=0)

        table = self.rotary_pos_emb(grid_thw[:, 1:].max())
        # (num_patches, 2, dim) -> (num_patches, 2 * dim)
        return table[pos_ids].flatten(1)

    def get_window_index(self, grid_thw):
        """Compute the window ordering of merge groups and the window boundaries.

        Returns:
            A tuple `(window_index, cu_window_seqlens)`: a 1-D tensor giving, in window order, the
            index of every merge group, and a Python list of cumulative window lengths counted in
            patches (starting with 0; empty windows produce repeated entries).
        """
        # window side length measured in merge groups
        win = self.window_size // self.merge_size
        assert win > 0

        order_chunks: list = []
        boundaries: list = [0]
        group_offset = 0

        for grid_t, grid_h, grid_w in grid_thw:
            groups_h = grid_h // self.merge_size
            groups_w = grid_w // self.merge_size
            num_groups = grid_t * groups_h * groups_w
            group_ids = torch.arange(num_groups).reshape(grid_t, groups_h, groups_w)

            # Pad bottom / right with the sentinel -100 up to a multiple of the window size.
            # When a side is already a multiple this adds one full, empty window.
            pad_h = win - groups_h % win
            pad_w = win - groups_w % win
            windows_h = (groups_h + pad_h) // win
            windows_w = (groups_w + pad_w) // win
            padded = F.pad(group_ids, (0, pad_w, 0, pad_h), "constant", -100)

            # (t, windows_h, win, windows_w, win) -> (t, windows_h * windows_w, win, win)
            windows = (
                padded.reshape(grid_t, windows_h, win, windows_w, win)
                .permute(0, 1, 3, 2, 4)
                .reshape(grid_t, windows_h * windows_w, win, win)
            )
            is_real = windows != -100
            groups_per_window = is_real.sum([2, 3]).reshape(-1)

            order_chunks.append(windows[is_real] + group_offset)
            window_ends = groups_per_window.cumsum(0) * self.merge_unit + boundaries[-1]
            boundaries.extend(window_ends.tolist())
            group_offset += num_groups.item()

        return torch.cat(order_chunks, dim=0), boundaries

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(sequence_length, hidden_size)`):
                Packed patch embeddings. `sequence_length` must be divisible by `merge_size ** 2`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.
            grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
                One `(t, h, w)` row per image. Although the default is `None`, it is required.
                NOTE(review): h and w are presumably counted in patches - confirm with the caller.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        all_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)

        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852 for more information
        index_dtype = grid_thw.dtype if torch.jit.is_tracing() else torch.int32

        # Collapse the repeated boundaries produced by empty (padding-only) windows.
        cu_window_seqlens = torch.unique_consecutive(
            torch.tensor(cu_window_seqlens, device=hidden_states.device, dtype=index_dtype)
        )

        # Bring tokens and their rotary rows into window order.
        # NOTE: this method never applies the inverse permutation, so outputs stay in window order.
        seq_len, _ = hidden_states.size()
        hidden_states = self._gather_windows(hidden_states, window_index, seq_len)
        rotary_pos_emb = self._gather_windows(rotary_pos_emb, window_index, seq_len)

        # Cumulative lengths with one segment per frame: h * w patches, repeated t times.
        patches_per_frame = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
        cu_seqlens = F.pad(patches_per_frame.cumsum(dim=0, dtype=index_dtype), (1, 0), value=0)

        for idx, encoder_layer in enumerate(self.layers):
            full_attention = (self.fullatt_block_indexes is None) or (idx in self.fullatt_block_indexes)
            layer_cu_seqlens = cu_seqlens if full_attention else cu_window_seqlens

            if output_hidden_states:
                all_states = all_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    partial(encoder_layer, cu_seqlens=layer_cu_seqlens, rotary_pos_emb=rotary_pos_emb),
                    hidden_states)
            else:
                hidden_states = encoder_layer(
                    hidden_states,
                    cu_seqlens=layer_cu_seqlens,
                    rotary_pos_emb=rotary_pos_emb,
                )

        if output_hidden_states:
            all_states = all_states + (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_states
            )
        return tuple(v for v in [hidden_states, all_states] if v is not None)
class NaViLVisionModelAnyRes(PreTrainedModel):
    """NaViL vision tower for variable-resolution inputs.

    Embeds pixel patches (unless pre-computed embeddings are supplied), runs them through
    [`NaViLVisionEncoderAnyRes`] and reshapes the final hidden states into groups of
    `merge_size x merge_size` tokens.
    """
    main_input_name = 'pixel_values'
    config_class = NaViLVisionConfig
    _no_split_modules = ['NaViLVisionEncoderLayerAnyRes']

    def __init__(self, config: NaViLVisionConfig):
        super().__init__(config)
        self.config = config
        # e.g. downsample_ratio = 0.5 -> merge_size = 2
        self.merge_size = int(1.0 / config.downsample_ratio)

        self.embeddings = NaViLVisionEmbeddingsAnyRes(config)
        self.encoder = NaViLVisionEncoderAnyRes(config)

    def get_input_embeddings(self):
        """Return the patch-embedding module."""
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.FloatTensor] = None,
            grid_thw: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Args:
            pixel_values (`torch.FloatTensor`, *optional*):
                4-D tensor passed to the patch embedding. Ignored when `pixel_embeds` is given.
            output_hidden_states (`bool`, *optional*):
                Whether to also return the hidden states of every layer. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.
            pixel_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed patch embeddings fed directly to the encoder.
            grid_thw (`torch.Tensor`):
                One `(t, h, w)` row per image; required by the encoder.

        Returns:
            `BaseModelOutputWithPooling` (with `pooler_output=None`) or a tuple whose first element is
            the last hidden state reshaped to `(-1, merge_size, merge_size, hidden_size)`.

        Raises:
            ValueError: if neither `pixel_values` nor `pixel_embeds` is given, if `grid_thw` is
                missing, or if `pixel_values` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')
        if grid_thw is None:
            # The encoder iterates over grid_thw; fail early with a clear message.
            raise ValueError('You have to specify grid_thw')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif pixel_values.ndim == 4:
            hidden_states = self.embeddings(pixel_values)
        else:
            raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            grid_thw=grid_thw
        )
        # Index instead of attribute access: with return_dict=False the encoder returns a plain
        # tuple, which has no `.last_hidden_state`. Integer indexing works for both return types.
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = last_hidden_state.reshape(
            -1, self.merge_size, self.merge_size, last_hidden_state.shape[-1]
        )

        if not return_dict:
            return (last_hidden_state, ) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )