Spaces:
Running
Running
| """ | |
| DeepSeek OCR Service Module | |
| Handles OCR text extraction using DeepSeek-OCR model | |
| """ | |
| import os | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Optional, Dict, Any | |
| import logging | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| # Load environment variables | |
| load_dotenv() | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| class DeepSeekOCRService: | |
| """ | |
| Service class for DeepSeek OCR text extraction | |
| """ | |
| def __init__(self, model_name: str = None): | |
| """ | |
| Initialize the DeepSeek OCR service | |
| Args: | |
| model_name (str): Hugging Face model name for DeepSeek OCR | |
| """ | |
| self.model_name = model_name or os.getenv('DEEPSEEK_OCR_MODEL', 'deepseek-ai/DeepSeek-OCR') | |
| self.model = None | |
| self.tokenizer = None | |
| # Device configuration - optimized for CPU | |
| device_config = os.getenv('DEEPSEEK_OCR_DEVICE', 'cpu') | |
| if device_config == 'auto': | |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| else: | |
| self.device = device_config | |
| logger.info(f"Initializing DeepSeek OCR on device: {self.device}") | |
| def load_model(self): | |
| """ | |
| Load the DeepSeek OCR model and tokenizer | |
| """ | |
| try: | |
| logger.info(f"Loading DeepSeek OCR model: {self.model_name}") | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| self.model_name, | |
| trust_remote_code=True | |
| ) | |
| # CPU-optimized model loading | |
| if self.device == "cpu": | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| self.model_name, | |
| trust_remote_code=True, | |
| torch_dtype=torch.float32, # Use float32 for CPU | |
| low_cpu_mem_usage=True, # Reduce memory usage | |
| device_map="cpu" # Force CPU usage | |
| ) | |
| else: | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| self.model_name, | |
| trust_remote_code=True, | |
| torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 | |
| ) | |
| self.model.to(self.device) | |
| logger.info("DeepSeek OCR model loaded successfully") | |
| except Exception as e: | |
| logger.error(f"Failed to load DeepSeek OCR model: {str(e)}") | |
| raise e | |
| def extract_text_from_image(self, image_path: str, prompt: str = None) -> Dict[str, Any]: | |
| """ | |
| Extract text from an image using DeepSeek OCR | |
| Args: | |
| image_path (str): Path to the image file | |
| prompt (str, optional): Custom prompt for OCR processing | |
| Returns: | |
| Dict containing extracted text and metadata | |
| """ | |
| if self.model is None or self.tokenizer is None: | |
| self.load_model() | |
| try: | |
| # Load and preprocess the image | |
| image = Image.open(image_path) | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| # Use default prompt if none provided | |
| if prompt is None: | |
| prompt = "<|grounding|>Extract all text from this image." | |
| # Prepare inputs | |
| inputs = self.tokenizer( | |
| prompt, | |
| image, | |
| return_tensors="pt" | |
| ).to(self.device) | |
| # Get configuration from environment - CPU optimized defaults | |
| max_tokens = int(os.getenv('DEEPSEEK_OCR_MAX_TOKENS', '256')) # Reduced for CPU | |
| temperature = float(os.getenv('DEEPSEEK_OCR_TEMPERATURE', '0.1')) | |
| # Generate text extraction | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_new_tokens=max_tokens, | |
| do_sample=False, | |
| temperature=temperature, | |
| pad_token_id=self.tokenizer.eos_token_id | |
| ) | |
| # Decode the output | |
| extracted_text = self.tokenizer.decode( | |
| outputs[0], | |
| skip_special_tokens=True | |
| ) | |
| # Clean up the extracted text | |
| extracted_text = extracted_text.replace(prompt, "").strip() | |
| return { | |
| "success": True, | |
| "extracted_text": extracted_text, | |
| "image_path": image_path, | |
| "model_used": self.model_name, | |
| "device": self.device | |
| } | |
| except Exception as e: | |
| logger.error(f"Error extracting text from image {image_path}: {str(e)}") | |
| return { | |
| "success": False, | |
| "error": str(e), | |
| "image_path": image_path | |
| } | |
| def extract_text_with_grounding(self, image_path: str, target_text: str = None) -> Dict[str, Any]: | |
| """ | |
| Extract text with grounding capabilities (locate specific text) | |
| Args: | |
| image_path (str): Path to the image file | |
| target_text (str, optional): Specific text to locate in the image | |
| Returns: | |
| Dict containing extracted text and location information | |
| """ | |
| if self.model is None or self.tokenizer is None: | |
| self.load_model() | |
| try: | |
| image = Image.open(image_path) | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| if target_text: | |
| prompt = f"<|grounding|>Locate <|ref|>{target_text}<|/ref|> in the image." | |
| else: | |
| prompt = "<|grounding|>Extract all text from this image with location information." | |
| inputs = self.tokenizer( | |
| prompt, | |
| image, | |
| return_tensors="pt" | |
| ).to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_new_tokens=512, | |
| do_sample=False, | |
| temperature=0.1, | |
| pad_token_id=self.tokenizer.eos_token_id | |
| ) | |
| extracted_text = self.tokenizer.decode( | |
| outputs[0], | |
| skip_special_tokens=True | |
| ) | |
| extracted_text = extracted_text.replace(prompt, "").strip() | |
| return { | |
| "success": True, | |
| "extracted_text": extracted_text, | |
| "grounding_info": target_text if target_text else "all_text", | |
| "image_path": image_path, | |
| "model_used": self.model_name | |
| } | |
| except Exception as e: | |
| logger.error(f"Error in grounding extraction from {image_path}: {str(e)}") | |
| return { | |
| "success": False, | |
| "error": str(e), | |
| "image_path": image_path | |
| } | |
| def convert_to_markdown(self, image_path: str) -> Dict[str, Any]: | |
| """ | |
| Convert document image to markdown format | |
| Args: | |
| image_path (str): Path to the image file | |
| Returns: | |
| Dict containing markdown formatted text | |
| """ | |
| if self.model is None or self.tokenizer is None: | |
| self.load_model() | |
| try: | |
| image = Image.open(image_path) | |
| if image.mode != 'RGB': | |
| image = image.convert('RGB') | |
| prompt = "<|grounding|>Convert the document to markdown format." | |
| inputs = self.tokenizer( | |
| prompt, | |
| image, | |
| return_tensors="pt" | |
| ).to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_new_tokens=1024, | |
| do_sample=False, | |
| temperature=0.1, | |
| pad_token_id=self.tokenizer.eos_token_id | |
| ) | |
| markdown_text = self.tokenizer.decode( | |
| outputs[0], | |
| skip_special_tokens=True | |
| ) | |
| markdown_text = markdown_text.replace(prompt, "").strip() | |
| return { | |
| "success": True, | |
| "markdown_text": markdown_text, | |
| "image_path": image_path, | |
| "model_used": self.model_name | |
| } | |
| except Exception as e: | |
| logger.error(f"Error converting to markdown from {image_path}: {str(e)}") | |
| return { | |
| "success": False, | |
| "error": str(e), | |
| "image_path": image_path | |
| } | |
| # Global OCR service instance | |
| ocr_service = DeepSeekOCRService() | |
| def get_ocr_service() -> DeepSeekOCRService: | |
| """ | |
| Get the global OCR service instance | |
| Returns: | |
| DeepSeekOCRService: The OCR service instance | |
| """ | |
| return ocr_service | |