# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.utils import logging

from .configuration_navil_vit import NaViLVisionConfig

try:
    # NOTE: the legacy `FlashAttention` wrapper module is intentionally not
    # imported; `InternAttention` calls `flash_attn_varlen_func` directly.
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
    has_flash_attn = True
except Exception:
    # Best-effort optional dependency: any failure while loading flash-attn
    # simply disables the flash code paths. `Exception` (rather than a bare
    # `except`) lets KeyboardInterrupt / SystemExit propagate.
    print('FlashAttention is not installed.')
    has_flash_attn = False

logger = logging.get_logger(__name__)


class InternRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: size of the last dimension of the inputs (length of the scale vector).
            eps: constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise `hidden_states` in float32, cast back to the input dtype, then scale."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


try:
    from apex.normalization import FusedRMSNorm

    InternRMSNorm = FusedRMSNorm  # noqa

    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    # using the normal InternRMSNorm
    pass
except Exception:
    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')


# Maps a normalisation name to the module class implementing it.
NORM2FN = {
    'rms_norm': InternRMSNorm,
    'layer_norm': nn.LayerNorm,
}


class InternVisionRotaryEmbedding(nn.Module):
    """Produces rotary-embedding angle tables (position x inverse frequency)."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """
        Args:
            dim: the inverse-frequency vector is built from `arange(0, dim, 2) / dim`.
            theta: base of the geometric frequency progression.
        """
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the outer product of positions `0..seqlen-1` and the inverse frequencies."""
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class InternAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Operates on batched `(B, N, C)` inputs. Uses flash-attn when it is both
    requested in the config and importable, otherwise a plain softmax
    attention implementation.
    """

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        # The flash path calls `flash_attn_varlen_func` directly (see `_flash_attn`);
        # the former `FlashAttention` wrapper is not imported by this module, so
        # instantiating it here raised NameError.

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        """Plain softmax attention over a `(B, N, C)` input; returns `(B, N, C)`."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            # The norms are sized for the full embedding, so heads are merged
            # before normalising and split again afterwards.
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        """Flash-attn implementation over a `(B, N, C)` input; returns `(B, N, C)`.

        Every sample in the batch is treated as one full-length sequence and the
        batch is handed to `flash_attn_varlen_func` as a packed token stream.

        Args:
            x: input of shape `(B, N, C)`.
            key_padding_mask: must be None; padded batches are not supported here.
            need_weights: accepted for interface compatibility; attention
                weights are never returned.

        Raises:
            NotImplementedError: if `key_padding_mask` is given.
        """
        if key_padding_mask is not None:
            raise NotImplementedError('key_padding_mask is not supported by the flash attention path.')

        batch, seqlen, _ = x.shape
        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        q, k, v = qkv.unbind(2)

        if self.qk_normalization:
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)

        # Pack (B, N, H, D) -> (B*N, H, D); sequence i occupies [i*N, (i+1)*N).
        q, k, v = (t.reshape(batch * seqlen, self.num_heads, self.head_dim) for t in (q, k, v))
        cu_seqlens = torch.arange(
            0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32, device=x.device
        )
        context = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen,
            dropout_p=self.attn_drop.p if self.training else 0.0,
            softmax_scale=self.scale,
            causal=False,
        )
        outs = self.proj(context.reshape(batch, seqlen, self.embed_dim))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Dispatch to the flash or naive implementation depending on `use_flash_attn`."""
        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
        return x


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to `tensor` using the angle table `freqs`.

    The computation runs in float32 and the result is cast back to the input
    dtype. `freqs` gets a singleton axis inserted at dim 1, is tiled twice along
    its last axis, and gets a leading singleton axis so it broadcasts against
    `tensor` (callers in this file pass `tensor` as `(1, tokens, heads, head_dim)`).
    """
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output


class InternVisionSdpaAttention(nn.Module):
    """Attention over packed variable-length sequences using PyTorch SDPA."""

    def __init__(self, config: NaViLVisionConfig) -> None:
        super().__init__()
        self.config = config
        dim = config.hidden_size
        num_heads = config.num_attention_heads

        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)

        self.proj_drop = nn.Dropout(config.dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: packed tokens; dim 0 is the total token count.
            cu_seqlens: cumulative sequence boundaries; consecutive entries
                delimit one sequence in the packed token axis.
            rotary_pos_emb: rotary angle table for the packed tokens, or None to
                skip the rotary embedding.

        Returns:
            Tensor with the same leading dimension as `hidden_states`.
        """
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if self.qk_normalization:
            # The norms are sized for the full hidden dimension: merge heads,
            # normalise, then restore the (tokens, heads, head_dim) layout.
            q = self.q_norm(q.flatten(1)).view(q.shape)
            k = self.k_norm(k.flatten(1)).view(k.shape)

        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Block-diagonal mask: a token may only attend within its own sequence.
        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            attention_mask[..., start:end, start:end] = True

        # (tokens, heads, head_dim) -> (heads, tokens, head_dim) for SDPA.
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        attn_output = self.proj_drop(attn_output)
        return attn_output


def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding via flash-attn's `apply_rotary_emb`.

    Inputs are promoted to float32 for the rotation and the result is cast back
    to the dtype of `tensor`. Requires flash-attn to be importable.
    """
    tensor_ = tensor.float()
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
    return output


class InternVisionFlashAttention2(nn.Module):
    """Attention over packed variable-length sequences using flash-attn."""

    def __init__(self, config: NaViLVisionConfig) -> None:
        super().__init__()
        self.config = config
        dim = config.hidden_size
        num_heads = config.num_attention_heads

        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)

        self.proj_drop = nn.Dropout(config.dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: packed tokens; dim 0 is the total token count.
            cu_seqlens: cumulative sequence boundaries passed to
                `flash_attn_varlen_func` for both queries and keys.
            rotary_pos_emb: rotary angle table for the packed tokens, or None to
                skip the rotary embedding.

        Returns:
            Tensor with the same leading dimension as `hidden_states`.

        Raises:
            ImportError: if flash-attn could not be imported.
        """
        if not has_flash_attn:
            raise ImportError('InternVisionFlashAttention2 requires the flash_attn package.')

        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if self.qk_normalization:
            # The norms are sized for the full hidden dimension: merge heads,
            # normalise, then restore the (tokens, heads, head_dim) layout.
            q = self.q_norm(q.flatten(1)).view(q.shape)
            k = self.k_norm(k.flatten(1)).view(k.shape)

        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Longest individual sequence in the packed batch.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        attn_output = self.proj_drop(attn_output)
        return attn_output


class InternMLP(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear."""

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project to the intermediate size, apply the activation, project back."""
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states