"""Dots.OCR Model Loader
This module handles downloading and loading the Dots.OCR model using
Hugging Face's `snapshot_download`. It centralizes device selection,
dtype configuration, model initialization, and safe fallbacks.
Why this exists:
- Keep model lifecycle and I/O concerns isolated from API/business logic.
- Provide safe CPU defaults, optional CUDA acceleration, and optional
FlashAttention2 when compatible and explicitly enabled.
Key environment variables:
- DOTS_OCR_REPO_ID: HF repo to download (default: "rednote-hilab/dots.ocr").
- DOTS_OCR_LOCAL_DIR: Local cache directory for `snapshot_download`.
- DOTS_OCR_DEVICE: One of {"cpu", "cuda", "auto"}. "auto" prefers CUDA.
- DOTS_OCR_MAX_NEW_TOKENS: Max generated tokens per request.
- DOTS_OCR_FLASH_ATTENTION: "1" to attempt FlashAttention2 when compatible.
- DOTS_OCR_MIN_PIXELS / DOTS_OCR_MAX_PIXELS: Image size bounds pre-inference.
- DOTS_OCR_PROMPT: Optional default transcription prompt.
Usage: call `load_model()` once, then `extract_text(image)` per request.
"""
import os
import logging
import torch
from typing import Optional, Tuple, Dict, Any
from pathlib import Path
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
# Configure logging
logger = logging.getLogger(__name__)
# Environment variable configuration
#
# These env vars make runtime behavior tunable without code changes. Defaults are
# conservative to favor stability on CPU-only platforms; performance features
# are opt-in and gated by compatibility checks.
REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr")
LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto") # "auto" prefers CUDA if available
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1" # opt-in
MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56 lower bound
MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360 upper bound
CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")
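# Illustrative overrides (a sketch; values below are examples, not defaults).
# Set these in the deployment environment before the process starts, e.g.:
#
#   export DOTS_OCR_DEVICE=cuda
#   export DOTS_OCR_FLASH_ATTENTION=1
#   export DOTS_OCR_MAX_NEW_TOKENS=4096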
# Default transcription prompt for faithful text extraction.
# Keep terse to reduce bias; we want faithful extraction, not translation or formatting.
DEFAULT_PROMPT = (
"Transcribe all visible text in the image in the original language. "
"Do not translate. Preserve natural reading order. Output plain text only."
)
class DotsOCRModelLoader:
"""Handles Dots.OCR model downloading, loading, and inference.
Encapsulates model lifecycle (download, init, device placement), preprocessing,
and a narrow inference surface for OCR. Exposes a minimal API and maintains a
single global instance via helpers below.
"""
def __init__(self):
"""Initialize the model loader.
Heavyweight work is deferred until `load_model()` so that constructing this
class is cheap. The default prompt is captured from env, if provided.
"""
self.model = None
self.processor = None
self.device = None
self.dtype = None
self.local_dir = None
self.prompt = CUSTOM_PROMPT or DEFAULT_PROMPT
def _determine_device_and_dtype(self) -> Tuple[str, torch.dtype]:
"""Pick device and dtype based on availability and configuration.
Rules:
- Respect explicit "cpu" or "cuda" when valid.
- "auto" selects CUDA when available, else CPU.
- Use bfloat16 on CUDA for throughput; float32 on CPU for correctness.
"""
if DEVICE_CONFIG == "cpu":
device = "cpu"
dtype = torch.float32
elif DEVICE_CONFIG == "cuda" and torch.cuda.is_available():
device = "cuda"
dtype = torch.bfloat16
elif DEVICE_CONFIG == "auto":
if torch.cuda.is_available():
device = "cuda"
dtype = torch.bfloat16
else:
device = "cpu"
dtype = torch.float32
else:
            # Fall back to CPU when CUDA was requested but is unavailable, or when
            # DOTS_OCR_DEVICE holds an unrecognized value.
            logger.warning(
                "CUDA requested but not available (or unrecognized DOTS_OCR_DEVICE "
                "value); falling back to CPU"
            )
device = "cpu"
dtype = torch.float32
logger.info(f"Selected device: {device}, dtype: {dtype}")
return device, dtype
def _download_model(self) -> str:
"""Download the model using `snapshot_download` and ensure cache dir exists.
Returns the resolved local path for deterministic, offline-friendly loading.
Raises `RuntimeError` on failure.
"""
logger.info(f"Downloading model from {REPO_ID} to {LOCAL_DIR}")
try:
# Ensure local directory exists
Path(LOCAL_DIR).mkdir(parents=True, exist_ok=True)
# Download model snapshot
local_path = snapshot_download(
repo_id=REPO_ID,
local_dir=LOCAL_DIR,
)
logger.info(f"Model downloaded successfully to {local_path}")
return local_path
except Exception as e:
logger.error(f"Failed to download model: {e}")
raise RuntimeError(f"Model download failed: {e}")
def _can_use_flash_attn(self) -> bool:
"""Check whether FlashAttention2 can be enabled safely.
Requires all of:
- DOTS_OCR_FLASH_ATTENTION toggle is set.
- `flash_attn` is importable.
- dtype is fp16/bf16 per library support.
"""
if not USE_FLASH_ATTENTION:
return False
try:
# Import check avoids runtime error from Transformers if not installed
import flash_attn # type: ignore # noqa: F401
except Exception:
logger.warning(
"flash_attn package not installed; disabling FlashAttention2"
)
return False
# FlashAttention2 supports fp16/bf16 only (see HF docs)
return self.dtype in (torch.float16, torch.bfloat16)
def load_model(self) -> None:
"""Load the Dots.OCR model and processor.
Steps:
1) Determine device/dtype
2) Download snapshot if missing
3) Load `AutoProcessor`
4) Configure attention/device mapping
5) Instantiate model and place on target device
"""
try:
# Determine device and dtype
self.device, self.dtype = self._determine_device_and_dtype()
# Download model if not already present
self.local_dir = self._download_model()
# Load processor
logger.info("Loading processor...")
self.processor = AutoProcessor.from_pretrained(
self.local_dir, trust_remote_code=True
)
# Load model with appropriate configuration
model_kwargs = {
"dtype": self.dtype, # NOTE: `torch_dtype` is deprecated upstream
"trust_remote_code": True,
}
# Add device-specific configurations
if self.device == "cuda":
# Prefer FlashAttention2 when truly available; otherwise SDPA
if self._can_use_flash_attn():
model_kwargs["attn_implementation"] = "flash_attention_2"
logger.info("Using flash attention 2")
else:
model_kwargs["attn_implementation"] = "sdpa"
logger.info(
"Using SDPA attention (flash-attn unavailable or disabled)"
)
# Use device_map for automatic GPU memory management
model_kwargs["device_map"] = "auto"
else:
# For CPU, don't use device_map
model_kwargs["device_map"] = None
logger.info("Loading model...")
self.model = AutoModelForCausalLM.from_pretrained(
self.local_dir, **model_kwargs
)
# Move model to device if not using device_map
if self.device == "cpu" or model_kwargs.get("device_map") is None:
self.model = self.model.to(self.device)
logger.info(f"Model loaded successfully on {self.device}")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise RuntimeError(f"Model loading failed: {e}")
def _preprocess_image(self, image: Image.Image) -> Image.Image:
"""Preprocess image to meet model requirements.
- Normalize to RGB
- Constrain pixel count within [MIN_PIXELS, MAX_PIXELS]
- Snap dimensions to multiples of 28 to satisfy backbone constraints
"""
# Convert to RGB if necessary
if image.mode != "RGB":
image = image.convert("RGB")
# Calculate current pixel count
width, height = image.size
current_pixels = width * height
# Resize if necessary to meet pixel requirements
if current_pixels < MIN_PIXELS:
# Scale up to meet minimum pixel requirement
scale_factor = (MIN_PIXELS / current_pixels) ** 0.5
new_width = int(width * scale_factor)
new_height = int(height * scale_factor)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
logger.info(
f"Scaled up image from {width}x{height} to {new_width}x{new_height}"
)
elif current_pixels > MAX_PIXELS:
# Scale down to meet maximum pixel requirement
scale_factor = (MAX_PIXELS / current_pixels) ** 0.5
new_width = int(width * scale_factor)
new_height = int(height * scale_factor)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
logger.info(
f"Scaled down image from {width}x{height} to {new_width}x{new_height}"
)
# Ensure dimensions are divisible by 28 (common requirement for vision models)
width, height = image.size
new_width = ((width + 27) // 28) * 28
new_height = ((height + 27) // 28) * 28
if new_width != width or new_height != height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
logger.info(
f"Adjusted image dimensions to be divisible by 28: {new_width}x{new_height}"
)
return image
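    # inference_mode() disables autograd tracking for the whole call, which lowers
    # memory use and avoids accidental gradient computation during generation.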
@torch.inference_mode()
def extract_text(self, image: Image.Image, prompt: Optional[str] = None) -> str:
"""Extract text from an image using the loaded model.
Builds a single-turn chat message with the image and a transcription prompt,
applies the model's chat template, and decodes deterministically (no sampling).
"""
if self.model is None or self.processor is None:
raise RuntimeError("Model not loaded. Call load_model() first.")
try:
# Preprocess image
processed_image = self._preprocess_image(image)
# Use provided prompt or default
text_prompt = prompt or self.prompt
# Prepare messages for the model
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": processed_image},
{"type": "text", "text": text_prompt},
],
}
]
# Apply chat template (preserves special tokens/formatting expected by model)
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Process vision information (required for some models)
try:
from qwen_vl_utils import process_vision_info
image_inputs, video_inputs = process_vision_info(messages)
            except ImportError:
                # Fallback if qwen_vl_utils is not installed: pass the image directly
                # and no videos (None avoids handing the processor an empty list)
                logger.warning("qwen_vl_utils not available, using basic processing")
                image_inputs = [processed_image]
                video_inputs = None
# Prepare inputs
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(self.device)
            # Generate deterministically via greedy decoding (do_sample=False);
            # temperature is irrelevant here and omitted to avoid a transformers warning
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )
# Decode only newly generated tokens (strip prompt tokens)
trimmed = [
out[len(inp) :] for inp, out in zip(inputs.input_ids, output_ids)
]
decoded = self.processor.batch_decode(
trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return decoded[0] if decoded else ""
except Exception as e:
logger.error(f"Text extraction failed: {e}")
raise RuntimeError(f"Text extraction failed: {e}")
def is_loaded(self) -> bool:
"""Return True when both model and processor are initialized."""
return self.model is not None and self.processor is not None
def get_model_info(self) -> Dict[str, Any]:
"""Get diagnostic information about the loaded model and configuration."""
return {
"device": self.device,
"dtype": str(self.dtype),
"local_dir": self.local_dir,
"repo_id": REPO_ID,
"max_new_tokens": MAX_NEW_TOKENS,
"use_flash_attention": USE_FLASH_ATTENTION,
"prompt": self.prompt,
"is_loaded": self.is_loaded(),
}
# Global model instance
_model_loader: Optional[DotsOCRModelLoader] = None
def get_model_loader() -> DotsOCRModelLoader:
"""Get the global model loader instance."""
global _model_loader
if _model_loader is None:
_model_loader = DotsOCRModelLoader()
return _model_loader
def load_model() -> None:
"""Load the Dots.OCR model."""
loader = get_model_loader()
loader.load_model()
def extract_text(image: Image.Image, prompt: Optional[str] = None) -> str:
"""Extract text from an image using the loaded model."""
loader = get_model_loader()
if not loader.is_loaded():
raise RuntimeError("Model not loaded. Call load_model() first.")
return loader.extract_text(image, prompt)
def is_model_loaded() -> bool:
"""Check if the model is loaded and ready."""
loader = get_model_loader()
return loader.is_loaded()
def get_model_info() -> Dict[str, Any]:
"""Get information about the loaded model."""
loader = get_model_loader()
return loader.get_model_info()
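# Minimal manual smoke test (a sketch, not part of the service API). Assumes the
# model dependencies are installed and that a local "sample.png" exists; adjust
# the path as needed.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    load_model()
    with Image.open("sample.png") as sample:
        print(extract_text(sample))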