yadavkapil23 commited on
Commit
8959aae
·
1 Parent(s): 70215f2

OCR implememntation

Browse files
Files changed (1) hide show
  1. ocr_service.py +277 -0
ocr_service.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek OCR Service Module
3
+ Handles OCR text extraction using DeepSeek-OCR model
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from typing import Optional, Dict, Any
11
+ import logging
12
+ from pathlib import Path
13
+ from dotenv import load_dotenv
14
+
15
+ # Load environment variables
16
+ load_dotenv()
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ class DeepSeekOCRService:
23
+ """
24
+ Service class for DeepSeek OCR text extraction
25
+ """
26
+
27
+ def __init__(self, model_name: str = None):
28
+ """
29
+ Initialize the DeepSeek OCR service
30
+
31
+ Args:
32
+ model_name (str): Hugging Face model name for DeepSeek OCR
33
+ """
34
+ self.model_name = model_name or os.getenv('DEEPSEEK_OCR_MODEL', 'deepseek-ai/DeepSeek-OCR')
35
+ self.model = None
36
+ self.tokenizer = None
37
+
38
+ # Device configuration - optimized for CPU
39
+ device_config = os.getenv('DEEPSEEK_OCR_DEVICE', 'cpu')
40
+ if device_config == 'auto':
41
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ else:
43
+ self.device = device_config
44
+
45
+ logger.info(f"Initializing DeepSeek OCR on device: {self.device}")
46
+
47
+ def load_model(self):
48
+ """
49
+ Load the DeepSeek OCR model and tokenizer
50
+ """
51
+ try:
52
+ logger.info(f"Loading DeepSeek OCR model: {self.model_name}")
53
+ self.tokenizer = AutoTokenizer.from_pretrained(
54
+ self.model_name,
55
+ trust_remote_code=True
56
+ )
57
+ # CPU-optimized model loading
58
+ if self.device == "cpu":
59
+ self.model = AutoModelForCausalLM.from_pretrained(
60
+ self.model_name,
61
+ trust_remote_code=True,
62
+ torch_dtype=torch.float32, # Use float32 for CPU
63
+ low_cpu_mem_usage=True, # Reduce memory usage
64
+ device_map="cpu" # Force CPU usage
65
+ )
66
+ else:
67
+ self.model = AutoModelForCausalLM.from_pretrained(
68
+ self.model_name,
69
+ trust_remote_code=True,
70
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
71
+ )
72
+ self.model.to(self.device)
73
+ logger.info("DeepSeek OCR model loaded successfully")
74
+ except Exception as e:
75
+ logger.error(f"Failed to load DeepSeek OCR model: {str(e)}")
76
+ raise e
77
+
78
+ def extract_text_from_image(self, image_path: str, prompt: str = None) -> Dict[str, Any]:
79
+ """
80
+ Extract text from an image using DeepSeek OCR
81
+
82
+ Args:
83
+ image_path (str): Path to the image file
84
+ prompt (str, optional): Custom prompt for OCR processing
85
+
86
+ Returns:
87
+ Dict containing extracted text and metadata
88
+ """
89
+ if self.model is None or self.tokenizer is None:
90
+ self.load_model()
91
+
92
+ try:
93
+ # Load and preprocess the image
94
+ image = Image.open(image_path)
95
+ if image.mode != 'RGB':
96
+ image = image.convert('RGB')
97
+
98
+ # Use default prompt if none provided
99
+ if prompt is None:
100
+ prompt = "<|grounding|>Extract all text from this image."
101
+
102
+ # Prepare inputs
103
+ inputs = self.tokenizer(
104
+ prompt,
105
+ image,
106
+ return_tensors="pt"
107
+ ).to(self.device)
108
+
109
+ # Get configuration from environment - CPU optimized defaults
110
+ max_tokens = int(os.getenv('DEEPSEEK_OCR_MAX_TOKENS', '256')) # Reduced for CPU
111
+ temperature = float(os.getenv('DEEPSEEK_OCR_TEMPERATURE', '0.1'))
112
+
113
+ # Generate text extraction
114
+ with torch.no_grad():
115
+ outputs = self.model.generate(
116
+ **inputs,
117
+ max_new_tokens=max_tokens,
118
+ do_sample=False,
119
+ temperature=temperature,
120
+ pad_token_id=self.tokenizer.eos_token_id
121
+ )
122
+
123
+ # Decode the output
124
+ extracted_text = self.tokenizer.decode(
125
+ outputs[0],
126
+ skip_special_tokens=True
127
+ )
128
+
129
+ # Clean up the extracted text
130
+ extracted_text = extracted_text.replace(prompt, "").strip()
131
+
132
+ return {
133
+ "success": True,
134
+ "extracted_text": extracted_text,
135
+ "image_path": image_path,
136
+ "model_used": self.model_name,
137
+ "device": self.device
138
+ }
139
+
140
+ except Exception as e:
141
+ logger.error(f"Error extracting text from image {image_path}: {str(e)}")
142
+ return {
143
+ "success": False,
144
+ "error": str(e),
145
+ "image_path": image_path
146
+ }
147
+
148
+ def extract_text_with_grounding(self, image_path: str, target_text: str = None) -> Dict[str, Any]:
149
+ """
150
+ Extract text with grounding capabilities (locate specific text)
151
+
152
+ Args:
153
+ image_path (str): Path to the image file
154
+ target_text (str, optional): Specific text to locate in the image
155
+
156
+ Returns:
157
+ Dict containing extracted text and location information
158
+ """
159
+ if self.model is None or self.tokenizer is None:
160
+ self.load_model()
161
+
162
+ try:
163
+ image = Image.open(image_path)
164
+ if image.mode != 'RGB':
165
+ image = image.convert('RGB')
166
+
167
+ if target_text:
168
+ prompt = f"<|grounding|>Locate <|ref|>{target_text}<|/ref|> in the image."
169
+ else:
170
+ prompt = "<|grounding|>Extract all text from this image with location information."
171
+
172
+ inputs = self.tokenizer(
173
+ prompt,
174
+ image,
175
+ return_tensors="pt"
176
+ ).to(self.device)
177
+
178
+ with torch.no_grad():
179
+ outputs = self.model.generate(
180
+ **inputs,
181
+ max_new_tokens=512,
182
+ do_sample=False,
183
+ temperature=0.1,
184
+ pad_token_id=self.tokenizer.eos_token_id
185
+ )
186
+
187
+ extracted_text = self.tokenizer.decode(
188
+ outputs[0],
189
+ skip_special_tokens=True
190
+ )
191
+
192
+ extracted_text = extracted_text.replace(prompt, "").strip()
193
+
194
+ return {
195
+ "success": True,
196
+ "extracted_text": extracted_text,
197
+ "grounding_info": target_text if target_text else "all_text",
198
+ "image_path": image_path,
199
+ "model_used": self.model_name
200
+ }
201
+
202
+ except Exception as e:
203
+ logger.error(f"Error in grounding extraction from {image_path}: {str(e)}")
204
+ return {
205
+ "success": False,
206
+ "error": str(e),
207
+ "image_path": image_path
208
+ }
209
+
210
+ def convert_to_markdown(self, image_path: str) -> Dict[str, Any]:
211
+ """
212
+ Convert document image to markdown format
213
+
214
+ Args:
215
+ image_path (str): Path to the image file
216
+
217
+ Returns:
218
+ Dict containing markdown formatted text
219
+ """
220
+ if self.model is None or self.tokenizer is None:
221
+ self.load_model()
222
+
223
+ try:
224
+ image = Image.open(image_path)
225
+ if image.mode != 'RGB':
226
+ image = image.convert('RGB')
227
+
228
+ prompt = "<|grounding|>Convert the document to markdown format."
229
+
230
+ inputs = self.tokenizer(
231
+ prompt,
232
+ image,
233
+ return_tensors="pt"
234
+ ).to(self.device)
235
+
236
+ with torch.no_grad():
237
+ outputs = self.model.generate(
238
+ **inputs,
239
+ max_new_tokens=1024,
240
+ do_sample=False,
241
+ temperature=0.1,
242
+ pad_token_id=self.tokenizer.eos_token_id
243
+ )
244
+
245
+ markdown_text = self.tokenizer.decode(
246
+ outputs[0],
247
+ skip_special_tokens=True
248
+ )
249
+
250
+ markdown_text = markdown_text.replace(prompt, "").strip()
251
+
252
+ return {
253
+ "success": True,
254
+ "markdown_text": markdown_text,
255
+ "image_path": image_path,
256
+ "model_used": self.model_name
257
+ }
258
+
259
+ except Exception as e:
260
+ logger.error(f"Error converting to markdown from {image_path}: {str(e)}")
261
+ return {
262
+ "success": False,
263
+ "error": str(e),
264
+ "image_path": image_path
265
+ }
266
+
267
+ # Global OCR service instance
268
+ ocr_service = DeepSeekOCRService()
269
+
270
+ def get_ocr_service() -> DeepSeekOCRService:
271
+ """
272
+ Get the global OCR service instance
273
+
274
+ Returns:
275
+ DeepSeekOCRService: The OCR service instance
276
+ """
277
+ return ocr_service