Update modeling_dots_ocr_vllm.py
Browse files- modeling_dots_ocr_vllm.py +11 -0
 
    	
        modeling_dots_ocr_vllm.py
    CHANGED
    
    | 
         @@ -99,6 +99,7 @@ class DotsOCRProcessingInfo(Qwen2_5_VLProcessingInfo): 
     | 
|
| 99 | 
         
             
                    size: Optional[dict[str, int]] = None,
         
     | 
| 100 | 
         
             
                    **kwargs: object,
         
     | 
| 101 | 
         
             
                ) -> Qwen2VLProcessor:
         
     | 
| 
         | 
|
| 102 | 
         
             
                    processor = self.ctx.get_hf_processor(
         
     | 
| 103 | 
         
             
                        Qwen2VLProcessor,
         
     | 
| 104 | 
         
             
                        image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
         
     | 
| 
         @@ -166,6 +167,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal): 
     | 
|
| 166 | 
         
             
                )
         
     | 
| 167 | 
         
             
                _tp_plan = {}
         
     | 
| 168 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 169 | 
         
             
                def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         
     | 
| 170 | 
         
             
                    super().__init__()
         
     | 
| 171 | 
         | 
| 
         @@ -409,6 +415,10 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal): 
     | 
|
| 409 | 
         | 
| 410 | 
         | 
| 411 | 
         
             
            def patch_vllm_chat_placeholder():
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 412 | 
         
             
                from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
         
     | 
| 413 | 
         | 
| 414 | 
         
             
                ori = BaseMultiModalItemTracker._placeholder_str
         
     | 
| 
         @@ -426,4 +436,5 @@ ModelRegistry.register_model( 
     | 
|
| 426 | 
         
             
                "DotsOCRForCausalLM", DotsOCRForCausalLM,
         
     | 
| 427 | 
         
             
            )
         
     | 
| 428 | 
         | 
| 
         | 
|
| 429 | 
         
             
            patch_vllm_chat_placeholder()
         
     | 
| 
         | 
|
| 99 | 
         
             
                    size: Optional[dict[str, int]] = None,
         
     | 
| 100 | 
         
             
                    **kwargs: object,
         
     | 
| 101 | 
         
             
                ) -> Qwen2VLProcessor:
         
     | 
| 102 | 
         
            +
                    self.get_tokenizer().image_token = "<|imgpad|>" # Ensure image token is set
         
     | 
| 103 | 
         
             
                    processor = self.ctx.get_hf_processor(
         
     | 
| 104 | 
         
             
                        Qwen2VLProcessor,
         
     | 
| 105 | 
         
             
                        image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size),
         
     | 
| 
         | 
|
| 167 | 
         
             
                )
         
     | 
| 168 | 
         
             
                _tp_plan = {}
         
     | 
| 169 | 
         | 
| 170 | 
         
            +
                @classmethod
         
     | 
| 171 | 
         
            +
                def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         
     | 
| 172 | 
         
            +
                    if modality in ("image",):
         
     | 
| 173 | 
         
            +
                        return "<|img|><|imgpad|><|endofimg|>"
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
             
                def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         
     | 
| 176 | 
         
             
                    super().__init__()
         
     | 
| 177 | 
         | 
| 
         | 
|
| 415 | 
         | 
| 416 | 
         | 
| 417 | 
         
             
            def patch_vllm_chat_placeholder():
         
     | 
| 418 | 
         
            +
                import vllm
         
     | 
| 419 | 
         
            +
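                # Return early (skip the patch) when the installed vLLM version is newer than 0.9.1.
                # NOTE(review): the check below tests major, minor and patch separately, so it
                # also returns early for older releases whose patch number is > 1 (e.g. 0.8.5).
                # Confirm whether that is intended; a tuple comparison such as
                # `vllm.__version_tuple__[:3] > (0, 9, 1)` would match the sentence above.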
         
     | 
| 420 | 
         
            +
                if not (vllm.__version_tuple__[0]==0 and vllm.__version_tuple__[1] <= 9 and vllm.__version_tuple__[2] <= 1):
         
     | 
| 421 | 
         
            +
                    return
         
     | 
| 422 | 
         
             
                from vllm.entrypoints.chat_utils import BaseMultiModalItemTracker
         
     | 
| 423 | 
         | 
| 424 | 
         
             
                ori = BaseMultiModalItemTracker._placeholder_str
         
     | 
| 
         | 
|
| 436 | 
         
             
                "DotsOCRForCausalLM", DotsOCRForCausalLM,
         
     | 
| 437 | 
         
             
            )
         
     | 
| 438 | 
         | 
| 439 | 
         
            +
             
     | 
| 440 | 
         
             
            patch_vllm_chat_placeholder()
         
     |