import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from transformers.utils import logging

from .configuration_navil_vit import NaViLVisionConfig

try:
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb

    has_flash_attn = True
except Exception:
    print('FlashAttention is not installed.')
    has_flash_attn = False

logger = logging.get_logger(__name__)


class InternRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


try:
    from apex.normalization import FusedRMSNorm

    InternRMSNorm = FusedRMSNorm

    logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    pass
except Exception:
    logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
    pass


NORM2FN = {
    'rms_norm': InternRMSNorm,
    'layer_norm': nn.LayerNorm,
}
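

# Usage sketch (illustrative only, not part of the model): NORM2FN lets callers pick the
# normalization layer by name. Sizes below are made up for the demo.
def _norm2fn_usage_sketch():
    norm = NORM2FN['rms_norm'](64, eps=1e-6)  # InternRMSNorm(64) unless apex substituted FusedRMSNorm
    x = torch.randn(2, 16, 64)
    y = norm(x)  # per-token RMS normalization over the last dimension
    assert y.shape == x.shape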


class InternVisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs
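

# Shape sketch (illustrative only, not part of the model): the returned table has shape
# (seqlen, dim // 2). apply_rotary_pos_emb_vision later duplicates it along the last axis,
# so dim should equal the per-head dimension of the attention module that consumes it.
def _rotary_embedding_shape_sketch():
    rotary = InternVisionRotaryEmbedding(dim=16)  # e.g. head_dim = 16
    freqs = rotary(7)                             # positions 0..6
    assert freqs.shape == (7, 8)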


class InternAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            # NOTE: FlashAttention here is a qkv-packed attention wrapper module; it is not
            # imported above and must be defined (or imported) elsewhere in this module before
            # the flash-attention path can be used.
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
        return x
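

# Usage sketch (illustrative only, not part of the model). A real NaViLVisionConfig would
# normally be passed in; a SimpleNamespace carrying just the attributes InternAttention reads
# keeps the sketch self-contained without assuming anything about the config constructor.
def _intern_attention_usage_sketch():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        hidden_size=64,
        num_attention_heads=4,
        use_flash_attn=False,  # force the naive (batched, unpacked) attention path
        qkv_bias=True,
        attention_dropout=0.0,
        dropout=0.0,
        qk_normalization=True,
        layer_norm_eps=1e-6,
    )
    attn = InternAttention(cfg)
    x = torch.randn(2, 9, 64)  # (batch, tokens, hidden_size)
    y = attn(x)
    assert y.shape == x.shape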


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output
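

# Worked sketch (illustrative only, not part of the model): rotate_half maps [x1, x2] to
# [-x2, x1] over the two halves of the last dimension, which is what turns the (cos, sin)
# tables into a rotation of each feature pair.
def _rotate_half_sketch():
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    assert torch.equal(rotate_half(x), torch.tensor([[-3.0, -4.0, 1.0, 2.0]]))

    # In this file, apply_rotary_pos_emb_vision is fed (1, seqlen, heads, head_dim) states and a
    # (seqlen, head_dim // 2) frequency table from InternVisionRotaryEmbedding(head_dim).
    states = torch.randn(1, 7, 4, 16)
    freqs = InternVisionRotaryEmbedding(16)(7)
    out = apply_rotary_pos_emb_vision(states, freqs)
    assert out.shape == states.shape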


class InternVisionSdpaAttention(nn.Module):
    def __init__(self, config: NaViLVisionConfig) -> None:
        super().__init__()

        self.config = config

        dim = config.hidden_size
        num_heads = config.num_attention_heads
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)

        self.proj_drop = nn.Dropout(config.dropout)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if self.qk_normalization:
            # Normalize over the full hidden dimension, then restore the (seq, heads, head_dim) layout.
            q = self.q_norm(q.flatten(1)).view(q.shape)
            k = self.k_norm(k.flatten(1)).view(k.shape)

        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Block-diagonal mask so tokens only attend within their own packed segment.
        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        attn_output = self.proj_drop(attn_output)
        return attn_output
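

# Usage sketch (illustrative only, not part of the model): InternVisionSdpaAttention takes a
# packed sequence of patch tokens, with cu_seqlens giving the cumulative boundaries of the
# segments packed together so that attention never crosses segment boundaries. Sizes are made
# up; a SimpleNamespace stands in for NaViLVisionConfig.
def _sdpa_attention_usage_sketch():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        hidden_size=64,
        num_attention_heads=4,
        qkv_bias=True,
        qk_normalization=False,
        dropout=0.0,
    )
    attn = InternVisionSdpaAttention(cfg)

    # Two segments packed into one sequence: 3 tokens + 4 tokens = 7 tokens in total.
    hidden_states = torch.randn(7, 64)
    cu_seqlens = torch.tensor([0, 3, 7])
    rotary_pos_emb = InternVisionRotaryEmbedding(16)(7)  # (seqlen, head_dim // 2)

    out = attn(hidden_states, cu_seqlens, rotary_pos_emb)
    assert out.shape == (7, 64)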


def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    tensor_ = tensor.float()
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
    return output


class InternVisionFlashAttention2(nn.Module):
    def __init__(self, config: NaViLVisionConfig) -> None:
        super().__init__()
        self.config = config

        dim = config.hidden_size
        num_heads = config.num_attention_heads

        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(dim, eps=config.layer_norm_eps)

        self.proj_drop = nn.Dropout(config.dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if self.qk_normalization:
            # Normalize over the full hidden dimension, then restore the (seq, heads, head_dim) layout.
            q = self.q_norm(q.flatten(1)).view(q.shape)
            k = self.k_norm(k.flatten(1)).view(k.shape)

        q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        attn_output = self.proj_drop(attn_output)
        return attn_output
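

# Usage sketch (illustrative only, not part of the model): the FlashAttention-2 path needs a
# CUDA device, half precision, and int32 cu_seqlens, so this sketch is a no-op when those are
# unavailable. Sizes and the SimpleNamespace config are made up for the demo.
def _flash_attention_usage_sketch():
    if not (has_flash_attn and torch.cuda.is_available()):
        return
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        hidden_size=64,
        num_attention_heads=4,
        qkv_bias=True,
        qk_normalization=False,
        dropout=0.0,
    )
    attn = InternVisionFlashAttention2(cfg).cuda().half()

    hidden_states = torch.randn(7, 64, device='cuda', dtype=torch.float16)
    cu_seqlens = torch.tensor([0, 3, 7], device='cuda', dtype=torch.int32)
    rotary_pos_emb = InternVisionRotaryEmbedding(16)(7).cuda()  # (seqlen, head_dim // 2)

    out = attn(hidden_states, cu_seqlens, rotary_pos_emb)
    assert out.shape == (7, 64)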


class InternMLP(nn.Module):
    def __init__(self, config: NaViLVisionConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
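

# Usage sketch (illustrative only, not part of the model): InternMLP is a standard two-layer
# feed-forward block; hidden_act must be a key of transformers' ACT2FN registry (e.g. 'gelu').
# A SimpleNamespace with made-up sizes stands in for NaViLVisionConfig.
def _intern_mlp_usage_sketch():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=64, intermediate_size=256, hidden_act='gelu')
    mlp = InternMLP(cfg)
    x = torch.randn(2, 9, 64)
    y = mlp(x)
    assert y.shape == x.shape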