# Source: PaddleOCR-VL / modeling_paddleocr_vl.py
# Last commit 4760b0e (verified) by Tingquan:
#   "Move PaddleOCR-VL-0.9B files to root directory (#31)"
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN, GELUActivation
from transformers.cache_utils import (
Cache,
DynamicCache,
SlidingWindowCache,
StaticCache,
)
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import (
ALL_ATTENTION_FUNCTIONS,
PreTrainedModel,
sdpa_attention_forward,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
ModelOutput,
TransformersKwargs,
auto_docstring,
can_return_tuple,
is_flash_attn_2_available,
torch_int,
)
from transformers.utils.generic import check_model_inputs
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb
else:
flash_attn_varlen_func = None
apply_rotary_emb = None
from .configuration_paddleocr_vl import PaddleOCRVisionConfig, PaddleOCRVLConfig
class RotaryEmbedding(nn.Module):
    """Rotary position embedding over multi-row (multimodal) position ids.

    ``forward`` expects ``position_ids`` of shape ``(rows, batch, seq_len)``
    (the text model in this file always passes ``rows == 3``: temporal,
    height, width) and returns ``(cos, sin)`` tables of shape
    ``(rows, batch, seq_len, 2 * len(inv_freq))`` cast to ``x.dtype``.
    """

    def __init__(self, config: PaddleOCRVLConfig, device=None):
        super().__init__()
        # Extra keyword arguments forwarded to the RoPE init function when the
        # frequencies are recomputed for dynamic RoPE variants.
        self.rope_kwargs = {}
        # BC: "rope_type" was originally "type". A missing or None
        # `rope_scaling` means plain ("default") RoPE.
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic RoPE can reset to the original frequencies.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len
        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # The registered buffer follows the module through `.to(device)`,
            # but this plain attribute does not; move it before restoring it.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids`` of shape ``(rows, batch, seq_len)``.

        ``x`` is only used for its device and dtype.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)
        # (rows, batch, dim/2, 1)
        inv_freq_expanded = (
            self.inv_freq[None, None, :, None]
            .float()
            .expand(position_ids.shape[0], position_ids.shape[1], -1, 1)
        )
        # (rows, batch, 1, seq_len)
        position_ids_expanded = position_ids[:, :, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            # (rows, batch, seq_len, dim/2)
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def rope_init(self):
        """Recompute ``inv_freq`` (and the reset copy) from the config."""
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device=None, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq
class Ernie4_5RotaryEmbedding(nn.Module):
    """Single-row rotary position embedding for the Ernie 4.5 decoder.

    ``forward`` takes ``position_ids`` of shape ``(batch, seq_len)`` and
    returns float32 ``(cos, sin)`` tables of shape
    ``(batch, seq_len, 2 * len(inv_freq))``. They are deliberately not cast
    back to ``x.dtype``.
    """

    def __init__(self, config: PaddleOCRVLConfig, device=None):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        # BC: "rope_type" was originally "type"; anything that is not a dict
        # falls back to plain RoPE.
        self.rope_type = (
            scaling_cfg.get("rope_type", scaling_cfg.get("type"))
            if isinstance(scaling_cfg, dict)
            else "default"
        )
        self.original_max_seq_len = self.max_seq_len_cached = (
            config.max_position_embeddings
        )
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        freqs, scaling = self.rope_init_fn(self.config, device)
        self.attention_scaling = scaling
        self.register_buffer("inv_freq", freqs, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        inv = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos = position_ids[:, None, :].float()

        dev_type = x.device.type
        if not isinstance(dev_type, str) or dev_type == "mps":
            dev_type = "cpu"

        with torch.autocast(device_type=dev_type, enabled=False):  # Force float32
            angles = torch.matmul(inv.float(), pos.float()).transpose(1, 2)
            table = torch.cat([angles, angles], dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling
        # Returned in full precision on purpose.
        return cos, sin
class Ernie4_5MLP(nn.Module):
    """Gated feed-forward block computing ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.use_bias
        d_model, d_ff = self.hidden_size, self.intermediate_size
        # Registration order (gate, up, down) is kept so parameter order and
        # RNG consumption during init are unchanged.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=use_bias)
        self.up_proj = nn.Linear(d_model, d_ff, bias=use_bias)
        self.down_proj = nn.Linear(d_ff, d_model, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``.
    Input ``(batch, num_key_value_heads, seq_len, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)``. When
    ``n_rep == 1`` the input tensor object itself is returned.
    """
    # Unpack first so a non-4D input is rejected even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
def eager_attention_forward_ernie(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain PyTorch scaled dot-product attention with grouped key/value heads.

    Assumes ``query``/``key``/``value`` are ``(batch, heads, seq_len, head_dim)``,
    with ``key``/``value`` carrying ``heads / module.num_key_value_groups``
    heads (TODO confirm against callers). ``attention_mask`` is an additive
    mask and is cropped to the key length.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has axes 1 and 2
        swapped and is contiguous. ``attn_weights`` are the softmax
        probabilities after dropout.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the mask to the (possibly cache-extended) key length.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for stability, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate ``q`` and ``k`` with an interleaved (GLM-style) rotary embedding.

    Only the first half of the last axis of ``cos``/``sin`` is used. Each
    entry is duplicated in place (``c0, c0, c1, c1, ...``) to form the
    full-width table. The math runs in float32 and the results are cast back
    to ``q.dtype``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into ``cos``/``sin`` so they broadcast against
            ``q``/``k``. Use 1 for ``(batch, heads, seq, dim)`` inputs and 2
            for ``(batch, seq, heads, dim)`` inputs.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    out_dtype = q.dtype

    def _interleave_first_half(table):
        # Keep the first half of the last axis, then repeat each entry twice.
        return table[..., : table.shape[-1] // 2].repeat_interleave(2, dim=-1)

    cos = _interleave_first_half(cos.unsqueeze(unsqueeze_dim))
    sin = _interleave_first_half(sin.unsqueeze(unsqueeze_dim))

    q_rot = (q.float() * cos) + (rotate_half(q).float() * sin)
    k_rot = (k.float() * cos) + (rotate_half(k).float() * sin)
    return q_rot.to(out_dtype), k_rot.to(out_dtype)
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding extends 1D rotary position
        embedding. For vision tokens, rotary embedding is applied separately
        over the temporal, height and width axes by splitting the channel
        dimension into three sections. For text tokens the three position
        rows are identical, so the result matches ordinary 1D RoPE.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table with a leading (temporal, height,
            width) axis, e.g. the ``(3, batch_size, seq_len, head_dim)``
            output of ``RotaryEmbedding.forward``.
        sin (`torch.Tensor`): Sine table with the same layout as ``cos``.
        mrope_section (`List[int]`):
            Channel counts for the temporal, height and width sections. The
            list is applied to each half of the last axis, so
            ``2 * sum(mrope_section)`` must equal ``cos.shape[-1]``
            (``torch.split`` raises otherwise).
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into the selected ``(batch_size, seq_len, head_dim)``
            tables so they broadcast against ``q``/``k``. Use 1 for
            ``(batch, heads, seq, dim)`` inputs and 2 for
            ``(batch, seq, heads, dim)`` inputs.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # The same section sizes are used for both halves of the rotary channels.
    sections = mrope_section * 2
    # Chunk i takes its values from position row i % 3 (t, h, w, t, h, w).
    cos = torch.cat(
        [m[i % 3] for i, m in enumerate(cos.split(sections, dim=-1))], dim=-1
    ).unsqueeze(unsqueeze_dim)
    sin = torch.cat(
        [m[i % 3] for i, m in enumerate(sin.split(sections, dim=-1))], dim=-1
    ).unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
class Ernie4_5Attention(nn.Module):
    """Grouped-query self-attention using multimodal rotary position embeddings.

    The attention kernel is selected from ``config._attn_implementation``;
    ``"eager"`` uses ``eager_attention_forward_ernie``.
    """

    def __init__(self, config: PaddleOCRVLConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        self.num_key_value_groups = n_heads // n_kv_heads
        self.scaling = self.head_dim**-0.5
        # Must contain "mrope_section" (read in forward).
        self.rope_scaling = config.rope_scaling
        self.attention_dropout = 0.0
        self.is_causal = True

        hidden, bias = config.hidden_size, config.use_bias
        self.q_proj = nn.Linear(hidden, n_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(hidden, n_kv_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(hidden, n_kv_heads * self.head_dim, bias=bias)
        self.o_proj = nn.Linear(n_heads * self.head_dim, hidden, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``hidden_states``.

        Returns ``(attn_output, attn_weights)``, where ``attn_weights`` is
        whatever the selected attention backend returns.
        """
        input_shape = hidden_states.shape[:-1]
        per_head_shape = (*input_shape, -1, self.head_dim)

        def _split_heads(proj):
            return proj(hidden_states).view(per_head_shape).transpose(1, 2)

        query_states = _split_heads(self.q_proj)
        key_states = _split_heads(self.k_proj)
        value_states = _split_heads(self.v_proj)

        # Only the first row of multi-row (3, ...) position ids is forwarded to
        # the attention backend. NOTE(review): presumably because backends
        # expect 2-D ids; confirm.
        pos_ids = kwargs.get("position_ids")
        if pos_ids is not None and pos_ids.dim() == 3 and pos_ids.shape[0] > 1:
            kwargs["position_ids"] = pos_ids[0:1]

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        impl = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward_ernie
            if impl == "eager"
            else ALL_ATTENTION_FUNCTIONS[impl]
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = self.o_proj(attn_output.reshape(*input_shape, -1).contiguous())
        return attn_output, attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class Ernie4_5RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned scale (equivalent to T5LayerNorm).

    Normalisation runs in float32. The result is cast back to the input
    dtype before the scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x32 * inv_rms).to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Ernie4_5DecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer made of two residual sub-blocks.

    The first is RMSNorm -> self-attention; the second is RMSNorm -> gated MLP.
    """

    def __init__(self, config: PaddleOCRVLConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx)
        self.mlp = Ernie4_5MLP(config)
        self.input_layernorm = Ernie4_5RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Ernie4_5RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run one decoder layer and return the updated hidden states.

        Returns a single tensor (not a tuple). The attention weights computed
        by ``self_attn`` are discarded here. ``position_ids`` and
        ``use_cache`` are forwarded to ``self_attn`` as keyword arguments.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
@auto_docstring
class Ernie4_5PreTrainedModel(PreTrainedModel):
    # Shared base for the Ernie 4.5 text classes. It declares the config type,
    # the checkpoint prefix and the transformers feature flags listed below.
    config: PaddleOCRVLConfig
    base_model_prefix = "model"

    # Device-placement hints (modules that are kept whole, keys that are skipped).
    _no_split_modules = ["Ernie4_5DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Feature / backend capability flags.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = True

    # Submodules whose outputs are recorded as hidden_states / attentions.
    _can_record_outputs = {
        "hidden_states": Ernie4_5DecoderLayer,
        "attentions": Ernie4_5Attention,
    }
@auto_docstring
class Ernie4_5Model(Ernie4_5PreTrainedModel):
    # Bare Ernie 4.5 decoder stack: token embeddings -> decoder layers ->
    # final RMSNorm, with the multi-row RotaryEmbedding shared by all layers.
    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                Ernie4_5DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        if cache_position is None:
            # New tokens continue numbering after what the cache already holds.
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        # RotaryEmbedding consumes ids shaped (3, batch, seq). Plain 1-D/2-D
        # positions are broadcast to all three rows.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(
                3, inputs_embeds.shape[0], -1
            )
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )
        hidden_states = inputs_embeds
        # Computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the attention mask handed to every decoder layer.

        Returns:
            * for ``flash_attention_2``: the 2-D ``attention_mask`` if it
              contains padding (a 0), else ``None``.
            * ``None`` when SDPA can rely on its own causal handling.
            * otherwise the mask produced by
              ``_prepare_4d_causal_attention_mask_with_cache_position``.

        Raises:
            ValueError: flash-attention-2 is used with a cache and a
                right-padded batch.
        """
        sliding_window = getattr(self.config, "sliding_window", None)
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # Right padding shows up as a last column that is not all ones.
                is_padding_right = (
                    attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                )
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with "
                        "padding_side='right'; this may lead to unexpected behaviour "
                        "for the Flash Attention version of Ernie4_5. Make sure to "
                        "call `tokenizer.padding_side = 'left'` before tokenizing "
                        "the input."
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=sliding_window,
                is_training=self.training,
            ):
                return None
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )
        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: PaddleOCRVLConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
        Masked positions hold `torch.finfo(dtype).min`, attended positions hold 0.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`PaddleOCRVLConfig`):
                The model's configuration class. Its `sliding_window` (if present and not None)
                additionally masks keys further back than the window.
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            # True where the key index lies after the query's cache position (future tokens).
            diagonal_attend_mask = torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            sliding_window = getattr(config, "sliding_window", None)
            if sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if (
                    not isinstance(past_key_values, SlidingWindowCache)
                    or sequence_length > target_length
                ):
                    sliding_attend_mask = torch.arange(
                        target_length, device=device
                    ) <= (cache_position.reshape(-1, 1) - sliding_window)
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                # Positions that are both causally visible (0) and padding (0) sum to 0 -> mask them.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                    :, None, None, :
                ].to(causal_mask.device)
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        return causal_mask
class Ernie4_5ForCausalLM(Ernie4_5PreTrainedModel, GenerationMixin):
    """Ernie 4.5 decoder stack with a bias-free linear language-modeling head."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Ernie4_5Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_decoder(self):
        return self.model

    def set_decoder(self, decoder):
        self.model = decoder

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        logits_to_keep (`int` or `torch.Tensor`, defaults to 0):
            If an int, logits are computed only for the last `logits_to_keep` positions (0 keeps all positions).
            If a tensor, it is used directly as the index into the sequence axis.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> torch.Tensor:
    r"""Fill ``tensor`` in place with a scaled/shifted truncated standard normal.

    Samples are drawn from :math:`\mathcal{N}(0, 1)` truncated to
    :math:`[a, b]` (inverse-CDF sampling via ``_trunc_normal_``). They are
    then multiplied by ``std`` and shifted by ``mean``.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.
    The final values therefore lie in ``[mean + std * a, mean + std * b]``.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value (in standard-normal units)
        b: the maximum cutoff value (in standard-normal units)

    Returns:
        The same ``tensor`` object, filled in place.
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialize ``tensor`` in place with variance ``scale / denom`` (JAX/Flax style).

    Args:
        tensor: the tensor to fill. Fans come from
            ``torch.nn.init._calculate_fan_in_and_fan_out``, which needs at
            least 2 dims.
        scale: numerator of the target variance.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``. Selects the
            denominator ``denom``.
        distribution: ``"truncated_normal"`` (normal truncated to +/-2 std,
            rescaled so the final std matches), ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not recognised.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(f"invalid mode {mode}")
    variance = scale / denom
    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # U(-bound, bound) has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """LeCun-normal initialization, applied in place.
    Samples a normal truncated at two standard deviations, rescaled so the
    samples have variance of about ``1 / fan_in``.
    """
    variance_scaling_(tensor, scale=1.0, distribution="truncated_normal", mode="fan_in")
def default_flax_embed_init(tensor):
    """Embedding initialization, applied in place: plain normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, scale=1.0, distribution="normal", mode="fan_in")
class Projector(nn.Module):
    """Merge neighbouring vision tokens and project them to the text width.
    Groups of ``merge_kernel_size[0] * merge_kernel_size[1]`` vision tokens are
    layer-normed, concatenated along the channel axis, and passed through
    Linear -> GELU -> Linear, ending at ``text_config.hidden_size`` channels.
    """
    def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisionConfig):
        super().__init__()
        self.text_config = text_config
        self.vision_config = vision_config
        self.merge_kernel_size = (2, 2)
        # Width of one merged token: vision width times the tokens folded into it.
        self.hidden_size = (
            self.vision_config.hidden_size
            * self.merge_kernel_size[0]
            * self.merge_kernel_size[1]
        )
        self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(
            self.hidden_size, self.text_config.hidden_size, bias=True
        )
    def forward(
        self, image_features: torch.Tensor, image_grid_thw: List[Tuple[int, int, int]]
    ) -> torch.Tensor:
        """Project vision features to the text hidden size.
        Args:
            image_features: either a list/tuple with one ``(t*h*w, vision_hidden)``
                tensor per image, or a single tensor whose last dim is the vision
                hidden size.
            image_grid_thw: per-image ``(t, h, w)`` token grids; only used on the
                list path, where ``h`` and ``w`` must be divisible by the merge kernel.
        Returns:
            List path: one ``(t * h/m1 * w/m2, text_hidden)`` tensor per image.
            Tensor path: ``(*lead_dims, tokens // (m1*m2), text_hidden)``, where
            ``lead_dims`` are all input dims except the last two. NOTE(review):
            this path merges *consecutive* tokens, so it assumes they are already
            ordered in merge groups.
        """
        m1, m2 = self.merge_kernel_size
        if isinstance(image_features, (list, tuple)):
            from einops import rearrange  # imported once, not per image
            processed_features = []
            for image_feature, image_grid in zip(image_features, image_grid_thw):
                image_feature = self.pre_norm(image_feature)
                t, h, w = image_grid
                # Gather each m1 x m2 spatial neighbourhood into one token whose
                # channels are the concatenation of the group's features.
                image_feature = rearrange(
                    image_feature,
                    "(t h p1 w p2) d -> (t h w) (p1 p2 d)",
                    t=t,
                    h=h // m1,
                    p1=m1,
                    w=w // m2,
                    p2=m2,
                )
                hidden_states = self.linear_1(image_feature)
                hidden_states = self.act(hidden_states)
                hidden_states = self.linear_2(hidden_states)
                processed_features.append(hidden_states)
            return processed_features
        dims = image_features.shape[:-1]
        dim = image_features.shape[-1]
        image_features = image_features.reshape(-1, dim)
        # Fold every m1*m2 consecutive tokens into one row of width self.hidden_size.
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        # Each output row stands for m1*m2 input tokens: the token axis shrinks
        # and the channel axis is the full text width. Reshaping with
        # `view(*dims, -1)` would instead chop each projected vector into
        # m1*m2 rows of width text_hidden / (m1*m2).
        return hidden_states.view(*dims[:-1], -1, hidden_states.shape[-1])
class SiglipVisionEmbeddings(nn.Module):
    """Patch embeddings plus position embeddings for the SigLIP vision tower.
    ``forward`` accepts only 5-D ``pixel_values`` ``(batch, seq_len, channels,
    height, width)``. Each ``(channels, height, width)`` slice goes through
    ``patch_embedding`` (Conv2d, kernel = stride = ``patch_size``). The reshape
    afterwards requires each slice to yield a single patch, i.e. presumably
    ``height == width == patch_size``.
    """
    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        # State for fetch_position_embedding_lfu_cache: (h, w) -> embedding / hit count.
        self.cache_position_embedding = dict()
        self.cache_position_count = dict()
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Position table used for packed inputs when interpolation is disabled.
        self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )
    def interpolate_pos_encoding(
        self,
        embeddings: torch.Tensor,
        height: int,
        width: int,
        is_after_patchify: bool = False,
    ) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.
        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        Args:
            embeddings: only its last dimension (embedding width) is read.
            height, width: target grid, in patches if ``is_after_patchify`` else in pixels.
            is_after_patchify: whether ``height``/``width`` are already patch counts.
        Returns:
            Tensor of shape ``(1, new_height * new_width, dim)``.
        """
        num_positions = self.position_embedding.weight.shape[0]
        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
        dim = embeddings.shape[-1]
        if is_after_patchify:
            new_height = height
            new_width = width
        else:
            new_height = height // self.patch_size
            new_width = width // self.patch_size
        # The learned table is a square grid of side sqrt(num_positions).
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed
    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level of nesting: list entries are spliced in, others kept as-is."""
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw
    def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache=20):
        """Return the interpolated position embedding for grid ``(h, w)``, memoized.
        At most ``max_cache`` grids are kept. When full, the entry with the fewest
        hits is evicted. NOTE(review): cached tensors are not invalidated if
        ``position_embedding`` weights change.
        """
        grid = (h, w)
        if grid in self.cache_position_embedding:
            self.cache_position_count[grid] += 1
            return self.cache_position_embedding[grid]
        if len(self.cache_position_embedding) >= max_cache:
            min_hit_grid = min(
                self.cache_position_count, key=self.cache_position_count.get
            )
            self.cache_position_count.pop(min_hit_grid)
            self.cache_position_embedding.pop(min_hit_grid)
        position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True)
        self.cache_position_count[grid] = 1
        self.cache_position_embedding[grid] = position_embedding
        return position_embedding
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        position_ids: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        interpolate_pos_encoding=False,
    ) -> torch.Tensor:
        """Embed patches and add position embeddings.
        Args:
            pixel_values: 5-D ``(batch, seq_len, channels, height, width)`` tensor.
            position_ids: indices into ``packing_position_embedding``. Asserted
                non-None, although only the non-interpolating path uses them.
            image_grid_thw: per-image ``(t, h, w)`` grids, optionally nested one level.
            interpolate_pos_encoding: when true (and grids are given), add the
                bilinearly-resized ``position_embedding`` per image instead of
                ``packing_position_embedding``. This path requires ``batch == 1``.
        Returns:
            ``(batch, seq_len, embed_dim)`` embeddings.
        Raises:
            NotImplementedError: if ``pixel_values`` is not 5-D.
        """
        if pixel_values.dim() != 5:
            raise NotImplementedError(str(pixel_values.shape))
        assert position_ids is not None
        from einops import rearrange
        batch_size, sequence_len, channel, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w")
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(-2).squeeze(-1)
        embeddings = rearrange(
            embeddings, "(b l) d -> b l d", b=batch_size, l=sequence_len
        )
        if interpolate_pos_encoding and image_grid_thw is not None:
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)
            assert batch_size == 1
            assert (
                sum([np.prod(x) for x in flatten_image_grid_thw])
                == embeddings.shape[1]
            ), (flatten_image_grid_thw, embeddings.shape)
            embeddings = embeddings.squeeze(0)
            tmp_embeddings = list()
            start = 0
            # Walk the *flattened* grids, i.e. the same list the assert validated,
            # so nested grid lists are handled instead of failing to unpack.
            for t, h, w in flatten_image_grid_thw:
                end = start + t * h * w
                image_embeddings = embeddings[start:end, :]
                # One (h*w, dim) table per image, tiled across its t frames.
                position_embedding = (
                    self.interpolate_pos_encoding(image_embeddings, h, w, True)
                    .squeeze(0)
                    .repeat(t, 1)
                )
                tmp_embeddings.append(image_embeddings + position_embedding)
                start = end
            embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
        else:
            embeddings = embeddings + self.packing_position_embedding(position_ids)
        return embeddings
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention.
    Args:
        module: owning module; only its ``training`` flag is read, to gate dropout.
        query, key, value: presumably ``(batch, heads, seq, head_dim)``; the
            output swaps dims 1 and 2.
        attention_mask: optional additive mask added to the score matrix.
        scaling: factor applied to ``query @ key^T`` before the softmax.
        dropout: dropout probability on the attention probabilities.
        **kwargs: accepted for interface compatibility and ignored.
    Returns:
        ``(attn_output, attn_weights)``: the weighted values with dims 1 and 2
        swapped (made contiguous), and the post-dropout probabilities.
    """
    scores = query @ key.transpose(-1, -2) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask
    # Normalize in float32 for stability, then return to the input precision.
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = (probs @ value).transpose(1, 2).contiguous()
    return context, probs
class SiglipAttention(nn.Module):
    """Multi-headed self-attention from the 'Attention Is All You Need' paper.
    The kernel is chosen in ``forward``. ``flash_attn_varlen_func`` is used over
    packed segments when ``cu_seqlens`` is given and
    ``config._attn_implementation == "flash_attention_2"``. Otherwise SDPA is
    used (when the implementation is ``"sdpa"`` and attention weights are not
    requested), or the eager implementation. Rotary position embeddings are
    applied to queries and keys when ``rope_emb`` is provided.
    """
    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # 1 / sqrt(head_dim) softmax temperature.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        # Bidirectional attention: no causal mask on any path.
        self.is_causal = False
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[List[torch.Tensor]] = None,
        rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel
        Args:
            hidden_states: ``(batch, seq_len, embed_dim)`` input.
            attention_mask: mask handed to the eager/SDPA kernel; not used on the
                flash-attention path.
            output_attentions: when false, the returned weights are always ``None``.
            cu_seqlens: cumulative segment boundaries of the packed sequence,
                consumed only by the flash-attention path (which also requires
                ``batch == 1``). Must be set whenever ``rope_emb`` is set.
            rope_emb: ``(cos, sin)`` rotary tables applied to queries and keys.
        Returns:
            ``(attn_output, attn_weights)``, where ``attn_output`` is
            ``(batch, seq_len, embed_dim)``. ``attn_weights`` is ``None`` on the
            flash path or when ``output_attentions`` is false.
        """
        # The flash kernel is only used for packed (varlen) inputs.
        use_flash_attn = (
            cu_seqlens is not None
        ) and self.config._attn_implementation == "flash_attention_2"
        batch_size, seq_length, embed_dim = hidden_states.shape
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)
        if rope_emb is None:
            # Split heads: (batch, seq, embed) -> (batch, heads, seq, head_dim).
            queries = queries.view(
                batch_size, seq_length, self.num_heads, self.head_dim
            ).transpose(1, 2)
            keys = keys.view(
                batch_size, seq_length, self.num_heads, self.head_dim
            ).transpose(1, 2)
            values = values.view(
                batch_size, seq_length, self.num_heads, self.head_dim
            ).transpose(1, 2)
        else:
            assert cu_seqlens is not None, "Rope support flash attn only."
            cos, sin = rope_emb
            # Rotary embeddings are applied in (batch, seq, heads, head_dim)
            # layout, before the heads are moved to dim 1.
            queries = queries.view(
                batch_size, seq_length, self.num_heads, self.head_dim
            )
            keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim)
            if use_flash_attn:
                queries, keys = apply_rotary_pos_emb_flashatt(queries, keys, cos, sin)
            else:
                queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin)
            queries = queries.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.view(
                batch_size, seq_length, self.num_heads, self.head_dim
            ).transpose(1, 2)
        if not use_flash_attn:
            # NOTE(review): cu_seqlens is not passed to the eager/SDPA kernels, so
            # packed segments are isolated only if attention_mask does it.
            attention_interface: Callable = eager_attention_forward
            if self.config._attn_implementation != "eager":
                # SDPA cannot return attention probabilities, so keep eager then.
                if self.config._attn_implementation == "sdpa" and output_attentions:
                    warnings.warn(
                        "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                        'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                    )
                elif self.config._attn_implementation == "sdpa":
                    attention_interface = sdpa_attention_forward
            attn_output, attn_weights = attention_interface(
                self,
                queries,
                keys,
                values,
                attention_mask,
                is_causal=self.is_causal,
                scaling=self.scale,
                dropout=0.0 if not self.training else self.dropout,
            )
            # Merge heads back: (batch, seq, heads, head_dim) -> (batch, seq, embed).
            attn_output = attn_output.reshape(
                batch_size, seq_length, embed_dim
            ).contiguous()
        else:
            assert batch_size == 1, hidden_states.shape
            # flash_attn_varlen_func takes packed (total_tokens, heads, head_dim)
            # tensors, so drop the singleton batch dimension.
            queries = queries.transpose(1, 2).squeeze(0)
            keys = keys.transpose(1, 2).squeeze(0)
            values = values.transpose(1, 2).squeeze(0)
            # Self-attention: query and key segments share the same boundaries.
            max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            max_seqlen_k = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            assert (
                cu_seqlens[-1].item()
                == queries.shape[0]
                == keys.shape[0]
                == values.shape[0]
            ), (cu_seqlens, queries.shape, keys.shape, values.shape)
            attn_output = flash_attn_varlen_func(
                queries,
                keys,
                values,
                cu_seqlens,
                cu_seqlens,
                max_seqlen_q,
                max_seqlen_k,
                causal=False,
                softmax_scale=self.scale,
            )
            # (tokens, heads, head_dim) -> (1, tokens, embed_dim).
            attn_output = attn_output.flatten(-2).unsqueeze(0)
            attn_weights = None
        attn_output = self.out_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.
    ``fc1`` maps ``hidden_size`` to ``intermediate_size`` and ``fc2`` maps back.
    The activation is looked up by name (``config.hidden_act``) in ``ACT2FN``.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        width, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(width, inner)
        self.fc2 = nn.Linear(inner, width)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two projections with the activation in between."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class SiglipEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then MLP, each with a residual."""
    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        eps = config.layer_norm_eps
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.self_attn = SiglipAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = SiglipMLP(config)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
        cu_seqlens: Optional[List[torch.Tensor]] = None,
        rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Additive attention mask of shape `(batch, 1, q_len, k_v_seq_len)`; padded positions hold very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                If true, the attention weights are appended to the returned tuple.
            cu_seqlens: packed-sequence boundaries, forwarded to the attention module.
            rope_emb: `(cos, sin)` rotary tables, forwarded to the attention module.
        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)`.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            cu_seqlens=cu_seqlens,
            rope_emb=rope_emb,
        )
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
class SiglipPreTrainedModel(PreTrainedModel):
    """Base class providing SigLIP-specific weight initialization."""
    config_class = PaddleOCRVLConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "SiglipTextEmbeddings",
        "SiglipEncoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    def _init_weights(self, module):
        """Initialize ``module``'s own parameters in place, chosen by its type.
        Types are tested in a fixed priority order and only the first match is
        applied.
        """
        if isinstance(module, SiglipVisionEmbeddings):
            # The config is either the composite VL config or the vision config itself.
            vision_cfg = (
                self.config.vision_config
                if isinstance(self.config, PaddleOCRVLConfig)
                else self.config
            )
            nn.init.normal_(
                module.position_embedding.weight, std=1 / np.sqrt(vision_cfg.hidden_size)
            )
            return
        if isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
            return
        if isinstance(module, SiglipAttention):
            projections = (module.q_proj, module.k_proj, module.v_proj, module.out_proj)
            for proj in projections:
                nn.init.xavier_uniform_(proj.weight)
            for proj in projections:
                nn.init.zeros_(proj.bias)
            return
        if isinstance(module, SiglipMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
            return
        if isinstance(module, SiglipMultiheadAttentionPoolingHead):
            nn.init.xavier_uniform_(module.probe.data)
            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
            return
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].
    Optionally applies 2-D rotary position embeddings, built from per-token
    height/width indices, and windowed attention. For windowed attention, tokens
    are permuted into window order before the layers and restored afterwards.
    Args:
        config: PaddleOCRVLConfig
    """
    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.layers = nn.ModuleList(
            [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # Sized so that the height- and width-index angles, concatenated and then
        # duplicated in forward (`flatten(1)` + `repeat(1, 2)`), span head_dim.
        self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2)
        self.gradient_checkpointing = False
    @staticmethod
    def flatten_list(image_grid_thw):
        """Flatten one level of nesting: list entries are spliced in, others kept as-is."""
        tmp_image_grid_thw = list()
        for image_grid in image_grid_thw:
            if isinstance(image_grid, list):
                tmp_image_grid_thw.extend(image_grid)
            else:
                tmp_image_grid_thw.append(image_grid)
        return tmp_image_grid_thw
    def build_window_index(self, image_grid, window_size, device):
        """Build the token permutation and segment boundaries for windowed attention.
        For each ``(t, h, w)`` grid, token indices are arranged as ``(t, h, w)``
        and padded on the bottom/right with ``-100`` up to multiples of
        ``window_size``. They are then regrouped so that each
        ``window_size x window_size`` window is contiguous, and the padding
        entries are dropped.
        Returns:
            window_indices: 1-D tensor; position ``i`` in window order holds the
                original token index (offset by preceding images' token counts).
            cu_seqlens_within_windows: int32 cumulative window lengths, starting at 0.
        """
        from einops import rearrange
        window_indices = list()
        pad_values = -100
        start_window_index = 0
        cu_seqlens_within_windows = list()
        for t, h, w in image_grid:
            window_index = torch.arange(t * h * w, device=device).reshape(t, h, w)
            pad_h = (-h) % window_size
            pad_w = (-w) % window_size
            assert pad_h >= 0 and pad_w >= 0, (pad_h, pad_w)
            window_index = F.pad(window_index, (0, pad_w, 0, pad_h), value=pad_values)
            window_index = rearrange(
                window_index,
                "t (h p1) (w p2) -> t (h w) (p1 p2)",
                p1=window_size,
                p2=window_size,
            )
            # Real (non-padding) tokens per window.
            window_seqlens = (window_index != pad_values).long().sum(-1).reshape(-1)
            window_index = window_index.reshape(-1)
            window_index = window_index[window_index != pad_values]
            window_indices.append(window_index + start_window_index)
            cu_seqlens_within_windows.append(
                window_seqlens.cumsum(0) + start_window_index
            )
            start_window_index += t * h * w
        window_indices = torch.concat(window_indices, dim=0)
        cu_seqlens_within_windows = torch.concat(cu_seqlens_within_windows, dim=0)
        cu_seqlens_within_windows = F.pad(
            cu_seqlens_within_windows, (1, 0), value=0
        ).to(torch.int32)
        return window_indices, cu_seqlens_within_windows
    # Ignore copy
    # @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cu_seqlens: Optional[List[torch.Tensor]] = None,
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        use_rope: Optional[bool] = False,
        window_size: int = -1,
        vision_or_text: str = "vision",
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input tokens.
            attention_mask (`torch.Tensor`, *optional*):
                Mask handed to every layer's attention; it is cast to the dtype of `inputs_embeds`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            cu_seqlens (*optional*):
                Packed-sequence boundaries used by attention when window attention is off.
            image_grid_thw (*optional*):
                Per-image `(t, h, w)` grids (optionally nested one level). Required for rope / window attention;
                their token counts must sum to `sequence_length`.
            height_position_ids / width_position_ids (*optional*):
                Per-token row/column indices for rope; derived from `image_grid_thw` when either is missing.
            use_rope (`bool`):
                Apply 2-D rotary position embeddings.
            window_size (`int`):
                Window side length; values `<= 0` disable window attention.
            vision_or_text (`str`):
                Currently ignored: it is overwritten with `"vision"` at the top of the body.
        """
        # NOTE(review): this unconditionally overrides the argument, so the
        # "text" mode is unreachable.
        vision_or_text = "vision"
        assert vision_or_text in ["vision", "text"]
        use_window_attn = window_size > 0 and vision_or_text == "vision"
        use_rope = (use_rope is True) and (vision_or_text == "vision")
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        device = inputs_embeds.device
        hidden_states = inputs_embeds
        attention_mask = (
            attention_mask.to(inputs_embeds.dtype)
            if attention_mask is not None
            else None
        )
        if use_rope is True:
            flatten_image_grid_thw = self.flatten_list(image_grid_thw)
            assert (
                sum([np.prod(x) for x in flatten_image_grid_thw])
                == hidden_states.shape[1]
            ), (flatten_image_grid_thw, hidden_states.shape)
            if width_position_ids is None or height_position_ids is None:
                # Row/column index of each token inside its frame; every frame
                # of an image restarts at (0, 0).
                split_hids = list()
                split_wids = list()
                for t, h, w in flatten_image_grid_thw:
                    image_pids = torch.arange(t * h * w, device=device) % (h * w)
                    sample_hids = image_pids // w
                    sample_wids = image_pids % w
                    split_hids.append(sample_hids)
                    split_wids.append(sample_wids)
                width_position_ids = torch.concat(split_wids, dim=0)
                height_position_ids = torch.concat(split_hids, dim=0)
            window_indices, cu_seqlens_within_windows = None, None
            if use_window_attn:
                window_indices, cu_seqlens_within_windows = self.build_window_index(
                    flatten_image_grid_thw, window_size, device
                )
                # Inverse permutation, used to restore the original token order.
                reversed_window_indices = window_indices.argsort()
                # Put the position ids in window order to match hidden_states below.
                height_position_ids = height_position_ids[window_indices]
                width_position_ids = width_position_ids[window_indices]
            pids = torch.stack([height_position_ids, width_position_ids], dim=-1)
            max_grid_size = pids.max() + 1
            rope_emb_max_grid = self.rotary_pos_emb(max_grid_size)
            # Look up angles for (row, col), concatenate them, then duplicate so
            # the table covers both halves used by the rotation.
            rope_emb = rope_emb_max_grid[pids].flatten(1)
            rope_emb = rope_emb.repeat(1, 2)
            rope_emb = (rope_emb.cos(), rope_emb.sin())
        else:
            rope_emb = None
            window_indices, cu_seqlens_within_windows = None, None
            if use_window_attn:
                flatten_image_grid_thw = self.flatten_list(image_grid_thw)
                assert (
                    sum([np.prod(x) for x in flatten_image_grid_thw])
                    == hidden_states.shape[1]
                ), (flatten_image_grid_thw, hidden_states.shape)
                window_indices, cu_seqlens_within_windows = self.build_window_index(
                    flatten_image_grid_thw, window_size, device
                )
                reversed_window_indices = window_indices.argsort()
        if use_window_attn:
            # Attention runs per window: segments are windows, tokens in window order.
            assert cu_seqlens_within_windows is not None
            attn_cu_seqlens = cu_seqlens_within_windows
            hidden_states = hidden_states[:, window_indices, :]
        else:
            attn_cu_seqlens = cu_seqlens
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Recorded hidden states are always in the original token order.
                encoder_states = encoder_states + (
                    (hidden_states[:, reversed_window_indices, :],)
                    if use_window_attn
                    else (hidden_states,)
                )
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    attn_cu_seqlens,
                    rope_emb,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                    cu_seqlens=attn_cu_seqlens,
                    rope_emb=rope_emb,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
        if use_window_attn:
            hidden_states = hidden_states[:, reversed_window_indices, :]
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: embeddings -> encoder -> post LayerNorm -> optional pooling.
    The output form depends on ``forward``'s arguments:
    * ``return_pooler_output`` with ``sample_indices``: tokens are grouped by
      sample id and attention-pooled by ``self.head``.
    * ``return_pooler_output`` without ``sample_indices``: the full sequence is
      returned, pooled by ``self.head`` when ``use_head`` is set.
    * otherwise: the sequence is split into one tensor per ``cu_seqlens`` segment.
    """
    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # The pooling head is built unless the config sets `vision_use_head` falsy.
        self.use_head = (
            True if not hasattr(config, "vision_use_head") else config.vision_use_head
        )
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(config)
    def _pool_per_sample(self, last_hidden_state, sample_indices, vision_return_embed_list):
        """Group tokens by sample id and attention-pool each group.
        Tokens whose id is ``-1`` belong to no sample and are skipped.
        Returns:
            ``(sample_hidden_state, pooler_output)``. ``sample_hidden_state`` is a
            list of per-sample tensors when ``vision_return_embed_list`` is true,
            otherwise a right-zero-padded ``(num_samples, max_len, dim)`` tensor.
        """
        assert self.use_head is True
        dim = last_hidden_state.shape[-1]
        hidden_state = last_hidden_state.squeeze(0)
        unique_sample_index = list(
            torch.unique(sample_indices).sort().values.unbind(0)
        )
        # Sorted ascending, so -1 (no sample) can only be the first entry.
        if len(unique_sample_index) > 0 and unique_sample_index[0] == -1:
            unique_sample_index = unique_sample_index[1:]
        sample_hidden_state_list = []
        for sample_idx in unique_sample_index:
            token_indices = (sample_indices == sample_idx).nonzero().flatten()
            sample_hidden_state_list.append(hidden_state[token_indices])
        if vision_return_embed_list:
            pooler_output = torch.concat(
                [self.head(state.unsqueeze(0)) for state in sample_hidden_state_list],
                dim=0,
            )
            return sample_hidden_state_list, pooler_output
        max_length = max(state.shape[0] for state in sample_hidden_state_list)
        padded_states = []
        padding_masks = []
        for state in sample_hidden_state_list:
            valid_length = state.shape[0]
            padding_length = max_length - valid_length
            # True marks padded key positions. Slicing from `valid_length` (rather
            # than `-padding_length`) leaves the mask all-False when no padding is
            # needed; `mask[-0:]` would have selected the entire row.
            mask = state.new_zeros(size=(max_length,), dtype=torch.bool)
            mask[valid_length:] = True
            padding_masks.append(mask)
            padding = state.new_zeros(size=(padding_length, dim))
            padded_states.append(torch.concat([state, padding], dim=0))
        sample_hidden_state = torch.stack(padded_states, dim=0)
        # nn.MultiheadAttention ignores keys where a *boolean* key_padding_mask is
        # True; a float mask would instead be added to the attention logits.
        padding_mask = torch.stack(padding_masks, dim=0)
        pooler_output = self.head(sample_hidden_state, key_padding_mask=padding_mask)
        return sample_hidden_state, pooler_output
    @staticmethod
    def _split_by_cu_seqlens(last_hidden_state, cu_seqlens):
        """Split the packed sequence into one ``(segment_len, dim)`` tensor per segment.
        The leading dim is squeezed, so presumably batch size is 1.
        """
        assert cu_seqlens is not None
        return [
            last_hidden_state[:, cu_seqlens[i] : cu_seqlens[i + 1], :].squeeze(0)
            for i in range(cu_seqlens.shape[0] - 1)
        ]
    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        attention_mask: Optional[torch.Tensor] = None,
        sample_indices: Optional[torch.Tensor] = None,
        image_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        height_position_ids: Optional[torch.Tensor] = None,
        width_position_ids: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: int = -1,
    ) -> BaseModelOutputWithPooling:
        r"""Encode ``pixel_values`` and optionally pool or split the result.
        Args:
            sample_indices: per-token sample id (``-1`` = none) for per-sample pooling.
            vision_return_embed_list: with ``sample_indices``, return per-sample
                tensors instead of one padded batch.
            cu_seqlens: segment boundaries; required when ``return_pooler_output``
                is false (used to split the output).
            image_indices, padding_mask: unused; kept for interface compatibility.
            Remaining arguments are forwarded to the embeddings / encoder.
        Returns:
            ``BaseModelOutputWithPooling`` whose ``last_hidden_state`` is a tensor
            or a list of per-sample / per-segment tensors (see class docstring).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            image_grid_thw=image_grid_thw,
        )
        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            image_grid_thw=image_grid_thw,
            use_rope=use_rope,
            height_position_ids=height_position_ids,
            width_position_ids=width_position_ids,
            window_size=window_size,
            vision_or_text="vision",
        )
        last_hidden_state = self.post_layernorm(encoder_outputs.last_hidden_state)
        if return_pooler_output is True:
            if sample_indices is not None:
                sample_hidden_state, pooler_output = self._pool_per_sample(
                    last_hidden_state, sample_indices, vision_return_embed_list
                )
                return BaseModelOutputWithPooling(
                    last_hidden_state=sample_hidden_state,
                    pooler_output=pooler_output,
                    hidden_states=encoder_outputs.hidden_states,
                    attentions=encoder_outputs.attentions,
                )
            pooler_output = self.head(last_hidden_state) if self.use_head else None
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooler_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return BaseModelOutputWithPooling(
            last_hidden_state=self._split_by_cu_seqlens(last_hidden_state, cu_seqlens),
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Attention pooling: a learned probe query attends over the tokens, then a residual MLP refines it."""
    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__()
        width = config.hidden_size
        self.probe = nn.Parameter(torch.randn(1, 1, width))
        self.attention = torch.nn.MultiheadAttention(
            width, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
    def forward(self, hidden_state, key_padding_mask=None):
        """Pool batch-first ``hidden_state`` into one vector per batch element.
        ``key_padding_mask`` is passed straight to ``nn.MultiheadAttention``.
        """
        query = self.probe.repeat(hidden_state.shape[0], 1, 1)
        pooled, _ = self.attention(
            query, hidden_state, hidden_state, key_padding_mask=key_padding_mask
        )
        pooled = pooled + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]
class SiglipVisionModel(SiglipPreTrainedModel):
    """PreTrainedModel wrapper around :class:`SiglipVisionTransformer`."""
    config_class = PaddleOCRVisionConfig
    main_input_name = "pixel_values"
    def __init__(self, config: PaddleOCRVisionConfig):
        super().__init__(config)
        self.vision_model = SiglipVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()
    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution."""
        return self.vision_model.embeddings.patch_embedding
    def forward(
        self,
        pixel_values,
        sample_indices: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        position_ids: Optional[torch.Tensor] = None,
        vision_return_embed_list: Optional[bool] = False,
        image_grid_thw: Optional[
            List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]]
        ] = None,
        cu_seqlens: Optional[List[torch.Tensor]] = None,
        return_pooler_output: Optional[bool] = True,
        use_rope: Optional[bool] = False,
        window_size: Optional[bool] = -1,
    ) -> BaseModelOutputWithPooling:
        r"""Run the wrapped vision transformer.
        Every argument is forwarded unchanged, by keyword, to
        ``self.vision_model``; arguments that transformer accepts but this method
        does not expose keep their defaults there.
        Returns:
            The ``BaseModelOutputWithPooling`` produced by ``self.vision_model``.
        """
        forwarded = dict(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            position_ids=position_ids,
            vision_return_embed_list=vision_return_embed_list,
            image_grid_thw=image_grid_thw,
            sample_indices=sample_indices,
            cu_seqlens=cu_seqlens,
            return_pooler_output=return_pooler_output,
            use_rope=use_rope,
            window_size=window_size,
        )
        return self.vision_model(**forwarded)
def apply_rotary_pos_emb_flashatt(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``q`` and ``k`` with ``apply_rotary_emb``, computing in float32.
    Only the first half of ``cos``/``sin`` along the last dim is passed on.
    Presumably the kernel expects half-width tables; TODO confirm against the
    ``apply_rotary_emb`` in use. The results are cast back to the input dtypes.
    """
    cos_half = cos.chunk(2, dim=-1)[0].contiguous().float()
    sin_half = sin.chunk(2, dim=-1)[0].contiguous().float()
    rotated = [
        apply_rotary_emb(x.float(), cos_half, sin_half).type_as(x) for x in (q, k)
    ]
    return rotated[0], rotated[1]
def rotate_half(x):
    """Return ``[-upper, lower]`` where ``lower``/``upper`` are the two halves of the last dim.
    For odd sizes the lower half is the shorter one (split at ``size // 2``).
    """
    split = x.shape[-1] // 2
    lower, upper = x[..., :split], x[..., split:]
    return torch.cat((upper.neg(), lower), dim=-1)
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings ``x * cos + rotate_half(x) * sin`` to ``q`` and ``k``.
    ``cos``/``sin`` get a singleton axis inserted at dim -2 so that they broadcast
    over it (the heads axis, for ``(batch, seq, heads, head_dim)`` inputs). Math
    is done in float32; outputs keep each input's original dtype.
    """
    cos_f = cos.unsqueeze(-2).float()
    sin_f = sin.unsqueeze(-2).float()
    def _rotate(x):
        x_f = x.float()
        return ((x_f * cos_f) + (rotate_half(x_f) * sin_f)).to(x.dtype)
    return _rotate(q), _rotate(k)
class SigLIPRotaryEmbedding(nn.Module):
    """Rotary-embedding angle table.
    ``forward(seqlen)`` returns the ``(seqlen, ceil(dim / 2))`` matrix
    ``position * inv_freq``, where ``inv_freq[i] = 1 / theta ** (2 * i / dim)``.
    """
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.rope_init()
    def rope_init(self):
        """Compute ``inv_freq`` and register it as a non-persistent buffer."""
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim
        self.register_buffer(
            "inv_freq", 1.0 / (self.theta ** exponents), persistent=False
        )
    def forward(self, seqlen: int) -> torch.Tensor:
        """Return angles for positions ``0 .. seqlen - 1`` against every frequency."""
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return positions[:, None] * inv_freq[None, :]
@dataclass
class PaddleOCRVLCausalLMOutputWithPast(ModelOutput):
    """
    Outputs of the PaddleOCR-VL causal language model (`PaddleOCRVLForConditionalGeneration.forward`).

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            Language-modeling (next-token prediction) loss, a scalar.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (*optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            The key/value cache produced by the text decoder. It can be passed back as the `past_key_values`
            input to speed up sequential decoding. Its concrete type is whatever the decoder returns
            (NOTE(review): likely a `transformers` `Cache` object despite the `List` annotation; confirm).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            Offset between the multimodal (3D) RoPE positions and the plain sequence positions. It is computed
            at prefill and added to the plain positions on later decoding steps.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
class PaddleOCRVLForConditionalGeneration(Ernie4_5PreTrainedModel, GenerationMixin):
    """PaddleOCR-VL vision-language model with a causal language-modeling head.

    Image patches are encoded by ``self.visual`` and projected by
    ``self.mlp_AR``. The projected features are scattered into the token
    embeddings at image-placeholder positions, then decoded by the text model
    ``self.model`` followed by ``self.lm_head``.
    """

    _tied_weights_keys = ["lm_head.weight"]
    config_class = PaddleOCRVLConfig
    _no_split_modules = ["Ernie4_5_DecoderLayer", "SiglipEncoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        # Submodule creation order is kept as-is; it fixes parameter registration order.
        self.mlp_AR = Projector(config, config.vision_config)
        self.visual = SiglipVisionModel(config.vision_config)
        self.model = Ernie4_5Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Multimodal RoPE offsets, set by `forward` at prefill and reused on
        # later decode steps.
        self.rope_deltas = None
        self.post_init()

    def get_input_embeddings(self):
        """Return the text decoder's token-embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the text decoder's token-embedding module with ``value``."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the text decoder (``self.model``)."""
        self.model = decoder

    def get_decoder(self):
        """Return the text decoder (``self.model``)."""
        return self.model

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

        Explanation:
            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
            For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
            Examples:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
                height position_ids: [0, 1, 2, 3, 4]
                width position_ids: [0, 1, 2, 3, 4]

            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
            and 1D rotary position embedding for text part.
            Examples:
                Temporal (Time): 3 patches, representing different segments of the video in time.
                Height: 2 patches, dividing each frame vertically.
                Width: 2 patches, dividing each frame horizontally.
                We also have some important parameters:
                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
                tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
                text temporal position_ids: [101, 102, 103, 104, 105]
                text height position_ids: [101, 102, 103, 104, 105]
                text width position_ids: [101, 102, 103, 104, 105]
                Here we calculate the text start position_ids as the max vision position_ids plus 1.

            Image patches always get a temporal index of 0 (their time step is 0).
            Height and width are counted in merged patches (grid size // spatial_merge_size).

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`)
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (
            image_grid_thw is not None or video_grid_thw is not None
        ):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            # Padding slots keep the placeholder value 1.
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            # Running indices into image_grid_thw / video_grid_thw, which hold
            # the grids of all samples in batch order.
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, sample_ids in enumerate(total_input_ids):
                # Only attended (non-padding) tokens receive positions.
                sample_ids = sample_ids[attention_mask[i] == 1]
                # The token right after each vision-start token tells whether
                # an image or a video block follows.
                vision_start_indices = torch.argwhere(
                    sample_ids == vision_start_token_id
                ).squeeze(1)
                vision_tokens = sample_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = sample_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    # Find whichever vision block (image or video) starts
                    # first at or after `st`.
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        # Images have no time axis: all patches share temporal index 0.
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    # The LLM sees the grid after spatial merging.
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    # 1-D positions for the text span preceding this vision block.
                    text_len = ed - st
                    st_idx = (
                        llm_pos_ids_list[-1].max() + 1
                        if len(llm_pos_ids_list) > 0
                        else 0
                    )
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )
                    if torch.is_tensor(second_per_grid_t):
                        second_per_grid_t = second_per_grid_t.detach().item()
                    # 3-D (t, h, w) positions for the vision block, offset so
                    # they continue after the preceding text.
                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
                    time_tensor = (
                        expanded_range
                        * second_per_grid_t
                        * self.config.vision_config.tokens_per_second
                    )
                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()
                    h_index = (
                        torch.arange(llm_grid_h)
                        .view(1, -1, 1)
                        .expand(llm_grid_t, -1, llm_grid_w)
                        .flatten()
                    )
                    w_index = (
                        torch.arange(llm_grid_w)
                        .view(1, 1, -1)
                        .expand(llm_grid_t, llm_grid_h, -1)
                        .flatten()
                    )
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
                    )
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
                # Trailing text after the last vision block.
                if st < len(input_tokens):
                    st_idx = (
                        llm_pos_ids_list[-1].max() + 1
                        if len(llm_pos_ids_list) > 0
                        else 0
                    )
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                    )
                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
                    position_ids.device
                )
                # Offset between the multimodal positions and plain 0..L-1
                # positions; reused to shift positions during decoding.
                mrope_position_deltas.append(
                    llm_positions.max() + 1 - len(total_input_ids[i])
                )
            mrope_position_deltas = torch.tensor(
                mrope_position_deltas, device=input_ids.device
            ).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                # Text only: positions count attended tokens; padded slots get 1.
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = (
                    position_ids.unsqueeze(0)
                    .expand(3, -1, -1)
                    .to(attention_mask.device)
                )
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(
                    -1, keepdim=True
                )[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )
            return position_ids, mrope_position_deltas

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, PaddleOCRVLCausalLMOutputWithPast]:
        r"""
        Encode images, splice their features into the token embeddings, and run the decoder.

        Args:
            input_ids: Token ids. Tokens equal to ``config.image_token_id`` mark
                where image features are written.
            pixel_values: Patches of all images in the batch. Requires
                ``image_grid_thw``. Only used when ``inputs_embeds`` is not given.
            image_grid_thw: One ``(t, h, w)`` patch-grid row per image.
            position_ids: If ``None`` (and ``attention_mask`` is ``None`` or 2-D),
                multimodal RoPE positions are computed here.
            labels: Next-token targets. Positions equal to ``-100`` are ignored
                (the ``CrossEntropyLoss`` default).
            rope_deltas: Accepted for interface compatibility but not read. The
                offsets cached on ``self.rope_deltas`` are used instead.
            pixel_values_videos, video_grid_thw, second_per_grid_ts: Video inputs.
                ``pixel_values_videos`` is not read here. The other two are only
                forwarded to ``get_rope_index``.

        Returns:
            ``PaddleOCRVLCausalLMOutputWithPast`` when ``return_dict`` is true,
            otherwise the tuple ``(loss, logits, *decoder_outputs[1:])``, with
            ``loss`` omitted when no labels are given.

        Raises:
            ValueError: if ``pixel_values`` is given without ``image_grid_thw``, or
                if the number of image tokens differs from the number of image features.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                if image_grid_thw is None:
                    raise ValueError(
                        "`image_grid_thw` must be provided together with `pixel_values`."
                    )
                pixel_values = pixel_values.type(self.visual.dtype)
                # All images' patches go to the vision tower as one packed
                # sequence with a leading batch dimension of 1.
                pixel_values = pixel_values.unsqueeze(0)
                siglip_position_ids = []
                image_grid_hws = []
                sample_indices = []
                # Image `idx` occupies cu_seqlens[idx]:cu_seqlens[idx + 1]
                # of the packed patch sequence.
                cu_seqlens = [0]
                for idx, thw in enumerate(image_grid_thw):
                    thw_tuple = tuple(thw.tolist())
                    numel = np.prod(thw_tuple)
                    image_grid_hws.append(thw_tuple)
                    # Positions restart for every image and cycle every h * w patches.
                    image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
                    siglip_position_ids.append(image_position_ids)
                    sample_indices.append(torch.full((numel,), idx, dtype=torch.int64))
                    cu_seqlens.append(cu_seqlens[-1] + numel)
                siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
                    pixel_values.device
                )
                cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
                    pixel_values.device
                )
                sample_indices = torch.concat(sample_indices, dim=0).to(
                    pixel_values.device
                )
                vision_outputs = self.visual(
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_hws,
                    position_ids=siglip_position_ids,
                    vision_return_embed_list=True,
                    interpolate_pos_encoding=True,
                    sample_indices=sample_indices,
                    cu_seqlens=cu_seqlens,
                    return_pooler_output=False,
                    use_rope=True,
                    window_size=-1,
                )
                image_embeds = vision_outputs.last_hidden_state
                # The projector returns one feature tensor per image;
                # concatenate them in order to line up with the image tokens.
                image_embeds = self.mlp_AR(image_embeds, image_grid_thw)
                image_embeds = torch.cat(image_embeds, dim=0)

                image_token_mask = input_ids == self.config.image_token_id
                n_image_tokens = image_token_mask.sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    image_token_mask.unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # Multimodal RoPE positions are only derived here for a 2-D (or absent)
        # attention mask. Otherwise position_ids is passed through unchanged.
        if position_ids is None and (
            attention_mask is None or attention_mask.ndim == 2
        ):
            # Compute the 3-D RoPE index once per generation, at prefill.
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # Decode steps: shift plain positions by the cached offsets.
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `delta` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return PaddleOCRVLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        """Build the per-step model inputs for ``generate``.

        Overridden so that ``position_ids`` is always reset to ``None``, which
        makes ``forward`` derive the multimodal RoPE positions itself. Vision
        inputs are kept only on the first step (``cache_position[0] == 0``)
        and dropped afterwards.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )

        model_inputs["position_ids"] = None

        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs

    def _get_image_nums_and_video_nums(
        self,
        input_ids: Optional[torch.LongTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

        An image (video) is counted when an image (video) token directly
        follows a vision-start token.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.

        Returns:
            image_nums (`torch.LongTensor` of shape `(batch_size,)`)
            video_nums (`torch.LongTensor` of shape `(batch_size,)`)
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        vision_start_mask = input_ids == vision_start_token_id
        # Flag tokens that directly follow a vision-start token. A plain shift
        # is used instead of torch.roll, because roll would wrap a vision-start
        # token in the last column onto the first token of the same row.
        vision_first_mask = torch.zeros_like(vision_start_mask)
        vision_first_mask[:, 1:] = vision_start_mask[:, :-1]
        image_mask = input_ids == image_token_id
        video_mask = input_ids == video_token_id
        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)

        return image_nums, video_nums

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Repeat every sample ``expand_size`` times (e.g. for beam search).

        Overridden because the visual inputs have no batch dimension:
        ``pixel_values`` rows are the concatenated patches of all images, and
        ``image_grid_thw`` rows cover all images of all samples. These tensors
        are split back into per-sample chunks, using the image/video counts
        found in ``input_ids``, and each chunk is repeated ``expand_size``
        times. Visual entries whose value is ``None`` are left as they are.
        """
        if expand_size == 1:
            return input_ids, model_kwargs

        visual_keys = [
            "pixel_values",
            "image_grid_thw",
            "pixel_values_videos",
            "video_grid_thw",
            "second_per_grid_ts",
        ]

        def _expand_dict_for_generation_visual(dict_to_expand):
            # Grids are read before the loop so that expanding
            # image_grid_thw / video_grid_thw does not affect the split
            # lengths used for the pixel tensors.
            image_grid_thw = model_kwargs.get("image_grid_thw", None)
            video_grid_thw = model_kwargs.get("video_grid_thw", None)
            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)

            def _repeat_interleave_samples(x, lengths, repeat_times):
                # Split x into per-sample chunks, repeat each chunk whole
                # `repeat_times` times along dim 0, and re-concatenate.
                samples = torch.split(x, lengths)
                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
                return torch.cat(
                    [sample.repeat(*repeat_args) for sample in samples], dim=0
                )

            for key in dict_to_expand:
                value = dict_to_expand[key]
                if value is None:
                    # An explicitly passed None (e.g. video_grid_thw=None for
                    # an image-only request) has nothing to expand.
                    continue
                if key == "pixel_values":
                    # split images into samples
                    samples = torch.split(image_grid_thw, list(image_nums))
                    # compute the sequence length of images for each sample
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        value, lengths=lengths, repeat_times=expand_size
                    )
                elif key == "image_grid_thw":
                    # get the num of images for each sample
                    lengths = list(image_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        value, lengths=lengths, repeat_times=expand_size
                    )
                elif key == "pixel_values_videos":
                    samples = torch.split(video_grid_thw, list(video_nums))
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        value, lengths=lengths, repeat_times=expand_size
                    )
                elif key == "video_grid_thw":
                    lengths = list(video_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        value, lengths=lengths, repeat_times=expand_size
                    )
                elif key == "second_per_grid_ts":
                    if not isinstance(value, list):
                        raise TypeError(
                            f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
                        )
                    tensor = torch.tensor(value)
                    lengths = list(video_nums)
                    tensor = _repeat_interleave_samples(
                        tensor, lengths=lengths, repeat_times=expand_size
                    )
                    dict_to_expand[key] = tensor.tolist()
            return dict_to_expand

        def _expand_dict_for_generation(dict_to_expand):
            # Ordinary batched tensors: repeat each row expand_size times.
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                    and key not in visual_keys
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(
                        expand_size, dim=0
                    )
            return dict_to_expand

        # input_ids is required for expanding visual inputs
        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
        if input_ids is not None and input_ids.numel() != 0:
            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError(
                    "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
                )
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(
                model_kwargs["encoder_outputs"]
            )

        return input_ids, model_kwargs