"""Dots.OCR Model Loader This module handles downloading and loading the Dots.OCR model using Hugging Face's `snapshot_download`. It centralizes device selection, dtype configuration, model initialization, and safe fallbacks. Why this exists: - Keep model lifecycle and I/O concerns isolated from API/business logic. - Provide safe CPU defaults, optional CUDA acceleration, and optional FlashAttention2 when compatible and explicitly enabled. Key environment variables: - DOTS_OCR_REPO_ID: HF repo to download (default: "rednote-hilab/dots.ocr"). - DOTS_OCR_LOCAL_DIR: Local cache directory for `snapshot_download`. - DOTS_OCR_DEVICE: One of {"cpu", "cuda", "auto"}. "auto" prefers CUDA. - DOTS_OCR_MAX_NEW_TOKENS: Max generated tokens per request. - DOTS_OCR_FLASH_ATTENTION: "1" to attempt FlashAttention2 when compatible. - DOTS_OCR_MIN_PIXELS / DOTS_OCR_MAX_PIXELS: Image size bounds pre-inference. - DOTS_OCR_PROMPT: Optional default transcription prompt. Usage: call `load_model()` once, then `extract_text(image)` per request. """ import os import logging import torch from typing import Optional, Tuple, Dict, Any from pathlib import Path from huggingface_hub import snapshot_download from transformers import AutoModelForCausalLM, AutoProcessor from PIL import Image # Configure logging logger = logging.getLogger(__name__) # Environment variable configuration # # These env vars make runtime behavior tunable without code changes. Defaults are # conservative to favor stability on CPU-only platforms; performance features # are opt-in and gated by compatibility checks. REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr") LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr") DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto") # "auto" prefers CUDA if available MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048")) USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1" # opt-in MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56 lower bound MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360 upper bound CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT") # Default transcription prompt for faithful text extraction. # Keep terse to reduce bias; we want faithful extraction, not translation or formatting. DEFAULT_PROMPT = ( "Transcribe all visible text in the image in the original language. " "Do not translate. Preserve natural reading order. Output plain text only." ) class DotsOCRModelLoader: """Handles Dots.OCR model downloading, loading, and inference. Encapsulates model lifecycle (download, init, device placement), preprocessing, and a narrow inference surface for OCR. Exposes a minimal API and maintains a single global instance via helpers below. """ def __init__(self): """Initialize the model loader. Heavyweight work is deferred until `load_model()` so that constructing this class is cheap. The default prompt is captured from env, if provided. """ self.model = None self.processor = None self.device = None self.dtype = None self.local_dir = None self.prompt = CUSTOM_PROMPT or DEFAULT_PROMPT def _determine_device_and_dtype(self) -> Tuple[str, torch.dtype]: """Pick device and dtype based on availability and configuration. Rules: - Respect explicit "cpu" or "cuda" when valid. - "auto" selects CUDA when available, else CPU. - Use bfloat16 on CUDA for throughput; float32 on CPU for correctness. 
""" if DEVICE_CONFIG == "cpu": device = "cpu" dtype = torch.float32 elif DEVICE_CONFIG == "cuda" and torch.cuda.is_available(): device = "cuda" dtype = torch.bfloat16 elif DEVICE_CONFIG == "auto": if torch.cuda.is_available(): device = "cuda" dtype = torch.bfloat16 else: device = "cpu" dtype = torch.float32 else: # Fallback to CPU if CUDA requested but not available logger.warning(f"CUDA requested but not available, falling back to CPU") device = "cpu" dtype = torch.float32 logger.info(f"Selected device: {device}, dtype: {dtype}") return device, dtype def _download_model(self) -> str: """Download the model using `snapshot_download` and ensure cache dir exists. Returns the resolved local path for deterministic, offline-friendly loading. Raises `RuntimeError` on failure. """ logger.info(f"Downloading model from {REPO_ID} to {LOCAL_DIR}") try: # Ensure local directory exists Path(LOCAL_DIR).mkdir(parents=True, exist_ok=True) # Download model snapshot local_path = snapshot_download( repo_id=REPO_ID, local_dir=LOCAL_DIR, ) logger.info(f"Model downloaded successfully to {local_path}") return local_path except Exception as e: logger.error(f"Failed to download model: {e}") raise RuntimeError(f"Model download failed: {e}") def _can_use_flash_attn(self) -> bool: """Check whether FlashAttention2 can be enabled safely. Requires all of: - DOTS_OCR_FLASH_ATTENTION toggle is set. - `flash_attn` is importable. - dtype is fp16/bf16 per library support. """ if not USE_FLASH_ATTENTION: return False try: # Import check avoids runtime error from Transformers if not installed import flash_attn # type: ignore # noqa: F401 except Exception: logger.warning( "flash_attn package not installed; disabling FlashAttention2" ) return False # FlashAttention2 supports fp16/bf16 only (see HF docs) return self.dtype in (torch.float16, torch.bfloat16) def load_model(self) -> None: """Load the Dots.OCR model and processor. 
        Steps:
        1) Determine device/dtype
        2) Download snapshot if missing
        3) Load `AutoProcessor`
        4) Configure attention/device mapping
        5) Instantiate model and place on target device
        """
        try:
            # Determine device and dtype
            self.device, self.dtype = self._determine_device_and_dtype()

            # Download model if not already present
            self.local_dir = self._download_model()

            # Load processor
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.local_dir, trust_remote_code=True
            )

            # Load model with appropriate configuration
            model_kwargs = {
                "dtype": self.dtype,  # NOTE: `torch_dtype` is deprecated upstream
                "trust_remote_code": True,
            }

            # Add device-specific configurations
            if self.device == "cuda":
                # Prefer FlashAttention2 when truly available; otherwise SDPA
                if self._can_use_flash_attn():
                    model_kwargs["attn_implementation"] = "flash_attention_2"
                    logger.info("Using flash attention 2")
                else:
                    model_kwargs["attn_implementation"] = "sdpa"
                    logger.info(
                        "Using SDPA attention (flash-attn unavailable or disabled)"
                    )
                # Use device_map for automatic GPU memory management
                model_kwargs["device_map"] = "auto"
            else:
                # For CPU, don't use device_map
                model_kwargs["device_map"] = None

            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.local_dir, **model_kwargs
            )

            # Move model to device if not using device_map
            if self.device == "cpu" or model_kwargs.get("device_map") is None:
                self.model = self.model.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model loading failed: {e}")

    def _preprocess_image(self, image: Image.Image) -> Image.Image:
        """Preprocess image to meet model requirements.

        - Normalize to RGB
        - Constrain pixel count within [MIN_PIXELS, MAX_PIXELS]
        - Snap dimensions to multiples of 28 to satisfy backbone constraints
        """
        # Convert to RGB if necessary
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Calculate current pixel count
        width, height = image.size
        current_pixels = width * height

        # Resize if necessary to meet pixel requirements
        if current_pixels < MIN_PIXELS:
            # Scale up to meet minimum pixel requirement
            scale_factor = (MIN_PIXELS / current_pixels) ** 0.5
            new_width = int(width * scale_factor)
            new_height = int(height * scale_factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(
                f"Scaled up image from {width}x{height} to {new_width}x{new_height}"
            )
        elif current_pixels > MAX_PIXELS:
            # Scale down to meet maximum pixel requirement
            scale_factor = (MAX_PIXELS / current_pixels) ** 0.5
            new_width = int(width * scale_factor)
            new_height = int(height * scale_factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(
                f"Scaled down image from {width}x{height} to {new_width}x{new_height}"
            )

        # Ensure dimensions are divisible by 28 (common requirement for vision models)
        width, height = image.size
        new_width = ((width + 27) // 28) * 28
        new_height = ((height + 27) // 28) * 28

        if new_width != width or new_height != height:
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(
                f"Adjusted image dimensions to be divisible by 28: {new_width}x{new_height}"
            )

        return image

    @torch.inference_mode()
    def extract_text(self, image: Image.Image, prompt: Optional[str] = None) -> str:
        """Extract text from an image using the loaded model.
        Builds a single-turn chat message with the image and a transcription
        prompt, applies the model's chat template, and decodes deterministically
        (no sampling).
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            # Preprocess image
            processed_image = self._preprocess_image(image)

            # Use provided prompt or default
            text_prompt = prompt or self.prompt

            # Prepare messages for the model
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_image},
                        {"type": "text", "text": text_prompt},
                    ],
                }
            ]

            # Apply chat template (preserves special tokens/formatting expected by model)
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Process vision information (required for some models)
            try:
                from qwen_vl_utils import process_vision_info

                image_inputs, video_inputs = process_vision_info(messages)
            except ImportError:
                # Fallback if qwen_vl_utils not available
                logger.warning("qwen_vl_utils not available, using basic processing")
                image_inputs = [processed_image]
                video_inputs = []

            # Prepare inputs
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            # Generate text deterministically (temperature=0, do_sample=False)
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                temperature=0.0,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )

            # Decode only newly generated tokens (strip prompt tokens)
            trimmed = [
                out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)
            ]
            decoded = self.processor.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )

            return decoded[0] if decoded else ""

        except Exception as e:
            logger.error(f"Text extraction failed: {e}")
            raise RuntimeError(f"Text extraction failed: {e}")

    def is_loaded(self) -> bool:
        """Return True when both model and processor are initialized."""
        return self.model is not None and self.processor is not None

    def get_model_info(self) -> Dict[str, Any]:
        """Get diagnostic information about the loaded model and configuration."""
        return {
            "device": self.device,
            "dtype": str(self.dtype),
            "local_dir": self.local_dir,
            "repo_id": REPO_ID,
            "max_new_tokens": MAX_NEW_TOKENS,
            "use_flash_attention": USE_FLASH_ATTENTION,
            "prompt": self.prompt,
            "is_loaded": self.is_loaded(),
        }


# Global model instance
_model_loader: Optional[DotsOCRModelLoader] = None


def get_model_loader() -> DotsOCRModelLoader:
    """Get the global model loader instance."""
    global _model_loader
    if _model_loader is None:
        _model_loader = DotsOCRModelLoader()
    return _model_loader


def load_model() -> None:
    """Load the Dots.OCR model."""
    loader = get_model_loader()
    loader.load_model()


def extract_text(image: Image.Image, prompt: Optional[str] = None) -> str:
    """Extract text from an image using the loaded model."""
    loader = get_model_loader()
    if not loader.is_loaded():
        raise RuntimeError("Model not loaded. Call load_model() first.")
    return loader.extract_text(image, prompt)


def is_model_loaded() -> bool:
    """Check if the model is loaded and ready."""
    loader = get_model_loader()
    return loader.is_loaded()


def get_model_info() -> Dict[str, Any]:
    """Get information about the loaded model."""
    loader = get_model_loader()
    return loader.get_model_info()
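

# Minimal usage sketch (illustrative only, not part of the public API): a typical
# caller loads the model once at process startup and then calls extract_text()
# per request, as described in the module docstring. The image path below is a
# hypothetical placeholder.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    load_model()
    sample = Image.open("example_page.png")  # hypothetical input image
    print(extract_text(sample))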