# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN, GELUActivation
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    SlidingWindowCache,
    StaticCache,
)
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPooling,
    CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import (
    ALL_ATTENTION_FUNCTIONS,
    PreTrainedModel,
    sdpa_attention_forward,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    ModelOutput,
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    is_flash_attn_2_available,
    torch_int,
)
from transformers.utils.generic import check_model_inputs

# Optional flash-attn kernels; the names are bound to None when flash-attn 2 is
# not available so later code can test for them.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb
else:
    flash_attn_varlen_func = None
    apply_rotary_emb = None

from .configuration_paddleocr_vl import PaddleOCRVisionConfig, PaddleOCRVLConfig


class RotaryEmbedding(nn.Module):
    """Rotary position embedding for 3-row (multimodal) position ids.

    `forward` expects `position_ids` with three leading rows (see the shape
    comment in `forward`) and returns one cos/sin table per row. Those rows are
    later combined per channel section by `apply_multimodal_rotary_pos_emb`.

    Args:
        config (`PaddleOCRVLConfig`): Source of `rope_scaling` (optional) and
            `max_position_embeddings`.
        device: Device forwarded to the RoPE init function when building
            `inv_freq`.
    """

    def __init__(self, config: PaddleOCRVLConfig, device=None):
        super().__init__()
        # Extra keyword arguments forwarded to `rope_init_fn` on re-initialization.
        self.rope_kwargs = {}
        # BC: "rope_type" was originally "type". `getattr` also tolerates configs
        # that do not define `rope_scaling` at all.
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so `_dynamic_frequency_update` can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Compute cos/sin tables for every row of `position_ids`.

        Args:
            x (`torch.Tensor`): Only its device and dtype are used.
            position_ids (`torch.Tensor`): Presumably `(3, batch, seq_len)`, per
                the shape comment below (TODO confirm against callers).

        Returns:
            `(cos, sin)`, each of shape `(3, batch, seq_len, 2 * len(inv_freq))`
            and cast to `x.dtype`.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        inv_freq_expanded = (
            self.inv_freq[None, None, :, None]
            .float()
            .expand(3, position_ids.shape[1], -1, 1)
        )
        position_ids_expanded = position_ids[
            :, :, None, :
        ].float()  # shape (3, bs, 1, positions)
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def rope_init(self):
        """Recompute `inv_freq` and `attention_scaling` from `self.config`."""
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device=None, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq


class Ernie4_5RotaryEmbedding(nn.Module):
    """Standard 1-D rotary embedding (one row of position ids per batch item).

    `Ernie4_5Model` in this section uses `RotaryEmbedding` instead; this class
    returns float32 cos/sin without casting to the input dtype.
    """

    def __init__(self, config: PaddleOCRVLConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` of shape `(batch, seq_len, 2 * len(inv_freq))`.

        `position_ids` is indexed as `[:, None, :]`, so it is 2-D
        `(batch, seq_len)`. Results are kept in float32.
        """
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # keeping it in full precision
        return cos, sin


class Ernie4_5MLP(nn.Module):
    """Gated feed-forward block: `down(act(gate(x)) * up(x))`.

    All three projections use `config.use_bias`; the activation is
    `ACT2FN[config.hidden_act]`.
    """

    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.use_bias
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.use_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=config.use_bias
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward_ernie(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (eager) scaled dot-product attention.

    Keys/values are expanded with `repeat_kv` by `module.num_key_value_groups`.
    `attention_mask` is added to the scores (additive mask) after slicing to the
    key length. Softmax runs in float32 and is cast back to `query.dtype`.

    Returns:
        `(attn_output, attn_weights)`; `attn_output` is transposed to
        `(batch, seq_len, heads, head_dim)` and made contiguous.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query.dtype
    )
    attn_weights = nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    The first half of `cos`/`sin` is repeat-interleaved to full width, and the
    rotation is computed in float32 before casting back to `q.dtype`. Relies on
    `rotate_half`, which must be defined at module level (not visible here).

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # glm rope style (with full dim) and full precision
    original_dtype = q.dtype

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Interleave them instead of usual shape
    cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
    sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)

    q_embed = (q.float() * cos) + (rotate_half(q).float() * sin)
    k_embed = (k.float() * cos) + (rotate_half(k).float() * sin)

    return q_embed.to(original_dtype), k_embed.to(original_dtype)


def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
        difference with modern LLMs.

    Relies on `rotate_half`, which must be defined at module level (not visible here).

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding, with the
            3 (temporal/height/width) rows in its leading dimension.
        sin (`torch.Tensor`): The sine part of the rotary embedding, laid out like `cos`.
        mrope_section(`List(int)`):
            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. The
            list is doubled internally to cover both halves of the embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and sin (after the
            per-section row selection) so that they can be properly broadcasted to the dimensions of q and k. If q
            and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos and
            sin broadcastable to the shapes of q and k. Similarly, if q and k have the shape
            [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    mrope_section = mrope_section * 2
    # Chunk i of the channel dim takes its values from row (i % 3) of cos/sin.
    cos = torch.cat(
        [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1
    ).unsqueeze(unsqueeze_dim)
    sin = torch.cat(
        [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1
    ).unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Ernie4_5Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Uses grouped key/value heads (`num_key_value_heads`), multimodal RoPE
    (`config.rope_scaling["mrope_section"]`) and the attention backend selected
    by `config._attn_implementation` (eager by default).
    """

    def __init__(self, config: PaddleOCRVLConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5

        self.rope_scaling = config.rope_scaling

        self.attention_dropout = 0.0
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.use_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.use_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.use_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.use_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, rotate, (optionally) cache, attend, and project back.

        Returns:
            `(attn_output, attn_weights)` as produced by the attention backend,
            with `attn_output` passed through `o_proj`.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # When 3-row (multimodal) position ids arrive via kwargs, only the first
        # row is forwarded to the attention backend. NOTE(review): presumably
        # because backends expect a single row of positions — confirm.
        if "position_ids" in kwargs and kwargs["position_ids"] is not None:
            position_ids = kwargs["position_ids"]
            if position_ids.dim() == 3 and position_ids.shape[0] > 1:
                kwargs["position_ids"] = position_ids[0:1]

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        attention_interface: Callable = eager_attention_forward_ernie
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


@use_kernel_forward_from_hub("RMSNorm")
class Ernie4_5RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Ernie4_5RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize by the root-mean-square of the last dim (computed in float32)."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Ernie4_5DecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(self, config: PaddleOCRVLConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Ernie4_5Attention(config=config, layer_idx=layer_idx)

        self.mlp = Ernie4_5MLP(config)
        self.input_layernorm = Ernie4_5RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Ernie4_5RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run one decoder block and return the updated hidden states.

        The attention weights returned by `self_attn` are discarded here.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@auto_docstring
class Ernie4_5PreTrainedModel(PreTrainedModel):
    config: PaddleOCRVLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Ernie4_5DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Modules whose outputs are recorded as hidden_states / attentions.
    _can_record_outputs = {
        "hidden_states": Ernie4_5DecoderLayer,
        "attentions": Ernie4_5Attention,
    }


@auto_docstring
class Ernie4_5Model(Ernie4_5PreTrainedModel):
    def __init__(self, config: PaddleOCRVLConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                Ernie4_5DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Embed (if needed), run all decoder layers, and apply the final norm.

        Exactly one of `input_ids` / `inputs_embeds` must be given. Missing
        `position_ids` default to `cache_position` broadcast to 3 rows; 2-D
        `position_ids` are likewise expanded to 3 rows for multimodal RoPE.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position: torch.Tensor = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(
                3, inputs_embeds.shape[0], -1
            )
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the attention mask appropriate for the configured backend.

        - flash_attention_2: returns the 2-D mask only if it contains padding,
          else None; raises ValueError for right-padded input when a cache is used.
        - sdpa: may return None so SDPA can use its `is_causal` path.
        - otherwise: a 4-D additive causal mask.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # The last column is all ones only when every row is left-padded
                # (or unpadded); anything else means right padding.
                is_padding_right = (
                    attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                )
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with "
                        "padding_side='right'; this may lead to unexpected behaviour "
                        "for the Flash Attention version of Ernie4_5. Make sure to "
                        "call `tokenizer.padding_side = 'left'` before tokenizing "
                        "the input."
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: PaddleOCRVLConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
            config (`PaddleOCRVLConfig`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            diagonal_attend_mask = torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            if config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if (
                    not isinstance(past_key_values, SlidingWindowCache)
                    or sequence_length > target_length
                ):
                    sliding_attend_mask = torch.arange(
                        target_length, device=device
                    ) <= (cache_position.reshape(-1, 1) - config.sliding_window)
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
                    :, None, None, :
                ].to(causal_mask.device)
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        return causal_mask


class Ernie4_5ForCausalLM(Ernie4_5PreTrainedModel, GenerationMixin):
    """`Ernie4_5Model` with a bias-free linear `lm_head` over the vocabulary."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Ernie4_5Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model
@can_return_tuple def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs, ) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # 
Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0, ) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsequently scaled and shifted by the mean and std args. 
Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = scale / denom if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": with torch.no_grad(): tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) with torch.no_grad(): tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") class Projector(nn.Module): def __init__(self, text_config: PaddleOCRVLConfig, vision_config: PaddleOCRVisionConfig): super().__init__() self.text_config = text_config self.vision_config = vision_config self.merge_kernel_size = (2, 2) self.hidden_size = ( self.vision_config.hidden_size * self.merge_kernel_size[0] * self.merge_kernel_size[1] ) self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) self.act = GELUActivation() self.linear_2 = nn.Linear( self.hidden_size, self.text_config.hidden_size, bias=True ) def forward( self, image_features: torch.Tensor, image_grid_thw: List[Tuple[int, int, int]] ) -> torch.Tensor: m1, m2 = self.merge_kernel_size if 
isinstance(image_features, (list, tuple)): processed_features = list() for image_feature, image_grid in zip(image_features, image_grid_thw): image_feature = self.pre_norm(image_feature) t, h, w = image_grid from einops import rearrange image_feature = rearrange( image_feature, "(t h p1 w p2) d -> (t h w) (p1 p2 d)", t=t, h=h // m1, p1=m1, w=w // m2, p2=m2, ) hidden_states = self.linear_1(image_feature) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) processed_features.append(hidden_states) return processed_features dims = image_features.shape[:-1] dim = image_features.shape[-1] image_features = image_features.view(np.prod(dims), dim) hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size) hidden_states = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states.view(*dims, -1) class SiglipVisionEmbeddings(nn.Module): def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.cache_position_embedding = dict() self.cache_position_count = dict() self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) self.register_buffer( "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False, ) def interpolate_pos_encoding( self, embeddings: torch.Tensor, height: int, width: int, is_after_patchify: bool = False, ) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on 
higher resolution images. This method is also adapted to support torch.jit tracing and no class embeddings. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_positions = self.position_embedding.weight.shape[0] patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] if is_after_patchify: new_height = height new_width = width else: new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape( 1, sqrt_num_positions, sqrt_num_positions, dim ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bilinear", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed @staticmethod def flatten_list(image_grid_thw): tmp_image_grid_thw = list() for image_grid in image_grid_thw: if isinstance(image_grid, list): tmp_image_grid_thw.extend(image_grid) else: tmp_image_grid_thw.append(image_grid) return tmp_image_grid_thw def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache=20): grid = (h, w) if grid in self.cache_position_embedding: self.cache_position_count[grid] += 1 return self.cache_position_embedding[grid] if len(self.cache_position_embedding) >= max_cache: min_hit_grid = min( self.cache_position_count, key=self.cache_position_count.get ) self.cache_position_count.pop(min_hit_grid) self.cache_position_embedding.pop(min_hit_grid) position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True) self.cache_position_count[grid] = 1 self.cache_position_embedding[grid] = position_embedding return position_embedding def forward( 
self, pixel_values: torch.FloatTensor, position_ids: Optional[torch.Tensor] = None, image_grid_thw: Optional[ List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]] ] = None, interpolate_pos_encoding=False, ) -> torch.Tensor: if pixel_values.dim() == 5: assert position_ids is not None from einops import rearrange batch_size, squence_len, channel, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") patch_embeds = self.patch_embedding( pixel_values.to(dtype=target_dtype) ) # shape = [*, width, grid, grid] embeddings = patch_embeds.flatten(-2).squeeze(-1) embeddings = rearrange( embeddings, "(b l) d -> b l d", b=batch_size, l=squence_len ) # todo: not dubug if interpolate_pos_encoding and image_grid_thw is not None: flatten_image_grid_thw = self.flatten_list(image_grid_thw) assert batch_size == 1 start = 0 image_embedding_list = list() assert ( sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1] ), (flatten_image_grid_thw, embeddings.shape) embeddings = embeddings.squeeze(0) tmp_embeddings = list() for image_grid in image_grid_thw: t, h, w = image_grid end = start + t * h * w image_embeddings = embeddings[start:end, :] position_embedding = ( self.interpolate_pos_encoding(image_embeddings, h, w, True) .squeeze(0) .repeat(t, 1) ) image_embeddings = image_embeddings + position_embedding tmp_embeddings.append(image_embeddings) start = end embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) else: embeddings = embeddings + self.packing_position_embedding(position_ids) return embeddings else: raise NotImplementedError(str(pixel_values.shape)) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: 
attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query.dtype ) attn_weights = nn.functional.dropout( attn_weights, p=dropout, training=module.training ) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, cu_seqlens: Optional[List[torch.Tensor]] = None, rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" use_flash_attn = ( cu_seqlens is not None ) and self.config._attn_implementation == "flash_attention_2" batch_size, seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) if rope_emb is None: queries = queries.view( batch_size, seq_length, self.num_heads, self.head_dim ).transpose(1, 2) keys = keys.view( batch_size, seq_length, self.num_heads, 
self.head_dim ).transpose(1, 2) values = values.view( batch_size, seq_length, self.num_heads, self.head_dim ).transpose(1, 2) else: assert cu_seqlens is not None, "Rope support flash attn only." cos, sin = rope_emb queries = queries.view( batch_size, seq_length, self.num_heads, self.head_dim ) keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim) if use_flash_attn: queries, keys = apply_rotary_pos_emb_flashatt(queries, keys, cos, sin) else: queries, keys = apply_rotary_pos_emb_vision(queries, keys, cos, sin) queries = queries.transpose(1, 2) keys = keys.transpose(1, 2) values = values.view( batch_size, seq_length, self.num_heads, self.head_dim ).transpose(1, 2) if not use_flash_attn: attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: warnings.warn( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
) elif self.config._attn_implementation == "sdpa": attention_interface = sdpa_attention_forward attn_output, attn_weights = attention_interface( self, queries, keys, values, attention_mask, is_causal=self.is_causal, scaling=self.scale, dropout=0.0 if not self.training else self.dropout, ) attn_output = attn_output.reshape( batch_size, seq_length, embed_dim ).contiguous() else: assert batch_size == 1, hidden_states.shape queries = queries.transpose(1, 2).squeeze(0) keys = keys.transpose(1, 2).squeeze(0) values = values.transpose(1, 2).squeeze(0) max_seqlen_q = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() max_seqlen_k = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() assert ( cu_seqlens[-1].item() == queries.shape[0] == keys.shape[0] == values.shape[0] ), (cu_seqlens, queries.shape, keys.shape, values.shape) attn_output = flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, causal=False, softmax_scale=self.scale, ) attn_output = attn_output.flatten(-2).unsqueeze(0) attn_weights = None attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class SiglipEncoderLayer(nn.Module): def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = SiglipAttention(config) 
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, cu_seqlens: Optional[List[torch.Tensor]] = None, rope_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, cu_seqlens=cu_seqlens, rope_emb=rope_emb, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class SiglipPreTrainedModel(PreTrainedModel): config_class = PaddleOCRVLConfig base_model_prefix = "siglip" supports_gradient_checkpointing = True _no_split_modules = [ "SiglipTextEmbeddings", "SiglipEncoderLayer", "SiglipVisionEmbeddings", "SiglipMultiheadAttentionPoolingHead", ] _supports_flash_attn_2 = True _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SiglipVisionEmbeddings): width = ( self.config.vision_config.hidden_size if isinstance(self.config, PaddleOCRVLConfig) else self.config.hidden_size ) nn.init.normal_(module.position_embedding.weight, 
std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, SiglipAttention): nn.init.xavier_uniform_(module.q_proj.weight) nn.init.xavier_uniform_(module.k_proj.weight) nn.init.xavier_uniform_(module.v_proj.weight) nn.init.xavier_uniform_(module.out_proj.weight) nn.init.zeros_(module.q_proj.bias) nn.init.zeros_(module.k_proj.bias) nn.init.zeros_(module.v_proj.bias) nn.init.zeros_(module.out_proj.bias) elif isinstance(module, SiglipMLP): nn.init.xavier_uniform_(module.fc1.weight) nn.init.xavier_uniform_(module.fc2.weight) nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, SiglipMultiheadAttentionPoolingHead): nn.init.xavier_uniform_(module.probe.data) nn.init.xavier_uniform_(module.attention.in_proj_weight.data) nn.init.zeros_(module.attention.in_proj_bias.data) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip class SiglipEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`SiglipEncoderLayer`]. 
Args: config: PaddleOCRVLConfig """ def __init__(self, config: PaddleOCRVLConfig): super().__init__() self.config = config embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads self.layers = nn.ModuleList( [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2) self.gradient_checkpointing = False @staticmethod def flatten_list(image_grid_thw): tmp_image_grid_thw = list() for image_grid in image_grid_thw: if isinstance(image_grid, list): tmp_image_grid_thw.extend(image_grid) else: tmp_image_grid_thw.append(image_grid) return tmp_image_grid_thw def build_window_index(self, image_grid, window_size, device): from einops import rearrange window_indices = list() pad_values = -100 start_window_index = 0 cu_seqlens_within_windows = list() for t, h, w in image_grid: window_index = torch.arange(t * h * w, device=device).reshape(t, h, w) pad_h = (-h) % window_size pad_w = (-w) % window_size assert pad_h >= 0 and pad_w >= 0, (pad_h, pad_w) window_index = F.pad(window_index, (0, pad_w, 0, pad_h), value=pad_values) window_index = rearrange( window_index, "t (h p1) (w p2) -> t (h w) (p1 p2)", p1=window_size, p2=window_size, ) window_seqlens = (window_index != pad_values).long().sum(-1).reshape(-1) window_index = window_index.reshape(-1) window_index = window_index[window_index != pad_values] window_indices.append(window_index + start_window_index) cu_seqlens_within_windows.append( window_seqlens.cumsum(0) + start_window_index ) start_window_index += t * h * w window_indices = torch.concat(window_indices, dim=0) cu_seqlens_within_windows = torch.concat(cu_seqlens_within_windows, dim=0) cu_seqlens_within_windows = F.pad( cu_seqlens_within_windows, (1, 0), value=0 ).to(torch.int32) return window_indices, cu_seqlens_within_windows # Ignore copy # @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, 
output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cu_seqlens: Optional[List[torch.Tensor]] = None, image_grid_thw: Optional[ List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]] ] = None, height_position_ids: Optional[torch.Tensor] = None, width_position_ids: Optional[torch.Tensor] = None, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, vision_or_text: str = "vision", ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" vision_or_text = "vision" assert vision_or_text in ["vision", "text"] use_window_attn = window_size > 0 and vision_or_text == "vision" use_rope = (use_rope is True) and (vision_or_text == "vision") output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None device = inputs_embeds.device hidden_states = inputs_embeds attention_mask = ( attention_mask.to(inputs_embeds.dtype) if attention_mask is not None else None ) if use_rope is True: flatten_image_grid_thw = self.flatten_list(image_grid_thw) assert ( sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1] ), (flatten_image_grid_thw, hidden_states.shape) if width_position_ids is None or height_position_ids is None: split_hids = list() split_wids = list() for t, h, w in flatten_image_grid_thw: image_pids = torch.arange(t * h * w, device=device) % (h * w) sample_hids = image_pids // w sample_wids = image_pids % w split_hids.append(sample_hids) split_wids.append(sample_wids) width_position_ids = torch.concat(split_wids, dim=0) height_position_ids = torch.concat(split_hids, dim=0) window_indices, cu_seqlens_within_windows = None, None if use_window_attn: window_indices, cu_seqlens_within_windows = self.build_window_index( flatten_image_grid_thw, window_size, device ) reversed_window_indices = window_indices.argsort() height_position_ids = height_position_ids[window_indices] width_position_ids = width_position_ids[window_indices] pids = torch.stack([height_position_ids, width_position_ids], dim=-1) max_grid_size = pids.max() + 1 rope_emb_max_grid = self.rotary_pos_emb(max_grid_size) rope_emb = rope_emb_max_grid[pids].flatten(1) rope_emb = rope_emb.repeat(1, 2) rope_emb = (rope_emb.cos(), rope_emb.sin()) else: rope_emb = None 
window_indices, cu_seqlens_within_windows = None, None if use_window_attn: flatten_image_grid_thw = self.flatten_list(image_grid_thw) assert ( sum([np.prod(x) for x in flatten_image_grid_thw]) == hidden_states.shape[1] ), (flatten_image_grid_thw, hidden_states.shape) window_indices, cu_seqlens_within_windows = self.build_window_index( flatten_image_grid_thw, window_size, device ) reversed_window_indices = window_indices.argsort() if use_window_attn: assert cu_seqlens_within_windows is not None attn_cu_seqlens = cu_seqlens_within_windows hidden_states = hidden_states[:, window_indices, :] else: attn_cu_seqlens = cu_seqlens for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + ( (hidden_states[:, reversed_window_indices, :],) if use_window_attn else (hidden_states,) ) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, output_attentions, attn_cu_seqlens, rope_emb, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, cu_seqlens=attn_cu_seqlens, rope_emb=rope_emb, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if use_window_attn: hidden_states = hidden_states[:, reversed_window_indices, :] if output_hidden_states: encoder_states = encoder_states + (hidden_states,) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, ) class SiglipVisionTransformer(nn.Module): def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.use_head = ( True if not hasattr(config, "vision_use_head") else config.vision_use_head ) if self.use_head: 
self.head = SiglipMultiheadAttentionPoolingHead(config) # @can_return_tuple def forward( self, pixel_values, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = False, attention_mask: Optional[torch.Tensor] = None, sample_indices: Optional[torch.Tensor] = None, image_indices: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, height_position_ids: Optional[torch.Tensor] = None, width_position_ids: Optional[torch.Tensor] = None, cu_seqlens: Optional[List[torch.Tensor]] = None, padding_mask: Optional[torch.Tensor] = None, vision_return_embed_list: Optional[bool] = False, image_grid_thw: Optional[ List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]] ] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: r""" Returns: """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, position_ids=position_ids, image_grid_thw=image_grid_thw, ) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, attention_mask=attention_mask, cu_seqlens=cu_seqlens, image_grid_thw=image_grid_thw, use_rope=use_rope, height_position_ids=height_position_ids, width_position_ids=width_position_ids, window_size=window_size, vision_or_text="vision", ) last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) if return_pooler_output is True: if sample_indices is not None: assert self.use_head is True dim = last_hidden_state.shape[-1] sample_hidden_state_list = list() 
hidden_state = last_hidden_state.squeeze(0) sample_index = sample_indices unique_sample_index = torch.unique(sample_index).sort().values.unbind(0) unique_sample_index = list(unique_sample_index) if len(unique_sample_index) > 0 and unique_sample_index[0] == -1: unique_sample_index = unique_sample_index[1:] for sample_idx in unique_sample_index: token_indices = (sample_index == sample_idx).nonzero().flatten() sample_hidden_state = hidden_state[token_indices] sample_hidden_state_list.append(sample_hidden_state) if not vision_return_embed_list: max_length = max( [_state.shape[0] for _state in sample_hidden_state_list] ) tmp_sample_hidden_state_list = list() padding_mask = list() for idx, _state in enumerate(sample_hidden_state_list): padding_length = max_length - _state.shape[0] mask = _state.new_zeros(size=(max_length,), dtype=torch.int64) mask[-padding_length:] = 1 padding_mask.append(mask) padding = _state.new_zeros(size=(padding_length, dim)) new_state = torch.concat([_state, padding], dim=0) tmp_sample_hidden_state_list.append(new_state) sample_hidden_state = torch.stack( tmp_sample_hidden_state_list, dim=0 ) padding_mask = ( torch.stack(padding_mask, dim=0) .float() .to(last_hidden_state.dtype) ) pooler_output = self.head( sample_hidden_state, key_padding_mask=padding_mask ) else: pooler_output = list() for state in sample_hidden_state_list: sample_pooler_output = self.head(state.unsqueeze(0)) pooler_output.append(sample_pooler_output) pooler_output = torch.concat(pooler_output, dim=0) sample_hidden_state = sample_hidden_state_list return BaseModelOutputWithPooling( last_hidden_state=sample_hidden_state, pooler_output=pooler_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) else: pooler_output = self.head(last_hidden_state) if self.use_head else None return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooler_output, hidden_states=encoder_outputs.hidden_states, 
attentions=encoder_outputs.attentions, ) sample_hidden_state = list() assert cu_seqlens is not None for i in range(cu_seqlens.shape[0] - 1): start = cu_seqlens[i] end = cu_seqlens[i + 1] tensor = last_hidden_state[:, start:end, :].squeeze(0) sample_hidden_state.append(tensor) return BaseModelOutputWithPooling( last_hidden_state=sample_hidden_state, pooler_output=None, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class SiglipMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" def __init__(self, config: PaddleOCRVisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.attention = torch.nn.MultiheadAttention( config.hidden_size, config.num_attention_heads, batch_first=True ) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) def forward(self, hidden_state, key_padding_mask=None): batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) hidden_state = self.attention( probe, hidden_state, hidden_state, key_padding_mask=key_padding_mask )[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) return hidden_state[:, 0] class SiglipVisionModel(SiglipPreTrainedModel): config_class = PaddleOCRVisionConfig main_input_name = "pixel_values" def __init__(self, config: PaddleOCRVisionConfig): super().__init__(config) self.vision_model = SiglipVisionTransformer(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding # @can_return_tuple def forward( self, pixel_values, sample_indices: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, position_ids: Optional[torch.Tensor] = None, vision_return_embed_list: 
Optional[bool] = False, image_grid_thw: Optional[ List[Union[Tuple[int, int, int], List[Tuple[int, int, int]]]] ] = None, cu_seqlens: Optional[List[torch.Tensor]] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, SiglipVisionModel >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, position_ids=position_ids, vision_return_embed_list=vision_return_embed_list, image_grid_thw=image_grid_thw, sample_indices=sample_indices, cu_seqlens=cu_seqlens, return_pooler_output=return_pooler_output, use_rope=use_rope, window_size=window_size, ) def apply_rotary_pos_emb_flashatt( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: cos = cos.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous() q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb_vision( q: 
torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: orig_q_dtype = q.dtype orig_k_dtype = k.dtype q, k = q.float(), k.float() cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) q_embed = q_embed.to(orig_q_dtype) k_embed = k_embed.to(orig_k_dtype) return q_embed, k_embed class SigLIPRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta self.rope_init() def rope_init(self): inv_freq = 1.0 / ( self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: seq = torch.arange( seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype ) freqs = torch.outer(seq, self.inv_freq) return freqs @dataclass class PaddleOCRVLCausalLMOutputWithPast(ModelOutput): """ Base class for PaddleOCRVL causal language model (or autoregressive) outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None rope_deltas: Optional[torch.LongTensor] = None class PaddleOCRVLForConditionalGeneration(Ernie4_5PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] config_class = PaddleOCRVLConfig _no_split_modules = ["Ernie4_5_DecoderLayer", "SiglipEncoderLayer"] def __init__(self, config): super().__init__(config) self.mlp_AR = Projector(config, config.vision_config) self.visual = SiglipVisionModel(config.vision_config) self.model = Ernie4_5Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.rope_deltas = None self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def 
get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Calculate the 3D rope index based on image and video's temporal, height and width in LLM. Explanation: Each embedding sequence contains vision embedding and text embedding or just contains text embedding. For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. Examples: input_ids: [T T T T T], here T is for text. temporal position_ids: [0, 1, 2, 3, 4] height position_ids: [0, 1, 2, 3, 4] width position_ids: [0, 1, 2, 3, 4] For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part and 1D rotary position embedding for text part. Examples: Temporal (Time): 3 patches, representing different segments of the video in time. Height: 2 patches, dividing each frame vertically. Width: 2 patches, dividing each frame horizontally. We also have some important parameters: fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. 
interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] text temporal position_ids: [101, 102, 103, 104, 105] text height position_ids: [101, 102, 103, 104, 105] text width position_ids: [101, 102, 103, 104, 105] Here we calculate the text start position_ids as the max vision position_ids plus 1. Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
Returns: position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) """ spatial_merge_size = self.config.vision_config.spatial_merge_size image_token_id = self.config.image_token_id video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id mrope_position_deltas = [] if input_ids is not None and ( image_grid_thw is not None or video_grid_thw is not None ): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) position_ids = torch.ones( 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device, ) image_index, video_index = 0, 0 attention_mask = attention_mask.to(total_input_ids.device) for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] image_nums, video_nums = 0, 0 vision_start_indices = torch.argwhere( input_ids == vision_start_token_id ).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() input_tokens = input_ids.tolist() llm_pos_ids_list: list = [] st = 0 remain_images, remain_videos = image_nums, video_nums for _ in range(image_nums + video_nums): if image_token_id in input_tokens and remain_images > 0: ed_image = input_tokens.index(image_token_id, st) else: ed_image = len(input_tokens) + 1 if video_token_id in input_tokens and remain_videos > 0: ed_video = input_tokens.index(video_token_id, st) else: ed_video = len(input_tokens) + 1 if ed_image < ed_video: t, h, w = ( image_grid_thw[image_index][0], image_grid_thw[image_index][1], image_grid_thw[image_index][2], ) second_per_grid_t = 0 image_index += 1 remain_images -= 1 ed = ed_image else: t, h, w = ( video_grid_thw[video_index][0], video_grid_thw[video_index][1], video_grid_thw[video_index][2], ) if second_per_grid_ts is not None: second_per_grid_t 
= second_per_grid_ts[video_index] else: second_per_grid_t = 1.0 video_index += 1 remain_videos -= 1 ed = ed_video llm_grid_t, llm_grid_h, llm_grid_w = ( t.item(), h.item() // spatial_merge_size, w.item() // spatial_merge_size, ) text_len = ed - st st_idx = ( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) llm_pos_ids_list.append( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ) if torch.is_tensor(second_per_grid_t): second_per_grid_t = second_per_grid_t.detach().item() range_tensor = torch.arange(llm_grid_t).view(-1, 1) expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) time_tensor = ( expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second ) time_tensor_long = time_tensor.long() t_index = time_tensor_long.flatten() h_index = ( torch.arange(llm_grid_h) .view(1, -1, 1) .expand(llm_grid_t, -1, llm_grid_w) .flatten() ) w_index = ( torch.arange(llm_grid_w) .view(1, 1, -1) .expand(llm_grid_t, llm_grid_h, -1) .flatten() ) llm_pos_ids_list.append( torch.stack([t_index, h_index, w_index]) + text_len + st_idx ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): st_idx = ( llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 ) text_len = len(input_tokens) - st llm_pos_ids_list.append( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( position_ids.device ) mrope_position_deltas.append( llm_positions.max() + 1 - len(total_input_ids[i]) ) mrope_position_deltas = torch.tensor( mrope_position_deltas, device=input_ids.device ).unsqueeze(1) return position_ids, mrope_position_deltas else: if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) position_ids = ( position_ids.unsqueeze(0) .expand(3, -1, -1) .to(attention_mask.device) ) max_position_ids = 
position_ids.max(0, keepdim=False)[0].max( -1, keepdim=True )[0] mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] else: position_ids = ( torch.arange(input_ids.shape[1], device=input_ids.device) .view(1, 1, -1) .expand(3, input_ids.shape[0], -1) ) mrope_position_deltas = torch.zeros( [input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype, ) return position_ids, mrope_position_deltas def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, PaddleOCRVLCausalLMOutputWithPast]: r""" Returns: """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.visual.dtype) pixel_values = pixel_values.unsqueeze(0) siglip_position_ids = list() image_grid_hws = list() sample_indices = list() cu_seqlens = [0] pro = 0 for idx, thw in enumerate(image_grid_thw): thw_tuple 
= tuple(thw.detach().cpu().numpy().tolist()) numel = np.prod(thw_tuple) image_grid_hws.append(thw_tuple) image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(image_position_ids) sample_indices.append(torch.full((numel,), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( pixel_values.device ) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( pixel_values.device ) sample_indices = torch.concat(sample_indices, dim=0).to( pixel_values.device ) vision_outputs = self.visual( pixel_values=pixel_values, image_grid_thw=image_grid_hws, position_ids=siglip_position_ids, vision_return_embed_list=True, interpolate_pos_encoding=True, sample_indices=sample_indices, cu_seqlens=cu_seqlens, return_pooler_output=False, use_rope=True, window_size=-1, ) image_embeds = vision_outputs.last_hidden_state image_embeds = self.mlp_AR(image_embeds, image_grid_thw) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() # image_embeds is a list of tensor, each tensor is a image feature,I want to concat them all into a tensor image_embeds = torch.cat(image_embeds, dim=0) n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) mask = input_ids == self.config.image_token_id mask_unsqueezed = mask.unsqueeze(-1) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) image_mask = mask_expanded.to(inputs_embeds.device) image_embeds = image_embeds.to( inputs_embeds.device, inputs_embeds.dtype ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) # position_ids = None # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme if position_ids is None and ( attention_mask is None or attention_mask.ndim == 2 ): # calculate RoPE index once per generation in the pre-fill stage only if ( (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None or (past_key_values is None or past_key_values.get_seq_length() == 0) ): position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask, ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape delta = ( (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) if cache_position is not None: # otherwise `deltas` is an int `0` delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) outputs = self.model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, 
shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return PaddleOCRVLCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, rope_deltas=self.rope_deltas, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, second_per_grid_ts=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, pixel_values=pixel_values, pixel_values_videos=pixel_values_videos, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, use_cache=use_cache, **kwargs, ) model_inputs["position_ids"] = None if cache_position[0] != 0: model_inputs["pixel_values"] = None model_inputs["pixel_values_videos"] = None return model_inputs def _get_image_nums_and_video_nums( self, input_ids: Optional[torch.LongTensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """ Get the number of images and videos for each sample to calculate the separation length of the sample tensor. These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. 
Returns: image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) """ image_token_id = self.config.image_token_id video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id vision_start_mask = input_ids == vision_start_token_id vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) image_mask = input_ids == image_token_id video_mask = input_ids == video_token_id image_nums = torch.sum(vision_first_mask & image_mask, dim=1) video_nums = torch.sum(vision_first_mask & video_mask, dim=1) return image_nums, video_nums def _expand_inputs_for_generation( self, expand_size: int = 1, is_encoder_decoder: bool = False, input_ids: Optional[torch.LongTensor] = None, **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: # Overwritten -- Support for expanding tensors without a batch size dimension # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t # pixel_values.shape[0] is sum(seqlen_images for samples) # image_grid_thw.shape[0] is sum(num_images for samples) if expand_size == 1: return input_ids, model_kwargs visual_keys = [ "pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts", ] def _expand_dict_for_generation_visual(dict_to_expand): image_grid_thw = model_kwargs.get("image_grid_thw", None) video_grid_thw = model_kwargs.get("video_grid_thw", None) image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) def _repeat_interleave_samples(x, lengths, repeat_times): samples = torch.split(x, lengths) repeat_args = [repeat_times] + [1] * (x.dim() - 1) result = torch.cat( [sample.repeat(*repeat_args) for sample in samples], dim=0 ) return result for key in dict_to_expand: if key == "pixel_values": # split images into samples samples = torch.split(image_grid_thw, list(image_nums)) # compute the sequence length of images for each 
sample lengths = [torch.prod(sample, dim=1).sum() for sample in samples] dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "image_grid_thw": # get the num of images for each sample lengths = list(image_nums) dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "pixel_values_videos": samples = torch.split(video_grid_thw, list(video_nums)) lengths = [torch.prod(sample, dim=1).sum() for sample in samples] dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "video_grid_thw": lengths = list(video_nums) dict_to_expand[key] = _repeat_interleave_samples( dict_to_expand[key], lengths=lengths, repeat_times=expand_size ) elif key == "second_per_grid_ts": if not isinstance(dict_to_expand[key], list): raise TypeError( f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." ) tensor = torch.tensor(dict_to_expand[key]) lengths = list(video_nums) tensor = _repeat_interleave_samples( tensor, lengths=lengths, repeat_times=expand_size ) dict_to_expand[key] = tensor.tolist() return dict_to_expand def _expand_dict_for_generation(dict_to_expand): for key in dict_to_expand: if ( key != "cache_position" and dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) and key not in visual_keys ): dict_to_expand[key] = dict_to_expand[key].repeat_interleave( expand_size, dim=0 ) return dict_to_expand # input_ids is required for expanding visual inputs # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. 
if input_ids is not None and input_ids.numel() != 0: model_kwargs = _expand_dict_for_generation_visual(model_kwargs) if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) model_kwargs = _expand_dict_for_generation(model_kwargs) if is_encoder_decoder: if model_kwargs.get("encoder_outputs") is None: raise ValueError( "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." ) model_kwargs["encoder_outputs"] = _expand_dict_for_generation( model_kwargs["encoder_outputs"] ) return input_ids, model_kwargs