Add Apple Silicon (MPS) backend support
Enables DeepSeek-OCR to run on Apple Silicon (M1/M2/M3/M4) using the MPS backend with proper OCR output quality.
Key changes:
- Replace masked_scatter_ with row-wise boolean assignment on MPS (fixes silent embedding injection failure; see the first sketch after this list)
- Use fp32 precision for images and inference on MPS (bfloat16 causes numerical issues; see the second sketch below)
- Disable autocast on MPS backend
- Make tensor placement device-agnostic (.to(self.device) instead of .cuda())
- Add NaN guards for vision tower outputs on MPS
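As a minimal, self-contained illustration of the first and last bullets (toy tensor names and shapes, not the model's actual variables), the row-wise boolean write fills the same rows as the broadcasted masked_scatter_, with nan_to_num guarding the injected features:

```python
import torch

# Toy shapes: 8 token positions, 4-dim embeddings, 3 image tokens.
inputs_embeds = torch.zeros(8, 4)
images_seq_mask = torch.tensor([0, 1, 1, 1, 0, 0, 0, 0], dtype=torch.bool)
image_feats = torch.randn(3, 4)

# Original path: broadcasted in-place masked_scatter_.
a = inputs_embeds.clone()
a.masked_scatter_(images_seq_mask.unsqueeze(-1), image_feats)

# MPS-safe path: guard against NaNs, then do a plain row-wise boolean write.
b = inputs_embeds.clone()
b[images_seq_mask] = torch.nan_to_num(image_feats)

assert torch.equal(a, b)  # identical result when image_feats contains no NaNs
```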
All changes are conditionally applied based on self.device.type == "mps".
The CUDA code path remains completely unchanged, preserving full backwards compatibility.
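The fp32/autocast/device-placement bullets boil down to a small conditional; the helper below is an illustrative consolidation and is an assumption about structure only (the diff inlines these checks rather than defining a helper):

```python
import torch
from contextlib import nullcontext

def mps_safe_settings(device: torch.device):
    """Illustrative helper (not in the diff): fp32 and no autocast on MPS,
    bfloat16 autocast on CUDA."""
    if device.type == "mps":
        return torch.float32, nullcontext()
    return torch.bfloat16, torch.autocast("cuda", dtype=torch.bfloat16)

device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
image_dtype, autocast_ctx = mps_safe_settings(device)

# Tensors are moved with .to(device) instead of .cuda(), so one code path serves both backends.
image_tensor = torch.randn(3, 640, 640).to(image_dtype).to(device)
with autocast_ctx, torch.no_grad():
    pass  # generate(...) runs here in the actual infer() path
```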
Tested on: macOS 26.0.1, Apple M4 Max, PyTorch 2.9.0, Transformers 4.46.3
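For completeness, a sketch of loading the model on Apple Silicon under these changes; the repository id and the use of AutoModel are assumptions about the surrounding setup, not part of this PR:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "deepseek-ai/DeepSeek-OCR"  # assumed repo id; adjust to your checkout
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32)
model = model.to("mps").eval()  # fp32 on MPS per this PR; bfloat16 + .cuda() stays the CUDA path
```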
- modeling_deepseekocr.py +42 -17

--- a/modeling_deepseekocr.py
+++ b/modeling_deepseekocr.py
@@ -3,6 +3,7 @@ from .configuration_deepseek_v2 import DeepseekV2Config
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from typing import List, Optional, Tuple, Union
 from transformers.cache_utils import Cache
+from contextlib import nullcontext
 import requests
 from PIL import Image, ImageOps, ImageDraw, ImageFont
 from io import BytesIO
@@ -502,7 +503,23 @@ class DeepseekOCRModel(DeepseekV2Model):
                     images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                     # exit()
 
-                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+                    # MPS compatibility: use row-wise assignment; CUDA: keep original masked_scatter_
+                    if self.device.type == "mps":
+                        # MPS-safe: row-wise boolean assignment instead of broadcasted masked_scatter_
+                        mask = images_seq_mask[idx].to(self.device)
+                        feats = images_in_this_batch.to(dtype=inputs_embeds.dtype, device=self.device)
+                        # Basic sanity: number of rows must match
+                        if mask.sum().item() != feats.shape[0]:
+                            raise RuntimeError(
+                                f"image token count mismatch: mask={mask.sum().item()} vs feats={feats.shape[0]}"
+                            )
+                        # Guard against NaNs from upstream vision tower (seen on some MPS builds)
+                        feats = torch.nan_to_num(feats)
+                        # Deterministic row write
+                        inputs_embeds[idx][mask] = feats
+                    else:
+                        # Original CUDA path (unchanged)
+                        inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
 
                 idx += 1
 
@@ -799,7 +816,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-                images_list.append(image_transform(global_view).to(torch.bfloat16))
+                # MPS needs fp32, CUDA can use bfloat16
+                image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+                images_list.append(image_transform(global_view).to(image_dtype))
 
                 # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
 
@@ -810,9 +829,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
                 if width_crop_num > 1 or height_crop_num > 1:
                     """process the local views"""
-
+
                     for i in range(len(images_crop_raw)):
-                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+                        images_crop_list.append(image_transform(images_crop_raw[i]).to(image_dtype))
 
                 if image_size == 640:
                     valid_img_tokens += len(images_crop_list) * 100
@@ -846,7 +865,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 # else:
                 global_view = ImageOps.pad(image, (image_size, image_size),
                                         color=tuple(int(x * 255) for x in image_transform.mean))
-                images_list.append(image_transform(global_view).to(torch.bfloat16))
+                # MPS needs fp32, CUDA can use bfloat16
+                image_dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
+                images_list.append(image_transform(global_view).to(image_dtype))
 
                 if base_size == 1024:
                     valid_img_tokens += int(256 * ratio)
@@ -911,12 +932,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
         if not eval_mode:
            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
+            with autocast_ctx:
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(self.device),
+                        images=[(images_crop.to(self.device), images_ori.to(self.device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(self.device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -929,12 +952,14 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                         )
 
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            # MPS: no autocast (pure fp32); CUDA: keep original bfloat16 autocast
+            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
+            with autocast_ctx:
                 with torch.no_grad():
                     output_ids = self.generate(
-                        input_ids.unsqueeze(0).cuda(),
-                        images=[(images_crop.cuda(), images_ori.cuda())],
-                        images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                        input_ids.unsqueeze(0).to(self.device),
+                        images=[(images_crop.to(self.device), images_ori.to(self.device))],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).to(self.device),
                         images_spatial_crop = images_spatial_crop,
                         # do_sample=False,
                         # num_beams = 1,
@@ -944,10 +969,10 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                         no_repeat_ngram_size = 35,
                         use_cache = True
                         )
-
+
 
         if '<image>' in conversation[0]['content'] and eval_mode:
-                outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+                outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
                 stop_str = '<|end▁of▁sentence|>'
                 if outputs.endswith(stop_str):
                     outputs = outputs[:-len(stop_str)]
@@ -957,7 +982,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 return outputs
 
         if '<image>' in conversation[0]['content'] and test_compress:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
             pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
             print('='*50)
             print('image size: ', (w, h))
@@ -968,7 +993,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
         if '<image>' in conversation[0]['content'] and save_results:
-            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
             stop_str = '<|end▁of▁sentence|>'
 
             print('='*15 + 'save results:' + '='*15)