| """Dots.OCR Model Loader | |
| This module handles downloading and loading the Dots.OCR model using | |
| Hugging Face's `snapshot_download`. It centralizes device selection, | |
| dtype configuration, model initialization, and safe fallbacks. | |
| Why this exists: | |
| - Keep model lifecycle and I/O concerns isolated from API/business logic. | |
| - Provide safe CPU defaults, optional CUDA acceleration, and optional | |
| FlashAttention2 when compatible and explicitly enabled. | |
| Key environment variables: | |
| - DOTS_OCR_REPO_ID: HF repo to download (default: "rednote-hilab/dots.ocr"). | |
| - DOTS_OCR_LOCAL_DIR: Local cache directory for `snapshot_download`. | |
| - DOTS_OCR_DEVICE: One of {"cpu", "cuda", "auto"}. "auto" prefers CUDA. | |
| - DOTS_OCR_MAX_NEW_TOKENS: Max generated tokens per request. | |
| - DOTS_OCR_FLASH_ATTENTION: "1" to attempt FlashAttention2 when compatible. | |
| - DOTS_OCR_MIN_PIXELS / DOTS_OCR_MAX_PIXELS: Image size bounds pre-inference. | |
| - DOTS_OCR_PROMPT: Optional default transcription prompt. | |
| Usage: call `load_model()` once, then `extract_text(image)` per request. | |
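
Example (illustrative; the import path depends on where this module lives):

    from model_loader import load_model, extract_text  # hypothetical module name
    from PIL import Image

    load_model()                                   # download and initialize once
    text = extract_text(Image.open("page.png"))    # per-request inference
    print(text)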
| """ | |

import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Configure logging
logger = logging.getLogger(__name__)

# Environment variable configuration
#
# These env vars make runtime behavior tunable without code changes. Defaults are
# conservative to favor stability on CPU-only platforms; performance features
# are opt-in and gated by compatibility checks.
REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr")
LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")  # "auto" prefers CUDA if available
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1"  # opt-in
MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136"))  # 56x56 lower bound
MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600"))  # 3360x3360 upper bound
CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")

# Default transcription prompt for faithful text extraction.
# Keep it terse to reduce bias; we want faithful extraction, not translation or formatting.
DEFAULT_PROMPT = (
    "Transcribe all visible text in the image in the original language. "
    "Do not translate. Preserve natural reading order. Output plain text only."
)


class DotsOCRModelLoader:
    """Handles Dots.OCR model downloading, loading, and inference.

    Encapsulates the model lifecycle (download, init, device placement), preprocessing,
    and a narrow inference surface for OCR. Exposes a minimal API; a single global
    instance is maintained via the module-level helpers below.
    """

    def __init__(self):
        """Initialize the model loader.

        Heavyweight work is deferred until `load_model()` so that constructing this
        class is cheap. The default prompt is taken from the environment, if provided.
        """
        self.model = None
        self.processor = None
        self.device = None
        self.dtype = None
        self.local_dir = None
        self.prompt = CUSTOM_PROMPT or DEFAULT_PROMPT

    def _determine_device_and_dtype(self) -> Tuple[str, torch.dtype]:
        """Pick the device and dtype based on availability and configuration.

        Rules:
        - Respect an explicit "cpu" or "cuda" setting when valid.
        - "auto" selects CUDA when available, otherwise CPU.
        - Use bfloat16 on CUDA for throughput; float32 on CPU for correctness.
        """
        if DEVICE_CONFIG == "cpu":
            device = "cpu"
            dtype = torch.float32
        elif DEVICE_CONFIG == "cuda" and torch.cuda.is_available():
            device = "cuda"
            dtype = torch.bfloat16
        elif DEVICE_CONFIG == "auto":
            if torch.cuda.is_available():
                device = "cuda"
                dtype = torch.bfloat16
            else:
                device = "cpu"
                dtype = torch.float32
        else:
            # CUDA requested but unavailable, or unrecognized value: fall back to CPU
            logger.warning(
                "CUDA unavailable or unrecognized DOTS_OCR_DEVICE=%r; falling back to CPU",
                DEVICE_CONFIG,
            )
            device = "cpu"
            dtype = torch.float32

        logger.info(f"Selected device: {device}, dtype: {dtype}")
        return device, dtype

    def _download_model(self) -> str:
        """Download the model using `snapshot_download` and ensure the cache dir exists.

        Returns the resolved local path for deterministic, offline-friendly loading.
        Raises `RuntimeError` on failure.
        """
        logger.info(f"Downloading model from {REPO_ID} to {LOCAL_DIR}")
        try:
            # Ensure the local directory exists
            Path(LOCAL_DIR).mkdir(parents=True, exist_ok=True)

            # Download the model snapshot
            local_path = snapshot_download(
                repo_id=REPO_ID,
                local_dir=LOCAL_DIR,
            )
            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            raise RuntimeError(f"Model download failed: {e}") from e

    def _can_use_flash_attn(self) -> bool:
        """Check whether FlashAttention2 can be enabled safely.

        Requires all of:
        - the DOTS_OCR_FLASH_ATTENTION toggle is set,
        - `flash_attn` is importable,
        - the dtype is fp16/bf16, per library support.
        """
        if not USE_FLASH_ATTENTION:
            return False
        try:
            # Import check avoids a runtime error from Transformers if not installed
            import flash_attn  # type: ignore # noqa: F401
        except Exception:
            logger.warning(
                "flash_attn package not installed; disabling FlashAttention2"
            )
            return False
        # FlashAttention2 supports fp16/bf16 only (see HF docs)
        return self.dtype in (torch.float16, torch.bfloat16)

    def load_model(self) -> None:
        """Load the Dots.OCR model and processor.

        Steps:
        1) Determine device/dtype
        2) Download the snapshot if missing
        3) Load `AutoProcessor`
        4) Configure attention/device mapping
        5) Instantiate the model and place it on the target device
        """
        try:
            # Determine device and dtype
            self.device, self.dtype = self._determine_device_and_dtype()

            # Download the model if not already present
            self.local_dir = self._download_model()

            # Load processor
            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.local_dir, trust_remote_code=True
            )

            # Load the model with appropriate configuration
            model_kwargs = {
                "dtype": self.dtype,  # NOTE: `torch_dtype` is deprecated upstream
                "trust_remote_code": True,
            }

            # Add device-specific configuration
            if self.device == "cuda":
                # Prefer FlashAttention2 when truly available; otherwise SDPA
                if self._can_use_flash_attn():
                    model_kwargs["attn_implementation"] = "flash_attention_2"
                    logger.info("Using flash attention 2")
                else:
                    model_kwargs["attn_implementation"] = "sdpa"
                    logger.info(
                        "Using SDPA attention (flash-attn unavailable or disabled)"
                    )
                # Use device_map for automatic GPU memory management
                model_kwargs["device_map"] = "auto"
            else:
                # For CPU, don't use device_map
                model_kwargs["device_map"] = None

            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.local_dir, **model_kwargs
            )

            # Move the model to the target device when device_map is not in use
            if model_kwargs.get("device_map") is None:
                self.model = self.model.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _preprocess_image(self, image: Image.Image) -> Image.Image:
        """Preprocess an image to meet the model's requirements.

        - Normalize to RGB
        - Constrain the pixel count to [MIN_PIXELS, MAX_PIXELS]
        - Snap dimensions to multiples of 28 to satisfy backbone constraints
        """
        # Convert to RGB if necessary
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Calculate the current pixel count
        width, height = image.size
        current_pixels = width * height

        # Resize if necessary to meet the pixel requirements
        if current_pixels < MIN_PIXELS:
            # Scale up to meet the minimum pixel requirement
            scale_factor = (MIN_PIXELS / current_pixels) ** 0.5
            new_width = int(width * scale_factor)
            new_height = int(height * scale_factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(
                f"Scaled up image from {width}x{height} to {new_width}x{new_height}"
            )
        elif current_pixels > MAX_PIXELS:
            # Scale down to meet the maximum pixel requirement
            scale_factor = (MAX_PIXELS / current_pixels) ** 0.5
            new_width = int(width * scale_factor)
            new_height = int(height * scale_factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(
                f"Scaled down image from {width}x{height} to {new_width}x{new_height}"
            )

        # Ensure dimensions are divisible by 28 (a common requirement for vision backbones)
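        # Example: an 800x600 input becomes 812x616 after each side is rounded up
        # to the next multiple of 28. Note that rounding up may push an image that
        # was scaled down to exactly MAX_PIXELS slightly above that bound.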
        width, height = image.size
        new_width = ((width + 27) // 28) * 28
        new_height = ((height + 27) // 28) * 28
        if new_width != width or new_height != height:
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(
                f"Adjusted image dimensions to be divisible by 28: {new_width}x{new_height}"
            )

        return image

    def extract_text(self, image: Image.Image, prompt: Optional[str] = None) -> str:
        """Extract text from an image using the loaded model.

        Builds a single-turn chat message with the image and a transcription prompt,
        applies the model's chat template, and decodes greedily (no sampling).
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            # Preprocess the image
            processed_image = self._preprocess_image(image)

            # Use the provided prompt or fall back to the default
            text_prompt = prompt or self.prompt

            # Prepare messages for the model
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_image},
                        {"type": "text", "text": text_prompt},
                    ],
                }
            ]

            # Apply the chat template (preserves special tokens/formatting expected by the model)
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Process vision information (required for some models)
            try:
                from qwen_vl_utils import process_vision_info

                image_inputs, video_inputs = process_vision_info(messages)
            except ImportError:
                # Fallback if qwen_vl_utils is not available
                logger.warning("qwen_vl_utils not available, using basic processing")
                image_inputs = [processed_image]
                video_inputs = None  # processors treat None as "no videos"

            # Prepare model inputs
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            # Greedy decoding (do_sample=False) for deterministic output
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens (strip the prompt tokens)
            trimmed = [
                out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)
            ]
            decoded = self.processor.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            return decoded[0] if decoded else ""
        except Exception as e:
            logger.error(f"Text extraction failed: {e}")
            raise RuntimeError(f"Text extraction failed: {e}") from e

    def is_loaded(self) -> bool:
        """Return True when both the model and the processor are initialized."""
        return self.model is not None and self.processor is not None

    def get_model_info(self) -> Dict[str, Any]:
        """Return diagnostic information about the loaded model and configuration."""
        return {
            "device": self.device,
            "dtype": str(self.dtype),
            "local_dir": self.local_dir,
            "repo_id": REPO_ID,
            "max_new_tokens": MAX_NEW_TOKENS,
            "use_flash_attention": USE_FLASH_ATTENTION,
            "prompt": self.prompt,
            "is_loaded": self.is_loaded(),
        }


# Global model instance
_model_loader: Optional[DotsOCRModelLoader] = None


def get_model_loader() -> DotsOCRModelLoader:
    """Get the global model loader instance."""
    global _model_loader
    if _model_loader is None:
        _model_loader = DotsOCRModelLoader()
    return _model_loader


def load_model() -> None:
    """Load the Dots.OCR model."""
    loader = get_model_loader()
    loader.load_model()


def extract_text(image: Image.Image, prompt: Optional[str] = None) -> str:
    """Extract text from an image using the loaded model."""
    loader = get_model_loader()
    if not loader.is_loaded():
        raise RuntimeError("Model not loaded. Call load_model() first.")
    return loader.extract_text(image, prompt)


def is_model_loaded() -> bool:
    """Check whether the model is loaded and ready."""
    loader = get_model_loader()
    return loader.is_loaded()


def get_model_info() -> Dict[str, Any]:
    """Get information about the loaded model."""
    loader = get_model_loader()
    return loader.get_model_info()
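

# Minimal smoke test: a sketch for local debugging, not part of the public API.
# The command-line handling below is an assumption; pass an image path as the
# first argument to run OCR, or pass nothing to print the model info instead.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO)
    load_model()
    if len(sys.argv) > 1:
        print(extract_text(Image.open(sys.argv[1])))
    else:
        print(get_model_info())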