"""Dots.OCR Model Loader
This module handles downloading and loading the Dots.OCR model using
Hugging Face's `snapshot_download`. It centralizes device selection,
dtype configuration, model initialization, and safe fallbacks.
Why this exists:
- Keep model lifecycle and I/O concerns isolated from API/business logic.
- Provide safe CPU defaults, optional CUDA acceleration, and optional
FlashAttention2 when compatible and explicitly enabled.
Key environment variables:
- DOTS_OCR_REPO_ID: HF repo to download (default: "rednote-hilab/dots.ocr").
- DOTS_OCR_LOCAL_DIR: Local cache directory for `snapshot_download`.
- DOTS_OCR_DEVICE: One of {"cpu", "cuda", "auto"}. "auto" prefers CUDA.
- DOTS_OCR_MAX_NEW_TOKENS: Max generated tokens per request.
- DOTS_OCR_FLASH_ATTENTION: "1" to attempt FlashAttention2 when compatible.
- DOTS_OCR_MIN_PIXELS / DOTS_OCR_MAX_PIXELS: Image size bounds pre-inference.
- DOTS_OCR_PROMPT: Optional default transcription prompt.
Usage: call `load_model()` once, then `extract_text(image)` per request.
"""
import os
import logging
import torch
from typing import Optional, Tuple, Dict, Any
from pathlib import Path
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
# Configure logging
logger = logging.getLogger(__name__)
# Environment variable configuration
#
# These env vars make runtime behavior tunable without code changes. Defaults are
# conservative to favor stability on CPU-only platforms; performance features
# are opt-in and gated by compatibility checks.
REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr")
LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto") # "auto" prefers CUDA if available
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1" # opt-in
MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56 lower bound
MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360 upper bound
CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")
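# Example (illustrative): a GPU host can opt in to CUDA + FlashAttention2 by setting
#   DOTS_OCR_DEVICE=cuda DOTS_OCR_FLASH_ATTENTION=1
# while CPU-only deployments can simply rely on the conservative defaults above.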
# Default transcription prompt for faithful text extraction.
# Keep terse to reduce bias; we want faithful extraction, not translation or formatting.
DEFAULT_PROMPT = (
"Transcribe all visible text in the image in the original language. "
"Do not translate. Preserve natural reading order. Output plain text only."
)
class DotsOCRModelLoader:
"""Handles Dots.OCR model downloading, loading, and inference.
    Encapsulates the model lifecycle (download, init, device placement), image
    preprocessing, and a narrow inference surface for OCR. A single module-level
    instance is maintained by the helper functions at the bottom of this module.
"""
def __init__(self):
"""Initialize the model loader.
Heavyweight work is deferred until `load_model()` so that constructing this
class is cheap. The default prompt is captured from env, if provided.
"""
self.model = None
self.processor = None
self.device = None
self.dtype = None
self.local_dir = None
self.prompt = CUSTOM_PROMPT or DEFAULT_PROMPT
def _determine_device_and_dtype(self) -> Tuple[str, torch.dtype]:
"""Pick device and dtype based on availability and configuration.
Rules:
- Respect explicit "cpu" or "cuda" when valid.
- "auto" selects CUDA when available, else CPU.
- Use bfloat16 on CUDA for throughput; float32 on CPU for correctness.
"""
if DEVICE_CONFIG == "cpu":
device = "cpu"
dtype = torch.float32
elif DEVICE_CONFIG == "cuda" and torch.cuda.is_available():
device = "cuda"
dtype = torch.bfloat16
elif DEVICE_CONFIG == "auto":
if torch.cuda.is_available():
device = "cuda"
dtype = torch.bfloat16
else:
device = "cpu"
dtype = torch.float32
        else:
            # "cuda" was requested without CUDA available, or DOTS_OCR_DEVICE holds an
            # unrecognized value; fall back to CPU rather than failing at startup.
            logger.warning(
                "DOTS_OCR_DEVICE=%r not usable; falling back to CPU", DEVICE_CONFIG
            )
            device = "cpu"
            dtype = torch.float32
logger.info(f"Selected device: {device}, dtype: {dtype}")
return device, dtype
def _download_model(self) -> str:
"""Download the model using `snapshot_download` and ensure cache dir exists.
Returns the resolved local path for deterministic, offline-friendly loading.
Raises `RuntimeError` on failure.
"""
logger.info(f"Downloading model from {REPO_ID} to {LOCAL_DIR}")
try:
# Ensure local directory exists
Path(LOCAL_DIR).mkdir(parents=True, exist_ok=True)
# Download model snapshot
local_path = snapshot_download(
repo_id=REPO_ID,
local_dir=LOCAL_DIR,
)
logger.info(f"Model downloaded successfully to {local_path}")
return local_path
except Exception as e:
logger.error(f"Failed to download model: {e}")
raise RuntimeError(f"Model download failed: {e}")
def _can_use_flash_attn(self) -> bool:
"""Check whether FlashAttention2 can be enabled safely.
Requires all of:
- DOTS_OCR_FLASH_ATTENTION toggle is set.
- `flash_attn` is importable.
- dtype is fp16/bf16 per library support.
"""
if not USE_FLASH_ATTENTION:
return False
try:
# Import check avoids runtime error from Transformers if not installed
import flash_attn # type: ignore # noqa: F401
except Exception:
logger.warning(
"flash_attn package not installed; disabling FlashAttention2"
)
return False
# FlashAttention2 supports fp16/bf16 only (see HF docs)
return self.dtype in (torch.float16, torch.bfloat16)
def load_model(self) -> None:
"""Load the Dots.OCR model and processor.
Steps:
1) Determine device/dtype
2) Download snapshot if missing
3) Load `AutoProcessor`
4) Configure attention/device mapping
5) Instantiate model and place on target device
"""
try:
# Determine device and dtype
self.device, self.dtype = self._determine_device_and_dtype()
# Download model if not already present
self.local_dir = self._download_model()
# Load processor
logger.info("Loading processor...")
self.processor = AutoProcessor.from_pretrained(
self.local_dir, trust_remote_code=True
)
# Load model with appropriate configuration
model_kwargs = {
"dtype": self.dtype, # NOTE: `torch_dtype` is deprecated upstream
"trust_remote_code": True,
}
# Add device-specific configurations
if self.device == "cuda":
# Prefer FlashAttention2 when truly available; otherwise SDPA
if self._can_use_flash_attn():
model_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Using flash attention 2")
else:
model_kwargs["attn_implementation"] = "sdpa"
logger.info(
"Using SDPA attention (flash-attn unavailable or disabled)"
)
# Use device_map for automatic GPU memory management
model_kwargs["device_map"] = "auto"
else:
# For CPU, don't use device_map
model_kwargs["device_map"] = None
logger.info("Loading model...")
self.model = AutoModelForCausalLM.from_pretrained(
self.local_dir, **model_kwargs
)
# Move model to device if not using device_map
if self.device == "cpu" or model_kwargs.get("device_map") is None:
self.model = self.model.to(self.device)
logger.info(f"Model loaded successfully on {self.device}")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise RuntimeError(f"Model loading failed: {e}")
def _preprocess_image(self, image: Image.Image) -> Image.Image:
"""Preprocess image to meet model requirements.
- Normalize to RGB
- Constrain pixel count within [MIN_PIXELS, MAX_PIXELS]
- Snap dimensions to multiples of 28 to satisfy backbone constraints
"""
# Convert to RGB if necessary
if image.mode != "RGB":
image = image.convert("RGB")
# Calculate current pixel count
width, height = image.size
current_pixels = width * height
# Resize if necessary to meet pixel requirements
if current_pixels < MIN_PIXELS:
# Scale up to meet minimum pixel requirement
scale_factor = (MIN_PIXELS / current_pixels) ** 0.5
new_width = int(width * scale_factor)
new_height = int(height * scale_factor)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
logger.info(
f"Scaled up image from {width}x{height} to {new_width}x{new_height}"
)
elif current_pixels > MAX_PIXELS:
# Scale down to meet maximum pixel requirement
scale_factor = (MAX_PIXELS / current_pixels) ** 0.5
new_width = int(width * scale_factor)
new_height = int(height * scale_factor)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
logger.info(
f"Scaled down image from {width}x{height} to {new_width}x{new_height}"
)
# Ensure dimensions are divisible by 28 (common requirement for vision models)
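        # Worked example: a 1000x700 input is snapped up to 1008x700
        # (1000 -> ceil(1000 / 28) * 28 = 36 * 28 = 1008; 700 is already 25 * 28).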
width, height = image.size
new_width = ((width + 27) // 28) * 28
new_height = ((height + 27) // 28) * 28
if new_width != width or new_height != height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
logger.info(
f"Adjusted image dimensions to be divisible by 28: {new_width}x{new_height}"
)
return image
@torch.inference_mode()
def extract_text(self, image: Image.Image, prompt: Optional[str] = None) -> str:
"""Extract text from an image using the loaded model.
Builds a single-turn chat message with the image and a transcription prompt,
applies the model's chat template, and decodes deterministically (no sampling).
"""
if self.model is None or self.processor is None:
raise RuntimeError("Model not loaded. Call load_model() first.")
try:
# Preprocess image
processed_image = self._preprocess_image(image)
# Use provided prompt or default
text_prompt = prompt or self.prompt
# Prepare messages for the model
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": processed_image},
{"type": "text", "text": text_prompt},
],
}
]
# Apply chat template (preserves special tokens/formatting expected by model)
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Process vision information (required for some models)
try:
from qwen_vl_utils import process_vision_info
image_inputs, video_inputs = process_vision_info(messages)
except ImportError:
# Fallback if qwen_vl_utils not available
logger.warning("qwen_vl_utils not available, using basic processing")
image_inputs = [processed_image]
video_inputs = []
# Prepare inputs
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(self.device)
            # Generate text deterministically (greedy decoding; sampling disabled)
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )
# Decode only newly generated tokens (strip prompt tokens)
trimmed = [
out[len(inp) :] for inp, out in zip(inputs.input_ids, output_ids)
]
decoded = self.processor.batch_decode(
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return decoded[0] if decoded else ""
except Exception as e:
logger.error(f"Text extraction failed: {e}")
raise RuntimeError(f"Text extraction failed: {e}")
def is_loaded(self) -> bool:
"""Return True when both model and processor are initialized."""
return self.model is not None and self.processor is not None
def get_model_info(self) -> Dict[str, Any]:
"""Get diagnostic information about the loaded model and configuration."""
return {
"device": self.device,
"dtype": str(self.dtype),
"local_dir": self.local_dir,
"repo_id": REPO_ID,
"max_new_tokens": MAX_NEW_TOKENS,
"use_flash_attention": USE_FLASH_ATTENTION,
"prompt": self.prompt,
"is_loaded": self.is_loaded(),
}
# Global model instance
_model_loader: Optional[DotsOCRModelLoader] = None
def get_model_loader() -> DotsOCRModelLoader:
"""Get the global model loader instance."""
global _model_loader
if _model_loader is None:
_model_loader = DotsOCRModelLoader()
return _model_loader
def load_model() -> None:
"""Load the Dots.OCR model."""
loader = get_model_loader()
loader.load_model()
def extract_text(image: Image.Image, prompt: Optional[str] = None) -> str:
"""Extract text from an image using the loaded model."""
loader = get_model_loader()
if not loader.is_loaded():
raise RuntimeError("Model not loaded. Call load_model() first.")
return loader.extract_text(image, prompt)
def is_model_loaded() -> bool:
"""Check if the model is loaded and ready."""
loader = get_model_loader()
return loader.is_loaded()
def get_model_info() -> Dict[str, Any]:
"""Get information about the loaded model."""
loader = get_model_loader()
return loader.get_model_info()
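

# Minimal local smoke test (illustrative sketch only): load the model once, then OCR a
# single image. The "sample.png" path is a placeholder and must be replaced with a real
# image file before running.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    load_model()
    sample_image = Image.open("sample.png")  # placeholder path, not part of this repo
    print(extract_text(sample_image))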