Shirochi committed on
Commit 0cc2912 · verified · 1 Parent(s): 0bf8c19

Delete ocr_manager.py

Files changed (1)
  1. ocr_manager.py +0 -1904
ocr_manager.py DELETED
@@ -1,1904 +0,0 @@
# ocr_manager.py
"""
OCR Manager for handling multiple OCR providers
Handles installation, model downloading, and OCR processing
Updated with HuggingFace donut model and proper bubble detection integration
"""
import os
import sys
import cv2
import json
import subprocess
import threading
import traceback
from typing import List, Dict, Optional, Tuple, Any
import numpy as np
from dataclasses import dataclass
from PIL import Image
import logging
import time
import random
import base64
import io
import requests

try:
    import gptqmodel
    HAS_GPTQ = True
except ImportError:
    try:
        import auto_gptq
        HAS_GPTQ = True
    except ImportError:
        HAS_GPTQ = False

try:
    import optimum
    HAS_OPTIMUM = True
except ImportError:
    HAS_OPTIMUM = False

try:
    import accelerate
    HAS_ACCELERATE = True
except ImportError:
    HAS_ACCELERATE = False

logger = logging.getLogger(__name__)

@dataclass
class OCRResult:
    """Unified OCR result format with built-in sanitization to prevent data corruption."""
    text: str
    bbox: Tuple[int, int, int, int]  # x, y, w, h
    confidence: float
    vertices: Optional[List[Tuple[int, int]]] = None

    def __post_init__(self):
        """
        This special method is called automatically after the object is created.
        It acts as a final safeguard to ensure the 'text' attribute is ALWAYS a clean string.
        """
        # --- THIS IS THE DEFINITIVE FIX ---
        # If the text we received is a tuple, we extract the first element.
        # This makes it impossible for a tuple to exist in a finished object.
        if isinstance(self.text, tuple):
            # Log that we are fixing a critical data error.
            print(f"CRITICAL WARNING: Corrupted tuple detected in OCRResult. Sanitizing '{self.text}' to '{self.text[0]}'.")
            self.text = self.text[0]

        # Ensure the final result is always a stripped string.
        self.text = str(self.text).strip()

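# Illustrative example (not part of the original module): a result constructed with a
# corrupted tuple is healed by __post_init__ —
#   r = OCRResult(text=("こんにちは", 0.91), bbox=(0, 0, 40, 12), confidence=0.91)
#   assert r.text == "こんにちは"   # the tuple was sanitized to its first element
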
class OCRProvider:
    """Base class for OCR providers"""

    def __init__(self, log_callback=None):
        # Set thread limits early if environment indicates single-threaded mode
        try:
            if os.environ.get('OMP_NUM_THREADS') == '1':
                # Already in single-threaded mode, ensure it's applied to this process
                try:
                    import sys
                    if 'torch' in sys.modules:
                        import torch
                        torch.set_num_threads(1)
                except (ImportError, RuntimeError, AttributeError):
                    pass
                try:
                    import cv2
                    cv2.setNumThreads(1)
                except (ImportError, AttributeError):
                    pass
        except Exception:
            pass

        self.log_callback = log_callback
        self.is_installed = False
        self.is_loaded = False
        self.model = None
        self.stop_flag = None
        self._stopped = False

    def _log(self, message: str, level: str = "info"):
        """Log message with stop suppression"""
        # Suppress logs when stopped (allow only essential stop confirmation messages)
        if self._check_stop():
            essential_stop_keywords = [
                "⏹️ Translation stopped by user",
                "⏹️ OCR processing stopped",
                "cleanup", "🧹"
            ]
            if not any(keyword in message for keyword in essential_stop_keywords):
                return

        if self.log_callback:
            self.log_callback(message, level)
        else:
            print(f"[{level.upper()}] {message}")

    def set_stop_flag(self, stop_flag):
        """Set the stop flag for checking interruptions"""
        self.stop_flag = stop_flag
        self._stopped = False

    def _check_stop(self) -> bool:
        """Check if stop has been requested"""
        if self._stopped:
            return True
        if self.stop_flag and self.stop_flag.is_set():
            self._stopped = True
            return True
        # Check global manga translator cancellation
        try:
            from manga_translator import MangaTranslator
            if MangaTranslator.is_globally_cancelled():
                self._stopped = True
                return True
        except Exception:
            pass
        return False

    def reset_stop_flags(self):
        """Reset stop flags when starting new processing"""
        self._stopped = False

    def check_installation(self) -> bool:
        """Check if provider is installed"""
        raise NotImplementedError

    def install(self, progress_callback=None) -> bool:
        """Install the provider"""
        raise NotImplementedError

    def load_model(self, **kwargs) -> bool:
        """Load the OCR model"""
        raise NotImplementedError

    def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
        """Detect text in image"""
        raise NotImplementedError

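# Illustrative stop-flag wiring (hypothetical names; assumes the caller shares a
# threading.Event with the UI thread):
#   stop_event = threading.Event()
#   provider.set_stop_flag(stop_event)
#   ...                        # later, from the UI thread: stop_event.set()
#   provider._check_stop()     # -> True; _log() now suppresses non-essential messages
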
class CustomAPIProvider(OCRProvider):
    """Custom API OCR provider that uses existing GUI variables"""

    def __init__(self, log_callback=None):
        super().__init__(log_callback)

        # Use EXISTING environment variables from TranslatorGUI
        self.api_url = os.environ.get('OPENAI_CUSTOM_BASE_URL', '')
        self.api_key = os.environ.get('API_KEY', '') or os.environ.get('OPENAI_API_KEY', '')
        self.model_name = os.environ.get('MODEL', 'gpt-4o-mini')

        # OCR prompt - use system prompt or a dedicated OCR prompt variable
        self.ocr_prompt = os.environ.get('OCR_SYSTEM_PROMPT',
            os.environ.get('SYSTEM_PROMPT',
                "YOU ARE AN OCR SYSTEM. YOUR ONLY JOB IS TEXT EXTRACTION.\n\n"
                "CRITICAL RULES:\n"
                "1. DO NOT TRANSLATE ANYTHING\n"
                "2. DO NOT MODIFY THE TEXT\n"
                "3. DO NOT EXPLAIN OR COMMENT\n"
                "4. ONLY OUTPUT THE EXACT TEXT YOU SEE\n"
                "5. PRESERVE NATURAL TEXT FLOW - DO NOT ADD UNNECESSARY LINE BREAKS\n\n"
                "If you see Korean text, output it in Korean.\n"
                "If you see Japanese text, output it in Japanese.\n"
                "If you see Chinese text, output it in Chinese.\n"
                "If you see English text, output it in English.\n\n"
                "IMPORTANT: Only use line breaks where they naturally occur in the original text "
                "(e.g., between dialogue lines or paragraphs). Do not break text mid-sentence or "
                "between every word/character.\n\n"
                "For vertical text common in manga/comics, transcribe it as a continuous line unless "
                "there are clear visual breaks.\n\n"
                "NEVER translate. ONLY extract exactly what is written.\n"
                "Output ONLY the raw text, nothing else."
            ))

        # Use existing temperature and token settings
        self.temperature = float(os.environ.get('TRANSLATION_TEMPERATURE', '0.01'))
        # Don't hardcode to 8192 - get a fresh value when actually used
        self.max_tokens = int(os.environ.get('MAX_OUTPUT_TOKENS', '4096'))

        # Image settings from existing compression variables
        self.image_format = 'jpeg' if os.environ.get('IMAGE_COMPRESSION_FORMAT', 'auto') != 'png' else 'png'
        self.image_quality = int(os.environ.get('JPEG_QUALITY', '100'))

        # Simple defaults
        self.api_format = 'openai'  # Most custom endpoints are OpenAI-compatible
        self.timeout = int(os.environ.get('CHUNK_TIMEOUT', '30'))
        self.api_headers = {}  # Additional custom headers

        # Retry configuration for Custom API OCR calls
        self.max_retries = int(os.environ.get('CUSTOM_OCR_MAX_RETRIES', '3'))
        self.retry_initial_delay = float(os.environ.get('CUSTOM_OCR_RETRY_INITIAL_DELAY', '0.8'))
        self.retry_backoff = float(os.environ.get('CUSTOM_OCR_RETRY_BACKOFF', '1.8'))
        self.retry_jitter = float(os.environ.get('CUSTOM_OCR_RETRY_JITTER', '0.4'))
        self.retry_on_empty = os.environ.get('CUSTOM_OCR_RETRY_ON_EMPTY', '1') == '1'

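    # Illustrative retry configuration (hypothetical values): with the defaults above,
    # attempt n sleeps roughly retry_initial_delay * retry_backoff**(n-1) plus up to
    # retry_jitter seconds, i.e. ~0.8-1.2s, then ~1.4-1.8s, then ~2.6-3.0s:
    #   os.environ['CUSTOM_OCR_MAX_RETRIES'] = '3'
    #   os.environ['CUSTOM_OCR_RETRY_ON_EMPTY'] = '1'   # also retry when the reply is empty
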
217
- def check_installation(self) -> bool:
218
- """Always installed - uses UnifiedClient"""
219
- self.is_installed = True
220
- return True
221
-
222
- def install(self, progress_callback=None) -> bool:
223
- """No installation needed for API-based provider"""
224
- return self.check_installation()
225
-
226
- def load_model(self, **kwargs) -> bool:
227
- """Initialize UnifiedClient with current settings"""
228
- try:
229
- from unified_api_client import UnifiedClient
230
-
231
- # Support passing API key from GUI if available
232
- if 'api_key' in kwargs:
233
- api_key = kwargs['api_key']
234
- else:
235
- api_key = os.environ.get('API_KEY', '') or os.environ.get('OPENAI_API_KEY', '')
236
-
237
- if 'model' in kwargs:
238
- model = kwargs['model']
239
- else:
240
- model = os.environ.get('MODEL', 'gpt-4o-mini')
241
-
242
- if not api_key:
243
- self._log("❌ No API key configured", "error")
244
- return False
245
-
246
- # Create UnifiedClient just like translations do
247
- self.client = UnifiedClient(model=model, api_key=api_key)
248
-
249
- #self._log(f"✅ Using {model} for OCR via UnifiedClient")
250
- self.is_loaded = True
251
- return True
252
-
253
- except Exception as e:
254
- self._log(f"❌ Failed to initialize UnifiedClient: {str(e)}", "error")
255
- return False
256
-
    def _test_connection(self) -> bool:
        """Test API connection with a simple request"""
        try:
            # Create a small test image
            test_image = np.ones((100, 100, 3), dtype=np.uint8) * 255
            cv2.putText(test_image, "TEST", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

            # Encode image
            image_base64 = self._encode_image(test_image)

            # Prepare test request based on API format
            if self.api_format == 'openai':
                test_payload = {
                    "model": self.model_name,
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": "What text do you see?"},
                                {"type": "image_url", "image_url": {"url": f"data:image/{self.image_format};base64,{image_base64}"}}
                            ]
                        }
                    ],
                    "max_tokens": 50
                }
            else:
                # For other formats, just try a basic health check
                return True

            headers = self._prepare_headers()
            response = requests.post(
                self.api_url,
                headers=headers,
                json=test_payload,
                timeout=10
            )

            return response.status_code == 200

        except Exception:
            return False

    def _encode_image(self, image: np.ndarray) -> str:
        """Encode numpy array to base64 string"""
        # Convert BGR to RGB if needed
        if len(image.shape) == 3 and image.shape[2] == 3:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            image_rgb = image

        # Convert to PIL Image
        pil_image = Image.fromarray(image_rgb)

        # Save to bytes buffer
        buffer = io.BytesIO()
        if self.image_format.lower() == 'png':
            pil_image.save(buffer, format='PNG')
        else:
            pil_image.save(buffer, format='JPEG', quality=self.image_quality)

        # Encode to base64
        buffer.seek(0)
        image_base64 = base64.b64encode(buffer.read()).decode('utf-8')

        return image_base64

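    # Round-trip sanity check (illustrative, not part of the original module):
    #   encoded = provider._encode_image(np.zeros((10, 10, 3), dtype=np.uint8))
    #   Image.open(io.BytesIO(base64.b64decode(encoded)))   # decodes back to a 10x10 image
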
    def _prepare_headers(self) -> dict:
        """Prepare request headers"""
        headers = {
            "Content-Type": "application/json"
        }

        # Add API key if configured
        if self.api_key:
            if self.api_format == 'anthropic':
                headers["x-api-key"] = self.api_key
            else:
                headers["Authorization"] = f"Bearer {self.api_key}"

        # Add any custom headers
        headers.update(self.api_headers)

        return headers

    def _prepare_request_payload(self, image_base64: str) -> dict:
        """Prepare request payload based on API format"""
        if self.api_format == 'openai':
            return {
                "model": self.model_name,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self.ocr_prompt},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/{self.image_format};base64,{image_base64}"
                                }
                            }
                        ]
                    }
                ],
                "max_tokens": self.max_tokens,
                "temperature": self.temperature
            }

        elif self.api_format == 'anthropic':
            return {
                "model": self.model_name,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": self.ocr_prompt
                            },
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": f"image/{self.image_format}",
                                    "data": image_base64
                                }
                            }
                        ]
                    }
                ]
            }

        else:
            # Custom format - use environment variable for template
            template = os.environ.get('CUSTOM_OCR_REQUEST_TEMPLATE', '{}')
            payload = json.loads(template)

            # Replace placeholders. Values are JSON-escaped first, so newlines and
            # quotes in the prompt don't produce invalid JSON after substitution.
            def _esc(value) -> str:
                return json.dumps(str(value))[1:-1]

            payload_str = json.dumps(payload)
            payload_str = payload_str.replace('{{IMAGE_BASE64}}', _esc(image_base64))
            payload_str = payload_str.replace('{{PROMPT}}', _esc(self.ocr_prompt))
            payload_str = payload_str.replace('{{MODEL}}', _esc(self.model_name))
            payload_str = payload_str.replace('{{MAX_TOKENS}}', str(self.max_tokens))
            payload_str = payload_str.replace('{{TEMPERATURE}}', str(self.temperature))

            return json.loads(payload_str)

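    # Example template for the custom format (hypothetical endpoint schema; the
    # placeholders are substituted by _prepare_request_payload above):
    #   os.environ['CUSTOM_OCR_REQUEST_TEMPLATE'] = json.dumps({
    #       "model": "{{MODEL}}",
    #       "prompt": "{{PROMPT}}",
    #       "image": "{{IMAGE_BASE64}}",
    #       "max_tokens": "{{MAX_TOKENS}}"
    #   })
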
    def _extract_text_from_response(self, response_data: dict) -> str:
        """Extract text from API response based on format"""
        try:
            if self.api_format == 'openai':
                # OpenAI format: response.choices[0].message.content
                return response_data.get('choices', [{}])[0].get('message', {}).get('content', '')

            elif self.api_format == 'anthropic':
                # Anthropic format: response.content[0].text
                content = response_data.get('content', [])
                if content and isinstance(content, list):
                    return content[0].get('text', '')
                return ''

            else:
                # Custom format - use environment variable for path
                response_path = os.environ.get('CUSTOM_OCR_RESPONSE_PATH', 'text')

                # Navigate through the response using the path
                result = response_data
                for key in response_path.split('.'):
                    if isinstance(result, dict):
                        result = result.get(key, '')
                    elif isinstance(result, list) and key.isdigit():
                        idx = int(key)
                        result = result[idx] if idx < len(result) else ''
                    else:
                        result = ''
                        break

                return str(result)

        except Exception as e:
            self._log(f"Failed to extract text from response: {e}", "error")
            return ''

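    # Example: for a response shaped like {"data": {"outputs": [{"text": "こんにちは"}]}}
    # (hypothetical), setting CUSTOM_OCR_RESPONSE_PATH to "data.outputs.0.text" walks
    # dicts by key and lists by numeric index to reach the string.
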
    def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
        """Process image using UnifiedClient.send()"""
        results = []

        try:
            # Get fresh max_tokens from environment - GUI will have set this
            max_tokens = int(os.environ.get('MAX_OUTPUT_TOKENS', '4096'))
            if not self.is_loaded:
                if not self.load_model():
                    return results

            import cv2
            from PIL import Image
            import base64
            import io

            # Convert numpy array to PIL Image
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image_rgb)
            h, w = image.shape[:2]

            # Convert PIL Image to base64 string
            buffer = io.BytesIO()

            # Use the image format from settings
            if self.image_format.lower() == 'png':
                pil_image.save(buffer, format='PNG')
            else:
                pil_image.save(buffer, format='JPEG', quality=self.image_quality)

            buffer.seek(0)
            image_base64 = base64.b64encode(buffer.read()).decode('utf-8')

            # For OpenAI vision models, we need BOTH:
            # 1. System prompt with instructions
            # 2. User message that includes the image
            messages = [
                {
                    "role": "system",
                    "content": self.ocr_prompt  # The OCR instruction as system prompt
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Image:"  # Minimal text, just to have something
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image_base64}"
                            }
                        }
                    ]
                }
            ]

            # Send this properly formatted message through the regular send() path
            # (not send_image); the UnifiedClient handles the multimodal content.

            # Retry-aware call
            from unified_api_client import UnifiedClientError  # local import to avoid hard dependency at module import time
            max_attempts = max(1, self.max_retries)
            attempt = 0
            last_error = None

            # Common refusal/error phrases that indicate a non-OCR response
            refusal_phrases = [
                "I can't extract", "I cannot extract",
                "I'm sorry", "I am sorry",
                "I'm unable", "I am unable",
                "cannot process images",
                "I can't help with that",
                "cannot view images",
                "no text in the image"
            ]

            while attempt < max_attempts:
                # Check for stop before each attempt
                if self._check_stop():
                    self._log("⏹️ OCR processing stopped by user", "warning")
                    return results

                try:
                    response = self.client.send(
                        messages=messages,
                        temperature=self.temperature,
                        max_tokens=max_tokens
                    )

                    # Extract content from response object
                    content, finish_reason = response

                    # Validate content
                    has_content = bool(content and str(content).strip())
                    refused = False
                    if has_content:
                        # Filter out explicit failure markers
                        if "[" in content and "FAILED]" in content:
                            refused = True
                        elif any(phrase.lower() in content.lower() for phrase in refusal_phrases):
                            refused = True

                    # Decide success or retry
                    if has_content and not refused:
                        text = str(content).strip()
                        results.append(OCRResult(
                            text=text,
                            bbox=(0, 0, w, h),
                            confidence=kwargs.get('confidence', 0.85),
                            vertices=[(0, 0), (w, 0), (w, h), (0, h)]
                        ))
                        self._log(f"✅ Detected: {text[:50]}...")
                        break  # success
                    else:
                        reason = "empty result" if not has_content else "refusal/non-OCR response"
                        last_error = f"{reason} (finish_reason: {finish_reason})"
                        # Check if we should retry on empty or refusal
                        should_retry = (not has_content and self.retry_on_empty) or refused
                        attempt += 1
                        if attempt >= max_attempts or not should_retry:
                            # No more retries or shouldn't retry
                            if not has_content:
                                self._log(f"⚠️ No text detected (finish_reason: {finish_reason})")
                            else:
                                self._log(f"❌ Model returned non-OCR response: {str(content)[:120]}", "warning")
                            break
                        # Backoff before retrying
                        delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
                        self._log(f"🔄 Retry {attempt}/{max_attempts - 1} after {delay:.1f}s due to {reason}...", "warning")
                        time.sleep(delay)
                        time.sleep(0.1)  # Brief pause for stability
                        self._log("💤 OCR retry pausing briefly for stability", "debug")
                        continue

                except UnifiedClientError as ue:
                    msg = str(ue)
                    last_error = msg
                    # Do not retry on explicit user cancellation
                    if 'cancelled' in msg.lower() or 'stopped by user' in msg.lower():
                        self._log(f"❌ OCR cancelled: {msg}", "error")
                        break
                    attempt += 1
                    if attempt >= max_attempts:
                        self._log(f"❌ OCR failed after {attempt} attempts: {msg}", "error")
                        break
                    delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
                    self._log(f"🔄 API error, retry {attempt}/{max_attempts - 1} after {delay:.1f}s: {msg}", "warning")
                    time.sleep(delay)
                    time.sleep(0.1)  # Brief pause for stability
                    self._log("💤 OCR API error retry pausing briefly for stability", "debug")
                    continue
                except Exception as e_inner:
                    last_error = str(e_inner)
                    attempt += 1
                    if attempt >= max_attempts:
                        self._log(f"❌ OCR exception after {attempt} attempts: {last_error}", "error")
                        break
                    delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
                    self._log(f"🔄 Exception, retry {attempt}/{max_attempts - 1} after {delay:.1f}s: {last_error}", "warning")
                    time.sleep(delay)
                    time.sleep(0.1)  # Brief pause for stability
                    self._log("💤 OCR exception retry pausing briefly for stability", "debug")
                    continue

        except Exception as e:
            self._log(f"❌ Error: {str(e)}", "error")
            import traceback
            self._log(traceback.format_exc(), "debug")

        return results

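# Illustrative call (hypothetical crop): the whole region is treated as one block of
# text, so at most one OCRResult spanning the full bbox is returned —
#   results = provider.detect_text(cv2.imread("bubble_crop.png"))
#   if results:
#       print(results[0].text)
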
class MangaOCRProvider(OCRProvider):
    """Manga OCR provider using HuggingFace model directly"""

    def __init__(self, log_callback=None):
        super().__init__(log_callback)
        self.processor = None
        self.model = None
        self.tokenizer = None

    def check_installation(self) -> bool:
        """Check if transformers is installed"""
        try:
            import transformers
            import torch
            self.is_installed = True
            return True
        except ImportError:
            return False

    def install(self, progress_callback=None) -> bool:
        """Install transformers and torch (stub: dependencies must be installed manually)"""
        pass

    def _is_valid_local_model_dir(self, path: str) -> bool:
        """Check that a local HF model directory has required files."""
        try:
            if not path or not os.path.isdir(path):
                return False
            needed_any_weights = any(
                os.path.exists(os.path.join(path, name)) for name in (
                    'pytorch_model.bin',
                    'model.safetensors'
                )
            )
            has_config = os.path.exists(os.path.join(path, 'config.json'))
            has_processor = (
                os.path.exists(os.path.join(path, 'preprocessor_config.json')) or
                os.path.exists(os.path.join(path, 'processor_config.json'))
            )
            has_tokenizer = (
                os.path.exists(os.path.join(path, 'tokenizer.json')) or
                os.path.exists(os.path.join(path, 'tokenizer_config.json'))
            )
            return has_config and needed_any_weights and has_processor and has_tokenizer
        except Exception:
            return False

    def load_model(self, **kwargs) -> bool:
        """Load the manga-ocr model, preferring a local directory to avoid re-downloading"""
        print("\n>>> MangaOCRProvider.load_model() called")
        try:
            if not self.is_installed and not self.check_installation():
                print("ERROR: Transformers not installed")
                self._log("❌ Transformers not installed", "error")
                return False

            # Always disable progress bars to avoid tqdm issues in some environments
            import os
            os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

            from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor
            import torch

            # Prefer a local model directory if present to avoid any Hub access
            candidates = []
            env_local = os.environ.get("MANGA_OCR_LOCAL_DIR")
            if env_local:
                candidates.append(env_local)

            # Project root one level up from this file
            root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
            candidates.append(os.path.join(root_dir, 'models', 'manga-ocr-base'))
            candidates.append(os.path.join(root_dir, 'models', 'kha-white', 'manga-ocr-base'))

            model_source = None
            local_only = False
            # Find a valid local dir
            for cand in candidates:
                if self._is_valid_local_model_dir(cand):
                    model_source = cand
                    local_only = True
                    break

            # If no valid local dir, use Hub
            if not model_source:
                model_source = "kha-white/manga-ocr-base"
                # Make sure we are not forcing offline mode
                if os.environ.get("HF_HUB_OFFLINE") == "1":
                    try:
                        del os.environ["HF_HUB_OFFLINE"]
                    except Exception:
                        pass
                self._log("🔥 Loading manga-ocr model from Hugging Face Hub")
                self._log(f"   Repo: {model_source}")
            else:
                # Only set offline when local dir is fully valid
                os.environ.setdefault("HF_HUB_OFFLINE", "1")
                self._log("🔥 Loading manga-ocr model from local directory")
                self._log(f"   Local path: {model_source}")

            # Decide target device once; we will move after full CPU load to avoid meta tensors
            use_cuda = torch.cuda.is_available()

            # Try loading components, falling back to Hub if local-only fails
            def _load_components(source: str, local_flag: bool):
                self._log("   Loading tokenizer...")
                tok = AutoTokenizer.from_pretrained(source, local_files_only=local_flag)

                self._log("   Loading image processor...")
                try:
                    from transformers import AutoProcessor
                except Exception:
                    AutoProcessor = None
                try:
                    proc = AutoImageProcessor.from_pretrained(source, local_files_only=local_flag)
                except Exception as e_proc:
                    if AutoProcessor is not None:
                        self._log(f"   ⚠️ AutoImageProcessor failed: {e_proc}. Trying AutoProcessor...", "warning")
                        proc = AutoProcessor.from_pretrained(source, local_files_only=local_flag)
                    else:
                        raise

                self._log("   Loading model...")
                # Prevent meta tensors by forcing full materialization on CPU at load time
                os.environ.setdefault('TORCHDYNAMO_DISABLE', '1')
                mdl = VisionEncoderDecoderModel.from_pretrained(
                    source,
                    local_files_only=local_flag,
                    low_cpu_mem_usage=False,
                    device_map=None,
                    torch_dtype=torch.float32  # Use torch_dtype instead of dtype
                )
                return tok, proc, mdl

            try:
                self.tokenizer, self.processor, self.model = _load_components(model_source, local_only)
            except Exception as e_local:
                if local_only:
                    # Fall back to the Hub once if the local load fails
                    self._log(f"   ⚠️ Local model load failed: {e_local}", "warning")
                    try:
                        if os.environ.get("HF_HUB_OFFLINE") == "1":
                            del os.environ["HF_HUB_OFFLINE"]
                    except Exception:
                        pass
                    model_source = "kha-white/manga-ocr-base"
                    local_only = False
                    self._log("   Retrying from Hugging Face Hub...")
                    self.tokenizer, self.processor, self.model = _load_components(model_source, local_only)
                else:
                    raise

            # Move to CUDA only after full CPU materialization
            target_device = 'cpu'
            if use_cuda:
                try:
                    self.model = self.model.to('cuda')
                    target_device = 'cuda'
                except Exception as move_err:
                    self._log(f"   ⚠️ Could not move model to CUDA: {move_err}", "warning")
                    target_device = 'cpu'

            # Finalize eval mode
            self.model.eval()

            # Sanity-check: ensure no parameter remains on 'meta' device
            try:
                for n, p in self.model.named_parameters():
                    dev = getattr(p, 'device', None)
                    if dev is not None and getattr(dev, 'type', '') == 'meta':
                        raise RuntimeError(f"Parameter {n} is on 'meta' after load")
            except Exception as sanity_err:
                self._log(f"❌ Manga-OCR model load sanity check failed: {sanity_err}", "error")
                return False

            print(f"SUCCESS: Model loaded on {target_device.upper()}")
            self._log(f"   ✅ Model loaded on {target_device.upper()}")
            self.is_loaded = True
            self._log("✅ Manga OCR model ready")
            print(">>> Returning True from load_model()")
            return True

        except Exception as e:
            print(f"\nEXCEPTION in load_model: {e}")
            import traceback
            print(traceback.format_exc())
            self._log(f"❌ Failed to load manga-ocr model: {str(e)}", "error")
            self._log(traceback.format_exc(), "error")
            try:
                if 'local_only' in locals() and local_only:
                    self._log("Hint: Local load failed. Ensure your models/manga-ocr-base contains required files (config.json, preprocessor_config.json, tokenizer.json or tokenizer_config.json, and model weights).", "warning")
            except Exception:
                pass
            return False

    def _run_ocr(self, pil_image):
        """Run OCR on a PIL image using the HuggingFace model"""
        import torch

        # Process image (keyword arg for broader compatibility across transformers versions)
        inputs = self.processor(images=pil_image, return_tensors="pt")
        pixel_values = inputs["pixel_values"]

        # Move to same device as model
        try:
            model_device = next(self.model.parameters()).device
        except StopIteration:
            model_device = torch.device('cpu')
        pixel_values = pixel_values.to(model_device)

        # Generate text
        with torch.no_grad():
            generated_ids = self.model.generate(pixel_values)

        # Decode
        generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text

    def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
        """
        Process the image region passed to it.
        This could be a bubble region or the full image.
        """
        results = []

        # Check for stop at start
        if self._check_stop():
            self._log("⏹️ Manga-OCR processing stopped by user", "warning")
            return results

        try:
            if not self.is_loaded:
                if not self.load_model():
                    return results

            import cv2
            from PIL import Image

            # Get confidence from kwargs
            confidence = kwargs.get('confidence', 0.7)

            # Convert numpy array to PIL
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image_rgb)
            h, w = image.shape[:2]

            self._log("🔍 Processing region with manga-ocr...")

            # Check for stop before inference
            if self._check_stop():
                self._log("⏹️ Manga-OCR inference stopped by user", "warning")
                return results

            # Run OCR on the image region
            text = self._run_ocr(pil_image)

            if text and text.strip():
                # Return result for this region with its actual bbox
                results.append(OCRResult(
                    text=text.strip(),
                    bbox=(0, 0, w, h),  # Relative to the region passed in
                    confidence=confidence,
                    vertices=[(0, 0), (w, 0), (w, h), (0, h)]
                ))
                self._log(f"✅ Detected text: {text[:50]}...")

        except Exception as e:
            self._log(f"❌ Error in manga-ocr: {str(e)}", "error")

        return results

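# Minimal usage sketch (hypothetical path; assumes transformers and torch are installed):
#   provider = MangaOCRProvider()
#   region = cv2.imread("bubble_crop.png")   # BGR crop of a single speech bubble
#   for r in provider.detect_text(region, confidence=0.8):
#       print(r.text, r.bbox)
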
class Qwen2VL(OCRProvider):
    """OCR using Qwen2-VL - Vision Language Model that can read Korean text"""

    def __init__(self, log_callback=None):
        super().__init__(log_callback)
        self.processor = None
        self.model = None
        self.tokenizer = None

        # Get OCR prompt from environment or use default
        self.ocr_prompt = os.environ.get('OCR_SYSTEM_PROMPT',
            "YOU ARE AN OCR SYSTEM. YOUR ONLY JOB IS TEXT EXTRACTION.\n\n"
            "CRITICAL RULES:\n"
            "1. DO NOT TRANSLATE ANYTHING\n"
            "2. DO NOT MODIFY THE TEXT\n"
            "3. DO NOT EXPLAIN OR COMMENT\n"
            "4. ONLY OUTPUT THE EXACT TEXT YOU SEE\n"
            "5. PRESERVE NATURAL TEXT FLOW - DO NOT ADD UNNECESSARY LINE BREAKS\n\n"
            "If you see Korean text, output it in Korean.\n"
            "If you see Japanese text, output it in Japanese.\n"
            "If you see Chinese text, output it in Chinese.\n"
            "If you see English text, output it in English.\n\n"
            "IMPORTANT: Only use line breaks where they naturally occur in the original text "
            "(e.g., between dialogue lines or paragraphs). Do not break text mid-sentence or "
            "between every word/character.\n\n"
            "For vertical text common in manga/comics, transcribe it as a continuous line unless "
            "there are clear visual breaks.\n\n"
            "NEVER translate. ONLY extract exactly what is written.\n"
            "Output ONLY the raw text, nothing else."
        )

    def set_ocr_prompt(self, prompt: str):
        """Allow setting the OCR prompt dynamically"""
        self.ocr_prompt = prompt

    def check_installation(self) -> bool:
        """Check if required packages are installed"""
        try:
            import transformers
            import torch
            self.is_installed = True
            return True
        except ImportError:
            return False

    def install(self, progress_callback=None) -> bool:
        """Install requirements for Qwen2-VL (stub: dependencies must be installed manually)"""
        pass

    def load_model(self, model_size=None, **kwargs) -> bool:
        """Load Qwen2-VL model with size selection"""
        self._log(f"DEBUG: load_model called with model_size={model_size}")

        try:
            if not self.is_installed and not self.check_installation():
                self._log("❌ Not installed", "error")
                return False

            self._log("🔥 Loading Qwen2-VL for Advanced OCR...")

            from transformers import AutoProcessor, AutoTokenizer
            import torch

            # Model options
            model_options = {
                "1": "Qwen/Qwen2-VL-2B-Instruct",
                "2": "Qwen/Qwen2-VL-7B-Instruct",
                "3": "Qwen/Qwen2-VL-72B-Instruct",
                "4": "custom"
            }

            # Check for a saved preference first (the default is the 2B model)
            if model_size is None:
                # Try to get from environment or config
                import os
                model_size = os.environ.get('QWEN2VL_MODEL_SIZE', '1')

            # Determine which model to load
            if model_size and str(model_size).startswith("custom:"):
                # Custom model passed with ID
                model_id = str(model_size).replace("custom:", "")
                self.loaded_model_size = "Custom"
                self.model_id = model_id
                self._log(f"Loading custom model: {model_id}")
            elif model_size == "4":
                # Custom option selected but no ID - shouldn't happen
                self._log("❌ Custom model selected but no ID provided", "error")
                return False
            elif model_size and str(model_size) in model_options:
                # Standard model option
                option = model_options[str(model_size)]
                if option == "custom":
                    self._log("❌ Custom model needs an ID", "error")
                    return False
                model_id = option
                # Set loaded_model_size for status display
                if model_size == "1":
                    self.loaded_model_size = "2B"
                elif model_size == "2":
                    self.loaded_model_size = "7B"
                elif model_size == "3":
                    self.loaded_model_size = "72B"
            else:
                # Fall back to the 2B model when no valid size is specified
                model_id = model_options["1"]
                self.loaded_model_size = "2B"
                self._log("No model size specified, defaulting to 2B")

            self._log(f"Loading model: {model_id}")

            # Load processor and tokenizer
            self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
            self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

            # Load the model - let it figure out the class dynamically
            if torch.cuda.is_available():
                self._log(f"GPU: {torch.cuda.get_device_name(0)}")
                # Use auto model class
                from transformers import AutoModelForVision2Seq
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,  # torch_dtype for compatibility across transformers versions
                    device_map="auto",
                    trust_remote_code=True
                )
                self._log("✅ Model loaded on GPU")
            else:
                self._log("Loading on CPU...")
                from transformers import AutoModelForVision2Seq
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_id,
                    torch_dtype=torch.float32,
                    trust_remote_code=True
                )
                self._log("✅ Model loaded on CPU")

            self.model.eval()
            self.is_loaded = True
            self._log("✅ Qwen2-VL ready for Advanced OCR!")
            return True

        except Exception as e:
            self._log(f"❌ Failed to load: {str(e)}", "error")
            import traceback
            self._log(traceback.format_exc(), "debug")
            return False

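    # Example (hypothetical): pick the 7B checkpoint before loading —
    #   os.environ['QWEN2VL_MODEL_SIZE'] = '2'   # "1"=2B, "2"=7B, "3"=72B
    #   provider.load_model()                    # or provider.load_model("custom:Org/Model")
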
    def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
        """Process image with Qwen2-VL for Korean text extraction"""
        results = []
        if hasattr(self, 'model_id'):
            self._log(f"DEBUG: Using model: {self.model_id}", "debug")

        # Check if OCR prompt was passed in kwargs (for dynamic updates)
        if 'ocr_prompt' in kwargs:
            self.ocr_prompt = kwargs['ocr_prompt']

        try:
            if not self.is_loaded:
                if not self.load_model():
                    return results

            import cv2
            from PIL import Image
            import torch

            # Convert to PIL
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image_rgb)
            h, w = image.shape[:2]

            self._log(f"🔍 Processing with Qwen2-VL ({w}x{h} pixels)...")

            # Use the configurable OCR prompt
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": pil_image,
                        },
                        {
                            "type": "text",
                            "text": self.ocr_prompt  # Use the configurable prompt
                        }
                    ]
                }
            ]

            # Alternative simpler prompt if the above still causes issues:
            # "text": "OCR: Extract text as-is"

            # Process with Qwen2-VL
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = self.processor(
                text=[text],
                images=[pil_image],
                padding=True,
                return_tensors="pt"
            )

            # Get the device and dtype the model is currently on
            model_device = next(self.model.parameters()).device
            model_dtype = next(self.model.parameters()).dtype

            # Move inputs to the same device as the model and cast float tensors to model dtype
            try:
                # Move first
                inputs = inputs.to(model_device)
                # Then align dtypes only for floating tensors (e.g., pixel_values)
                for k, v in inputs.items():
                    if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                        inputs[k] = v.to(model_dtype)
            except Exception:
                # Fallback: ensure at least pixel_values is correct if present
                try:
                    if isinstance(inputs, dict) and "pixel_values" in inputs:
                        pv = inputs["pixel_values"].to(model_device)
                        if torch.is_floating_point(pv):
                            inputs["pixel_values"] = pv.to(model_dtype)
                except Exception:
                    pass

            # Ensure pixel_values explicitly matches model dtype if present
            try:
                if isinstance(inputs, dict) and "pixel_values" in inputs:
                    inputs["pixel_values"] = inputs["pixel_values"].to(device=model_device, dtype=model_dtype)
            except Exception:
                pass

            # Generate text with stricter parameters to avoid creative responses
            use_amp = (hasattr(torch, 'cuda') and model_device.type == 'cuda' and model_dtype in (torch.float16, torch.bfloat16))
            autocast_dev = 'cuda' if model_device.type == 'cuda' else 'cpu'
            autocast_dtype = model_dtype if model_dtype in (torch.float16, torch.bfloat16) else None

            generation_kwargs = dict(
                max_new_tokens=128,        # Reduced from 512 - manga bubbles are typically short
                do_sample=False,           # Keep deterministic
                temperature=0.01,          # Keep very low temperature
                top_p=1.0,                 # No nucleus sampling
                repetition_penalty=1.0,    # No repetition penalty
                num_beams=1,               # Greedy decoding (faster than beam search)
                use_cache=True,            # Enable KV cache for speed
                early_stopping=True,       # Stop at EOS token
                pad_token_id=self.tokenizer.pad_token_id,  # Proper padding
                eos_token_id=self.tokenizer.eos_token_id,  # Proper stopping
            )

            with torch.no_grad():
                if use_amp and autocast_dtype is not None:
                    with torch.autocast(autocast_dev, dtype=autocast_dtype):
                        generated_ids = self.model.generate(**inputs, **generation_kwargs)
                else:
                    generated_ids = self.model.generate(**inputs, **generation_kwargs)

            # Decode the output
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            if output_text and output_text.strip():
                text = output_text.strip()

                # Filter out any response that looks like an explanation or apology.
                # Common patterns that indicate the model is being "helpful" instead of just extracting:
                unwanted_patterns = [
                    "죄송합니다",  # "I apologize"
                    "sorry",
                    "apologize",
                    "이미지에는",  # "in this image"
                    "텍스트가 없습니다",  # "there is no text"
                    "I cannot",
                    "I don't see",
                    "There is no",
                    "질문이 있으시면",  # "if you have questions"
                ]

                # Check if response contains unwanted patterns
                text_lower = text.lower()
                is_explanation = any(pattern.lower() in text_lower for pattern in unwanted_patterns)

                # Also check if the response is suspiciously long for a bubble:
                # most manga bubbles are short, so 100+ characters with sentence
                # punctuation likely means an explanation rather than an extraction
                is_too_long = len(text) > 100 and ('.' in text or ',' in text or '!' in text)

                if is_explanation or is_too_long:
                    self._log("⚠️ Model returned explanation instead of text, ignoring", "warning")
                    # Return empty results and skip this region
                    return results

                # Check language
                has_korean = any('\uAC00' <= c <= '\uD7AF' for c in text)
                has_japanese = any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' for c in text)
                has_chinese = any('\u4E00' <= c <= '\u9FFF' for c in text)

                if has_korean:
                    self._log(f"✅ Korean detected: {text[:50]}...")
                elif has_japanese:
                    self._log(f"✅ Japanese detected: {text[:50]}...")
                elif has_chinese:
                    self._log(f"✅ Chinese detected: {text[:50]}...")
                else:
                    self._log(f"✅ Text: {text[:50]}...")

                results.append(OCRResult(
                    text=text,
                    bbox=(0, 0, w, h),
                    confidence=0.9,
                    vertices=[(0, 0), (w, 0), (w, h), (0, h)]
                ))
            else:
                self._log("⚠️ No text detected", "warning")

        except Exception as e:
            self._log(f"❌ Error: {str(e)}", "error")
            import traceback
            self._log(traceback.format_exc(), "debug")

        return results

class EasyOCRProvider(OCRProvider):
    """EasyOCR provider for multiple languages"""

    def __init__(self, log_callback=None, languages=None):
        super().__init__(log_callback)
        # Default to a safe language combination
        self.languages = languages or ['ja', 'en']
        self._validate_language_combination()

    def _validate_language_combination(self):
        """Validate and fix EasyOCR language combinations"""
        # EasyOCR language compatibility rules
        incompatible_pairs = [
            (['ja', 'ko'], 'Japanese and Korean cannot be used together'),
            (['ja', 'zh'], 'Japanese and Chinese cannot be used together'),
            (['ko', 'zh'], 'Korean and Chinese cannot be used together')
        ]

        for incompatible, reason in incompatible_pairs:
            if all(lang in self.languages for lang in incompatible):
                self._log(f"⚠️ EasyOCR: {reason}", "warning")
                # Keep first language + English
                self.languages = [self.languages[0], 'en']
                self._log(f"🔧 Auto-adjusted to: {self.languages}", "info")
                break

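    # Example: EasyOCRProvider(languages=['ja', 'ko']) is auto-adjusted to ['ja', 'en'],
    # since EasyOCR cannot combine the Japanese and Korean recognizers in one Reader.
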
    def check_installation(self) -> bool:
        """Check if easyocr is installed"""
        try:
            import easyocr
            self.is_installed = True
            return True
        except ImportError:
            return False

    def install(self, progress_callback=None) -> bool:
        """Install easyocr (stub: dependencies must be installed manually)"""
        pass

    def load_model(self, **kwargs) -> bool:
        """Load easyocr model"""
        try:
            if not self.is_installed and not self.check_installation():
                self._log("❌ easyocr not installed", "error")
                return False

            self._log(f"🔥 Loading easyocr model for languages: {self.languages}...")
            import easyocr

            # This will download models on first run
            self.model = easyocr.Reader(self.languages, gpu=True)
            self.is_loaded = True

            self._log("✅ easyocr model loaded successfully")
            return True

        except Exception as e:
            self._log(f"❌ Failed to load easyocr: {str(e)}", "error")
            # Try CPU mode if GPU fails
            try:
                import easyocr
                self.model = easyocr.Reader(self.languages, gpu=False)
                self.is_loaded = True
                self._log("✅ easyocr loaded in CPU mode")
                return True
            except Exception:
                return False

    def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
        """Detect text using easyocr"""
        results = []

        try:
            if not self.is_loaded:
                if not self.load_model():
                    return results

            # EasyOCR can work directly with numpy arrays
            ocr_results = self.model.readtext(image, detail=1)

            # Parse results
            for (bbox, text, confidence) in ocr_results:
                # bbox is a list of 4 points
                xs = [point[0] for point in bbox]
                ys = [point[1] for point in bbox]
                x_min, x_max = min(xs), max(xs)
                y_min, y_max = min(ys), max(ys)

                results.append(OCRResult(
                    text=text,
                    bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
                    confidence=confidence,
                    vertices=[(int(p[0]), int(p[1])) for p in bbox]
                ))

            self._log(f"✅ Detected {len(results)} text regions")

        except Exception as e:
            self._log(f"❌ Error in easyocr detection: {str(e)}", "error")

        return results

class PaddleOCRProvider(OCRProvider):
    """PaddleOCR provider with memory safety measures"""

    def check_installation(self) -> bool:
        """Check if paddleocr is installed"""
        try:
            from paddleocr import PaddleOCR
            self.is_installed = True
            return True
        except ImportError:
            return False

    def install(self, progress_callback=None) -> bool:
        """Install paddleocr (stub: dependencies must be installed manually)"""
        pass

    def load_model(self, **kwargs) -> bool:
        """Load paddleocr model with memory-safe configurations"""
        try:
            if not self.is_installed and not self.check_installation():
                self._log("❌ paddleocr not installed", "error")
                return False

            self._log("🔥 Loading PaddleOCR model...")

            # Set memory-safe environment variables BEFORE importing
            import os
            os.environ['OMP_NUM_THREADS'] = '1'       # Prevent OpenMP conflicts
            os.environ['MKL_NUM_THREADS'] = '1'       # Prevent MKL conflicts
            os.environ['OPENBLAS_NUM_THREADS'] = '1'  # Prevent OpenBLAS conflicts
            os.environ['FLAGS_use_mkldnn'] = '0'      # Disable MKL-DNN

            from paddleocr import PaddleOCR
            import gc

            # Try memory-safe configurations
            configs_to_try = [
                # Config 1: Most memory-safe configuration
                {
                    'use_angle_cls': False,   # Disable angle classification to save memory
                    'lang': 'ch',
                    'rec_batch_num': 1,       # Process one at a time
                    'max_text_length': 100,   # Limit text length
                    'drop_score': 0.5,        # Higher threshold to reduce detections
                    'cpu_threads': 1,         # Single thread to avoid conflicts
                },
                # Config 2: Minimal memory footprint
                {
                    'lang': 'ch',
                    'rec_batch_num': 1,
                    'cpu_threads': 1,
                },
                # Config 3: Absolute minimal
                {
                    'lang': 'ch'
                },
                # Config 4: Empty config
                {}
            ]

            for i, config in enumerate(configs_to_try):
                try:
                    self._log(f"   Trying configuration {i+1}/{len(configs_to_try)}: {config}")

                    # Force garbage collection before loading
                    gc.collect()

                    self.model = PaddleOCR(**config)
                    self.is_loaded = True
                    self.current_config = config
                    self._log(f"✅ PaddleOCR loaded successfully with config: {config}")
                    return True
                except Exception as e:
                    error_str = str(e)
                    self._log(f"   Config {i+1} failed: {error_str}", "debug")

                    # Clean up on failure
                    if hasattr(self, 'model'):
                        del self.model
                    gc.collect()
                    continue

            self._log("❌ PaddleOCR failed to load with any configuration", "error")
            return False

        except Exception as e:
            self._log(f"❌ Failed to load paddleocr: {str(e)}", "error")
            import traceback
            self._log(traceback.format_exc(), "debug")
            return False

    def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
        """Detect text with memory safety measures"""
        results = []

        try:
            if not self.is_loaded:
                if not self.load_model():
                    return results

            import cv2
            import numpy as np
            import gc

            # Memory safety: ensure image isn't too large
            h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)

            # Limit image size to prevent memory issues
            MAX_DIMENSION = 1500
            if h > MAX_DIMENSION or w > MAX_DIMENSION:
                scale = min(MAX_DIMENSION/h, MAX_DIMENSION/w)
                new_h, new_w = int(h*scale), int(w*scale)
                self._log(f"⚠️ Resizing large image from {w}x{h} to {new_w}x{new_h} for memory safety", "warning")
                image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
                scale_factor = 1/scale
            else:
                scale_factor = 1.0

            # Ensure correct format
            if len(image.shape) == 2:  # Grayscale
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            elif len(image.shape) == 4:  # Batch
                image = image[0]

            # Ensure uint8 type
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)

            # Make a copy to avoid memory corruption
            image_copy = image.copy()

            # Force garbage collection before OCR
            gc.collect()

            # Process with timeout protection (threading is imported at module level)
            ocr_results = None
            ocr_error = None

            def run_ocr():
                nonlocal ocr_results, ocr_error
                try:
                    ocr_results = self.model.ocr(image_copy)
                except Exception as e:
                    ocr_error = e

            # Run OCR in a separate thread with timeout
            ocr_thread = threading.Thread(target=run_ocr)
            ocr_thread.daemon = True
            ocr_thread.start()
            ocr_thread.join(timeout=30)  # 30 second timeout

            if ocr_thread.is_alive():
                self._log("❌ PaddleOCR timeout - taking too long", "error")
                return results

            if ocr_error:
                raise ocr_error

            # Parse results
            results = self._parse_ocr_results(ocr_results)

            # Scale coordinates back if image was resized
            if scale_factor != 1.0 and results:
                for r in results:
                    x, y, width, height = r.bbox
                    r.bbox = (int(x*scale_factor), int(y*scale_factor),
                              int(width*scale_factor), int(height*scale_factor))
                    r.vertices = [(int(v[0]*scale_factor), int(v[1]*scale_factor))
                                  for v in r.vertices]

            if results:
                self._log(f"✅ Detected {len(results)} text regions", "info")
            else:
                self._log("No text regions found", "debug")

            # Clean up
            del image_copy
            gc.collect()

        except Exception as e:
            error_msg = str(e) if str(e) else type(e).__name__

            if "memory" in error_msg.lower() or "0x" in error_msg:
                self._log("❌ Memory access violation in PaddleOCR", "error")
                self._log("   This is a known Windows issue with PaddleOCR", "info")
                self._log("   Please switch to EasyOCR or manga-ocr instead", "warning")
            elif "trace_order.size()" in error_msg:
                self._log("❌ PaddleOCR internal error", "error")
                self._log("   Please switch to EasyOCR or manga-ocr", "warning")
            else:
                self._log(f"❌ Error in paddleocr detection: {error_msg}", "error")

            import traceback
            self._log(traceback.format_exc(), "debug")

        return results

    def _parse_ocr_results(self, ocr_results) -> List[OCRResult]:
        """Parse OCR results safely"""
        results = []

        if ocr_results is False:
            return results

        if ocr_results is None or not isinstance(ocr_results, list):
            return results

        if len(ocr_results) == 0:
            return results

        # Handle batch format
        if isinstance(ocr_results[0], list) and len(ocr_results[0]) > 0:
            first_item = ocr_results[0][0]
            if isinstance(first_item, list) and len(first_item) > 0:
                if isinstance(first_item[0], (list, tuple)) and len(first_item[0]) == 2:
                    ocr_results = ocr_results[0]

        # Parse detections
        for detection in ocr_results:
            if not detection or isinstance(detection, bool):
                continue

            if not isinstance(detection, (list, tuple)) or len(detection) < 2:
                continue

            try:
                bbox_points = detection[0]
                text_data = detection[1]

                if not isinstance(bbox_points, (list, tuple)) or len(bbox_points) != 4:
                    continue

                if not isinstance(text_data, (tuple, list)) or len(text_data) < 2:
                    continue

                text = str(text_data[0]).strip()
                confidence = float(text_data[1])

                if not text or confidence < 0.3:
                    continue

                xs = [float(p[0]) for p in bbox_points]
                ys = [float(p[1]) for p in bbox_points]
                x_min, x_max = min(xs), max(xs)
                y_min, y_max = min(ys), max(ys)

                if (x_max - x_min) < 5 or (y_max - y_min) < 5:
                    continue

                results.append(OCRResult(
                    text=text,
                    bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
                    confidence=confidence,
                    vertices=[(int(p[0]), int(p[1])) for p in bbox_points]
                ))

            except Exception:
                continue

        return results

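# Shape handled above (typical PaddleOCR output, illustrative):
#   [ [ [[x1,y1],[x2,y2],[x3,y3],[x4,y4]], ("text", 0.97) ], ... ]
# optionally wrapped in one more list when returned in batch form.
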
class DocTROCRProvider(OCRProvider):
    """DocTR OCR provider"""

    def check_installation(self) -> bool:
        """Check if doctr is installed"""
        try:
            from doctr.models import ocr_predictor
            self.is_installed = True
            return True
        except ImportError:
            return False

    def install(self, progress_callback=None) -> bool:
        """Install doctr (stub: dependencies must be installed manually)"""
        pass

    def load_model(self, **kwargs) -> bool:
        """Load doctr model"""
        try:
            if not self.is_installed and not self.check_installation():
                self._log("❌ doctr not installed", "error")
                return False

            self._log("🔥 Loading DocTR model...")
            from doctr.models import ocr_predictor

            # Load pretrained model
            self.model = ocr_predictor(pretrained=True)
            self.is_loaded = True

            self._log("✅ DocTR model loaded successfully")
            return True

        except Exception as e:
            self._log(f"❌ Failed to load doctr: {str(e)}", "error")
            return False

    def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
        """Detect text using doctr"""
        results = []

        try:
            if not self.is_loaded:
                if not self.load_model():
                    return results

            from doctr.io import DocumentFile

            # DocTR expects document input: write the numpy array to a temporary
            # image file and load it back as a document
            import tempfile
            import cv2

            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
                cv2.imwrite(tmp.name, image)
                doc = DocumentFile.from_images(tmp.name)

            # Run OCR
            result = self.model(doc)

            # Parse results
            h, w = image.shape[:2]
            for page in result.pages:
                for block in page.blocks:
                    for line in block.lines:
                        for word in line.words:
                            # Handle different geometry formats
                            geometry = word.geometry

                            if len(geometry) == 4:
                                # Standard format: (x1, y1, x2, y2)
                                x1, y1, x2, y2 = geometry
                            elif len(geometry) == 2:
                                # Alternative format: ((x1, y1), (x2, y2))
                                (x1, y1), (x2, y2) = geometry
                            else:
                                self._log(f"Unexpected geometry format: {geometry}", "warning")
                                continue

                            # Convert relative coordinates to absolute
                            x1, x2 = int(x1 * w), int(x2 * w)
                            y1, y2 = int(y1 * h), int(y2 * h)

                            results.append(OCRResult(
                                text=word.value,
                                bbox=(x1, y1, x2 - x1, y2 - y1),
                                confidence=word.confidence,
                                vertices=[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
                            ))

            # Clean up temp file
            try:
                os.unlink(tmp.name)
            except Exception:
                pass

            self._log(f"DocTR detected {len(results)} text regions")

        except Exception as e:
            self._log(f"Error in doctr detection: {str(e)}", "error")
            import traceback
            self._log(traceback.format_exc(), "error")

        return results

1707
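As a reference for the traversal above, a minimal standalone DocTR sketch (assuming `python-doctr` is installed; "page.png" is a placeholder path):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    model = ocr_predictor(pretrained=True)
    doc = DocumentFile.from_images("page.png")
    result = model(doc)

    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    # word.geometry holds relative coordinates in [0, 1],
                    # typically ((x1, y1), (x2, y2))
                    print(word.value, word.confidence, word.geometry)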
- class RapidOCRProvider(OCRProvider):
-     """RapidOCR provider for fast local OCR"""
- 
-     def check_installation(self) -> bool:
-         """Check if rapidocr is installed"""
-         try:
-             import rapidocr_onnxruntime
-             self.is_installed = True
-             return True
-         except ImportError:
-             return False
- 
-     def install(self, progress_callback=None) -> bool:
-         """Install rapidocr (requires manual pip install)"""
-         # RapidOCR requires manual installation
-         if progress_callback:
-             progress_callback("RapidOCR requires manual pip installation")
-         self._log("Run: pip install rapidocr-onnxruntime", "info")
-         return False  # Always return False since we can't auto-install
- 
-     def load_model(self, **kwargs) -> bool:
-         """Load RapidOCR model"""
-         try:
-             if not self.is_installed and not self.check_installation():
-                 self._log("RapidOCR not installed", "error")
-                 return False
- 
-             self._log("Loading RapidOCR...")
-             from rapidocr_onnxruntime import RapidOCR
- 
-             self.model = RapidOCR()
-             self.is_loaded = True
- 
-             self._log("RapidOCR model loaded successfully")
-             return True
- 
-         except Exception as e:
-             self._log(f"Failed to load RapidOCR: {str(e)}", "error")
-             return False
- 
-     def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
-         """Detect text using RapidOCR"""
-         if not self.is_loaded:
-             self._log("RapidOCR model not loaded", "error")
-             return []
- 
-         results = []
- 
-         try:
-             # RapidOCR accepts numpy arrays; convert BGR to RGB first
-             if len(image.shape) == 3:
-                 image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             else:
-                 image_rgb = image
- 
-             ocr_results, _ = self.model(image_rgb)
- 
-             if ocr_results:
-                 for result in ocr_results:
-                     # RapidOCR returns [bbox, text, confidence]
-                     bbox_points = result[0]  # 4 corner points
-                     text = result[1]
-                     confidence = float(result[2])
- 
-                     if not text or not text.strip():
-                         continue
- 
-                     # Convert 4-point bbox to x,y,w,h format
-                     xs = [point[0] for point in bbox_points]
-                     ys = [point[1] for point in bbox_points]
-                     x_min, x_max = min(xs), max(xs)
-                     y_min, y_max = min(ys), max(ys)
- 
-                     results.append(OCRResult(
-                         text=text.strip(),
-                         bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
-                         confidence=confidence,
-                         vertices=[(int(p[0]), int(p[1])) for p in bbox_points]
-                     ))
- 
-             self._log(f"Detected {len(results)} text regions")
- 
-         except Exception as e:
-             self._log(f"Error in RapidOCR detection: {str(e)}", "error")
- 
-         return results
- 
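For comparison, the raw RapidOCR call the provider wraps, as a minimal sketch (assumes `rapidocr-onnxruntime` is installed; "page.png" is a placeholder):

    import cv2
    from rapidocr_onnxruntime import RapidOCR

    engine = RapidOCR()
    img = cv2.cvtColor(cv2.imread("page.png"), cv2.COLOR_BGR2RGB)
    detections, elapse = engine(img)  # returns (results, timing info)

    for box, text, score in detections or []:
        print(text, score, box)  # box is four [x, y] corner points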
- class OCRManager:
-     """Manager for multiple OCR providers"""
- 
-     def __init__(self, log_callback=None):
-         self.log_callback = log_callback
-         self.providers = {
-             'custom-api': CustomAPIProvider(log_callback),
-             'manga-ocr': MangaOCRProvider(log_callback),
-             'easyocr': EasyOCRProvider(log_callback),
-             'paddleocr': PaddleOCRProvider(log_callback),
-             'doctr': DocTROCRProvider(log_callback),
-             'rapidocr': RapidOCRProvider(log_callback),
-             'Qwen2-VL': Qwen2VL(log_callback)
-         }
-         self.current_provider = None
-         self.stop_flag = None
- 
-     def get_provider(self, name: str) -> Optional[OCRProvider]:
-         """Get OCR provider by name"""
-         return self.providers.get(name)
- 
-     def set_current_provider(self, name: str) -> bool:
-         """Set current active provider"""
-         if name in self.providers:
-             self.current_provider = name
-             return True
-         return False
- 
-     def check_provider_status(self, name: str) -> Dict[str, bool]:
-         """Check installation and loading status of a provider"""
-         provider = self.providers.get(name)
-         if not provider:
-             return {'installed': False, 'loaded': False}
- 
-         result = {
-             'installed': provider.check_installation(),
-             'loaded': provider.is_loaded
-         }
-         if self.log_callback:
-             self.log_callback(f"DEBUG: check_provider_status({name}) returning loaded={result['loaded']}", "debug")
-         return result
- 
-     def install_provider(self, name: str, progress_callback=None) -> bool:
-         """Install a provider"""
-         provider = self.providers.get(name)
-         if not provider:
-             return False
- 
-         return provider.install(progress_callback)
- 
-     def load_provider(self, name: str, **kwargs) -> bool:
-         """Load a provider's model with optional parameters"""
-         provider = self.providers.get(name)
-         if not provider:
-             return False
- 
-         return provider.load_model(**kwargs)  # passes model_size and any other kwargs through
- 
-     def shutdown(self):
-         """Release models/processors/tokenizers for all providers and clear caches."""
-         try:
-             import gc
-             for name, provider in list(self.providers.items()):
-                 try:
-                     if hasattr(provider, 'model'):
-                         provider.model = None
-                     if hasattr(provider, 'processor'):
-                         provider.processor = None
-                     if hasattr(provider, 'tokenizer'):
-                         provider.tokenizer = None
-                     if hasattr(provider, 'reader'):
-                         provider.reader = None
-                     if hasattr(provider, 'is_loaded'):
-                         provider.is_loaded = False
-                 except Exception:
-                     pass
-             gc.collect()
-             try:
-                 import torch
-                 torch.cuda.empty_cache()
-             except Exception:
-                 pass
-         except Exception:
-             pass
- 
-     def detect_text(self, image: np.ndarray, provider_name: Optional[str] = None, **kwargs) -> List[OCRResult]:
-         """Detect text using the specified or current provider"""
-         provider_name = provider_name or self.current_provider
-         if not provider_name:
-             return []
- 
-         provider = self.providers.get(provider_name)
-         if not provider:
-             return []
- 
-         return provider.detect_text(image, **kwargs)
- 
-     def set_stop_flag(self, stop_flag):
-         """Set stop flag for all providers"""
-         self.stop_flag = stop_flag
-         for provider in self.providers.values():
-             if hasattr(provider, 'set_stop_flag'):
-                 provider.set_stop_flag(stop_flag)
- 
-     def reset_stop_flags(self):
-         """Reset stop flags for all providers"""
-         for provider in self.providers.values():
-             if hasattr(provider, 'reset_stop_flags'):
-                 provider.reset_stop_flags()
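Finally, a hedged usage sketch of the manager as defined above ("page.png" and the choice of 'rapidocr' are placeholders; any key from the providers dict works):

    import cv2

    manager = OCRManager(log_callback=print)
    if manager.load_provider('rapidocr'):
        manager.set_current_provider('rapidocr')
        image = cv2.imread("page.png")
        for r in manager.detect_text(image):
            print(r.text, r.bbox, r.confidence)
    manager.shutdown()  # release model references and clear GPU caches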