"""Inference-only Deepseek-OCR model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
                                                          MlpProjectorConfig,
                                                          VisionEncoderConfig)
from process.image_process import (
    DeepseekOCRProcessor, count_tiles)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
# from vllm.utils import is_list_of
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                                              init_vllm_registered_model, maybe_prefix,
                                              merge_multimodal_embeddings)
from deepencoder.sam_vary_sdpa import build_sam_vit_b
from deepencoder.clip_sdpa import build_clip_l
from deepencoder.build_linear import MlpProjector
from addict import Dict
# import time
from config import IMAGE_SIZE, BASE_SIZE, CROP_MODE, PRINT_NUM_VIS_TOKENS, PROMPT

# The image token id may be various.
# NOTE(review): this literal was empty in the reviewed source, which makes
# `tokenizer.vocab[_IMAGE_TOKEN]` and the `in PROMPT` check meaningless; it
# looks like the placeholder was stripped by a markup filter. Restored to the
# image placeholder -- confirm against the tokenizer vocabulary.
_IMAGE_TOKEN = "<image>"

# Substrings identifying checkpoint tensors that belong to the vision stack
# (everything else is routed to the language model in `load_weights`).
_VISION_WEIGHT_KEYS = (
    'sam_model',
    'vision_model',
    'projector',
    'image_newline',
    'view_seperator',
)


class DeepseekOCRProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Deepseek-OCR: config/processor access and
    image-token accounting."""

    def get_hf_config(self):
        """Return the HF config, typed as ``DeepseekVLV2Config``."""
        return self.ctx.get_hf_config(DeepseekVLV2Config)

    def get_hf_processor(self, **kwargs: object):
        """Return the ``DeepseekOCRProcessor`` obtained from the context."""
        return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare the supported modalities; ``None`` is the per-prompt limit
        reported for images."""
        return {"image": None}

    def get_num_image_tokens(self,
                             *,
                             image_width: int,
                             image_height: int,
                             cropping: bool = True) -> int:
        """Compute how many placeholder tokens one image expands to.

        The count is: global view rows (each with one trailing newline
        token), plus tiled local-view rows (same newline rule) when the
        image is split into more than one tile, plus one view separator.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.
            cropping: When False, the image is counted as a single tile.
                Tiling additionally requires the module-level ``CROP_MODE``
                to be enabled.

        Returns:
            Total number of image tokens.
        """
        image_size = IMAGE_SIZE
        base_size = BASE_SIZE
        patch_size = 16
        downsample_ratio = 4

        if cropping and CROP_MODE:
            if image_width <= 640 and image_height <= 640:
                # Small images are never tiled.
                crop_ratio = [1, 1]
            else:
                # Find the closest aspect ratio to the target.
                crop_ratio = count_tiles(image_width,
                                         image_height,
                                         image_size=IMAGE_SIZE)
            num_width_tiles, num_height_tiles = crop_ratio
        else:
            num_width_tiles = num_height_tiles = 1

        # Side length (in tokens) of the global view and of one local tile.
        h = w = math.ceil((base_size // patch_size) / downsample_ratio)
        h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio)

        # "+ 1" per row accounts for the newline token appended to each row.
        global_views_tokens = h * (w + 1)
        if num_width_tiles > 1 or num_height_tiles > 1:
            local_views_tokens = (num_height_tiles * h2) * (
                num_width_tiles * w2 + 1)
        else:
            local_views_tokens = 0

        # Final "+ 1" is the view separator token.
        return global_views_tokens + local_views_tokens + 1

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used to build dummy inputs, chosen from the
        configured ``IMAGE_SIZE`` / ``BASE_SIZE`` pair."""
        if IMAGE_SIZE == 1024 and BASE_SIZE == 1280:
            return ImageSize(width=1024 * 2, height=1024 * 2)
        return ImageSize(width=640 * 2, height=640 * 2)


class DeepseekOCRDummyInputsBuilder(
        BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]):
    """Builds dummy text and image inputs for the Deepseek-OCR processor."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        num_images = mm_counts.get("image", 0)
        processor = self.info.get_hf_processor()
        return processor.image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy image data sized by
        ``get_image_size_with_most_features``.

        When the configured ``PROMPT`` contains no image placeholder, an
        empty image list is returned instead.
        """
        num_images = mm_counts.get("image", 0)
        max_image_size = self.info.get_image_size_with_most_features()

        if _IMAGE_TOKEN not in PROMPT:
            return {"image": []}

        dummy_images = self._get_dummy_images(width=max_image_size.width,
                                              height=max_image_size.height,
                                              num_images=num_images)
        return {
            "image":
            DeepseekOCRProcessor().tokenize_with_images(images=dummy_images,
                                                        bos=True,
                                                        eos=True,
                                                        cropping=CROP_MODE)
        }


class DeepseekOCRMultiModalProcessor(
        BaseMultiModalProcessor[DeepseekOCRProcessingInfo]):
    """Multimodal processor wiring ``DeepseekOCRProcessor`` into vLLM."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor when multimodal data is present; otherwise
        tokenize the prompt as plain text."""
        if mm_data:
            return self.info.ctx.call_hf_processor(
                self.info.get_hf_processor(**mm_kwargs),
                dict(prompt=prompt, **mm_data),
                mm_kwargs,
            )

        tokenizer = self.info.get_tokenizer()
        return tokenizer(prompt,
                         add_special_tokens=True,
                         return_tensors="pt")

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare the processor output fields that are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            images_spatial_crop=MultiModalFieldConfig.batched("image"),
            images_crop=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token in the prompt with the number of image
        tokens that image expands to."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token_id = hf_processor.image_token_id
        assert isinstance(image_token_id, int)

        def get_replacement_deepseek_vl2(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                # NOTE(review): the size is always read from the first item
                # (`images[0]`), not `item_idx`; presumably the last element
                # of the pre-tokenized item holds (width, height) pairs --
                # confirm against `tokenize_with_images`.
                width = images[0][-1][0][0]
                height = images[0][-1][0][1]
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=width,
                    image_height=height,
                    cropping=CROP_MODE,
                )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_deepseek_vl2,
            )
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs, bool]:
        """Apply the HF processor, bypassing the cache for > 2 images."""
        # The processor logic is different for len(images) <= 2 vs > 2
        # Since the processing cache assumes that the processor output is
        # invariant of how many images are passed per prompt, we only
        # perform caching for the most common case
        if mm_data_items.get_count("image", strict=False) > 2:
            # This code path corresponds to the cache being disabled
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_update=True,
            )

        return super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    DeepseekOCRMultiModalProcessor,
    info=DeepseekOCRProcessingInfo,
    dummy_inputs=DeepseekOCRDummyInputsBuilder)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Deepseek-OCR: SAM + CLIP vision encoders, a linear projector and a
    Deepseek language model."""

    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
        "language.": "language_model.",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision encoders, projector, special-token embeddings
        and the language model.

        Raises:
            ValueError: If ``config.tile_tag`` is not ``"2D"``.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config: DeepseekVLV2Config = model_config.hf_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        tokenizer = cached_tokenizer_from_config(model_config)
        self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        n_embed = 1280
        self.projector = MlpProjector(
            Dict(projector_type="linear", input_dim=2048, n_embed=n_embed))
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # special token for image token sequence format
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            # <|view_separator|>, <|\n|>
            self.image_newline = nn.Parameter(
                torch.randn(n_embed) * embed_std)
            self.view_seperator = nn.Parameter(
                torch.randn(n_embed) * embed_std)
        else:
            raise ValueError(
                f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
            )

        if self.text_config.topk_method == "noaux_tc":
            architectures = ["DeepseekV3ForCausalLM"]
        elif not self.text_config.use_mla:
            architectures = ["DeepseekForCausalLM"]
        else:
            architectures = ["DeepseekV2ForCausalLM"]

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.text_config,
            prefix=maybe_prefix(prefix, "language"),
            architectures=architectures,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Extract and type-check the image tensors from ``kwargs``.

        Returns:
            ``None`` when ``pixel_values`` is absent or sums to zero
            (treated as "no image"); otherwise the list
            ``[pixel_values, images_crop, images_spatial_crop]``.

        Raises:
            ValueError: If any of the three inputs is neither a tensor nor
                a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        images_crop = kwargs.pop("images_crop", None)

        if pixel_values is None or torch.sum(pixel_values).item() == 0:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(images_spatial_crop, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(images_spatial_crop)}")

        if not isinstance(images_crop, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image crop. "
                             f"Got type: {type(images_crop)}")

        return [pixel_values, images_crop, images_spatial_crop]

    def _encode_views(self, images: torch.Tensor) -> torch.Tensor:
        """Run SAM then CLIP on ``images`` and project the fused features.

        CLIP's first token is dropped and the remaining tokens are
        concatenated with the flattened SAM feature map along the last
        dimension before projection.
        """
        sam_features = self.sam_model(images)
        clip_features = self.vision_model(images, sam_features)
        fused = torch.cat(
            (clip_features[:, 1:], sam_features.flatten(2).permute(0, 2, 1)),
            dim=-1)
        return self.projector(fused)

    def _format_global_view(self,
                            global_features: torch.Tensor) -> torch.Tensor:
        """Lay the global-view tokens out as a square grid, append the
        newline embedding to every row and flatten back to a sequence."""
        _, hw, n_dim = global_features.shape
        side = math.isqrt(hw)

        grid = global_features.view(side, side, n_dim)
        grid = torch.cat(
            [grid, self.image_newline[None, None, :].expand(side, 1, n_dim)],
            dim=1)
        return grid.view(-1, n_dim)

    def _format_local_views(self, local_features: torch.Tensor,
                            crop_shape: torch.Tensor) -> torch.Tensor:
        """Stitch per-tile tokens into one large grid, append the newline
        embedding to every row and flatten back to a sequence.

        ``crop_shape`` is read as ``(width_crop_num, height_crop_num)``.
        """
        _, hw2, n_dim2 = local_features.shape
        side = math.isqrt(hw2)
        width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

        # (tiles_h, tiles_w, side, side, C) -> (tiles_h*side, tiles_w*side, C)
        grid = local_features.view(height_crop_num, width_crop_num, side,
                                   side, n_dim2)
        grid = grid.permute(0, 2, 1, 3, 4).reshape(height_crop_num * side,
                                                   width_crop_num * side,
                                                   n_dim2)
        grid = torch.cat(
            [
                grid,
                self.image_newline[None, None, :].expand(
                    height_crop_num * side, 1, n_dim2)
            ],
            dim=1)
        return grid.view(-1, n_dim2)

    def _pixel_values_to_embedding(
        self,
        pixel_values: torch.Tensor,
        images_crop: torch.Tensor,
        images_spatial_crop: torch.Tensor,
    ) -> NestedTensors:
        """Encode every image into its sequence of vision embeddings.

        Per image the output is ``[local tokens,] global tokens, view
        separator``; local tokens are present only when the crop patches
        are not all zero.

        Returns:
            A list with one 2-D embedding tensor per image.
        """
        # Pixel_values (global view): [n_image, batch_size, 3, height, width]
        # images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]]
        # images_crop (local view): [n_image, batch_size, num_pathes, 3, h, w]
        # split the pixel and image_crop, all batch_size = 1
        images_in_this_batch = []
        separator = self.view_seperator[None, :]

        with torch.no_grad():
            for jdx in range(images_spatial_crop.size(0)):
                patches = images_crop[jdx][0].to(torch.bfloat16)  # batch_size = 1
                image_ori = pixel_values[jdx]
                crop_shape = images_spatial_crop[jdx][0]

                # if all values = 0, no crop
                has_patches = torch.sum(patches).item() != 0

                local_features = (self._encode_views(patches)
                                  if has_patches else None)
                global_features = self._encode_views(image_ori)

                if PRINT_NUM_VIS_TOKENS:
                    print('=====================')
                    print('BASE: ', global_features.shape)
                    if has_patches:
                        print('PATCHES: ', local_features.shape)
                    else:
                        print('NO PATCHES')
                    print('=====================')

                global_tokens = self._format_global_view(global_features)
                if has_patches:
                    local_tokens = self._format_local_views(
                        local_features, crop_shape)
                    parts = [local_tokens, global_tokens, separator]
                else:
                    parts = [global_tokens, separator]

                images_in_this_batch.append(torch.cat(parts, dim=0))

        return images_in_this_batch

    def _process_image_input(self, image_input) -> list[torch.Tensor]:
        """Cast the validated inputs and compute the vision embeddings.

        Args:
            image_input: ``[pixel_values, images_crop, images_spatial_crop]``
                as returned by ``_parse_and_validate_image_input``.
        """
        pixel_values = image_input[0].to(torch.bfloat16)
        # Crops are cast per image inside `_pixel_values_to_embedding`.
        images_crop = image_input[1]
        images_spatial_crop = image_input[2].to(dtype=torch.long)

        return self._pixel_values_to_embedding(
            pixel_values=pixel_values,
            images_crop=images_crop,
            images_spatial_crop=images_spatial_crop)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return vision embeddings for the image inputs in ``kwargs``, or
        ``None`` when there is no usable image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge the vision embeddings into the
        positions holding the image token id."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.image_token_id)

        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        """Run the language model, computing merged input embeddings first
        when none were supplied."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, routing names to the right submodule.

        Vision-stack tensors (see ``_VISION_WEIGHT_KEYS``) have a leading
        ``model.`` stripped; every other tensor is prefixed with
        ``language.`` so that ``hf_to_vllm_mapper`` maps it onto
        ``language_model``.

        Returns:
            The set of parameter names reported as loaded by the loader.
        """

        def _renamed():
            # Stream the weights instead of materializing them all at once.
            for name, tensor in weights:
                if any(key in name for key in _VISION_WEIGHT_KEYS):
                    # Only strip a *leading* "model."; names such as
                    # "sam_model.x" also contain "model." further in.
                    new_name = name.removeprefix('model.')
                else:
                    new_name = 'language.' + name
                yield new_name, tensor

        loader = AutoWeightsLoader(self)
        return loader.load_weights(_renamed(), mapper=self.hf_to_vllm_mapper)