Shirochi committed on
Commit 7769622 · verified · 1 Parent(s): 0cc2912

Upload 8 files

app.py ADDED
The diff for this file is too large to render. See raw diff
 
bubble_detector.py ADDED
@@ -0,0 +1,2030 @@
1
+ """
2
+ bubble_detector.py - Modified version that works in frozen PyInstaller executables
3
+ Replace your bubble_detector.py with this version
4
+ """
5
+ import os
6
+ import sys
7
+ import json
8
+ import numpy as np
9
+ import cv2
10
+ from typing import List, Tuple, Optional, Dict, Any
11
+ import logging
12
+ import traceback
13
+ import hashlib
14
+ from pathlib import Path
15
+ import threading
16
+ import time
17
+
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Check if we're running in a frozen environment
22
+ IS_FROZEN = getattr(sys, 'frozen', False)
23
+ if IS_FROZEN:
24
+ # In frozen environment, set proper paths for ML libraries
25
+ MEIPASS = sys._MEIPASS
26
+ os.environ['TORCH_HOME'] = MEIPASS
27
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(MEIPASS, 'transformers')
28
+ os.environ['HF_HOME'] = os.path.join(MEIPASS, 'huggingface')
29
+ logger.info(f"Running in frozen environment: {MEIPASS}")
30
+
31
+ # Modified import checks for frozen environment
32
+ YOLO_AVAILABLE = False
33
+ YOLO = None
34
+ torch = None
35
+ TORCH_AVAILABLE = False
36
+ ONNX_AVAILABLE = False
37
+ TRANSFORMERS_AVAILABLE = False
38
+ RTDetrForObjectDetection = None
39
+ RTDetrImageProcessor = None
40
+ PIL_AVAILABLE = False
41
+
42
+ # Try to import YOLO dependencies with better error handling
43
+ if IS_FROZEN:
44
+ # In frozen environment, try harder to import
45
+ try:
46
+ # First try to import torch components individually
47
+ import torch
48
+ import torch.nn
49
+ import torch.cuda
50
+ TORCH_AVAILABLE = True
51
+ logger.info("✓ PyTorch loaded in frozen environment")
52
+ except Exception as e:
53
+ logger.warning(f"PyTorch not available in frozen environment: {e}")
54
+ TORCH_AVAILABLE = False
55
+ torch = None
56
+
57
+ # Try ultralytics after torch
58
+ if TORCH_AVAILABLE:
59
+ try:
60
+ from ultralytics import YOLO
61
+ YOLO_AVAILABLE = True
62
+ logger.info("✓ Ultralytics YOLO loaded in frozen environment")
63
+ except Exception as e:
64
+ logger.warning(f"Ultralytics not available in frozen environment: {e}")
65
+ YOLO_AVAILABLE = False
66
+
67
+ # Try transformers
68
+ try:
69
+ import transformers
70
+ # Try specific imports
71
+ try:
72
+ from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
73
+ TRANSFORMERS_AVAILABLE = True
74
+ logger.info("✓ Transformers RT-DETR loaded in frozen environment")
75
+ except ImportError:
76
+ # Try alternative import
77
+ try:
78
+ from transformers import AutoModel, AutoImageProcessor
79
+ RTDetrForObjectDetection = AutoModel
80
+ RTDetrImageProcessor = AutoImageProcessor
81
+ TRANSFORMERS_AVAILABLE = True
82
+ logger.info("✓ Transformers loaded with AutoModel fallback")
83
+ except Exception:
84
+ TRANSFORMERS_AVAILABLE = False
85
+ logger.warning("Transformers RT-DETR not available in frozen environment")
86
+ except Exception as e:
87
+ logger.warning(f"Transformers not available in frozen environment: {e}")
88
+ TRANSFORMERS_AVAILABLE = False
89
+ else:
90
+ # Normal environment - original import logic
91
+ try:
92
+ from ultralytics import YOLO
93
+ YOLO_AVAILABLE = True
94
+ except Exception:
95
+ YOLO_AVAILABLE = False
96
+ logger.warning("Ultralytics YOLO not available")
97
+
98
+ try:
99
+ import torch
100
+ # Test if cuda attribute exists
101
+ _ = torch.cuda
102
+ TORCH_AVAILABLE = True
103
+ except (ImportError, AttributeError):
104
+ TORCH_AVAILABLE = False
105
+ torch = None
106
+ logger.warning("PyTorch not available or incomplete")
107
+
108
+ try:
109
+ from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
110
+ try:
111
+ from transformers import RTDetrV2ForObjectDetection
112
+ RTDetrForObjectDetection = RTDetrV2ForObjectDetection
113
+ except ImportError:
114
+ pass
115
+ TRANSFORMERS_AVAILABLE = True
116
+ except Exception:
117
+ TRANSFORMERS_AVAILABLE = False
118
+ logger.info("Transformers not available for RT-DETR")
119
+
120
+ # Configure ORT memory behavior before importing
121
+ try:
122
+ os.environ.setdefault('ORT_DISABLE_MEMORY_ARENA', '1')
123
+ except Exception:
124
+ pass
125
+ # ONNX Runtime - works well in frozen environments
126
+ try:
127
+ import onnxruntime as ort
128
+ ONNX_AVAILABLE = True
129
+ logger.info("✓ ONNX Runtime available")
130
+ except ImportError:
131
+ ONNX_AVAILABLE = False
132
+ logger.warning("ONNX Runtime not available")
133
+
134
+ # PIL
135
+ try:
136
+ from PIL import Image
137
+ PIL_AVAILABLE = True
138
+ except ImportError:
139
+ PIL_AVAILABLE = False
140
+ logger.info("PIL not available")
141
+
142
+
143
+ class BubbleDetector:
144
+ """
145
+ Combined YOLOv8 and RT-DETR speech bubble detector for comics and manga.
146
+ Supports multiple model formats and provides configurable detection.
147
+ Backward compatible with existing code while adding RT-DETR support.
148
+ """
149
+
150
+ # Process-wide shared RT-DETR to avoid concurrent meta-device loads
151
+ _rtdetr_init_lock = threading.Lock()
152
+ _rtdetr_shared_model = None
153
+ _rtdetr_shared_processor = None
154
+ _rtdetr_loaded = False
155
+ _rtdetr_repo_id = 'ogkalu/comic-text-and-bubble-detector'
156
+
157
+ # Shared RT-DETR (ONNX) across process to avoid device/context storms
158
+ _rtdetr_onnx_init_lock = threading.Lock()
159
+ _rtdetr_onnx_shared_session = None
160
+ _rtdetr_onnx_loaded = False
161
+ _rtdetr_onnx_providers = None
162
+ _rtdetr_onnx_model_path = None
163
+ # Limit concurrent runs to avoid device hangs. Defaults to 2 for better parallelism.
164
+ # Can be overridden via env DML_MAX_CONCURRENT or config rtdetr_max_concurrency
165
+ try:
166
+ _rtdetr_onnx_max_concurrent = int(os.environ.get('DML_MAX_CONCURRENT', '2'))
167
+ except Exception:
168
+ _rtdetr_onnx_max_concurrent = 2
169
+ _rtdetr_onnx_sema = threading.Semaphore(max(1, _rtdetr_onnx_max_concurrent))
170
+ _rtdetr_onnx_sema_initialized = False
171
+
172
+ def __init__(self, config_path: str = "config.json"):
173
+ """
174
+ Initialize the bubble detector.
175
+
176
+ Args:
177
+ config_path: Path to configuration file
178
+ """
179
+ # Set thread limits early if environment indicates single-threaded mode
180
+ try:
181
+ if os.environ.get('OMP_NUM_THREADS') == '1':
182
+ # Already in single-threaded mode, ensure it's applied to this process
183
+ # Check if torch is available at module level before trying to use it
184
+ if TORCH_AVAILABLE and torch is not None:
185
+ try:
186
+ torch.set_num_threads(1)
187
+ except (RuntimeError, AttributeError):
188
+ pass
189
+ try:
190
+ import cv2
191
+ cv2.setNumThreads(1)
192
+ except (ImportError, AttributeError):
193
+ pass
194
+ except Exception:
195
+ pass
196
+
197
+ self.config_path = config_path
198
+ self.config = self._load_config()
199
+
200
+ # YOLOv8 components (original)
201
+ self.model = None
202
+ self.model_loaded = False
203
+ self.model_type = None # 'yolo', 'onnx', or 'torch'
204
+ self.onnx_session = None
205
+
206
+ # RT-DETR components (new)
207
+ self.rtdetr_model = None
208
+ self.rtdetr_processor = None
209
+ self.rtdetr_loaded = False
210
+ self.rtdetr_repo = 'ogkalu/comic-text-and-bubble-detector'
211
+
212
+ # RT-DETR (ONNX) backend components
213
+ self.rtdetr_onnx_session = None
214
+ self.rtdetr_onnx_loaded = False
215
+ self.rtdetr_onnx_repo = 'ogkalu/comic-text-and-bubble-detector'
216
+
217
+ # RT-DETR class definitions
218
+ self.CLASS_BUBBLE = 0 # Empty speech bubble
219
+ self.CLASS_TEXT_BUBBLE = 1 # Bubble with text
220
+ self.CLASS_TEXT_FREE = 2 # Text without bubble
221
+
222
+ # Detection settings
223
+ self.default_confidence = 0.3
224
+ self.default_iou_threshold = 0.45
225
+ # Allow override from settings
226
+ try:
227
+ ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
228
+ self.default_max_detections = int(ocr_cfg.get('bubble_max_detections', 100))
229
+ self.max_det_yolo = int(ocr_cfg.get('bubble_max_detections_yolo', self.default_max_detections))
230
+ self.max_det_rtdetr = int(ocr_cfg.get('bubble_max_detections_rtdetr', self.default_max_detections))
231
+ except Exception:
232
+ self.default_max_detections = 100
233
+ self.max_det_yolo = 100
234
+ self.max_det_rtdetr = 100
235
+
236
+ # Cache directory for ONNX conversions
237
+ self.cache_dir = os.environ.get('BUBBLE_CACHE_DIR', 'models')
238
+ os.makedirs(self.cache_dir, exist_ok=True)
239
+
240
+ # RT-DETR concurrency setting from config
241
+ try:
242
+ rtdetr_max_conc = int(ocr_cfg.get('rtdetr_max_concurrency', 2))
243
+ # Update class-level semaphore if not yet initialized or if value changed
244
+ if not BubbleDetector._rtdetr_onnx_sema_initialized or rtdetr_max_conc != BubbleDetector._rtdetr_onnx_max_concurrent:
245
+ BubbleDetector._rtdetr_onnx_max_concurrent = max(1, rtdetr_max_conc)
246
+ BubbleDetector._rtdetr_onnx_sema = threading.Semaphore(BubbleDetector._rtdetr_onnx_max_concurrent)
247
+ BubbleDetector._rtdetr_onnx_sema_initialized = True
248
+ logger.info(f"RT-DETR concurrency set to: {BubbleDetector._rtdetr_onnx_max_concurrent}")
249
+ except Exception as e:
250
+ logger.warning(f"Failed to set RT-DETR concurrency: {e}")
251
+
252
+ # GPU availability
253
+ self.use_gpu = TORCH_AVAILABLE and torch.cuda.is_available()
254
+ self.device = 'cuda' if self.use_gpu else 'cpu'
255
+
256
+ # Quantization/precision settings
257
+ adv_cfg = self.config.get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
258
+ ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
259
+ env_quant = os.environ.get('MODEL_QUANTIZE', 'false').lower() == 'true'
260
+ self.quantize_enabled = bool(env_quant or adv_cfg.get('quantize_models', False) or ocr_cfg.get('quantize_bubble_detector', False))
261
+ self.quantize_dtype = str(adv_cfg.get('torch_precision', os.environ.get('TORCH_PRECISION', 'auto'))).lower()
262
+ # Prefer advanced.onnx_quantize; fall back to env or global quantize
263
+ self.onnx_quantize_enabled = bool(adv_cfg.get('onnx_quantize', os.environ.get('ONNX_QUANTIZE', 'false').lower() == 'true' or self.quantize_enabled))
264
+
265
+ # Stop flag support
266
+ self.stop_flag = None
267
+ self._stopped = False
268
+ self.log_callback = None
269
+
270
+ logger.info(f"🗨️ BubbleDetector initialized")
271
+ logger.info(f" GPU: {'Available' if self.use_gpu else 'Not available'}")
272
+ logger.info(f" YOLO: {'Available' if YOLO_AVAILABLE else 'Not installed'}")
273
+ logger.info(f" ONNX: {'Available' if ONNX_AVAILABLE else 'Not installed'}")
274
+ logger.info(f" RT-DETR: {'Available' if TRANSFORMERS_AVAILABLE else 'Not installed'}")
275
+ logger.info(f" Quantization: {'ENABLED' if self.quantize_enabled else 'disabled'} (torch_precision={self.quantize_dtype}, onnx_quantize={'on' if self.onnx_quantize_enabled else 'off'})" )
276
+
277
+ def _load_config(self) -> Dict[str, Any]:
278
+ """Load configuration from file."""
279
+ if os.path.exists(self.config_path):
280
+ try:
281
+ with open(self.config_path, 'r', encoding='utf-8') as f:
282
+ return json.load(f)
283
+ except Exception as e:
284
+ logger.warning(f"Failed to load config: {e}")
285
+ return {}
286
+
287
+ def _save_config(self):
288
+ """Save configuration to file."""
289
+ try:
290
+ with open(self.config_path, 'w', encoding='utf-8') as f:
291
+ json.dump(self.config, f, indent=2)
292
+ except Exception as e:
293
+ logger.error(f"Failed to save config: {e}")
294
+
295
+ def set_stop_flag(self, stop_flag):
296
+ """Set the stop flag for checking interruptions"""
297
+ self.stop_flag = stop_flag
298
+ self._stopped = False
299
+
300
+ def set_log_callback(self, log_callback):
301
+ """Set log callback for GUI integration"""
302
+ self.log_callback = log_callback
303
+
304
+ def _check_stop(self) -> bool:
305
+ """Check if stop has been requested"""
306
+ if self._stopped:
307
+ return True
308
+ if self.stop_flag and self.stop_flag.is_set():
309
+ self._stopped = True
310
+ return True
311
+ # Check global manga translator cancellation
312
+ try:
313
+ from manga_translator import MangaTranslator
314
+ if MangaTranslator.is_globally_cancelled():
315
+ self._stopped = True
316
+ return True
317
+ except Exception:
318
+ pass
319
+ return False
320
+
321
+ def _log(self, message: str, level: str = "info"):
322
+ """Log message with stop suppression"""
323
+ # Suppress logs when stopped (allow only essential stop confirmation messages)
324
+ if self._check_stop():
325
+ essential_stop_keywords = [
326
+ "⏹️ Translation stopped by user",
327
+ "⏹️ Bubble detection stopped",
328
+ "cleanup", "🧹"
329
+ ]
330
+ if not any(keyword in message for keyword in essential_stop_keywords):
331
+ return
332
+
333
+ if self.log_callback:
334
+ self.log_callback(message, level)
335
+ else:
336
+ getattr(logger, level, logger.info)(message)
337
+
338
+ def reset_stop_flags(self):
339
+ """Reset stop flags when starting new processing"""
340
+ self._stopped = False
341
+
342
+ def load_model(self, model_path: str, force_reload: bool = False) -> bool:
343
+ """
344
+ Load a YOLOv8 model for bubble detection.
345
+
346
+ Args:
347
+ model_path: Path to model file (.pt, .onnx, or .torchscript)
348
+ force_reload: Force reload even if model is already loaded
349
+
350
+ Returns:
351
+ True if model loaded successfully, False otherwise
352
+ """
353
+ try:
354
+ # If given a Hugging Face repo ID (e.g., 'owner/name'), fetch detector.onnx into models/
355
+ if model_path and (('/' in model_path) and not os.path.exists(model_path)):
356
+ try:
357
+ from huggingface_hub import hf_hub_download
358
+ os.makedirs(self.cache_dir, exist_ok=True)
359
+ logger.info(f"📥 Resolving repo '{model_path}' to detector.onnx in {self.cache_dir}...")
360
+ resolved = hf_hub_download(repo_id=model_path, filename='detector.onnx', cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
361
+ if resolved and os.path.exists(resolved):
362
+ model_path = resolved
363
+ logger.info(f"✅ Downloaded detector.onnx to: {model_path}")
364
+ except Exception as repo_err:
365
+ logger.error(f"Failed to download from repo '{model_path}': {repo_err}")
366
+ if not os.path.exists(model_path):
367
+ logger.error(f"Model file not found: {model_path}")
368
+ return False
369
+
370
+ # Check if it's the same model already loaded
371
+ if self.model_loaded and not force_reload:
372
+ last_path = self.config.get('last_model_path', '')
373
+ if last_path == model_path:
374
+ logger.info("Model already loaded (same path)")
375
+ return True
376
+ else:
377
+ logger.info(f"Model path changed from {last_path} to {model_path}, reloading...")
378
+ force_reload = True
379
+
380
+ # Clear previous model if force reload
381
+ if force_reload:
382
+ logger.info("Force reloading model...")
383
+ self.model = None
384
+ self.onnx_session = None
385
+ self.model_loaded = False
386
+ self.model_type = None
387
+
388
+ logger.info(f"📥 Loading bubble detection model: {model_path}")
389
+
390
+ # Determine model type by extension
391
+ ext = Path(model_path).suffix.lower()
392
+
393
+ if ext in ['.pt', '.pth']:
394
+ if not YOLO_AVAILABLE:
395
+ logger.warning("Ultralytics package not available in this build")
396
+ logger.info("Bubble detection will be disabled - this is normal for lightweight builds")
397
+ # Mark the model as not loaded and report failure to the caller
398
+ self.model_loaded = False
399
+ return False
400
+
401
+ # Load YOLOv8 model
402
+ try:
403
+ self.model = YOLO(model_path)
404
+ self.model_type = 'yolo'
405
+
406
+ # Set to eval mode
407
+ if hasattr(self.model, 'model'):
408
+ self.model.model.eval()
409
+
410
+ # Move to GPU if available
411
+ if self.use_gpu and TORCH_AVAILABLE:
412
+ try:
413
+ self.model.to('cuda')
414
+ except Exception as gpu_error:
415
+ logger.warning(f"Could not move model to GPU: {gpu_error}")
416
+
417
+ logger.info("✅ YOLOv8 model loaded successfully")
418
+ # Apply optional FP16 precision to reduce VRAM if enabled
419
+ if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
420
+ try:
421
+ m = self.model.model if hasattr(self.model, 'model') else self.model
422
+ m.half()
423
+ logger.info("🔻 Applied FP16 precision to YOLO model (GPU)")
424
+ except Exception as _e:
425
+ logger.warning(f"Could not switch YOLO model to FP16: {_e}")
426
+
427
+ except Exception as yolo_error:
428
+ logger.error(f"Failed to load YOLO model: {yolo_error}")
429
+ return False
430
+
431
+ elif ext == '.onnx':
432
+ if not ONNX_AVAILABLE:
433
+ logger.warning("ONNX Runtime not available in this build")
434
+ logger.info("ONNX model support disabled - this is normal for lightweight builds")
435
+ return False
436
+
437
+ try:
438
+ # Load ONNX model
439
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.use_gpu else ['CPUExecutionProvider']
440
+ session_path = model_path
441
+ if self.quantize_enabled:
442
+ try:
443
+ from onnxruntime.quantization import quantize_dynamic, QuantType
444
+ quant_path = os.path.splitext(model_path)[0] + ".int8.onnx"
445
+ if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
446
+ logger.info("🔻 Quantizing ONNX model weights to INT8 (dynamic)...")
447
+ quantize_dynamic(model_input=model_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
448
+ session_path = quant_path
449
+ self.config['last_onnx_quantized_path'] = quant_path
450
+ self._save_config()
451
+ logger.info(f"✅ Using quantized ONNX model: {quant_path}")
452
+ except Exception as qe:
453
+ logger.warning(f"ONNX quantization not applied: {qe}")
454
+ # Use conservative ORT memory options to reduce RAM growth
455
+ so = ort.SessionOptions()
456
+ try:
457
+ so.enable_mem_pattern = False
458
+ so.enable_cpu_mem_arena = False
459
+ except Exception:
460
+ pass
461
+ self.onnx_session = ort.InferenceSession(session_path, sess_options=so, providers=providers)
462
+ self.model_type = 'onnx'
463
+
464
+ logger.info("✅ ONNX model loaded successfully")
465
+
466
+ except Exception as onnx_error:
467
+ logger.error(f"Failed to load ONNX model: {onnx_error}")
468
+ return False
469
+
470
+ elif ext == '.torchscript':
471
+ if not TORCH_AVAILABLE:
472
+ logger.warning("PyTorch not available in this build")
473
+ logger.info("TorchScript model support disabled - this is normal for lightweight builds")
474
+ return False
475
+
476
+ try:
477
+ # Add safety check for torch being None
478
+ if torch is None:
479
+ logger.error("PyTorch module is None - cannot load TorchScript model")
480
+ return False
481
+
482
+ # Load TorchScript model
483
+ self.model = torch.jit.load(model_path, map_location='cpu')
484
+ self.model.eval()
485
+ self.model_type = 'torch'
486
+
487
+ if self.use_gpu:
488
+ try:
489
+ self.model = self.model.cuda()
490
+ except Exception as gpu_error:
491
+ logger.warning(f"Could not move TorchScript model to GPU: {gpu_error}")
492
+
493
+ logger.info("✅ TorchScript model loaded successfully")
494
+
495
+ # Optional FP16 precision on GPU
496
+ if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
497
+ try:
498
+ self.model = self.model.half()
499
+ logger.info("🔻 Applied FP16 precision to TorchScript model (GPU)")
500
+ except Exception as _e:
501
+ logger.warning(f"Could not switch TorchScript model to FP16: {_e}")
502
+
503
+ except Exception as torch_error:
504
+ logger.error(f"Failed to load TorchScript model: {torch_error}")
505
+ return False
506
+
507
+ else:
508
+ logger.error(f"Unsupported model format: {ext}")
509
+ logger.info("Supported formats: .pt/.pth (YOLOv8), .onnx (ONNX), .torchscript (TorchScript)")
510
+ return False
511
+
512
+ # Only set loaded if we actually succeeded
513
+ self.model_loaded = True
514
+ self.config['last_model_path'] = model_path
515
+ self.config['model_type'] = self.model_type
516
+ self._save_config()
517
+
518
+ return True
519
+
520
+ except Exception as e:
521
+ logger.error(f"Failed to load model: {e}")
522
+ logger.error(traceback.format_exc())
523
+ self.model_loaded = False
524
+
525
+ # Provide helpful context for .exe users
526
+ logger.info("Note: If running from .exe, some ML libraries may not be included")
527
+ logger.info("This is normal for lightweight builds - bubble detection will be disabled")
528
+
529
+ return False
530
+
531
+ def load_rtdetr_model(self, model_path: str = None, model_id: str = None, force_reload: bool = False) -> bool:
532
+ """
533
+ Load RT-DETR model for advanced bubble and text detection.
534
+ This implementation avoids the 'meta tensor' copy error by:
535
+ - Serializing the entire load under a class lock (no concurrent loads)
536
+ - Loading directly onto the target device (CUDA if available) via device_map='auto'
537
+ - Avoiding .to() on a potentially-meta model; no device migration post-load
538
+
539
+ Args:
540
+ model_path: Optional path to local model
541
+ model_id: Optional HuggingFace model ID (default: 'ogkalu/comic-text-and-bubble-detector')
542
+ force_reload: Force reload even if already loaded
543
+
544
+ Returns:
545
+ True if successful, False otherwise
546
+ """
547
+ if not TRANSFORMERS_AVAILABLE:
548
+ logger.error("Transformers library required for RT-DETR. Install with: pip install transformers")
549
+ return False
550
+
551
+ if not PIL_AVAILABLE:
552
+ logger.error("PIL required for RT-DETR. Install with: pip install pillow")
553
+ return False
554
+
555
+ if self.rtdetr_loaded and not force_reload:
556
+ logger.info("RT-DETR model already loaded")
557
+ return True
558
+
559
+ # Fast path: if shared already loaded and not forcing reload, attach
560
+ if BubbleDetector._rtdetr_loaded and not force_reload:
561
+ self.rtdetr_model = BubbleDetector._rtdetr_shared_model
562
+ self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
563
+ self.rtdetr_loaded = True
564
+ logger.info("RT-DETR model attached from shared cache")
565
+ return True
566
+
567
+ # Serialize the ENTIRE loading sequence to avoid concurrent init issues
568
+ with BubbleDetector._rtdetr_init_lock:
569
+ try:
570
+ # Re-check after acquiring lock
571
+ if BubbleDetector._rtdetr_loaded and not force_reload:
572
+ self.rtdetr_model = BubbleDetector._rtdetr_shared_model
573
+ self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
574
+ self.rtdetr_loaded = True
575
+ logger.info("RT-DETR model attached from shared cache (post-lock)")
576
+ return True
577
+
578
+ # Use custom model_id if provided, otherwise use default
579
+ repo_id = model_id if model_id else self.rtdetr_repo
580
+ logger.info(f"📥 Loading RT-DETR model from {repo_id}...")
581
+
582
+ # Ensure TorchDynamo/compile doesn't interfere on some builds
583
+ try:
584
+ os.environ.setdefault('TORCHDYNAMO_DISABLE', '1')
585
+ except Exception:
586
+ pass
587
+
588
+ # Decide device strategy
589
+ gpu_available = bool(TORCH_AVAILABLE and hasattr(torch, 'cuda') and torch.cuda.is_available())
590
+ device_map = 'auto' if gpu_available else None
591
+ # Choose dtype
592
+ dtype = None
593
+ if TORCH_AVAILABLE:
594
+ try:
595
+ dtype = torch.float16 if gpu_available else torch.float32
596
+ except Exception:
597
+ dtype = None
598
+ low_cpu = True if gpu_available else False
599
+
600
+ # Load processor (once)
601
+ self.rtdetr_processor = RTDetrImageProcessor.from_pretrained(
602
+ repo_id,
603
+ size={"width": 640, "height": 640},
604
+ cache_dir=self.cache_dir if not model_path else None
605
+ )
606
+
607
+ # Prepare kwargs for from_pretrained
608
+ from_kwargs = {
609
+ 'cache_dir': self.cache_dir if not model_path else None,
610
+ 'low_cpu_mem_usage': low_cpu,
611
+ 'device_map': device_map,
612
+ }
613
+ if dtype is not None:
614
+ from_kwargs['dtype'] = dtype
615
+
616
+ # First attempt: load directly to target (CUDA if available)
617
+ try:
618
+ self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
619
+ model_path if model_path else repo_id,
620
+ **from_kwargs,
621
+ )
622
+ except Exception as primary_err:
623
+ # Fallback to a simple CPU load (no device move) if CUDA path fails
624
+ logger.warning(f"RT-DETR primary load failed ({primary_err}); retrying on CPU...")
625
+ from_kwargs_fallback = {
626
+ 'cache_dir': self.cache_dir if not model_path else None,
627
+ 'low_cpu_mem_usage': False,
628
+ 'device_map': None,
629
+ }
630
+ if TORCH_AVAILABLE:
631
+ from_kwargs_fallback['dtype'] = torch.float32
632
+ self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
633
+ model_path if model_path else repo_id,
634
+ **from_kwargs_fallback,
635
+ )
636
+
637
+ # Optional dynamic quantization for linear layers (CPU only)
638
+ if self.quantize_enabled and TORCH_AVAILABLE and (not gpu_available):
639
+ try:
640
+ try:
641
+ import torch.ao.quantization as tq
642
+ quantize_dynamic = tq.quantize_dynamic # type: ignore
643
+ except Exception:
644
+ import torch.quantization as tq # type: ignore
645
+ quantize_dynamic = tq.quantize_dynamic # type: ignore
646
+ self.rtdetr_model = quantize_dynamic(self.rtdetr_model, {torch.nn.Linear}, dtype=torch.qint8)
647
+ logger.info("🔻 Applied dynamic INT8 quantization to RT-DETR linear layers (CPU)")
648
+ except Exception as qe:
649
+ logger.warning(f"RT-DETR dynamic quantization skipped: {qe}")
650
+
651
+ # Finalize
652
+ self.rtdetr_model.eval()
653
+
654
+ # Sanity check: ensure no parameter is left on 'meta' device
655
+ try:
656
+ for n, p in self.rtdetr_model.named_parameters():
657
+ dev = getattr(p, 'device', None)
658
+ if dev is not None and getattr(dev, 'type', '') == 'meta':
659
+ raise RuntimeError(f"Parameter {n} is on 'meta' device after load")
660
+ except Exception as e:
661
+ logger.error(f"RT-DETR load sanity check failed: {e}")
662
+ self.rtdetr_loaded = False
663
+ return False
664
+
665
+ # Publish shared cache
666
+ BubbleDetector._rtdetr_shared_model = self.rtdetr_model
667
+ BubbleDetector._rtdetr_shared_processor = self.rtdetr_processor
668
+ BubbleDetector._rtdetr_loaded = True
669
+ BubbleDetector._rtdetr_repo_id = repo_id
670
+
671
+ self.rtdetr_loaded = True
672
+
673
+ # Save the model ID that was used
674
+ self.config['rtdetr_loaded'] = True
675
+ self.config['rtdetr_model_id'] = repo_id
676
+ self._save_config()
677
+
678
+ loc = 'CUDA' if gpu_available else 'CPU'
679
+ logger.info(f"✅ RT-DETR model loaded successfully ({loc})")
680
+ logger.info(" Classes: Empty bubbles, Text bubbles, Free text")
681
+
682
+ # Auto-convert to ONNX for RT-DETR only if explicitly enabled
683
+ if os.environ.get('AUTO_CONVERT_RTDETR_ONNX', 'false').lower() == 'true':
684
+ onnx_path = os.path.join(self.cache_dir, 'rtdetr_comic.onnx')
685
+ if self.convert_to_onnx('rtdetr', onnx_path):
686
+ logger.info("🚀 RT-DETR converted to ONNX for faster inference")
687
+ # Store ONNX path for later use
688
+ self.config['rtdetr_onnx_path'] = onnx_path
689
+ self._save_config()
690
+ # Optionally quantize ONNX for reduced RAM
691
+ if self.onnx_quantize_enabled:
692
+ try:
693
+ from onnxruntime.quantization import quantize_dynamic, QuantType
694
+ quant_path = os.path.splitext(onnx_path)[0] + ".int8.onnx"
695
+ if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
696
+ logger.info("🔻 Quantizing RT-DETR ONNX to INT8 (dynamic)...")
697
+ quantize_dynamic(model_input=onnx_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
698
+ self.config['rtdetr_onnx_quantized_path'] = quant_path
699
+ self._save_config()
700
+ logger.info(f"✅ Quantized RT-DETR ONNX saved to: {quant_path}")
701
+ except Exception as qe:
702
+ logger.warning(f"ONNX quantization for RT-DETR skipped: {qe}")
703
+ else:
704
+ logger.info("ℹ️ Skipping RT-DETR ONNX export (converter not supported in current environment)")
705
+
706
+ return True
707
+ except Exception as e:
708
+ logger.error(f"❌ Failed to load RT-DETR: {e}")
709
+ self.rtdetr_loaded = False
710
+ return False
711
+
712
+ def check_rtdetr_available(self, model_id: str = None) -> bool:
713
+ """
714
+ Check if RT-DETR model is available (cached).
715
+
716
+ Args:
717
+ model_id: Optional HuggingFace model ID
718
+
719
+ Returns:
720
+ True if model is cached and available
721
+ """
722
+ try:
723
+ from pathlib import Path
724
+
725
+ # Use provided model_id or default
726
+ repo_id = model_id if model_id else self.rtdetr_repo
727
+
728
+ # Check HuggingFace cache
729
+ cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
730
+ model_id_formatted = repo_id.replace("/", "--")
731
+
732
+ # Look for model folder
733
+ model_folders = list(cache_dir.glob(f"models--{model_id_formatted}*"))
734
+
735
+ if model_folders:
736
+ for folder in model_folders:
737
+ if (folder / "snapshots").exists():
738
+ snapshots = list((folder / "snapshots").iterdir())
739
+ if snapshots:
740
+ return True
741
+
742
+ return False
743
+
744
+ except Exception:
745
+ return False
746
+
747
+ def detect_bubbles(self,
748
+ image_path: str,
749
+ confidence: float = None,
750
+ iou_threshold: float = None,
751
+ max_detections: int = None,
752
+ use_rtdetr: bool = None) -> List[Tuple[int, int, int, int]]:
753
+ """
754
+ Detect speech bubbles in an image (backward compatible method).
755
+
756
+ Args:
757
+ image_path: Path to image file
758
+ confidence: Minimum confidence threshold (0-1)
759
+ iou_threshold: IOU threshold for NMS (0-1)
760
+ max_detections: Maximum number of detections to return
761
+ use_rtdetr: If True, use RT-DETR instead of YOLOv8 (if available)
762
+
763
+ Returns:
764
+ List of bubble bounding boxes as (x, y, width, height) tuples
765
+ """
766
+ # Check for stop at start
767
+ if self._check_stop():
768
+ self._log("⏹️ Bubble detection stopped by user", "warning")
769
+ return []
770
+
771
+ # Decide which model to use
772
+ if use_rtdetr is None:
773
+ # Auto-select: prefer RT-DETR if available
774
+ use_rtdetr = self.rtdetr_loaded
775
+
776
+ if use_rtdetr:
777
+ # Prefer ONNX backend if available, else PyTorch
778
+ if getattr(self, 'rtdetr_onnx_loaded', False):
779
+ results = self.detect_with_rtdetr_onnx(
780
+ image_path=image_path,
781
+ confidence=confidence,
782
+ return_all_bubbles=True
783
+ )
784
+ return results
785
+ if self.rtdetr_loaded:
786
+ results = self.detect_with_rtdetr(
787
+ image_path=image_path,
788
+ confidence=confidence,
789
+ return_all_bubbles=True
790
+ )
791
+ return results
792
+
793
+ # Original YOLOv8 detection
794
+ if not self.model_loaded:
795
+ logger.error("No model loaded. Call load_model() first.")
796
+ return []
797
+
798
+ # Use defaults if not specified
799
+ confidence = confidence or self.default_confidence
800
+ iou_threshold = iou_threshold or self.default_iou_threshold
801
+ max_detections = max_detections or self.default_max_detections
802
+
803
+ try:
804
+ # Load image
805
+ image = cv2.imread(image_path)
806
+ if image is None:
807
+ logger.error(f"Failed to load image: {image_path}")
808
+ return []
809
+
810
+ h, w = image.shape[:2]
811
+ self._log(f"🔍 Detecting bubbles in {w}x{h} image")
812
+
813
+ # Check for stop before inference
814
+ if self._check_stop():
815
+ self._log("⏹️ Bubble detection inference stopped by user", "warning")
816
+ return []
817
+
818
+ if self.model_type == 'yolo':
819
+ # YOLOv8 inference
820
+ results = self.model(
821
+ image_path,
822
+ conf=confidence,
823
+ iou=iou_threshold,
824
+ max_det=min(max_detections, getattr(self, 'max_det_yolo', max_detections)),
825
+ verbose=False
826
+ )
827
+
828
+ bubbles = []
829
+ for r in results:
830
+ if r.boxes is not None:
831
+ for box in r.boxes:
832
+ # Get box coordinates
833
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
834
+ x, y = int(x1), int(y1)
835
+ width = int(x2 - x1)
836
+ height = int(y2 - y1)
837
+
838
+ # Get confidence
839
+ conf = float(box.conf[0])
840
+
841
+ # Add to list
842
+ if len(bubbles) < max_detections:
843
+ bubbles.append((x, y, width, height))
844
+
845
+ logger.debug(f" Bubble: ({x},{y}) {width}x{height} conf={conf:.2f}")
846
+
847
+ elif self.model_type == 'onnx':
848
+ # ONNX inference
849
+ bubbles = self._detect_with_onnx(image, confidence, iou_threshold, max_detections)
850
+
851
+ elif self.model_type == 'torch':
852
+ # TorchScript inference
853
+ bubbles = self._detect_with_torchscript(image, confidence, iou_threshold, max_detections)
854
+
855
+ else:
856
+ logger.error(f"Unknown model type: {self.model_type}")
857
+ return []
858
+
859
+ logger.info(f"✅ Detected {len(bubbles)} speech bubbles")
860
+ time.sleep(0.1) # Brief pause for stability
861
+ logger.debug("💤 Bubble detection pausing briefly for stability")
862
+ return bubbles
863
+
864
+ except Exception as e:
865
+ logger.error(f"Detection failed: {e}")
866
+ logger.error(traceback.format_exc())
867
+ return []
868
+
869
+ def detect_with_rtdetr(self,
870
+ image_path: str = None,
871
+ image: np.ndarray = None,
872
+ confidence: float = None,
873
+ return_all_bubbles: bool = False) -> Any:
874
+ """
875
+ Detect using RT-DETR model with 3-class detection (PyTorch backend).
876
+
877
+ Args:
878
+ image_path: Path to image file
879
+ image: Image array (BGR format)
880
+ confidence: Confidence threshold
881
+ return_all_bubbles: If True, return list of bubble boxes (for compatibility)
882
+ If False, return dict with all classes
883
+
884
+ Returns:
885
+ List of bubbles if return_all_bubbles=True, else dict with classes
886
+ """
887
+ # Check for stop at start
888
+ if self._check_stop():
889
+ self._log("⏹️ RT-DETR detection stopped by user", "warning")
890
+ if return_all_bubbles:
891
+ return []
892
+ return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
893
+
894
+ if not self.rtdetr_loaded:
895
+ self._log("RT-DETR not loaded. Call load_rtdetr_model() first.", "warning")
896
+ if return_all_bubbles:
897
+ return []
898
+ return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
899
+
900
+ confidence = confidence or self.default_confidence
901
+
902
+ try:
903
+ # Load image
904
+ if image_path:
905
+ image = cv2.imread(image_path)
906
+ elif image is None:
907
+ logger.error("No image provided")
908
+ if return_all_bubbles:
909
+ return []
910
+ return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
911
+
912
+ # Convert BGR to RGB for PIL
913
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
914
+ pil_image = Image.fromarray(image_rgb)
915
+
916
+ # Prepare image for model
917
+ inputs = self.rtdetr_processor(images=pil_image, return_tensors="pt")
918
+
919
+ # Move inputs to the same device as the model and match model dtype for floating tensors
920
+ model_device = next(self.rtdetr_model.parameters()).device if self.rtdetr_model is not None else (torch.device('cpu') if TORCH_AVAILABLE else 'cpu')
921
+ model_dtype = None
922
+ if TORCH_AVAILABLE and self.rtdetr_model is not None:
923
+ try:
924
+ model_dtype = next(self.rtdetr_model.parameters()).dtype
925
+ except Exception:
926
+ model_dtype = None
927
+
928
+ if TORCH_AVAILABLE:
929
+ new_inputs = {}
930
+ for k, v in inputs.items():
931
+ if isinstance(v, torch.Tensor):
932
+ v = v.to(model_device)
933
+ if model_dtype is not None and torch.is_floating_point(v):
934
+ v = v.to(model_dtype)
935
+ new_inputs[k] = v
936
+ inputs = new_inputs
937
+
938
+ # Run inference with autocast when model is half/bfloat16 on CUDA
939
+ use_amp = TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == 'cuda' and (model_dtype in (torch.float16, torch.bfloat16))
940
+ autocast_dtype = model_dtype if model_dtype in (torch.float16, torch.bfloat16) else None
941
+
942
+ with torch.no_grad():
943
+ if use_amp and autocast_dtype is not None:
944
+ with torch.autocast('cuda', dtype=autocast_dtype):
945
+ outputs = self.rtdetr_model(**inputs)
946
+ else:
947
+ outputs = self.rtdetr_model(**inputs)
948
+
949
+ # Brief pause for stability after inference
950
+ time.sleep(0.1)
951
+ logger.debug("💤 RT-DETR inference pausing briefly for stability")
952
+
953
+ # Post-process results
954
+ target_sizes = torch.tensor([pil_image.size[::-1]]) if TORCH_AVAILABLE else None
955
+ if TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == "cuda":
956
+ target_sizes = target_sizes.to(model_device)
957
+
958
+ results = self.rtdetr_processor.post_process_object_detection(
959
+ outputs,
960
+ target_sizes=target_sizes,
961
+ threshold=confidence
962
+ )[0]
963
+
964
+ # Apply per-detector cap if configured
965
+ cap = getattr(self, 'max_det_rtdetr', self.default_max_detections)
966
+ if cap and len(results['boxes']) > cap:
967
+ # Keep top-scoring first
968
+ scores = results['scores']
969
+ top_idx = scores.topk(k=cap).indices if hasattr(scores, 'topk') else range(cap)
970
+ results = {
971
+ 'boxes': [results['boxes'][i] for i in top_idx],
972
+ 'scores': [results['scores'][i] for i in top_idx],
973
+ 'labels': [results['labels'][i] for i in top_idx]
974
+ }
975
+
976
+ logger.info(f"📊 RT-DETR found {len(results['boxes'])} detections above {confidence:.2f} confidence")
977
+
978
+ # Apply NMS to remove duplicate detections
979
+ # Group detections by class
980
+ class_detections = {self.CLASS_BUBBLE: [], self.CLASS_TEXT_BUBBLE: [], self.CLASS_TEXT_FREE: []}
981
+
982
+ for box, score, label in zip(results['boxes'], results['scores'], results['labels']):
983
+ x1, y1, x2, y2 = map(float, box.tolist())
984
+ label_id = label.item()
985
+ if label_id in class_detections:
986
+ class_detections[label_id].append((x1, y1, x2, y2, float(score.item())))
987
+
988
+ # Apply NMS per class to remove duplicates
989
+ def compute_iou(box1, box2):
990
+ """Compute IoU between two boxes (x1, y1, x2, y2)"""
991
+ x1_1, y1_1, x2_1, y2_1 = box1[:4]
992
+ x1_2, y1_2, x2_2, y2_2 = box2[:4]
993
+
994
+ # Intersection
995
+ x_left = max(x1_1, x1_2)
996
+ y_top = max(y1_1, y1_2)
997
+ x_right = min(x2_1, x2_2)
998
+ y_bottom = min(y2_1, y2_2)
999
+
1000
+ if x_right < x_left or y_bottom < y_top:
1001
+ return 0.0
1002
+
1003
+ intersection = (x_right - x_left) * (y_bottom - y_top)
1004
+
1005
+ # Union
1006
+ area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
1007
+ area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
1008
+ union = area1 + area2 - intersection
1009
+
1010
+ return intersection / union if union > 0 else 0.0
1011
+
1012
+ def apply_nms(boxes_with_scores, iou_threshold=0.45):
1013
+ """Apply Non-Maximum Suppression"""
1014
+ if not boxes_with_scores:
1015
+ return []
1016
+
1017
+ # Sort by score (descending)
1018
+ sorted_boxes = sorted(boxes_with_scores, key=lambda x: x[4], reverse=True)
1019
+ keep = []
1020
+
1021
+ while sorted_boxes:
1022
+ # Keep the box with highest score
1023
+ current = sorted_boxes.pop(0)
1024
+ keep.append(current)
1025
+
1026
+ # Remove boxes with high IoU
1027
+ sorted_boxes = [box for box in sorted_boxes if compute_iou(current, box) < iou_threshold]
1028
+
1029
+ return keep
1030
+
1031
+ # Apply NMS and organize by class
1032
+ detections = {
1033
+ 'bubbles': [], # Empty speech bubbles
1034
+ 'text_bubbles': [], # Bubbles with text
1035
+ 'text_free': [] # Text without bubbles
1036
+ }
1037
+
1038
+ for class_id, boxes_list in class_detections.items():
1039
+ nms_boxes = apply_nms(boxes_list, iou_threshold=self.default_iou_threshold)
1040
+
1041
+ for x1, y1, x2, y2, scr in nms_boxes:
1042
+ width = int(x2 - x1)
1043
+ height = int(y2 - y1)
1044
+ # Store as (x, y, width, height) to match YOLOv8 format
1045
+ bbox = (int(x1), int(y1), width, height)
1046
+
1047
+ if class_id == self.CLASS_BUBBLE:
1048
+ detections['bubbles'].append(bbox)
1049
+ elif class_id == self.CLASS_TEXT_BUBBLE:
1050
+ detections['text_bubbles'].append(bbox)
1051
+ elif class_id == self.CLASS_TEXT_FREE:
1052
+ detections['text_free'].append(bbox)
1053
+
1054
+ # Stop early if we hit the configured cap across all classes
1055
+ total_count = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
1056
+ if total_count >= (self.config.get('manga_settings', {}).get('ocr', {}).get('bubble_max_detections', self.default_max_detections) if isinstance(self.config, dict) else self.default_max_detections):
1057
+ break
1058
+
1059
+ # Log results
1060
+ total = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
1061
+ logger.info(f"✅ RT-DETR detected {total} objects:")
1062
+ logger.info(f" - Empty bubbles: {len(detections['bubbles'])}")
1063
+ logger.info(f" - Text bubbles: {len(detections['text_bubbles'])}")
1064
+ logger.info(f" - Free text: {len(detections['text_free'])}")
1065
+
1066
+ # Return format based on compatibility mode
1067
+ if return_all_bubbles:
1068
+ # Return all bubbles (empty + with text) for backward compatibility
1069
+ all_bubbles = detections['bubbles'] + detections['text_bubbles']
1070
+ return all_bubbles
1071
+ else:
1072
+ return detections
1073
+
1074
+ except Exception as e:
1075
+ logger.error(f"RT-DETR detection failed: {e}")
1076
+ logger.error(traceback.format_exc())
1077
+ if return_all_bubbles:
1078
+ return []
1079
+ return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1080
+
1081
+ def detect_all_text_regions(self, image_path: str = None, image: np.ndarray = None) -> List[Tuple[int, int, int, int]]:
1082
+ """
1083
+ Detect all text regions using RT-DETR (both in bubbles and free text).
1084
+
1085
+ Returns:
1086
+ List of bounding boxes for all text regions
1087
+ """
1088
+ if not self.rtdetr_loaded:
1089
+ logger.warning("RT-DETR required for text detection")
1090
+ return []
1091
+
1092
+ detections = self.detect_with_rtdetr(image_path=image_path, image=image, return_all_bubbles=False)
1093
+
1094
+ # Combine text bubbles and free text
1095
+ all_text = detections['text_bubbles'] + detections['text_free']
1096
+
1097
+ logger.info(f"📝 Found {len(all_text)} text regions total")
1098
+ return all_text
1099
+
1100
+ def _detect_with_onnx(self, image: np.ndarray, confidence: float,
1101
+ iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
1102
+ """Run detection using ONNX model."""
1103
+ # Preprocess image
1104
+ img_size = 640 # Standard YOLOv8 input size
1105
+ img_resized = cv2.resize(image, (img_size, img_size))
1106
+ img_norm = img_resized.astype(np.float32) / 255.0
1107
+ img_transposed = np.transpose(img_norm, (2, 0, 1))
1108
+ img_batch = np.expand_dims(img_transposed, axis=0)
1109
+
1110
+ # Run inference
1111
+ input_name = self.onnx_session.get_inputs()[0].name
1112
+ outputs = self.onnx_session.run(None, {input_name: img_batch})
1113
+
1114
+ # Process outputs (YOLOv8 format)
1115
+ predictions = outputs[0][0] # Remove batch dimension
1116
+
1117
+ # Filter by confidence and apply NMS
1118
+ bubbles = []
1119
+ boxes = []
1120
+ scores = []
1121
+
1122
+ for pred in predictions.T: # Transpose to get predictions per detection
1123
+ if len(pred) >= 5:
1124
+ x_center, y_center, width, height, obj_conf = pred[:5]
1125
+
1126
+ if obj_conf >= confidence:
1127
+ # Convert to corner coordinates
1128
+ x1 = x_center - width / 2
1129
+ y1 = y_center - height / 2
1130
+
1131
+ # Scale to original image size
1132
+ h, w = image.shape[:2]
1133
+ x1 = int(x1 * w / img_size)
1134
+ y1 = int(y1 * h / img_size)
1135
+ width = int(width * w / img_size)
1136
+ height = int(height * h / img_size)
1137
+
1138
+ boxes.append([x1, y1, x1 + width, y1 + height])
1139
+ scores.append(float(obj_conf))
1140
+
1141
+ # Apply NMS
1142
+ if boxes:
1143
+ indices = cv2.dnn.NMSBoxes(boxes, scores, confidence, iou_threshold)
1144
+ if len(indices) > 0:
1145
+ indices = indices.flatten()[:max_detections]
1146
+ for i in indices:
1147
+ x1, y1, x2, y2 = boxes[i]
1148
+ bubbles.append((x1, y1, x2 - x1, y2 - y1))
1149
+
1150
+ return bubbles
1151
+
1152
+ def _detect_with_torchscript(self, image: np.ndarray, confidence: float,
1153
+ iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
1154
+ """Run detection using TorchScript model."""
1155
+ # Similar to ONNX but using PyTorch tensors
1156
+ img_size = 640
1157
+ img_resized = cv2.resize(image, (img_size, img_size))
1158
+ img_norm = img_resized.astype(np.float32) / 255.0
1159
+ img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0)
1160
+
1161
+ if self.use_gpu:
1162
+ img_tensor = img_tensor.cuda()
1163
+
1164
+ with torch.no_grad():
1165
+ outputs = self.model(img_tensor)
1166
+
1167
+ # Process outputs similar to ONNX
1168
+ # Implementation depends on exact model output format
1169
+ # This is a placeholder - adjust based on your model
1170
+ return []
1171
+
1172
+ def visualize_detections(self, image_path: str, bubbles: List[Tuple[int, int, int, int]] = None,
1173
+ output_path: str = None, use_rtdetr: bool = False) -> np.ndarray:
1174
+ """
1175
+ Visualize detected bubbles on the image.
1176
+
1177
+ Args:
1178
+ image_path: Path to original image
1179
+ bubbles: List of bubble bounding boxes (if None, will detect)
1180
+ output_path: Optional path to save visualization
1181
+ use_rtdetr: Use RT-DETR for visualization with class colors
1182
+
1183
+ Returns:
1184
+ Image with drawn bounding boxes
1185
+ """
1186
+ image = cv2.imread(image_path)
1187
+ if image is None:
1188
+ logger.error(f"Failed to load image: {image_path}")
1189
+ return None
1190
+
1191
+ vis_image = image.copy()
1192
+
1193
+ if use_rtdetr and self.rtdetr_loaded:
1194
+ # RT-DETR visualization with different colors per class
1195
+ detections = self.detect_with_rtdetr(image_path=image_path, return_all_bubbles=False)
1196
+
1197
+ # Colors for each class
1198
+ colors = {
1199
+ 'bubbles': (0, 255, 0), # Green for empty bubbles
1200
+ 'text_bubbles': (255, 0, 0), # Blue for text bubbles
1201
+ 'text_free': (0, 0, 255) # Red for free text
1202
+ }
1203
+
1204
+ # Draw detections
1205
+ for class_name, bboxes in detections.items():
1206
+ color = colors[class_name]
1207
+
1208
+ for i, (x, y, w, h) in enumerate(bboxes):
1209
+ # Draw rectangle
1210
+ cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)
1211
+
1212
+ # Add label
1213
+ label = f"{class_name.replace('_', ' ').title()} {i+1}"
1214
+ label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
1215
+ cv2.rectangle(vis_image, (x, y - label_size[1] - 4),
1216
+ (x + label_size[0], y), color, -1)
1217
+ cv2.putText(vis_image, label, (x, y - 2),
1218
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
1219
+ else:
1220
+ # Original YOLOv8 visualization
1221
+ if bubbles is None:
1222
+ bubbles = self.detect_bubbles(image_path)
1223
+
1224
+ # Draw bounding boxes
1225
+ for i, (x, y, w, h) in enumerate(bubbles):
1226
+ # Draw rectangle
1227
+ color = (0, 255, 0) # Green
1228
+ thickness = 2
1229
+ cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, thickness)
1230
+
1231
+ # Add label
1232
+ label = f"Bubble {i+1}"
1233
+ label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
1234
+ cv2.rectangle(vis_image, (x, y - label_size[1] - 4), (x + label_size[0], y), color, -1)
1235
+ cv2.putText(vis_image, label, (x, y - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
1236
+
1237
+ # Save if output path provided
1238
+ if output_path:
1239
+ cv2.imwrite(output_path, vis_image)
1240
+ logger.info(f"💾 Visualization saved to: {output_path}")
1241
+
1242
+ return vis_image
1243
+
1244
+ def convert_to_onnx(self, model_path: str, output_path: str = None) -> bool:
1245
+ """
1246
+ Convert a YOLOv8 or RT-DETR model to ONNX format.
1247
+
1248
+ Args:
1249
+ model_path: Path to model file or 'rtdetr' for loaded RT-DETR
1250
+ output_path: Path for ONNX output (auto-generated if None)
1251
+
1252
+ Returns:
1253
+ True if conversion successful, False otherwise
1254
+ """
1255
+ try:
1256
+ logger.info(f"🔄 Converting {model_path} to ONNX...")
1257
+
1258
+ # Generate output path if not provided
1259
+ if output_path is None:
1260
+ if model_path == 'rtdetr' and self.rtdetr_loaded:
1261
+ base_name = 'rtdetr_comic'
1262
+ else:
1263
+ base_name = Path(model_path).stem
1264
+ output_path = os.path.join(self.cache_dir, f"{base_name}.onnx")
1265
+
1266
+ # Check if already exists
1267
+ if os.path.exists(output_path) and not os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
1268
+ logger.info(f"✅ ONNX model already exists: {output_path}")
1269
+ return True
1270
+
1271
+ # Handle RT-DETR conversion
1272
+ if model_path == 'rtdetr' and self.rtdetr_loaded:
1273
+ if not TORCH_AVAILABLE:
1274
+ logger.error("PyTorch required for RT-DETR ONNX conversion")
1275
+ return False
1276
+
1277
+ # RT-DETR specific conversion
1278
+ self.rtdetr_model.eval()
1279
+
1280
+ # Create dummy input (pixel values): BxCxHxW
1281
+ dummy_input = torch.randn(1, 3, 640, 640)
1282
+ if self.device == 'cuda':
1283
+ dummy_input = dummy_input.to('cuda')
1284
+
1285
+ # Wrap the model to return only tensors (logits, pred_boxes)
1286
+ class _RTDetrExportWrapper(torch.nn.Module):
1287
+ def __init__(self, mdl):
1288
+ super().__init__()
1289
+ self.mdl = mdl
1290
+ def forward(self, images):
1291
+ out = self.mdl(pixel_values=images)
1292
+ # Handle dict/ModelOutput/tuple outputs
1293
+ logits = None
1294
+ boxes = None
1295
+ try:
1296
+ if isinstance(out, dict):
1297
+ logits = out.get('logits', None)
1298
+ boxes = out.get('pred_boxes', out.get('boxes', None))
1299
+ else:
1300
+ logits = getattr(out, 'logits', None)
1301
+ boxes = getattr(out, 'pred_boxes', getattr(out, 'boxes', None))
1302
+ except Exception:
1303
+ pass
1304
+ if (logits is None or boxes is None) and isinstance(out, (tuple, list)) and len(out) >= 2:
1305
+ logits, boxes = out[0], out[1]
1306
+ return logits, boxes
1307
+
1308
+ wrapper = _RTDetrExportWrapper(self.rtdetr_model)
1309
+ if self.device == 'cuda':
1310
+ wrapper = wrapper.to('cuda')
1311
+
1312
+ # Try PyTorch 2.x dynamo_export first (more tolerant of newer aten ops)
1313
+ try:
1314
+ success = False
1315
+ try:
1316
+ from torch.onnx import dynamo_export
1317
+ try:
1318
+ exp = dynamo_export(wrapper, dummy_input)
1319
+ except TypeError:
1320
+ # Retry once on TypeError; note the fallback currently repeats the identical call
1321
+ exp = dynamo_export(wrapper, dummy_input)
1322
+ # exp may have save(); otherwise, it may expose model_proto
1323
+ try:
1324
+ exp.save(output_path) # type: ignore
1325
+ success = True
1326
+ except Exception:
1327
+ try:
1328
+ import onnx as _onnx
1329
+ _onnx.save(exp.model_proto, output_path) # type: ignore
1330
+ success = True
1331
+ except Exception as _se:
1332
+ logger.warning(f"dynamo_export produced model but could not save: {_se}")
1333
+ except Exception as de:
1334
+ logger.warning(f"dynamo_export failed; falling back to legacy exporter: {de}")
1335
+ if success:
1336
+ logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (dynamo_export)")
1337
+ return True
1338
+ except Exception as de2:
1339
+ logger.warning(f"dynamo_export path error: {de2}")
1340
+
1341
+ # Legacy exporter with opset fallback
1342
+ last_err = None
1343
+ for opset in [19, 18, 17, 16, 15, 14, 13]:
1344
+ try:
1345
+ torch.onnx.export(
1346
+ wrapper,
1347
+ dummy_input,
1348
+ output_path,
1349
+ export_params=True,
1350
+ opset_version=opset,
1351
+ do_constant_folding=True,
1352
+ input_names=['pixel_values'],
1353
+ output_names=['logits', 'boxes'],
1354
+ dynamic_axes={
1355
+ 'pixel_values': {0: 'batch', 2: 'height', 3: 'width'},
1356
+ 'logits': {0: 'batch'},
1357
+ 'boxes': {0: 'batch'}
1358
+ }
1359
+ )
1360
+ logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (opset {opset})")
1361
+ return True
1362
+ except Exception as _e:
1363
+ last_err = _e
1364
+ try:
1365
+ msg = str(_e)
1366
+ except Exception:
1367
+ msg = ''
1368
+ logger.warning(f"RT-DETR ONNX export failed at opset {opset}: {msg}")
1369
+ continue
1370
+
1371
+ logger.error(f"All RT-DETR ONNX export attempts failed. Last error: {last_err}")
1372
+ return False
1373
+
1374
+ # Handle YOLOv8 conversion - FIXED
1375
+ elif YOLO_AVAILABLE and os.path.exists(model_path):
1376
+ logger.info(f"Loading YOLOv8 model from: {model_path}")
1377
+
1378
+ # Load model
1379
+ model = YOLO(model_path)
1380
+
1381
+ # Export to ONNX - this returns the path to the exported model
1382
+ logger.info("Exporting to ONNX format...")
1383
+ exported_path = model.export(format='onnx', imgsz=640, simplify=True)
1384
+
1385
+ # exported_path could be a string or Path object
1386
+ exported_path = str(exported_path) if exported_path else None
1387
+
1388
+ if exported_path and os.path.exists(exported_path):
1389
+ # Move to desired location if different
1390
+ if exported_path != output_path:
1391
+ import shutil
1392
+ logger.info(f"Moving ONNX from {exported_path} to {output_path}")
1393
+ shutil.move(exported_path, output_path)
1394
+
1395
+ logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
1396
+ return True
1397
+ else:
1398
+ # Fallback: check if it was created with expected name
1399
+ expected_onnx = model_path.replace('.pt', '.onnx')
1400
+ if os.path.exists(expected_onnx):
1401
+ if expected_onnx != output_path:
1402
+ import shutil
1403
+ shutil.move(expected_onnx, output_path)
1404
+ logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
1405
+ return True
1406
+ else:
1407
+ logger.error(f"ONNX export failed - no output file found")
1408
+ return False
1409
+
1410
+ else:
1411
+ logger.error(f"Cannot convert {model_path}: Model not found or dependencies missing")
1412
+ return False
1413
+
1414
+ except Exception as e:
1415
+ logger.error(f"Conversion failed: {e}")
1416
+ # Avoid noisy full stack trace in production logs; return False gracefully
1417
+ return False
1418
+
1419
+ def batch_detect(self, image_paths: List[str], **kwargs) -> Dict[str, List[Tuple[int, int, int, int]]]:
1420
+ """
1421
+ Detect bubbles in multiple images.
1422
+
1423
+ Args:
1424
+ image_paths: List of image paths
1425
+ **kwargs: Detection parameters (confidence, iou_threshold, max_detections, use_rtdetr)
1426
+
1427
+ Returns:
1428
+ Dictionary mapping image paths to bubble lists
1429
+ """
1430
+ results = {}
1431
+
1432
+ for i, image_path in enumerate(image_paths):
1433
+ logger.info(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
1434
+ bubbles = self.detect_bubbles(image_path, **kwargs)
1435
+ results[image_path] = bubbles
1436
+
1437
+ return results
1438
+
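+ # Illustrative usage sketch (hypothetical file names; assumes a YOLOv8 model was loaded first):
+ #   detector = BubbleDetector()
+ #   detector.load_model("models/comic-speech-bubble-detector-yolov8m.pt")
+ #   results = detector.batch_detect(["page_01.png", "page_02.png"], confidence=0.5)
+ #   for path, boxes in results.items():
+ #       print(os.path.basename(path), len(boxes), "bubbles")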
1439
+ def unload(self, release_shared: bool = False):
1440
+ """Release model resources held by this detector instance.
1441
+ Args:
1442
+ release_shared: If True, also clear class-level shared RT-DETR caches.
1443
+ """
1444
+ try:
1445
+ # Release instance-level models and sessions
1446
+ try:
1447
+ if getattr(self, 'onnx_session', None) is not None:
1448
+ self.onnx_session = None
1449
+ except Exception:
1450
+ pass
1451
+ try:
1452
+ if getattr(self, 'rtdetr_onnx_session', None) is not None:
1453
+ self.rtdetr_onnx_session = None
1454
+ except Exception:
1455
+ pass
1456
+ for attr in ['model', 'rtdetr_model', 'rtdetr_processor']:
1457
+ try:
1458
+ if hasattr(self, attr):
1459
+ setattr(self, attr, None)
1460
+ except Exception:
1461
+ pass
1462
+ for flag in ['model_loaded', 'rtdetr_loaded', 'rtdetr_onnx_loaded']:
1463
+ try:
1464
+ if hasattr(self, flag):
1465
+ setattr(self, flag, False)
1466
+ except Exception:
1467
+ pass
1468
+
1469
+ # Optional: release shared caches
1470
+ if release_shared:
1471
+ try:
1472
+ BubbleDetector._rtdetr_shared_model = None
1473
+ BubbleDetector._rtdetr_shared_processor = None
1474
+ BubbleDetector._rtdetr_loaded = False
1475
+ except Exception:
1476
+ pass
1477
+
1478
+ # Free CUDA cache and trigger GC
1479
+ try:
1480
+ if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
1481
+ torch.cuda.empty_cache()
1482
+ except Exception:
1483
+ pass
1484
+ try:
1485
+ import gc
1486
+ gc.collect()
1487
+ except Exception:
1488
+ pass
1489
+ except Exception:
1490
+ # Best-effort only
1491
+ pass
1492
+
1493
+ def get_bubble_masks(self, image_path: str, bubbles: List[Tuple[int, int, int, int]]) -> np.ndarray:
1494
+ """
1495
+ Create a mask image with bubble regions.
1496
+
1497
+ Args:
1498
+ image_path: Path to original image
1499
+ bubbles: List of bubble bounding boxes
1500
+
1501
+ Returns:
1502
+ Binary mask with bubble regions as white (255)
1503
+ """
1504
+ image = cv2.imread(image_path)
1505
+ if image is None:
1506
+ return None
1507
+
1508
+ h, w = image.shape[:2]
1509
+ mask = np.zeros((h, w), dtype=np.uint8)
1510
+
1511
+ # Fill bubble regions
1512
+ for x, y, bw, bh in bubbles:
1513
+ cv2.rectangle(mask, (x, y), (x + bw, y + bh), 255, -1)
1514
+
1515
+ return mask
1516
+
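+ # Illustrative follow-up (a sketch with a hypothetical file name): the mask can feed OpenCV
+ # inpainting to blank out detected bubble regions.
+ #   mask = detector.get_bubble_masks("page_01.png", bubbles)
+ #   if mask is not None:
+ #       page = cv2.imread("page_01.png")
+ #       cleaned = cv2.inpaint(page, mask, 3, cv2.INPAINT_TELEA)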
1517
+ def filter_bubbles_by_size(self, bubbles: List[Tuple[int, int, int, int]],
1518
+ min_area: int = 100,
1519
+ max_area: int = None) -> List[Tuple[int, int, int, int]]:
1520
+ """
1521
+ Filter bubbles by area.
1522
+
1523
+ Args:
1524
+ bubbles: List of bubble bounding boxes
1525
+ min_area: Minimum area in pixels
1526
+ max_area: Maximum area in pixels (None for no limit)
1527
+
1528
+ Returns:
1529
+ Filtered list of bubbles
1530
+ """
1531
+ filtered = []
1532
+
1533
+ for x, y, w, h in bubbles:
1534
+ area = w * h
1535
+ if area >= min_area and (max_area is None or area <= max_area):
1536
+ filtered.append((x, y, w, h))
1537
+
1538
+ return filtered
1539
+
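+ # Illustrative example (thresholds are arbitrary, not tuned defaults):
+ #   bubbles = detector.filter_bubbles_by_size(bubbles, min_area=500, max_area=250000)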
1540
+ def merge_overlapping_bubbles(self, bubbles: List[Tuple[int, int, int, int]],
1541
+ overlap_threshold: float = 0.1) -> List[Tuple[int, int, int, int]]:
1542
+ """
1543
+ Merge overlapping bubble detections.
1544
+
1545
+ Args:
1546
+ bubbles: List of bubble bounding boxes
1547
+ overlap_threshold: Minimum overlap ratio to merge
1548
+
1549
+ Returns:
1550
+ Merged list of bubbles
1551
+ """
1552
+ if not bubbles:
1553
+ return []
1554
+
1555
+ # Convert to numpy array for easier manipulation
1556
+ boxes = np.array([(x, y, x+w, y+h) for x, y, w, h in bubbles])
1557
+
1558
+ merged = []
1559
+ used = set()
1560
+
1561
+ for i, box1 in enumerate(boxes):
1562
+ if i in used:
1563
+ continue
1564
+
1565
+ # Start with current box
1566
+ x1, y1, x2, y2 = box1
1567
+
1568
+ # Check for overlaps with remaining boxes
1569
+ for j in range(i + 1, len(boxes)):
1570
+ if j in used:
1571
+ continue
1572
+
1573
+ box2 = boxes[j]
1574
+
1575
+ # Calculate intersection
1576
+ ix1 = max(x1, box2[0])
1577
+ iy1 = max(y1, box2[1])
1578
+ ix2 = min(x2, box2[2])
1579
+ iy2 = min(y2, box2[3])
1580
+
1581
+ if ix1 < ix2 and iy1 < iy2:
1582
+ # Calculate overlap ratio
1583
+ intersection = (ix2 - ix1) * (iy2 - iy1)
1584
+ area1 = (x2 - x1) * (y2 - y1)
1585
+ area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
1586
+ overlap = intersection / min(area1, area2)
1587
+
1588
+ if overlap >= overlap_threshold:
1589
+ # Merge boxes
1590
+ x1 = min(x1, box2[0])
1591
+ y1 = min(y1, box2[1])
1592
+ x2 = max(x2, box2[2])
1593
+ y2 = max(y2, box2[3])
1594
+ used.add(j)
1595
+
1596
+ merged.append((int(x1), int(y1), int(x2 - x1), int(y2 - y1)))
1597
+
1598
+ return merged
1599
+
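+ # Worked example of the merge rule above (overlap is measured against the smaller box):
+ #   a = (0, 0, 100, 100)        # x, y, w, h
+ #   b = (90, 0, 100, 100)       # overlaps a by a 10x100 strip
+ #   detector.merge_overlapping_bubbles([a, b], overlap_threshold=0.1)
+ #   # intersection 1000 / min-area 10000 = 0.1 >= threshold, so the result is [(0, 0, 190, 100)]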
1600
+ # ============================
1601
+ # RT-DETR (ONNX) BACKEND
1602
+ # ============================
1603
+ def load_rtdetr_onnx_model(self, model_id: str = None, force_reload: bool = False) -> bool:
1604
+ """
1605
+ Load RT-DETR ONNX model using onnxruntime. Downloads detector.onnx and config.json
1606
+ from the provided Hugging Face repo if not already cached.
1607
+ """
1608
+ if not ONNX_AVAILABLE:
1609
+ logger.error("ONNX Runtime not available for RT-DETR ONNX backend")
1610
+ return False
1611
+ try:
1612
+ # If singleton mode and already loaded, just attach shared session
1613
+ try:
1614
+ adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
1615
+ singleton = bool(adv.get('use_singleton_models', True))
1616
+ except Exception:
1617
+ singleton = True
1618
+ if singleton and BubbleDetector._rtdetr_onnx_loaded and not force_reload and BubbleDetector._rtdetr_onnx_shared_session is not None:
1619
+ self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session
1620
+ self.rtdetr_onnx_loaded = True
1621
+ return True
1622
+
1623
+ repo = model_id or self.rtdetr_onnx_repo
1624
+ try:
1625
+ from huggingface_hub import hf_hub_download
1626
+ except Exception as e:
1627
+ logger.error(f"huggingface-hub required to fetch RT-DETR ONNX: {e}")
1628
+ return False
1629
+
1630
+ # Ensure local models dir (use configured cache_dir directly: e.g., 'models')
1631
+ cache_dir = self.cache_dir
1632
+ os.makedirs(cache_dir, exist_ok=True)
1633
+
1634
+ # Download files into models/ and avoid symlinks so the file is visible there
1635
+ try:
1636
+ _ = hf_hub_download(repo_id=repo, filename='config.json', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False)
1637
+ except Exception:
1638
+ pass
1639
+ onnx_fp = hf_hub_download(repo_id=repo, filename='detector.onnx', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False)
1640
+ BubbleDetector._rtdetr_onnx_model_path = onnx_fp
1641
+
1642
+ # Pick providers: prefer CUDA if available; otherwise CPU. Do NOT use DML.
1643
+ providers = ['CPUExecutionProvider']
1644
+ try:
1645
+ avail = ort.get_available_providers() if ONNX_AVAILABLE else []
1646
+ if 'CUDAExecutionProvider' in avail:
1647
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
1648
+ except Exception:
1649
+ pass
1650
+
1651
+ # Session options with reduced memory arena and optional thread limiting in singleton mode
1652
+ so = ort.SessionOptions()
1653
+ try:
1654
+ so.enable_mem_pattern = False
1655
+ so.enable_cpu_mem_arena = False
1656
+ except Exception:
1657
+ pass
1658
+ # If singleton models mode is enabled in config, limit ORT threading to reduce CPU spikes
1659
+ try:
1660
+ adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
1661
+ if bool(adv.get('use_singleton_models', True)):
1662
+ so.intra_op_num_threads = 1
1663
+ so.inter_op_num_threads = 1
1664
+ try:
1665
+ so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
1666
+ except Exception:
1667
+ pass
1668
+ try:
1669
+ so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
1670
+ except Exception:
1671
+ pass
1672
+ except Exception:
1673
+ pass
1674
+
1675
+ # Create session (serialize creation in singleton mode to avoid device storms)
1676
+ if singleton:
1677
+ with BubbleDetector._rtdetr_onnx_init_lock:
1678
+ # Re-check after acquiring lock
1679
+ if BubbleDetector._rtdetr_onnx_loaded and BubbleDetector._rtdetr_onnx_shared_session is not None and not force_reload:
1680
+ self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session
1681
+ self.rtdetr_onnx_loaded = True
1682
+ return True
1683
+ sess = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so)
1684
+ BubbleDetector._rtdetr_onnx_shared_session = sess
1685
+ BubbleDetector._rtdetr_onnx_loaded = True
1686
+ BubbleDetector._rtdetr_onnx_providers = providers
1687
+ self.rtdetr_onnx_session = sess
1688
+ self.rtdetr_onnx_loaded = True
1689
+ else:
1690
+ self.rtdetr_onnx_session = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so)
1691
+ self.rtdetr_onnx_loaded = True
1692
+ logger.info("✅ RT-DETR (ONNX) model ready")
1693
+ return True
1694
+ except Exception as e:
1695
+ logger.error(f"Failed to load RT-DETR ONNX: {e}")
1696
+ self.rtdetr_onnx_session = None
1697
+ self.rtdetr_onnx_loaded = False
1698
+ return False
1699
+
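+ # Illustrative usage sketch (assumes onnxruntime and huggingface-hub are installed):
+ #   det = BubbleDetector()
+ #   if det.load_rtdetr_onnx_model():
+ #       regions = det.detect_with_rtdetr_onnx(image_path="page_01.png", confidence=0.3)
+ #       print(len(regions['text_bubbles']), "text bubbles")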
1700
+ def detect_with_rtdetr_onnx(self,
1701
+ image_path: str = None,
1702
+ image: np.ndarray = None,
1703
+ confidence: float = 0.3,
1704
+ return_all_bubbles: bool = False) -> Any:
1705
+ """Detect using RT-DETR ONNX backend.
1706
+ Returns bubbles list if return_all_bubbles else dict by classes similar to PyTorch path.
1707
+ """
1708
+ if not self.rtdetr_onnx_loaded or self.rtdetr_onnx_session is None:
1709
+ logger.warning("RT-DETR ONNX not loaded")
1710
+ return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1711
+ try:
1712
+ # Acquire image
1713
+ if image_path is not None:
1714
+ import cv2
1715
+ image = cv2.imread(image_path)
1716
+ if image is None:
1717
+ raise RuntimeError(f"Failed to read image: {image_path}")
1718
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1719
+ else:
1720
+ if image is None:
1721
+ raise RuntimeError("No image provided")
1722
+ # Assume image is BGR np.ndarray if from OpenCV
1723
+ try:
1724
+ import cv2
1725
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1726
+ except Exception:
1727
+ image_rgb = image
1728
+
1729
+ # To PIL then resize 640x640 as in reference
1730
+ from PIL import Image as _PILImage
1731
+ pil_image = _PILImage.fromarray(image_rgb)
1732
+ im_resized = pil_image.resize((640, 640))
1733
+ arr = np.asarray(im_resized, dtype=np.float32) / 255.0
1734
+ arr = np.transpose(arr, (2, 0, 1)) # (3,H,W)
1735
+ im_data = arr[np.newaxis, ...]
1736
+
1737
+ w, h = pil_image.size
1738
+ orig_size = np.array([[w, h]], dtype=np.int64)
1739
+
1740
+ # Run with a concurrency guard to prevent device hangs and limit memory usage
1741
+ # Apply semaphore for ALL providers (not just DML) to control concurrency
1742
+ providers = BubbleDetector._rtdetr_onnx_providers or []
1743
+ def _do_run(session):
1744
+ return session.run(None, {
1745
+ 'images': im_data,
1746
+ 'orig_target_sizes': orig_size
1747
+ })
1748
+
1749
+ # Always use semaphore to limit concurrent RT-DETR calls
1750
+ acquired = False
1751
+ try:
1752
+ BubbleDetector._rtdetr_onnx_sema.acquire()
1753
+ acquired = True
1754
+
1755
+ # Special DML error handling
1756
+ if 'DmlExecutionProvider' in providers:
1757
+ try:
1758
+ outputs = _do_run(self.rtdetr_onnx_session)
1759
+ except Exception as dml_err:
1760
+ msg = str(dml_err)
1761
+ if '887A0005' in msg or '887A0006' in msg or 'Dml' in msg:
1762
+ # Rebuild CPU session and retry once
1763
+ try:
1764
+ base_path = BubbleDetector._rtdetr_onnx_model_path
1765
+ if base_path:
1766
+ so = ort.SessionOptions()
1767
+ so.enable_mem_pattern = False
1768
+ so.enable_cpu_mem_arena = False
1769
+ cpu_providers = ['CPUExecutionProvider']
1770
+ # Serialize rebuild
1771
+ with BubbleDetector._rtdetr_onnx_init_lock:
1772
+ sess = ort.InferenceSession(base_path, providers=cpu_providers, sess_options=so)
1773
+ BubbleDetector._rtdetr_onnx_shared_session = sess
1774
+ BubbleDetector._rtdetr_onnx_providers = cpu_providers
1775
+ self.rtdetr_onnx_session = sess
1776
+ outputs = _do_run(self.rtdetr_onnx_session)
1777
+ else:
1778
+ raise
1779
+ except Exception:
1780
+ raise
1781
+ else:
1782
+ raise
1783
+ else:
1784
+ # Non-DML providers - just run directly
1785
+ outputs = _do_run(self.rtdetr_onnx_session)
1786
+ finally:
1787
+ if acquired:
1788
+ try:
1789
+ BubbleDetector._rtdetr_onnx_sema.release()
1790
+ except Exception:
1791
+ pass
1792
+
1793
+ # outputs expected: labels, boxes, scores
1794
+ labels, boxes, scores = outputs[:3]
1795
+ if labels.ndim == 2 and labels.shape[0] == 1:
1796
+ labels = labels[0]
1797
+ if scores.ndim == 2 and scores.shape[0] == 1:
1798
+ scores = scores[0]
1799
+ if boxes.ndim == 3 and boxes.shape[0] == 1:
1800
+ boxes = boxes[0]
1801
+
1802
+ # Apply NMS to remove duplicate detections
1803
+ # Group detections by class and apply NMS per class
1804
+ class_detections = {self.CLASS_BUBBLE: [], self.CLASS_TEXT_BUBBLE: [], self.CLASS_TEXT_FREE: []}
1805
+
1806
+ for lab, box, scr in zip(labels, boxes, scores):
1807
+ if float(scr) < float(confidence):
1808
+ continue
1809
+ label_id = int(lab)
1810
+ if label_id in class_detections:
1811
+ x1, y1, x2, y2 = map(float, box)
1812
+ class_detections[label_id].append((x1, y1, x2, y2, float(scr)))
1813
+
1814
+ # Apply NMS per class to remove duplicates
1815
+ def compute_iou(box1, box2):
1816
+ """Compute IoU between two boxes (x1, y1, x2, y2)"""
1817
+ x1_1, y1_1, x2_1, y2_1 = box1[:4]
1818
+ x1_2, y1_2, x2_2, y2_2 = box2[:4]
1819
+
1820
+ # Intersection
1821
+ x_left = max(x1_1, x1_2)
1822
+ y_top = max(y1_1, y1_2)
1823
+ x_right = min(x2_1, x2_2)
1824
+ y_bottom = min(y2_1, y2_2)
1825
+
1826
+ if x_right < x_left or y_bottom < y_top:
1827
+ return 0.0
1828
+
1829
+ intersection = (x_right - x_left) * (y_bottom - y_top)
1830
+
1831
+ # Union
1832
+ area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
1833
+ area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
1834
+ union = area1 + area2 - intersection
1835
+
1836
+ return intersection / union if union > 0 else 0.0
1837
+
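+ # Worked example for compute_iou (illustrative numbers):
+ #   box1 = (0, 0, 10, 10), box2 = (5, 0, 15, 10) -> intersection 50, union 150, IoU ~ 0.33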
1838
+ def apply_nms(boxes_with_scores, iou_threshold=0.45):
1839
+ """Apply Non-Maximum Suppression"""
1840
+ if not boxes_with_scores:
1841
+ return []
1842
+
1843
+ # Sort by score (descending)
1844
+ sorted_boxes = sorted(boxes_with_scores, key=lambda x: x[4], reverse=True)
1845
+ keep = []
1846
+
1847
+ while sorted_boxes:
1848
+ # Keep the box with highest score
1849
+ current = sorted_boxes.pop(0)
1850
+ keep.append(current)
1851
+
1852
+ # Remove boxes with high IoU
1853
+ sorted_boxes = [box for box in sorted_boxes if compute_iou(current, box) < iou_threshold]
1854
+
1855
+ return keep
1856
+
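+ # Quick illustrative trace of apply_nms:
+ #   apply_nms([(0, 0, 10, 10, 0.9), (1, 1, 11, 11, 0.8)], iou_threshold=0.45)
+ #   # -> keeps only the 0.9 box (the pair's IoU is ~0.68, above the threshold)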
1857
+ # Apply NMS and build final detections
1858
+ detections = {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1859
+ bubbles_all = []
1860
+
1861
+ for class_id, boxes_list in class_detections.items():
1862
+ nms_boxes = apply_nms(boxes_list, iou_threshold=self.default_iou_threshold)
1863
+
1864
+ for x1, y1, x2, y2, scr in nms_boxes:
1865
+ bbox = (int(x1), int(y1), int(x2 - x1), int(y2 - y1))
1866
+
1867
+ if class_id == self.CLASS_BUBBLE:
1868
+ detections['bubbles'].append(bbox)
1869
+ bubbles_all.append(bbox)
1870
+ elif class_id == self.CLASS_TEXT_BUBBLE:
1871
+ detections['text_bubbles'].append(bbox)
1872
+ bubbles_all.append(bbox)
1873
+ elif class_id == self.CLASS_TEXT_FREE:
1874
+ detections['text_free'].append(bbox)
1875
+
1876
+ return bubbles_all if return_all_bubbles else detections
1877
+ except Exception as e:
1878
+ logger.error(f"RT-DETR ONNX detection failed: {e}")
1879
+ return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []}
1880
+
1881
+
1882
+ # Standalone utility functions
1883
+ def download_model_from_huggingface(repo_id: str = "ogkalu/comic-speech-bubble-detector-yolov8m",
1884
+ filename: str = "comic-speech-bubble-detector-yolov8m.pt",
1885
+ cache_dir: str = "models") -> str:
1886
+ """
1887
+ Download model from Hugging Face Hub.
1888
+
1889
+ Args:
1890
+ repo_id: Hugging Face repository ID
1891
+ filename: Model filename in the repository
1892
+ cache_dir: Local directory to cache the model
1893
+
1894
+ Returns:
1895
+ Path to downloaded model file
1896
+ """
1897
+ try:
1898
+ from huggingface_hub import hf_hub_download
1899
+
1900
+ os.makedirs(cache_dir, exist_ok=True)
1901
+
1902
+ logger.info(f"📥 Downloading {filename} from {repo_id}...")
1903
+
1904
+ model_path = hf_hub_download(
1905
+ repo_id=repo_id,
1906
+ filename=filename,
1907
+ cache_dir=cache_dir,
1908
+ local_dir=cache_dir
1909
+ )
1910
+
1911
+ logger.info(f"✅ Model downloaded to: {model_path}")
1912
+ return model_path
1913
+
1914
+ except ImportError:
1915
+ logger.error("huggingface-hub package required. Install with: pip install huggingface-hub")
1916
+ return None
1917
+ except Exception as e:
1918
+ logger.error(f"Download failed: {e}")
1919
+ return None
1920
+
1921
+
1922
+ def download_rtdetr_model(cache_dir: str = "models") -> bool:
1923
+ """
1924
+ Download RT-DETR model for advanced detection.
1925
+
1926
+ Args:
1927
+ cache_dir: Directory to cache the model
1928
+
1929
+ Returns:
1930
+ True if successful
1931
+ """
1932
+ if not TRANSFORMERS_AVAILABLE:
1933
+ logger.error("Transformers required. Install with: pip install transformers")
1934
+ return False
1935
+
1936
+ try:
1937
+ logger.info("📥 Downloading RT-DETR model...")
1938
+ from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
1939
+
1940
+ # This will download and cache the model
1941
+ processor = RTDetrImageProcessor.from_pretrained(
1942
+ "ogkalu/comic-text-and-bubble-detector",
1943
+ cache_dir=cache_dir
1944
+ )
1945
+ model = RTDetrForObjectDetection.from_pretrained(
1946
+ "ogkalu/comic-text-and-bubble-detector",
1947
+ cache_dir=cache_dir
1948
+ )
1949
+
1950
+ logger.info("✅ RT-DETR model downloaded successfully")
1951
+ return True
1952
+
1953
+ except Exception as e:
1954
+ logger.error(f"Download failed: {e}")
1955
+ return False
1956
+
1957
+
1958
+ # Example usage and testing
1959
+ if __name__ == "__main__":
1960
+ import sys
1961
+
1962
+ # Create detector
1963
+ detector = BubbleDetector()
1964
+
1965
+ if len(sys.argv) > 1:
1966
+ if sys.argv[1] == "download":
1967
+ # Download model from Hugging Face
1968
+ model_path = download_model_from_huggingface()
1969
+ if model_path:
1970
+ print(f"YOLOv8 model downloaded to: {model_path}")
1971
+
1972
+ # Also download RT-DETR
1973
+ if download_rtdetr_model():
1974
+ print("RT-DETR model downloaded")
1975
+
1976
+ elif sys.argv[1] == "detect" and len(sys.argv) > 3:
1977
+ # Detect bubbles in an image
1978
+ model_path = sys.argv[2]
1979
+ image_path = sys.argv[3]
1980
+
1981
+ # Load appropriate model
1982
+ if 'rtdetr' in model_path.lower():
1983
+ if detector.load_rtdetr_model():
1984
+ # Use RT-DETR
1985
+ results = detector.detect_with_rtdetr(image_path)
1986
+ print(f"RT-DETR Detection:")
1987
+ print(f" Empty bubbles: {len(results['bubbles'])}")
1988
+ print(f" Text bubbles: {len(results['text_bubbles'])}")
1989
+ print(f" Free text: {len(results['text_free'])}")
1990
+ else:
1991
+ if detector.load_model(model_path):
1992
+ bubbles = detector.detect_bubbles(image_path, confidence=0.5)
1993
+ print(f"YOLOv8 detected {len(bubbles)} bubbles:")
1994
+ for i, (x, y, w, h) in enumerate(bubbles):
1995
+ print(f" Bubble {i+1}: position=({x},{y}) size=({w}x{h})")
1996
+
1997
+ # Optionally visualize
1998
+ if len(sys.argv) > 4:
1999
+ output_path = sys.argv[4]
2000
+ detector.visualize_detections(image_path, output_path=output_path,
2001
+ use_rtdetr='rtdetr' in model_path.lower())
2002
+
2003
+ elif sys.argv[1] == "test-both" and len(sys.argv) > 2:
2004
+ # Test both models
2005
+ image_path = sys.argv[2]
2006
+
2007
+ # Load YOLOv8
2008
+ yolo_path = "models/comic-speech-bubble-detector-yolov8m.pt"
2009
+ if os.path.exists(yolo_path):
2010
+ detector.load_model(yolo_path)
2011
+ yolo_bubbles = detector.detect_bubbles(image_path, use_rtdetr=False)
2012
+ print(f"YOLOv8: {len(yolo_bubbles)} bubbles")
2013
+
2014
+ # Load RT-DETR
2015
+ if detector.load_rtdetr_model():
2016
+ rtdetr_bubbles = detector.detect_bubbles(image_path, use_rtdetr=True)
2017
+ print(f"RT-DETR: {len(rtdetr_bubbles)} bubbles")
2018
+
2019
+ else:
2020
+ print("Usage:")
2021
+ print(" python bubble_detector.py download")
2022
+ print(" python bubble_detector.py detect <model_path> <image_path> [output_path]")
2023
+ print(" python bubble_detector.py test-both <image_path>")
2024
+
2025
+ else:
2026
+ print("Bubble Detector Module (YOLOv8 + RT-DETR)")
2027
+ print("Usage:")
2028
+ print(" python bubble_detector.py download")
2029
+ print(" python bubble_detector.py detect <model_path> <image_path> [output_path]")
2030
+ print(" python bubble_detector.py test-both <image_path>")
hyphen_textwrap.py ADDED
@@ -0,0 +1,508 @@
1
+ # modified textwrap module to add hyphens whenever it breaks a long word
2
+ # https://github.com/python/cpython/blob/main/Lib/textwrap.py
3
+
4
+ """Text wrapping and filling with improved hyphenation support.
5
+
6
+ This module is adapted from comic-translate's enhanced textwrap implementation.
7
+ It provides better hyphenation behavior when breaking long words across lines.
8
+ """
9
+
10
+ # Copyright (C) 1999-2001 Gregory P. Ward.
11
+ # Copyright (C) 2002, 2003 Python Software Foundation.
12
+ # Written by Greg Ward <gward@python.net>
13
+
14
+ import re
15
+
16
+ __all__ = ['TextWrapper', 'wrap', 'fill', 'dedent', 'indent', 'shorten']
17
+
18
+ # Hardcode the recognized whitespace characters to the US-ASCII
19
+ # whitespace characters. The main reason for doing this is that
20
+ # some Unicode spaces (like \u00a0) are non-breaking whitespaces.
21
+ _whitespace = '\t\n\x0b\x0c\r '
22
+
23
+ class TextWrapper:
24
+ """
25
+ Object for wrapping/filling text. The public interface consists of
26
+ the wrap() and fill() methods; the other methods are just there for
27
+ subclasses to override in order to tweak the default behaviour.
28
+ If you want to completely replace the main wrapping algorithm,
29
+ you'll probably have to override _wrap_chunks().
30
+
31
+ Several instance attributes control various aspects of wrapping:
32
+ width (default: 70)
33
+ the maximum width of wrapped lines (unless break_long_words
34
+ is false)
35
+ initial_indent (default: "")
36
+ string that will be prepended to the first line of wrapped
37
+ output. Counts towards the line's width.
38
+ subsequent_indent (default: "")
39
+ string that will be prepended to all lines save the first
40
+ of wrapped output; also counts towards each line's width.
41
+ expand_tabs (default: true)
42
+ Expand tabs in input text to spaces before further processing.
43
+ Each tab will become 0 .. 'tabsize' spaces, depending on its position
44
+ in its line. If false, each tab is treated as a single character.
45
+ tabsize (default: 8)
46
+ Expand tabs in input text to 0 .. 'tabsize' spaces, unless
47
+ 'expand_tabs' is false.
48
+ replace_whitespace (default: true)
49
+ Replace all whitespace characters in the input text by spaces
50
+ after tab expansion. Note that if expand_tabs is false and
51
+ replace_whitespace is true, every tab will be converted to a
52
+ single space!
53
+ fix_sentence_endings (default: false)
54
+ Ensure that sentence-ending punctuation is always followed
55
+ by two spaces. Off by default because the algorithm is
56
+ (unavoidably) imperfect.
57
+ break_long_words (default: true)
58
+ Break words longer than 'width'. If false, those words will not
59
+ be broken, and some lines might be longer than 'width'.
60
+ break_on_hyphens (default: true)
61
+ Allow breaking hyphenated words. If true, wrapping will occur
62
+ preferably on whitespaces and right after hyphens part of
63
+ compound words.
64
+ drop_whitespace (default: true)
65
+ Drop leading and trailing whitespace from lines.
66
+ max_lines (default: None)
67
+ Truncate wrapped lines.
68
+ placeholder (default: ' [...]')
69
+ Append to the last line of truncated text.
70
+ hyphenate_broken_words (default: True)
71
+ Add hyphens when breaking long words across lines.
72
+ """
73
+
74
+ unicode_whitespace_trans = dict.fromkeys(map(ord, _whitespace), ord(' '))
75
+
76
+ # This funky little regex is just the trick for splitting
77
+ # text up into word-wrappable chunks. E.g.
78
+ # "Hello there -- you goof-ball, use the -b option!"
79
+ # splits into
80
+ # Hello/ /there/ /--/ /you/ /goof-/ball,/ /use/ /the/ /-b/ /option!
81
+ # (after stripping out empty strings).
82
+ word_punct = r'[\w!"\'\&.,?]'
83
+ letter = r'[^\d\W]'
84
+ whitespace = r'[%s]' % re.escape(_whitespace)
85
+ nowhitespace = '[^' + whitespace[1:]
86
+ wordsep_re = re.compile(r'''
87
+ ( # any whitespace
88
+ %(ws)s+
89
+ | # em-dash between words
90
+ (?<=%(wp)s) -{2,} (?=\w)
91
+ | # word, possibly hyphenated
92
+ %(nws)s+? (?:
93
+ # hyphenated word
94
+ -(?: (?<=%(lt)s{2}-) | (?<=%(lt)s-%(lt)s-))
95
+ (?= %(lt)s -? %(lt)s)
96
+ | # end of word
97
+ (?=%(ws)s|\Z)
98
+ | # em-dash
99
+ (?<=%(wp)s) (?=-{2,}\w)
100
+ )
101
+ )''' % {'wp': word_punct, 'lt': letter,
102
+ 'ws': whitespace, 'nws': nowhitespace},
103
+ re.VERBOSE)
104
+ del word_punct, letter, nowhitespace
105
+
106
+ # This less funky little regex just split on recognized spaces. E.g.
107
+ # "Hello there -- you goof-ball, use the -b option!"
108
+ # splits into
109
+ # Hello/ /there/ /--/ /you/ /goof-ball,/ /use/ /the/ /-b/ /option!/
110
+ wordsep_simple_re = re.compile(r'(%s+)' % whitespace)
111
+ del whitespace
112
+
113
+ # XXX this is not locale- or charset-aware -- string.lowercase
114
+ # is US-ASCII only (and therefore English-only)
115
+ sentence_end_re = re.compile(r'[a-z]' # lowercase letter
116
+ r'[\.\!\?]' # sentence-ending punct.
117
+ r'[\"\']?' # optional end-of-quote
118
+ r'\Z') # end of chunk
119
+
120
+ def __init__(self,
121
+ width=70,
122
+ initial_indent="",
123
+ subsequent_indent="",
124
+ expand_tabs=True,
125
+ replace_whitespace=True,
126
+ fix_sentence_endings=False,
127
+ break_long_words=True,
128
+ drop_whitespace=True,
129
+ break_on_hyphens=True,
130
+ hyphenate_broken_words=True,
131
+ tabsize=8,
132
+ *,
133
+ max_lines=None,
134
+ placeholder=' [...]'):
135
+ self.width = width
136
+ self.initial_indent = initial_indent
137
+ self.subsequent_indent = subsequent_indent
138
+ self.expand_tabs = expand_tabs
139
+ self.replace_whitespace = replace_whitespace
140
+ self.fix_sentence_endings = fix_sentence_endings
141
+ self.break_long_words = break_long_words
142
+ self.drop_whitespace = drop_whitespace
143
+ self.break_on_hyphens = break_on_hyphens
144
+ self.tabsize = tabsize
145
+ self.max_lines = max_lines
146
+ self.placeholder = placeholder
147
+ self.hyphenate_broken_words = hyphenate_broken_words
148
+
149
+
150
+ # -- Private methods -----------------------------------------------
151
+ # (possibly useful for subclasses to override)
152
+
153
+ def _munge_whitespace(self, text):
154
+ """_munge_whitespace(text : string) -> string
155
+
156
+ Munge whitespace in text: expand tabs and convert all other
157
+ whitespace characters to spaces. Eg. " foo\\tbar\\n\\nbaz"
158
+ becomes " foo bar baz".
159
+ """
160
+ if self.expand_tabs:
161
+ text = text.expandtabs(self.tabsize)
162
+ if self.replace_whitespace:
163
+ text = text.translate(self.unicode_whitespace_trans)
164
+ return text
165
+
166
+
167
+ def _split(self, text):
168
+ """_split(text : string) -> [string]
169
+
170
+ Split the text to wrap into indivisible chunks. Chunks are
171
+ not quite the same as words; see _wrap_chunks() for full
172
+ details. As an example, the text
173
+ Look, goof-ball -- use the -b option!
174
+ breaks into the following chunks:
175
+ 'Look,', ' ', 'goof-', 'ball', ' ', '--', ' ',
176
+ 'use', ' ', 'the', ' ', '-b', ' ', 'option!'
177
+ if break_on_hyphens is True, or in:
178
+ 'Look,', ' ', 'goof-ball', ' ', '--', ' ',
179
+ 'use', ' ', 'the', ' ', '-b', ' ', 'option!'
180
+ otherwise.
181
+ """
182
+ if self.break_on_hyphens is True:
183
+ chunks = self.wordsep_re.split(text)
184
+ else:
185
+ chunks = self.wordsep_simple_re.split(text)
186
+ chunks = [c for c in chunks if c]
187
+
188
+ return chunks
189
+
190
+ def _fix_sentence_endings(self, chunks):
191
+ """_fix_sentence_endings(chunks : [string])
192
+
193
+ Correct for sentence endings buried in 'chunks'. Eg. when the
194
+ original text contains "... foo.\\nBar ...", munge_whitespace()
195
+ and split() will convert that to [..., "foo.", " ", "Bar", ...]
196
+ which has one too few spaces; this method simply changes the one
197
+ space to two.
198
+ """
199
+ i = 0
200
+ patsearch = self.sentence_end_re.search
201
+ while i < len(chunks)-1:
202
+ if chunks[i+1] == " " and patsearch(chunks[i]):
203
+ chunks[i+1] = "  "
204
+ i += 2
205
+ else:
206
+ i += 1
207
+
208
+ def _handle_long_word(self, reversed_chunks, cur_line, cur_len, width):
209
+ """_handle_long_word(chunks : [string],
210
+ cur_line : [string],
211
+ cur_len : int, width : int)
212
+
213
+ Handle a chunk of text (most likely a word, not whitespace) that
214
+ is too long to fit in any line.
215
+ """
216
+ # Figure out when indent is larger than the specified width, and make
217
+ # sure at least one character is stripped off on every pass
218
+ if width < 1:
219
+ space_left = 1
220
+ else:
221
+ space_left = width - cur_len
222
+
223
+ # If we're allowed to break long words, then do so: put as much
224
+ # of the next chunk onto the current line as will fit.
225
+ if self.break_long_words:
226
+ end = space_left
227
+ chunk = reversed_chunks[-1]
228
+ if self.break_on_hyphens and len(chunk) > space_left:
229
+ # break after last hyphen, but only if there are
230
+ # non-hyphens before it
231
+ hyphen = chunk.rfind('-', 0, space_left)
232
+ if hyphen > 0 and any(c != '-' for c in chunk[:hyphen]):
233
+ end = hyphen + 1
234
+
235
+ if chunk[:end]:
236
+ cur_line.append(chunk[:end])
237
+ # Now adds a hyphen whenever a long word is split to the next line
238
+ # unless certain chracters already exists at the split
239
+ if self.hyphenate_broken_words and chunk[:end][-1] not in ['-','.',',']:
240
+ cur_line.append('-')
241
+ reversed_chunks[-1] = chunk[end:]
242
+
243
+ # Otherwise, we have to preserve the long word intact. Only add
244
+ # it to the current line if there's nothing already there --
245
+ # that minimizes how much we violate the width constraint.
246
+ elif not cur_line:
247
+ cur_line.append(reversed_chunks.pop())
248
+
249
+ # If we're not allowed to break long words, and there's already
250
+ # text on the current line, do nothing. Next time through the
251
+ # main loop of _wrap_chunks(), we'll wind up here again, but
252
+ # cur_len will be zero, so the next line will be entirely
253
+ # devoted to the long word that we can't handle right now.
254
+
255
+ def _wrap_chunks(self, chunks):
256
+ """_wrap_chunks(chunks : [string]) -> [string]
257
+
258
+ Wrap a sequence of text chunks and return a list of lines of
259
+ length 'self.width' or less. (If 'break_long_words' is false,
260
+ some lines may be longer than this.) Chunks correspond roughly
261
+ to words and the whitespace between them: each chunk is
262
+ indivisible (modulo 'break_long_words'), but a line break can
263
+ come between any two chunks. Chunks should not have internal
264
+ whitespace; ie. a chunk is either all whitespace or a "word".
265
+ Whitespace chunks will be removed from the beginning and end of
266
+ lines, but apart from that whitespace is preserved.
267
+ """
268
+ lines = []
269
+ if self.width <= 0:
270
+ raise ValueError("invalid width %r (must be > 0)" % self.width)
271
+ if self.max_lines is not None:
272
+ if self.max_lines > 1:
273
+ indent = self.subsequent_indent
274
+ else:
275
+ indent = self.initial_indent
276
+ if len(indent) + len(self.placeholder.lstrip()) > self.width:
277
+ raise ValueError("placeholder too large for max width")
278
+
279
+ # Arrange in reverse order so items can be efficiently popped
280
+ # from a stack of chucks.
281
+ chunks.reverse()
282
+
283
+ while chunks:
284
+
285
+ # Start the list of chunks that will make up the current line.
286
+ # cur_len is just the length of all the chunks in cur_line.
287
+ cur_line = []
288
+ cur_len = 0
289
+
290
+ # Figure out which static string will prefix this line.
291
+ if lines:
292
+ indent = self.subsequent_indent
293
+ else:
294
+ indent = self.initial_indent
295
+
296
+ # Maximum width for this line.
297
+ width = self.width - len(indent)
298
+
299
+ # First chunk on line is whitespace -- drop it, unless this
300
+ # is the very beginning of the text (ie. no lines started yet).
301
+ if self.drop_whitespace and chunks[-1].strip() == '' and lines:
302
+ del chunks[-1]
303
+
304
+ while chunks:
305
+ l = len(chunks[-1])
306
+
307
+ # Can at least squeeze this chunk onto the current line.
308
+ if cur_len + l <= width:
309
+ cur_line.append(chunks.pop())
310
+ cur_len += l
311
+
312
+ # Nope, this line is full.
313
+ else:
314
+ break
315
+
316
+ # The current line is full, and the next chunk is too big to
317
+ # fit on *any* line (not just this one).
318
+ if chunks and len(chunks[-1]) > width:
319
+ self._handle_long_word(chunks, cur_line, cur_len, width)
320
+ cur_len = sum(map(len, cur_line))
321
+
322
+ # If the last chunk on this line is all whitespace, drop it.
323
+ if self.drop_whitespace and cur_line and cur_line[-1].strip() == '':
324
+ cur_len -= len(cur_line[-1])
325
+ del cur_line[-1]
326
+
327
+ if cur_line:
328
+ if (self.max_lines is None or
329
+ len(lines) + 1 < self.max_lines or
330
+ (not chunks or
331
+ self.drop_whitespace and
332
+ len(chunks) == 1 and
333
+ not chunks[0].strip()) and cur_len <= width):
334
+ # Convert current line back to a string and store it in
335
+ # list of all lines (return value).
336
+ lines.append(indent + ''.join(cur_line))
337
+ else:
338
+ while cur_line:
339
+ if (cur_line[-1].strip() and
340
+ cur_len + len(self.placeholder) <= width):
341
+ cur_line.append(self.placeholder)
342
+ lines.append(indent + ''.join(cur_line))
343
+ break
344
+ cur_len -= len(cur_line[-1])
345
+ del cur_line[-1]
346
+ else:
347
+ if lines:
348
+ prev_line = lines[-1].rstrip()
349
+ if (len(prev_line) + len(self.placeholder) <=
350
+ self.width):
351
+ lines[-1] = prev_line + self.placeholder
352
+ break
353
+ lines.append(indent + self.placeholder.lstrip())
354
+ break
355
+
356
+ return lines
357
+
358
+ def _split_chunks(self, text):
359
+ text = self._munge_whitespace(text)
360
+ return self._split(text)
361
+
362
+ # -- Public interface ----------------------------------------------
363
+
364
+ def wrap(self, text):
365
+ """wrap(text : string) -> [string]
366
+
367
+ Reformat the single paragraph in 'text' so it fits in lines of
368
+ no more than 'self.width' columns, and return a list of wrapped
369
+ lines. Tabs in 'text' are expanded with string.expandtabs(),
370
+ and all other whitespace characters (including newline) are
371
+ converted to space.
372
+ """
373
+ chunks = self._split_chunks(text)
374
+ if self.fix_sentence_endings:
375
+ self._fix_sentence_endings(chunks)
376
+ return self._wrap_chunks(chunks)
377
+
378
+ def fill(self, text):
379
+ """fill(text : string) -> string
380
+
381
+ Reformat the single paragraph in 'text' to fit in lines of no
382
+ more than 'self.width' columns, and return a new string
383
+ containing the entire wrapped paragraph.
384
+ """
385
+ return "\n".join(self.wrap(text))
386
+
387
+
388
+ # -- Convenience interface ---------------------------------------------
389
+
390
+ def wrap(text, width=70, **kwargs):
391
+ """Wrap a single paragraph of text, returning a list of wrapped lines.
392
+
393
+ Reformat the single paragraph in 'text' so it fits in lines of no
394
+ more than 'width' columns, and return a list of wrapped lines. By
395
+ default, tabs in 'text' are expanded with string.expandtabs(), and
396
+ all other whitespace characters (including newline) are converted to
397
+ space. See TextWrapper class for available keyword args to customize
398
+ wrapping behaviour.
399
+ """
400
+ w = TextWrapper(width=width, **kwargs)
401
+ return w.wrap(text)
402
+
403
+ def fill(text, width=70, **kwargs):
404
+ """Fill a single paragraph of text, returning a new string.
405
+
406
+ Reformat the single paragraph in 'text' to fit in lines of no more
407
+ than 'width' columns, and return a new string containing the entire
408
+ wrapped paragraph. As with wrap(), tabs are expanded and other
409
+ whitespace characters converted to space. See TextWrapper class for
410
+ available keyword args to customize wrapping behaviour.
411
+ """
412
+ w = TextWrapper(width=width, **kwargs)
413
+ return w.fill(text)
414
+
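+ # Illustrative example of the hyphenation behaviour this module adds (hand-traced, treat as a sketch):
+ #   wrap("supercalifragilistic", width=10)
+ #   # -> ['supercalif-', 'ragilistic']  (the appended '-' can push a line one character past 'width')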
415
+ def shorten(text, width, **kwargs):
416
+ """Collapse and truncate the given text to fit in the given width.
417
+
418
+ The text first has its whitespace collapsed. If it then fits in
419
+ the *width*, it is returned as is. Otherwise, as many words
420
+ as possible are joined and then the placeholder is appended::
421
+
422
+ >>> textwrap.shorten("Hello world!", width=12)
423
+ 'Hello world!'
424
+ >>> textwrap.shorten("Hello world!", width=11)
425
+ 'Hello [...]'
426
+ """
427
+ w = TextWrapper(width=width, max_lines=1, **kwargs)
428
+ return w.fill(' '.join(text.strip().split()))
429
+
430
+
431
+ # -- Loosely related functionality -------------------------------------
432
+
433
+ _whitespace_only_re = re.compile('^[ \t]+$', re.MULTILINE)
434
+ _leading_whitespace_re = re.compile('(^[ \t]*)(?:[^ \t\n])', re.MULTILINE)
435
+
436
+ def dedent(text):
437
+ """Remove any common leading whitespace from every line in `text`.
438
+
439
+ This can be used to make triple-quoted strings line up with the left
440
+ edge of the display, while still presenting them in the source code
441
+ in indented form.
442
+
443
+ Note that tabs and spaces are both treated as whitespace, but they
444
+ are not equal: the lines " hello" and "\\thello" are
445
+ considered to have no common leading whitespace.
446
+
447
+ Entirely blank lines are normalized to a newline character.
448
+ """
449
+ # Look for the longest leading string of spaces and tabs common to
450
+ # all lines.
451
+ margin = None
452
+ text = _whitespace_only_re.sub('', text)
453
+ indents = _leading_whitespace_re.findall(text)
454
+ for indent in indents:
455
+ if margin is None:
456
+ margin = indent
457
+
458
+ # Current line more deeply indented than previous winner:
459
+ # no change (previous winner is still on top).
460
+ elif indent.startswith(margin):
461
+ pass
462
+
463
+ # Current line consistent with and no deeper than previous winner:
464
+ # it's the new winner.
465
+ elif margin.startswith(indent):
466
+ margin = indent
467
+
468
+ # Find the largest common whitespace between current line and previous
469
+ # winner.
470
+ else:
471
+ for i, (x, y) in enumerate(zip(margin, indent)):
472
+ if x != y:
473
+ margin = margin[:i]
474
+ break
475
+
476
+ # sanity check (testing/debugging only)
477
+ if 0 and margin:
478
+ for line in text.split("\n"):
479
+ assert not line or line.startswith(margin), \
480
+ "line = %r, margin = %r" % (line, margin)
481
+
482
+ if margin:
483
+ text = re.sub(r'(?m)^' + margin, '', text)
484
+ return text
485
+
486
+
487
+ def indent(text, prefix, predicate=None):
488
+ """Adds 'prefix' to the beginning of selected lines in 'text'.
489
+
490
+ If 'predicate' is provided, 'prefix' will only be added to the lines
491
+ where 'predicate(line)' is True. If 'predicate' is not provided,
492
+ it will default to adding 'prefix' to all non-empty lines that do not
493
+ consist solely of whitespace characters.
494
+ """
495
+ if predicate is None:
496
+ # str.splitlines(True) doesn't produce empty string.
497
+ # ''.splitlines(True) => []
498
+ # 'foo\n'.splitlines(True) => ['foo\n']
499
+ # So we can use just `not s.isspace()` here.
500
+ predicate = lambda s: not s.isspace()
501
+
502
+ prefixed_lines = []
503
+ for line in text.splitlines(True):
504
+ if predicate(line):
505
+ prefixed_lines.append(prefix)
506
+ prefixed_lines.append(line)
507
+
508
+ return ''.join(prefixed_lines)
local_inpainter.py ADDED
The diff for this file is too large to render. See raw diff
 
manga_integration.py ADDED
The diff for this file is too large to render. See raw diff
 
manga_settings_dialog.py ADDED
The diff for this file is too large to render. See raw diff
 
manga_translator.py ADDED
The diff for this file is too large to render. See raw diff
 
ocr_manager.py ADDED
@@ -0,0 +1,1904 @@
1
+ # ocr_manager.py
2
+ """
3
+ OCR Manager for handling multiple OCR providers
4
+ Handles installation, model downloading, and OCR processing
5
+ Updated with HuggingFace donut model and proper bubble detection integration
6
+ """
7
+ import os
8
+ import sys
9
+ import cv2
10
+ import json
11
+ import subprocess
12
+ import threading
13
+ import traceback
14
+ from typing import List, Dict, Optional, Tuple, Any
15
+ import numpy as np
16
+ from dataclasses import dataclass
17
+ from PIL import Image
18
+ import logging
19
+ import time
20
+ import random
21
+ import base64
22
+ import io
23
+ import requests
24
+
25
+ try:
26
+ import gptqmodel
27
+ HAS_GPTQ = True
28
+ except ImportError:
29
+ try:
30
+ import auto_gptq
31
+ HAS_GPTQ = True
32
+ except ImportError:
33
+ HAS_GPTQ = False
34
+
35
+ try:
36
+ import optimum
37
+ HAS_OPTIMUM = True
38
+ except ImportError:
39
+ HAS_OPTIMUM = False
40
+
41
+ try:
42
+ import accelerate
43
+ HAS_ACCELERATE = True
44
+ except ImportError:
45
+ HAS_ACCELERATE = False
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ @dataclass
50
+ class OCRResult:
51
+ """Unified OCR result format with built-in sanitization to prevent data corruption."""
52
+ text: str
53
+ bbox: Tuple[int, int, int, int] # x, y, w, h
54
+ confidence: float
55
+ vertices: Optional[List[Tuple[int, int]]] = None
56
+
57
+ def __post_init__(self):
58
+ """
59
+ This special method is called automatically after the object is created.
60
+ It acts as a final safeguard to ensure the 'text' attribute is ALWAYS a clean string.
61
+ """
62
+ # --- THIS IS THE DEFINITIVE FIX ---
63
+ # If the text we received is a tuple, we extract the first element.
64
+ # This makes it impossible for a tuple to exist in a finished object.
65
+ if isinstance(self.text, tuple):
66
+ # Log that we are fixing a critical data error.
67
+ print(f"CRITICAL WARNING: Corrupted tuple detected in OCRResult. Sanitizing '{self.text}' to '{self.text[0]}'.")
68
+ self.text = self.text[0]
69
+
70
+ # Ensure the final result is always a stripped string.
71
+ self.text = str(self.text).strip()
72
+
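+ # Illustrative sketch of the sanitization above (hypothetical values):
+ #   r = OCRResult(text=("こんにちは", 0.98), bbox=(10, 20, 100, 40), confidence=0.98)
+ #   assert r.text == "こんにちは"   # the tuple is unwrapped and stripped automatically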
73
+ class OCRProvider:
74
+ """Base class for OCR providers"""
75
+
76
+ def __init__(self, log_callback=None):
77
+ # Set thread limits early if environment indicates single-threaded mode
78
+ try:
79
+ if os.environ.get('OMP_NUM_THREADS') == '1':
80
+ # Already in single-threaded mode, ensure it's applied to this process
81
+ try:
82
+ import sys
83
+ if 'torch' in sys.modules:
84
+ import torch
85
+ torch.set_num_threads(1)
86
+ except (ImportError, RuntimeError, AttributeError):
87
+ pass
88
+ try:
89
+ import cv2
90
+ cv2.setNumThreads(1)
91
+ except (ImportError, AttributeError):
92
+ pass
93
+ except Exception:
94
+ pass
95
+
96
+ self.log_callback = log_callback
97
+ self.is_installed = False
98
+ self.is_loaded = False
99
+ self.model = None
100
+ self.stop_flag = None
101
+ self._stopped = False
102
+
103
+ def _log(self, message: str, level: str = "info"):
104
+ """Log message with stop suppression"""
105
+ # Suppress logs when stopped (allow only essential stop confirmation messages)
106
+ if self._check_stop():
107
+ essential_stop_keywords = [
108
+ "⏹️ Translation stopped by user",
109
+ "⏹️ OCR processing stopped",
110
+ "cleanup", "🧹"
111
+ ]
112
+ if not any(keyword in message for keyword in essential_stop_keywords):
113
+ return
114
+
115
+ if self.log_callback:
116
+ self.log_callback(message, level)
117
+ else:
118
+ print(f"[{level.upper()}] {message}")
119
+
120
+ def set_stop_flag(self, stop_flag):
121
+ """Set the stop flag for checking interruptions"""
122
+ self.stop_flag = stop_flag
123
+ self._stopped = False
124
+
125
+ def _check_stop(self) -> bool:
126
+ """Check if stop has been requested"""
127
+ if self._stopped:
128
+ return True
129
+ if self.stop_flag and self.stop_flag.is_set():
130
+ self._stopped = True
131
+ return True
132
+ # Check global manga translator cancellation
133
+ try:
134
+ from manga_translator import MangaTranslator
135
+ if MangaTranslator.is_globally_cancelled():
136
+ self._stopped = True
137
+ return True
138
+ except Exception:
139
+ pass
140
+ return False
141
+
142
+ def reset_stop_flags(self):
143
+ """Reset stop flags when starting new processing"""
144
+ self._stopped = False
145
+
146
+ def check_installation(self) -> bool:
147
+ """Check if provider is installed"""
148
+ raise NotImplementedError
149
+
150
+ def install(self, progress_callback=None) -> bool:
151
+ """Install the provider"""
152
+ raise NotImplementedError
153
+
154
+ def load_model(self, **kwargs) -> bool:
155
+ """Load the OCR model"""
156
+ raise NotImplementedError
157
+
158
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
159
+ """Detect text in image"""
160
+ raise NotImplementedError
161
+
162
+ class CustomAPIProvider(OCRProvider):
163
+ """Custom API OCR provider that uses existing GUI variables"""
164
+
165
+ def __init__(self, log_callback=None):
166
+ super().__init__(log_callback)
167
+
168
+ # Use EXISTING environment variables from TranslatorGUI
169
+ self.api_url = os.environ.get('OPENAI_CUSTOM_BASE_URL', '')
170
+ self.api_key = os.environ.get('API_KEY', '') or os.environ.get('OPENAI_API_KEY', '')
171
+ self.model_name = os.environ.get('MODEL', 'gpt-4o-mini')
172
+
173
+ # OCR prompt - use system prompt or a dedicated OCR prompt variable
174
+ self.ocr_prompt = os.environ.get('OCR_SYSTEM_PROMPT',
175
+ os.environ.get('SYSTEM_PROMPT',
176
+ "YOU ARE AN OCR SYSTEM. YOUR ONLY JOB IS TEXT EXTRACTION.\n\n"
177
+ "CRITICAL RULES:\n"
178
+ "1. DO NOT TRANSLATE ANYTHING\n"
179
+ "2. DO NOT MODIFY THE TEXT\n"
180
+ "3. DO NOT EXPLAIN OR COMMENT\n"
181
+ "4. ONLY OUTPUT THE EXACT TEXT YOU SEE\n"
182
+ "5. PRESERVE NATURAL TEXT FLOW - DO NOT ADD UNNECESSARY LINE BREAKS\n\n"
183
+ "If you see Korean text, output it in Korean.\n"
184
+ "If you see Japanese text, output it in Japanese.\n"
185
+ "If you see Chinese text, output it in Chinese.\n"
186
+ "If you see English text, output it in English.\n\n"
187
+ "IMPORTANT: Only use line breaks where they naturally occur in the original text "
188
+ "(e.g., between dialogue lines or paragraphs). Do not break text mid-sentence or "
189
+ "between every word/character.\n\n"
190
+ "For vertical text common in manga/comics, transcribe it as a continuous line unless "
191
+ "there are clear visual breaks.\n\n"
192
+ "NEVER translate. ONLY extract exactly what is written.\n"
193
+ "Output ONLY the raw text, nothing else."
194
+ ))
195
+
196
+ # Use existing temperature and token settings
197
+ self.temperature = float(os.environ.get('TRANSLATION_TEMPERATURE', '0.01'))
198
+ # Don't hardcode to 8192 - get fresh value when actually used
199
+ self.max_tokens = int(os.environ.get('MAX_OUTPUT_TOKENS', '4096'))
200
+
201
+ # Image settings from existing compression variables
202
+ self.image_format = 'jpeg' if os.environ.get('IMAGE_COMPRESSION_FORMAT', 'auto') != 'png' else 'png'
203
+ self.image_quality = int(os.environ.get('JPEG_QUALITY', '100'))
204
+
205
+ # Simple defaults
206
+ self.api_format = 'openai' # Most custom endpoints are OpenAI-compatible
207
+ self.timeout = int(os.environ.get('CHUNK_TIMEOUT', '30'))
208
+ self.api_headers = {} # Additional custom headers
209
+
210
+ # Retry configuration for Custom API OCR calls
211
+ self.max_retries = int(os.environ.get('CUSTOM_OCR_MAX_RETRIES', '3'))
212
+ self.retry_initial_delay = float(os.environ.get('CUSTOM_OCR_RETRY_INITIAL_DELAY', '0.8'))
213
+ self.retry_backoff = float(os.environ.get('CUSTOM_OCR_RETRY_BACKOFF', '1.8'))
214
+ self.retry_jitter = float(os.environ.get('CUSTOM_OCR_RETRY_JITTER', '0.4'))
215
+ self.retry_on_empty = os.environ.get('CUSTOM_OCR_RETRY_ON_EMPTY', '1') == '1'
216
+
217
+ def check_installation(self) -> bool:
218
+ """Always installed - uses UnifiedClient"""
219
+ self.is_installed = True
220
+ return True
221
+
222
+ def install(self, progress_callback=None) -> bool:
223
+ """No installation needed for API-based provider"""
224
+ return self.check_installation()
225
+
226
+ def load_model(self, **kwargs) -> bool:
227
+ """Initialize UnifiedClient with current settings"""
228
+ try:
229
+ from unified_api_client import UnifiedClient
230
+
231
+ # Support passing API key from GUI if available
232
+ if 'api_key' in kwargs:
233
+ api_key = kwargs['api_key']
234
+ else:
235
+ api_key = os.environ.get('API_KEY', '') or os.environ.get('OPENAI_API_KEY', '')
236
+
237
+ if 'model' in kwargs:
238
+ model = kwargs['model']
239
+ else:
240
+ model = os.environ.get('MODEL', 'gpt-4o-mini')
241
+
242
+ if not api_key:
243
+ self._log("❌ No API key configured", "error")
244
+ return False
245
+
246
+ # Create UnifiedClient just like translations do
247
+ self.client = UnifiedClient(model=model, api_key=api_key)
248
+
249
+ #self._log(f"✅ Using {model} for OCR via UnifiedClient")
250
+ self.is_loaded = True
251
+ return True
252
+
253
+ except Exception as e:
254
+ self._log(f"❌ Failed to initialize UnifiedClient: {str(e)}", "error")
255
+ return False
256
+
257
+ def _test_connection(self) -> bool:
258
+ """Test API connection with a simple request"""
259
+ try:
260
+ # Create a small test image
261
+ test_image = np.ones((100, 100, 3), dtype=np.uint8) * 255
262
+ cv2.putText(test_image, "TEST", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
263
+
264
+ # Encode image
265
+ image_base64 = self._encode_image(test_image)
266
+
267
+ # Prepare test request based on API format
268
+ if self.api_format == 'openai':
269
+ test_payload = {
270
+ "model": self.model_name,
271
+ "messages": [
272
+ {
273
+ "role": "user",
274
+ "content": [
275
+ {"type": "text", "text": "What text do you see?"},
276
+ {"type": "image_url", "image_url": {"url": f"data:image/{self.image_format};base64,{image_base64}"}}
277
+ ]
278
+ }
279
+ ],
280
+ "max_tokens": 50
281
+ }
282
+ else:
283
+ # For other formats, just try a basic health check
284
+ return True
285
+
286
+ headers = self._prepare_headers()
287
+ response = requests.post(
288
+ self.api_url,
289
+ headers=headers,
290
+ json=test_payload,
291
+ timeout=10
292
+ )
293
+
294
+ return response.status_code == 200
295
+
296
+ except Exception:
297
+ return False
298
+
299
+ def _encode_image(self, image: np.ndarray) -> str:
300
+ """Encode numpy array to base64 string"""
301
+ # Convert BGR to RGB if needed
302
+ if len(image.shape) == 3 and image.shape[2] == 3:
303
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
304
+ else:
305
+ image_rgb = image
306
+
307
+ # Convert to PIL Image
308
+ pil_image = Image.fromarray(image_rgb)
309
+
310
+ # Save to bytes buffer
311
+ buffer = io.BytesIO()
312
+ if self.image_format.lower() == 'png':
313
+ pil_image.save(buffer, format='PNG')
314
+ else:
315
+ pil_image.save(buffer, format='JPEG', quality=self.image_quality)
316
+
317
+ # Encode to base64
318
+ buffer.seek(0)
319
+ image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
320
+
321
+ return image_base64
322
+
323
+ def _prepare_headers(self) -> dict:
324
+ """Prepare request headers"""
325
+ headers = {
326
+ "Content-Type": "application/json"
327
+ }
328
+
329
+ # Add API key if configured
330
+ if self.api_key:
331
+ if self.api_format == 'anthropic':
332
+ headers["x-api-key"] = self.api_key
333
+ else:
334
+ headers["Authorization"] = f"Bearer {self.api_key}"
335
+
336
+ # Add any custom headers
337
+ headers.update(self.api_headers)
338
+
339
+ return headers
340
+
341
+ def _prepare_request_payload(self, image_base64: str) -> dict:
342
+ """Prepare request payload based on API format"""
343
+ if self.api_format == 'openai':
344
+ return {
345
+ "model": self.model_name,
346
+ "messages": [
347
+ {
348
+ "role": "user",
349
+ "content": [
350
+ {"type": "text", "text": self.ocr_prompt},
351
+ {
352
+ "type": "image_url",
353
+ "image_url": {
354
+ "url": f"data:image/{self.image_format};base64,{image_base64}"
355
+ }
356
+ }
357
+ ]
358
+ }
359
+ ],
360
+ "max_tokens": self.max_tokens,
361
+ "temperature": self.temperature
362
+ }
363
+
364
+ elif self.api_format == 'anthropic':
365
+ return {
366
+ "model": self.model_name,
367
+ "max_tokens": self.max_tokens,
368
+ "temperature": self.temperature,
369
+ "messages": [
370
+ {
371
+ "role": "user",
372
+ "content": [
373
+ {
374
+ "type": "text",
375
+ "text": self.ocr_prompt
376
+ },
377
+ {
378
+ "type": "image",
379
+ "source": {
380
+ "type": "base64",
381
+ "media_type": f"image/{self.image_format}",
382
+ "data": image_base64
383
+ }
384
+ }
385
+ ]
386
+ }
387
+ ]
388
+ }
389
+
390
+ else:
391
+ # Custom format - use environment variable for template
392
+ template = os.environ.get('CUSTOM_OCR_REQUEST_TEMPLATE', '{}')
393
+ payload = json.loads(template)
394
+
395
+ # Replace placeholders
396
+ payload_str = json.dumps(payload)
397
+ payload_str = payload_str.replace('{{IMAGE_BASE64}}', image_base64)
398
+ payload_str = payload_str.replace('{{PROMPT}}', self.ocr_prompt)
399
+ payload_str = payload_str.replace('{{MODEL}}', self.model_name)
400
+ payload_str = payload_str.replace('{{MAX_TOKENS}}', str(self.max_tokens))
401
+ payload_str = payload_str.replace('{{TEMPERATURE}}', str(self.temperature))
402
+
403
+ return json.loads(payload_str)
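+ # Hypothetical example (illustration only): a CUSTOM_OCR_REQUEST_TEMPLATE such as
+ #   {"model": "{{MODEL}}", "prompt": "{{PROMPT}}", "image": "{{IMAGE_BASE64}}"}
+ # must itself be valid JSON; the {{...}} placeholders are substituted into the dumped
+ # string before the payload is re-parsed and sent.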
404
+
405
+ def _extract_text_from_response(self, response_data: dict) -> str:
406
+ """Extract text from API response based on format"""
407
+ try:
408
+ if self.api_format == 'openai':
409
+ # OpenAI format: response.choices[0].message.content
410
+ return response_data.get('choices', [{}])[0].get('message', {}).get('content', '')
411
+
412
+ elif self.api_format == 'anthropic':
413
+ # Anthropic format: response.content[0].text
414
+ content = response_data.get('content', [])
415
+ if content and isinstance(content, list):
416
+ return content[0].get('text', '')
417
+ return ''
418
+
419
+ else:
420
+ # Custom format - use environment variable for path
421
+ response_path = os.environ.get('CUSTOM_OCR_RESPONSE_PATH', 'text')
422
+
423
+ # Navigate through the response using the path
424
+ result = response_data
425
+ for key in response_path.split('.'):
426
+ if isinstance(result, dict):
427
+ result = result.get(key, '')
428
+ elif isinstance(result, list) and key.isdigit():
429
+ idx = int(key)
430
+ result = result[idx] if idx < len(result) else ''
431
+ else:
432
+ result = ''
433
+ break
434
+
435
+ return str(result)
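+ # Hypothetical example (illustration only): for an OpenAI-style body,
+ # CUSTOM_OCR_RESPONSE_PATH="choices.0.message.content" makes the loop above walk
+ # response_data['choices'][0]['message']['content'].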
436
+
437
+ except Exception as e:
438
+ self._log(f"Failed to extract text from response: {e}", "error")
439
+ return ''
440
+
441
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
442
+ """Process image using UnifiedClient.send_image()"""
443
+ results = []
444
+
445
+ try:
446
+ # Get fresh max_tokens from environment - GUI will have set this
447
+ max_tokens = int(os.environ.get('MAX_OUTPUT_TOKENS', '4096'))
448
+ if not self.is_loaded:
449
+ if not self.load_model():
450
+ return results
451
+
452
+ import cv2
453
+ from PIL import Image
454
+ import base64
455
+ import io
456
+
457
+ # Convert numpy array to PIL Image
458
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
459
+ pil_image = Image.fromarray(image_rgb)
460
+ h, w = image.shape[:2]
461
+
462
+ # Convert PIL Image to base64 string
463
+ buffer = io.BytesIO()
464
+
465
+ # Use the image format from settings
466
+ if self.image_format.lower() == 'png':
467
+ pil_image.save(buffer, format='PNG')
468
+ else:
469
+ pil_image.save(buffer, format='JPEG', quality=self.image_quality)
470
+
471
+ buffer.seek(0)
472
+ image_base64 = base64.b64encode(buffer.read()).decode('utf-8')
473
+
474
+ # For OpenAI vision models, we need BOTH:
475
+ # 1. System prompt with instructions
476
+ # 2. User message that includes the image
477
+ messages = [
478
+ {
479
+ "role": "system",
480
+ "content": self.ocr_prompt # The OCR instruction as system prompt
481
+ },
482
+ {
483
+ "role": "user",
484
+ "content": [
485
+ {
486
+ "type": "text",
487
+ "text": "Image:" # Minimal text, just to have something
488
+ },
489
+ {
490
+ "type": "image_url",
491
+ "image_url": {
492
+ "url": f"data:image/jpeg;base64,{image_base64}"
493
+ }
494
+ }
495
+ ]
496
+ }
497
+ ]
498
+
499
+ # Now send this properly formatted message
500
+ # The UnifiedClient should handle this correctly
501
+ # But we're NOT using send_image, we're using regular send
502
+
503
+ # Retry-aware call
504
+ from unified_api_client import UnifiedClientError # local import to avoid hard dependency at module import time
505
+ max_attempts = max(1, self.max_retries)
506
+ attempt = 0
507
+ last_error = None
508
+
509
+ # Common refusal/error phrases that indicate a non-OCR response
510
+ refusal_phrases = [
511
+ "I can't extract", "I cannot extract",
512
+ "I'm sorry", "I am sorry",
513
+ "I'm unable", "I am unable",
514
+ "cannot process images",
515
+ "I can't help with that",
516
+ "cannot view images",
517
+ "no text in the image"
518
+ ]
519
+
520
+ while attempt < max_attempts:
521
+ # Check for stop before each attempt
522
+ if self._check_stop():
523
+ self._log("⏹️ OCR processing stopped by user", "warning")
524
+ return results
525
+
526
+ try:
527
+ response = self.client.send(
528
+ messages=messages,
529
+ temperature=self.temperature,
530
+ max_tokens=max_tokens
531
+ )
532
+
533
+ # Unpack the (content, finish_reason) tuple returned by client.send()
534
+ content, finish_reason = response
535
+
536
+ # Validate content
537
+ has_content = bool(content and str(content).strip())
538
+ refused = False
539
+ if has_content:
540
+ # Filter out explicit failure markers
541
+ if "[" in content and "FAILED]" in content:
542
+ refused = True
543
+ elif any(phrase.lower() in content.lower() for phrase in refusal_phrases):
544
+ refused = True
545
+
546
+ # Decide success or retry
547
+ if has_content and not refused:
548
+ text = str(content).strip()
549
+ results.append(OCRResult(
550
+ text=text,
551
+ bbox=(0, 0, w, h),
552
+ confidence=kwargs.get('confidence', 0.85),
553
+ vertices=[(0, 0), (w, 0), (w, h), (0, h)]
554
+ ))
555
+ self._log(f"✅ Detected: {text[:50]}...")
556
+ break # success
557
+ else:
558
+ reason = "empty result" if not has_content else "refusal/non-OCR response"
559
+ last_error = f"{reason} (finish_reason: {finish_reason})"
560
+ # Check if we should retry on empty or refusal
561
+ should_retry = (not has_content and self.retry_on_empty) or refused
562
+ attempt += 1
563
+ if attempt >= max_attempts or not should_retry:
564
+ # No more retries or shouldn't retry
565
+ if not has_content:
566
+ self._log(f"⚠️ No text detected (finish_reason: {finish_reason})")
567
+ else:
568
+ self._log(f"❌ Model returned non-OCR response: {str(content)[:120]}", "warning")
569
+ break
570
+ # Backoff before retrying
571
+ delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
572
+ self._log(f"🔄 Retry {attempt}/{max_attempts - 1} after {delay:.1f}s due to {reason}...", "warning")
573
+ time.sleep(delay)
574
+ time.sleep(0.1) # Brief pause for stability
575
+ self._log("💤 OCR retry pausing briefly for stability", "debug")
576
+ continue
577
+
578
+ except UnifiedClientError as ue:
579
+ msg = str(ue)
580
+ last_error = msg
581
+ # Do not retry on explicit user cancellation
582
+ if 'cancelled' in msg.lower() or 'stopped by user' in msg.lower():
583
+ self._log(f"❌ OCR cancelled: {msg}", "error")
584
+ break
585
+ attempt += 1
586
+ if attempt >= max_attempts:
587
+ self._log(f"❌ OCR failed after {attempt} attempts: {msg}", "error")
588
+ break
589
+ delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
590
+ self._log(f"🔄 API error, retry {attempt}/{max_attempts - 1} after {delay:.1f}s: {msg}", "warning")
591
+ time.sleep(delay)
592
+ time.sleep(0.1) # Brief pause for stability
593
+ self._log("💤 OCR API error retry pausing briefly for stability", "debug")
594
+ continue
595
+ except Exception as e_inner:
596
+ last_error = str(e_inner)
597
+ attempt += 1
598
+ if attempt >= max_attempts:
599
+ self._log(f"❌ OCR exception after {attempt} attempts: {last_error}", "error")
600
+ break
601
+ delay = self.retry_initial_delay * (self.retry_backoff ** (attempt - 1)) + random.uniform(0, self.retry_jitter)
602
+ self._log(f"🔄 Exception, retry {attempt}/{max_attempts - 1} after {delay:.1f}s: {last_error}", "warning")
603
+ time.sleep(delay)
604
+ time.sleep(0.1) # Brief pause for stability
605
+ self._log("💤 OCR exception retry pausing briefly for stability", "debug")
606
+ continue
607
+
608
+ except Exception as e:
609
+ self._log(f"❌ Error: {str(e)}", "error")
610
+ import traceback
611
+ self._log(traceback.format_exc(), "debug")
612
+
613
+ return results
614
+
615
+ class MangaOCRProvider(OCRProvider):
616
+ """Manga OCR provider using HuggingFace model directly"""
617
+
618
+ def __init__(self, log_callback=None):
619
+ super().__init__(log_callback)
620
+ self.processor = None
621
+ self.model = None
622
+ self.tokenizer = None
623
+
624
+ def check_installation(self) -> bool:
625
+ """Check if transformers is installed"""
626
+ try:
627
+ import transformers
628
+ import torch
629
+ self.is_installed = True
630
+ return True
631
+ except ImportError:
632
+ return False
633
+
634
+ def install(self, progress_callback=None) -> bool:
635
+ """Install transformers and torch"""
636
+ return False  # no auto-install implemented; install manually: pip install transformers torch
637
+
638
+ def _is_valid_local_model_dir(self, path: str) -> bool:
639
+ """Check that a local HF model directory has required files."""
640
+ try:
641
+ if not path or not os.path.isdir(path):
642
+ return False
643
+ needed_any_weights = any(
644
+ os.path.exists(os.path.join(path, name)) for name in (
645
+ 'pytorch_model.bin',
646
+ 'model.safetensors'
647
+ )
648
+ )
649
+ has_config = os.path.exists(os.path.join(path, 'config.json'))
650
+ has_processor = (
651
+ os.path.exists(os.path.join(path, 'preprocessor_config.json')) or
652
+ os.path.exists(os.path.join(path, 'processor_config.json'))
653
+ )
654
+ has_tokenizer = (
655
+ os.path.exists(os.path.join(path, 'tokenizer.json')) or
656
+ os.path.exists(os.path.join(path, 'tokenizer_config.json'))
657
+ )
658
+ return has_config and needed_any_weights and has_processor and has_tokenizer
659
+ except Exception:
660
+ return False
661
+
662
+ def load_model(self, **kwargs) -> bool:
663
+ """Load the manga-ocr model, preferring a local directory to avoid re-downloading"""
664
+ print("\n>>> MangaOCRProvider.load_model() called")
665
+ try:
666
+ if not self.is_installed and not self.check_installation():
667
+ print("ERROR: Transformers not installed")
668
+ self._log("❌ Transformers not installed", "error")
669
+ return False
670
+
671
+ # Always disable progress bars to avoid tqdm issues in some environments
672
+ import os
673
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
674
+
675
+ from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoImageProcessor
676
+ import torch
677
+
678
+ # Prefer a local model directory if present to avoid any Hub access
679
+ candidates = []
680
+ env_local = os.environ.get("MANGA_OCR_LOCAL_DIR")
681
+ if env_local:
682
+ candidates.append(env_local)
683
+
684
+ # Project root one level up from this file
685
+ root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
686
+ candidates.append(os.path.join(root_dir, 'models', 'manga-ocr-base'))
687
+ candidates.append(os.path.join(root_dir, 'models', 'kha-white', 'manga-ocr-base'))
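+ # Illustrative layout (matching _is_valid_local_model_dir above): a usable local directory
+ # contains config.json, preprocessor_config.json (or processor_config.json),
+ # tokenizer.json (or tokenizer_config.json), and pytorch_model.bin or model.safetensors.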
688
+
689
+ model_source = None
690
+ local_only = False
691
+ # Find a valid local dir
692
+ for cand in candidates:
693
+ if self._is_valid_local_model_dir(cand):
694
+ model_source = cand
695
+ local_only = True
696
+ break
697
+
698
+ # If no valid local dir, use Hub
699
+ if not model_source:
700
+ model_source = "kha-white/manga-ocr-base"
701
+ # Make sure we are not forcing offline mode
702
+ if os.environ.get("HF_HUB_OFFLINE") == "1":
703
+ try:
704
+ del os.environ["HF_HUB_OFFLINE"]
705
+ except Exception:
706
+ pass
707
+ self._log("🔥 Loading manga-ocr model from Hugging Face Hub")
708
+ self._log(f" Repo: {model_source}")
709
+ else:
710
+ # Only set offline when local dir is fully valid
711
+ os.environ.setdefault("HF_HUB_OFFLINE", "1")
712
+ self._log("🔥 Loading manga-ocr model from local directory")
713
+ self._log(f" Local path: {model_source}")
714
+
715
+ # Decide target device once; we will move after full CPU load to avoid meta tensors
716
+ use_cuda = torch.cuda.is_available()
717
+
718
+ # Try loading components, falling back to Hub if local-only fails
719
+ def _load_components(source: str, local_flag: bool):
720
+ self._log(" Loading tokenizer...")
721
+ tok = AutoTokenizer.from_pretrained(source, local_files_only=local_flag)
722
+
723
+ self._log(" Loading image processor...")
724
+ try:
725
+ from transformers import AutoProcessor
726
+ except Exception:
727
+ AutoProcessor = None
728
+ try:
729
+ proc = AutoImageProcessor.from_pretrained(source, local_files_only=local_flag)
730
+ except Exception as e_proc:
731
+ if AutoProcessor is not None:
732
+ self._log(f" ⚠️ AutoImageProcessor failed: {e_proc}. Trying AutoProcessor...", "warning")
733
+ proc = AutoProcessor.from_pretrained(source, local_files_only=local_flag)
734
+ else:
735
+ raise
736
+
737
+ self._log(" Loading model...")
738
+ # Prevent meta tensors by forcing full materialization on CPU at load time
739
+ os.environ.setdefault('TORCHDYNAMO_DISABLE', '1')
740
+ mdl = VisionEncoderDecoderModel.from_pretrained(
741
+ source,
742
+ local_files_only=local_flag,
743
+ low_cpu_mem_usage=False,
744
+ device_map=None,
745
+ torch_dtype=torch.float32 # Use torch_dtype instead of dtype
746
+ )
747
+ return tok, proc, mdl
748
+
749
+ try:
750
+ self.tokenizer, self.processor, self.model = _load_components(model_source, local_only)
751
+ except Exception as e_local:
752
+ if local_only:
753
+ # Fallback to Hub once if local fails
754
+ self._log(f" ⚠️ Local model load failed: {e_local}", "warning")
755
+ try:
756
+ if os.environ.get("HF_HUB_OFFLINE") == "1":
757
+ del os.environ["HF_HUB_OFFLINE"]
758
+ except Exception:
759
+ pass
760
+ model_source = "kha-white/manga-ocr-base"
761
+ local_only = False
762
+ self._log(" Retrying from Hugging Face Hub...")
763
+ self.tokenizer, self.processor, self.model = _load_components(model_source, local_only)
764
+ else:
765
+ raise
766
+
767
+ # Move to CUDA only after full CPU materialization
768
+ target_device = 'cpu'
769
+ if use_cuda:
770
+ try:
771
+ self.model = self.model.to('cuda')
772
+ target_device = 'cuda'
773
+ except Exception as move_err:
774
+ self._log(f" ⚠️ Could not move model to CUDA: {move_err}", "warning")
775
+ target_device = 'cpu'
776
+
777
+ # Finalize eval mode
778
+ self.model.eval()
779
+
780
+ # Sanity-check: ensure no parameter remains on 'meta' device
781
+ try:
782
+ for n, p in self.model.named_parameters():
783
+ dev = getattr(p, 'device', None)
784
+ if dev is not None and getattr(dev, 'type', '') == 'meta':
785
+ raise RuntimeError(f"Parameter {n} is on 'meta' after load")
786
+ except Exception as sanity_err:
787
+ self._log(f"❌ Manga-OCR model load sanity check failed: {sanity_err}", "error")
788
+ return False
789
+
790
+ print(f"SUCCESS: Model loaded on {target_device.upper()}")
791
+ self._log(f" ✅ Model loaded on {target_device.upper()}")
792
+ self.is_loaded = True
793
+ self._log("✅ Manga OCR model ready")
794
+ print(">>> Returning True from load_model()")
795
+ return True
796
+
797
+ except Exception as e:
798
+ print(f"\nEXCEPTION in load_model: {e}")
799
+ import traceback
800
+ print(traceback.format_exc())
801
+ self._log(f"❌ Failed to load manga-ocr model: {str(e)}", "error")
802
+ self._log(traceback.format_exc(), "error")
803
+ try:
804
+ if 'local_only' in locals() and local_only:
805
+ self._log("Hint: Local load failed. Ensure your models/manga-ocr-base contains required files (config.json, preprocessor_config.json, tokenizer.json or tokenizer_config.json, and model weights).", "warning")
806
+ except Exception:
807
+ pass
808
+ return False
809
+
810
+ def _run_ocr(self, pil_image):
811
+ """Run OCR on a PIL image using the HuggingFace model"""
812
+ import torch
813
+
814
+ # Process image (keyword arg for broader compatibility across transformers versions)
815
+ inputs = self.processor(images=pil_image, return_tensors="pt")
816
+ pixel_values = inputs["pixel_values"]
817
+
818
+ # Move to same device as model
819
+ try:
820
+ model_device = next(self.model.parameters()).device
821
+ except StopIteration:
822
+ model_device = torch.device('cpu')
823
+ pixel_values = pixel_values.to(model_device)
824
+
825
+ # Generate text
826
+ with torch.no_grad():
827
+ generated_ids = self.model.generate(pixel_values)
828
+
829
+ # Decode
830
+ generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
831
+
832
+ return generated_text
833
+
834
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
835
+ """
836
+ Process the image region passed to it.
837
+ This could be a bubble region or the full image.
838
+ """
839
+ results = []
840
+
841
+ # Check for stop at start
842
+ if self._check_stop():
843
+ self._log("⏹️ Manga-OCR processing stopped by user", "warning")
844
+ return results
845
+
846
+ try:
847
+ if not self.is_loaded:
848
+ if not self.load_model():
849
+ return results
850
+
851
+ import cv2
852
+ from PIL import Image
853
+
854
+ # Get confidence from kwargs
855
+ confidence = kwargs.get('confidence', 0.7)
856
+
857
+ # Convert numpy array to PIL
858
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
859
+ pil_image = Image.fromarray(image_rgb)
860
+ h, w = image.shape[:2]
861
+
862
+ self._log("🔍 Processing region with manga-ocr...")
863
+
864
+ # Check for stop before inference
865
+ if self._check_stop():
866
+ self._log("⏹️ Manga-OCR inference stopped by user", "warning")
867
+ return results
868
+
869
+ # Run OCR on the image region
870
+ text = self._run_ocr(pil_image)
871
+
872
+ if text and text.strip():
873
+ # Return result for this region with its actual bbox
874
+ results.append(OCRResult(
875
+ text=text.strip(),
876
+ bbox=(0, 0, w, h), # Relative to the region passed in
877
+ confidence=confidence,
878
+ vertices=[(0, 0), (w, 0), (w, h), (0, h)]
879
+ ))
880
+ self._log(f"✅ Detected text: {text[:50]}...")
881
+
882
+ except Exception as e:
883
+ self._log(f"❌ Error in manga-ocr: {str(e)}", "error")
884
+
885
+ return results
886
+
887
+ class Qwen2VL(OCRProvider):
888
+ """OCR using Qwen2-VL - Vision Language Model that can read Korean text"""
889
+
890
+ def __init__(self, log_callback=None):
891
+ super().__init__(log_callback)
892
+ self.processor = None
893
+ self.model = None
894
+ self.tokenizer = None
895
+
896
+ # Get OCR prompt from environment or use default
897
+ self.ocr_prompt = os.environ.get('OCR_SYSTEM_PROMPT',
898
+ "YOU ARE AN OCR SYSTEM. YOUR ONLY JOB IS TEXT EXTRACTION.\n\n"
899
+ "CRITICAL RULES:\n"
900
+ "1. DO NOT TRANSLATE ANYTHING\n"
901
+ "2. DO NOT MODIFY THE TEXT\n"
902
+ "3. DO NOT EXPLAIN OR COMMENT\n"
903
+ "4. ONLY OUTPUT THE EXACT TEXT YOU SEE\n"
904
+ "5. PRESERVE NATURAL TEXT FLOW - DO NOT ADD UNNECESSARY LINE BREAKS\n\n"
905
+ "If you see Korean text, output it in Korean.\n"
906
+ "If you see Japanese text, output it in Japanese.\n"
907
+ "If you see Chinese text, output it in Chinese.\n"
908
+ "If you see English text, output it in English.\n\n"
909
+ "IMPORTANT: Only use line breaks where they naturally occur in the original text "
910
+ "(e.g., between dialogue lines or paragraphs). Do not break text mid-sentence or "
911
+ "between every word/character.\n\n"
912
+ "For vertical text common in manga/comics, transcribe it as a continuous line unless "
913
+ "there are clear visual breaks.\n\n"
914
+ "NEVER translate. ONLY extract exactly what is written.\n"
915
+ "Output ONLY the raw text, nothing else."
916
+ )
917
+
918
+ def set_ocr_prompt(self, prompt: str):
919
+ """Allow setting the OCR prompt dynamically"""
920
+ self.ocr_prompt = prompt
921
+
922
+ def check_installation(self) -> bool:
923
+ """Check if required packages are installed"""
924
+ try:
925
+ import transformers
926
+ import torch
927
+ self.is_installed = True
928
+ return True
929
+ except ImportError:
930
+ return False
931
+
932
+ def install(self, progress_callback=None) -> bool:
933
+ """Install requirements for Qwen2-VL"""
934
+ return False  # no auto-install implemented; install manually: pip install transformers torch
935
+
936
+ def load_model(self, model_size=None, **kwargs) -> bool:
937
+ """Load Qwen2-VL model with size selection"""
938
+ self._log(f"DEBUG: load_model called with model_size={model_size}")
939
+
940
+ try:
941
+ if not self.is_installed and not self.check_installation():
942
+ self._log("❌ Not installed", "error")
943
+ return False
944
+
945
+ self._log("🔥 Loading Qwen2-VL for Advanced OCR...")
946
+
947
+
948
+
949
+ from transformers import AutoProcessor, AutoTokenizer
950
+ import torch
951
+
952
+ # Model options
953
+ model_options = {
954
+ "1": "Qwen/Qwen2-VL-2B-Instruct",
955
+ "2": "Qwen/Qwen2-VL-7B-Instruct",
956
+ "3": "Qwen/Qwen2-VL-72B-Instruct",
957
+ "4": "custom"
958
+ }
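+ # Illustrative usage (assumption about how the caller sets this): QWEN2VL_MODEL_SIZE='2'
+ # selects the 7B model, while a value such as 'custom:Qwen/Qwen2-VL-7B-Instruct'
+ # is handled by the "custom:" branch below and loads that repo id directly.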
959
+ # Default comes from QWEN2VL_MODEL_SIZE when no size is passed in (option "1" = 2B)
960
+ # Check for saved preference first
961
+ if model_size is None:
962
+ # Try to get from environment or config
963
+ import os
964
+ model_size = os.environ.get('QWEN2VL_MODEL_SIZE', '1')
965
+
966
+ # Determine which model to load
967
+ if model_size and str(model_size).startswith("custom:"):
968
+ # Custom model passed with ID
969
+ model_id = str(model_size).replace("custom:", "")
970
+ self.loaded_model_size = "Custom"
971
+ self.model_id = model_id
972
+ self._log(f"Loading custom model: {model_id}")
973
+ elif model_size == "4":
974
+ # Custom option selected but no ID - shouldn't happen
975
+ self._log("❌ Custom model selected but no ID provided", "error")
976
+ return False
977
+ elif model_size and str(model_size) in model_options:
978
+ # Standard model option
979
+ option = model_options[str(model_size)]
980
+ if option == "custom":
981
+ self._log("❌ Custom model needs an ID", "error")
982
+ return False
983
+ model_id = option
984
+ # Set loaded_model_size for status display
985
+ if model_size == "1":
986
+ self.loaded_model_size = "2B"
987
+ elif model_size == "2":
988
+ self.loaded_model_size = "7B"
989
+ elif model_size == "3":
990
+ self.loaded_model_size = "72B"
991
+ else:
992
+ # Fall back to the 2B model (option "1") when no valid size was selected
993
+ model_id = model_options["1"]
994
+ self.loaded_model_size = "2B"
995
+ self._log("No model size specified, defaulting to 2B")
996
+
997
+ self._log(f"Loading model: {model_id}")
998
+
999
+ # Load processor and tokenizer
1000
+ self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
1001
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
1002
+
1003
+ # Load the model - let it figure out the class dynamically
1004
+ if torch.cuda.is_available():
1005
+ self._log(f"GPU: {torch.cuda.get_device_name(0)}")
1006
+ # Use auto model class
1007
+ from transformers import AutoModelForVision2Seq
1008
+ self.model = AutoModelForVision2Seq.from_pretrained(
1009
+ model_id,
1010
+ dtype=torch.float16,
1011
+ device_map="auto",
1012
+ trust_remote_code=True
1013
+ )
1014
+ self._log("✅ Model loaded on GPU")
1015
+ else:
1016
+ self._log("Loading on CPU...")
1017
+ from transformers import AutoModelForVision2Seq
1018
+ self.model = AutoModelForVision2Seq.from_pretrained(
1019
+ model_id,
1020
+ dtype=torch.float32,
1021
+ trust_remote_code=True
1022
+ )
1023
+ self._log("✅ Model loaded on CPU")
1024
+
1025
+ self.model.eval()
1026
+ self.is_loaded = True
1027
+ self._log("✅ Qwen2-VL ready for Advanced OCR!")
1028
+ return True
1029
+
1030
+ except Exception as e:
1031
+ self._log(f"❌ Failed to load: {str(e)}", "error")
1032
+ import traceback
1033
+ self._log(traceback.format_exc(), "debug")
1034
+ return False
1035
+
1036
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1037
+ """Process image with Qwen2-VL for Korean text extraction"""
1038
+ results = []
1039
+ if hasattr(self, 'model_id'):
1040
+ self._log(f"DEBUG: Using model: {self.model_id}", "debug")
1041
+
1042
+ # Check if OCR prompt was passed in kwargs (for dynamic updates)
1043
+ if 'ocr_prompt' in kwargs:
1044
+ self.ocr_prompt = kwargs['ocr_prompt']
1045
+
1046
+ try:
1047
+ if not self.is_loaded:
1048
+ if not self.load_model():
1049
+ return results
1050
+
1051
+ import cv2
1052
+ from PIL import Image
1053
+ import torch
1054
+
1055
+ # Convert to PIL
1056
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1057
+ pil_image = Image.fromarray(image_rgb)
1058
+ h, w = image.shape[:2]
1059
+
1060
+ self._log(f"🔍 Processing with Qwen2-VL ({w}x{h} pixels)...")
1061
+
1062
+ # Use the configurable OCR prompt
1063
+ messages = [
1064
+ {
1065
+ "role": "user",
1066
+ "content": [
1067
+ {
1068
+ "type": "image",
1069
+ "image": pil_image,
1070
+ },
1071
+ {
1072
+ "type": "text",
1073
+ "text": self.ocr_prompt # Use the configurable prompt
1074
+ }
1075
+ ]
1076
+ }
1077
+ ]
1078
+
1079
+ # Alternative simpler prompt if the above still causes issues:
1080
+ # "text": "OCR: Extract text as-is"
1081
+
1082
+ # Process with Qwen2-VL
1083
+ text = self.processor.apply_chat_template(
1084
+ messages,
1085
+ tokenize=False,
1086
+ add_generation_prompt=True
1087
+ )
1088
+
1089
+ inputs = self.processor(
1090
+ text=[text],
1091
+ images=[pil_image],
1092
+ padding=True,
1093
+ return_tensors="pt"
1094
+ )
1095
+
1096
+ # Get the device and dtype the model is currently on
1097
+ model_device = next(self.model.parameters()).device
1098
+ model_dtype = next(self.model.parameters()).dtype
1099
+
1100
+ # Move inputs to the same device as the model and cast float tensors to model dtype
1101
+ try:
1102
+ # Move first
1103
+ inputs = inputs.to(model_device)
1104
+ # Then align dtypes only for floating tensors (e.g., pixel_values)
1105
+ for k, v in inputs.items():
1106
+ if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
1107
+ inputs[k] = v.to(model_dtype)
1108
+ except Exception:
1109
+ # Fallback: ensure at least pixel_values is correct if present
1110
+ try:
1111
+ if isinstance(inputs, dict) and "pixel_values" in inputs:
1112
+ pv = inputs["pixel_values"].to(model_device)
1113
+ if torch.is_floating_point(pv):
1114
+ inputs["pixel_values"] = pv.to(model_dtype)
1115
+ except Exception:
1116
+ pass
1117
+
1118
+ # Ensure pixel_values explicitly matches model dtype if present
1119
+ try:
1120
+ if isinstance(inputs, dict) and "pixel_values" in inputs:
1121
+ inputs["pixel_values"] = inputs["pixel_values"].to(device=model_device, dtype=model_dtype)
1122
+ except Exception:
1123
+ pass
1124
+
1125
+ # Generate text with stricter parameters to avoid creative responses
1126
+ use_amp = (hasattr(torch, 'cuda') and model_device.type == 'cuda' and model_dtype in (torch.float16, torch.bfloat16))
1127
+ autocast_dev = 'cuda' if model_device.type == 'cuda' else 'cpu'
1128
+ autocast_dtype = model_dtype if model_dtype in (torch.float16, torch.bfloat16) else None
1129
+
1130
+ with torch.no_grad():
1131
+ if use_amp and autocast_dtype is not None:
1132
+ with torch.autocast(autocast_dev, dtype=autocast_dtype):
1133
+ generated_ids = self.model.generate(
1134
+ **inputs,
1135
+ max_new_tokens=128, # Reduced from 512 - manga bubbles are typically short
1136
+ do_sample=False, # Keep deterministic
1137
+ temperature=0.01, # Keep your very low temperature
1138
+ top_p=1.0, # Keep no nucleus sampling
1139
+ repetition_penalty=1.0, # Keep no repetition penalty
1140
+ num_beams=1, # Ensure greedy decoding (faster than beam search)
1141
+ use_cache=True, # Enable KV cache for speed
1142
+ early_stopping=True, # Stop at EOS token
1143
+ pad_token_id=self.tokenizer.pad_token_id, # Proper padding
1144
+ eos_token_id=self.tokenizer.eos_token_id, # Proper stopping
1145
+ )
1146
+ else:
1147
+ generated_ids = self.model.generate(
1148
+ **inputs,
1149
+ max_new_tokens=128, # Reduced from 512 - manga bubbles are typically short
1150
+ do_sample=False, # Keep deterministic
1151
+ temperature=0.01, # Keep your very low temperature
1152
+ top_p=1.0, # Keep no nucleus sampling
1153
+ repetition_penalty=1.0, # Keep no repetition penalty
1154
+ num_beams=1, # Ensure greedy decoding (faster than beam search)
1155
+ use_cache=True, # Enable KV cache for speed
1156
+ early_stopping=True, # Stop at EOS token
1157
+ pad_token_id=self.tokenizer.pad_token_id, # Proper padding
1158
+ eos_token_id=self.tokenizer.eos_token_id, # Proper stopping
1159
+ )
1160
+
1161
+ # Decode the output
1162
+ generated_ids_trimmed = [
1163
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
1164
+ ]
1165
+ output_text = self.processor.batch_decode(
1166
+ generated_ids_trimmed,
1167
+ skip_special_tokens=True,
1168
+ clean_up_tokenization_spaces=False
1169
+ )[0]
1170
+
1171
+ if output_text and output_text.strip():
1172
+ text = output_text.strip()
1173
+
1174
+ # ADDED: Filter out any response that looks like an explanation or apology
1175
+ # Common patterns that indicate the model is being "helpful" instead of just extracting
1176
+ unwanted_patterns = [
1177
+ "죄송합니다", # "I apologize"
1178
+ "sorry",
1179
+ "apologize",
1180
+ "이미지에는", # "in this image"
1181
+ "텍스트가 없습니다", # "there is no text"
1182
+ "I cannot",
1183
+ "I don't see",
1184
+ "There is no",
1185
+ "질문이 있으시면", # "if you have questions"
1186
+ ]
1187
+
1188
+ # Check if response contains unwanted patterns
1189
+ text_lower = text.lower()
1190
+ is_explanation = any(pattern.lower() in text_lower for pattern in unwanted_patterns)
1191
+
1192
+ # Also check if the response is suspiciously long for a bubble
1193
+ # Most manga bubbles are short, if we get 50+ chars it might be an explanation
1194
+ is_too_long = len(text) > 100 and ('.' in text or ',' in text or '!' in text)
1195
+
1196
+ if is_explanation or is_too_long:
1197
+ self._log(f"⚠️ Model returned explanation instead of text, ignoring", "warning")
1198
+ # Return empty result or just skip this region
1199
+ return results
1200
+
1201
+ # Check language
1202
+ has_korean = any('\uAC00' <= c <= '\uD7AF' for c in text)
1203
+ has_japanese = any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' for c in text)
1204
+ has_chinese = any('\u4E00' <= c <= '\u9FFF' for c in text)
1205
+
1206
+ if has_korean:
1207
+ self._log(f"✅ Korean detected: {text[:50]}...")
1208
+ elif has_japanese:
1209
+ self._log(f"✅ Japanese detected: {text[:50]}...")
1210
+ elif has_chinese:
1211
+ self._log(f"✅ Chinese detected: {text[:50]}...")
1212
+ else:
1213
+ self._log(f"✅ Text: {text[:50]}...")
1214
+
1215
+ results.append(OCRResult(
1216
+ text=text,
1217
+ bbox=(0, 0, w, h),
1218
+ confidence=0.9,
1219
+ vertices=[(0, 0), (w, 0), (w, h), (0, h)]
1220
+ ))
1221
+ else:
1222
+ self._log("⚠️ No text detected", "warning")
1223
+
1224
+ except Exception as e:
1225
+ self._log(f"❌ Error: {str(e)}", "error")
1226
+ import traceback
1227
+ self._log(traceback.format_exc(), "debug")
1228
+
1229
+ return results
1230
+
1231
+ class EasyOCRProvider(OCRProvider):
1232
+ """EasyOCR provider for multiple languages"""
1233
+
1234
+ def __init__(self, log_callback=None, languages=None):
1235
+ super().__init__(log_callback)
1236
+ # Default to safe language combination
1237
+ self.languages = languages or ['ja', 'en'] # Safe default
1238
+ self._validate_language_combination()
1239
+
1240
+ def _validate_language_combination(self):
1241
+ """Validate and fix EasyOCR language combinations"""
1242
+ # EasyOCR language compatibility rules
1243
+ incompatible_pairs = [
1244
+ (['ja', 'ko'], 'Japanese and Korean cannot be used together'),
1245
+ (['ja', 'zh'], 'Japanese and Chinese cannot be used together'),
1246
+ (['ko', 'zh'], 'Korean and Chinese cannot be used together')
1247
+ ]
1248
+
1249
+ for incompatible, reason in incompatible_pairs:
1250
+ if all(lang in self.languages for lang in incompatible):
1251
+ self._log(f"⚠️ EasyOCR: {reason}", "warning")
1252
+ # Keep first language + English
1253
+ self.languages = [self.languages[0], 'en']
1254
+ self._log(f"🔧 Auto-adjusted to: {self.languages}", "info")
1255
+ break
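+ # Worked example (illustration only): languages=['ja', 'ko'] trips the first rule above
+ # and is auto-adjusted to ['ja', 'en'] before the EasyOCR reader is created.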
1256
+
1257
+ def check_installation(self) -> bool:
1258
+ """Check if easyocr is installed"""
1259
+ try:
1260
+ import easyocr
1261
+ self.is_installed = True
1262
+ return True
1263
+ except ImportError:
1264
+ return False
1265
+
1266
+ def install(self, progress_callback=None) -> bool:
1267
+ """Install easyocr"""
1268
+ return False  # no auto-install implemented; install manually: pip install easyocr
1269
+
1270
+ def load_model(self, **kwargs) -> bool:
1271
+ """Load easyocr model"""
1272
+ try:
1273
+ if not self.is_installed and not self.check_installation():
1274
+ self._log("❌ easyocr not installed", "error")
1275
+ return False
1276
+
1277
+ self._log(f"🔥 Loading easyocr model for languages: {self.languages}...")
1278
+ import easyocr
1279
+
1280
+ # This will download models on first run
1281
+ self.model = easyocr.Reader(self.languages, gpu=True)
1282
+ self.is_loaded = True
1283
+
1284
+ self._log("✅ easyocr model loaded successfully")
1285
+ return True
1286
+
1287
+ except Exception as e:
1288
+ self._log(f"❌ Failed to load easyocr: {str(e)}", "error")
1289
+ # Try CPU mode if GPU fails
1290
+ try:
1291
+ import easyocr
1292
+ self.model = easyocr.Reader(self.languages, gpu=False)
1293
+ self.is_loaded = True
1294
+ self._log("✅ easyocr loaded in CPU mode")
1295
+ return True
1296
+ except:
1297
+ return False
1298
+
1299
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1300
+ """Detect text using easyocr"""
1301
+ results = []
1302
+
1303
+ try:
1304
+ if not self.is_loaded:
1305
+ if not self.load_model():
1306
+ return results
1307
+
1308
+ # EasyOCR can work directly with numpy arrays
1309
+ ocr_results = self.model.readtext(image, detail=1)
1310
+
1311
+ # Parse results
1312
+ for (bbox, text, confidence) in ocr_results:
1313
+ # bbox is a list of 4 points
1314
+ xs = [point[0] for point in bbox]
1315
+ ys = [point[1] for point in bbox]
1316
+ x_min, x_max = min(xs), max(xs)
1317
+ y_min, y_max = min(ys), max(ys)
1318
+
1319
+ results.append(OCRResult(
1320
+ text=text,
1321
+ bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
1322
+ confidence=confidence,
1323
+ vertices=[(int(p[0]), int(p[1])) for p in bbox]
1324
+ ))
1325
+
1326
+ self._log(f"✅ Detected {len(results)} text regions")
1327
+
1328
+ except Exception as e:
1329
+ self._log(f"❌ Error in easyocr detection: {str(e)}", "error")
1330
+
1331
+ return results
1332
+
1333
+
1334
+ class PaddleOCRProvider(OCRProvider):
1335
+ """PaddleOCR provider with memory safety measures"""
1336
+
1337
+ def check_installation(self) -> bool:
1338
+ """Check if paddleocr is installed"""
1339
+ try:
1340
+ from paddleocr import PaddleOCR
1341
+ self.is_installed = True
1342
+ return True
1343
+ except ImportError:
1344
+ return False
1345
+
1346
+ def install(self, progress_callback=None) -> bool:
1347
+ """Install paddleocr"""
1348
+ return False  # no auto-install implemented; install manually: pip install paddleocr
1349
+
1350
+ def load_model(self, **kwargs) -> bool:
1351
+ """Load paddleocr model with memory-safe configurations"""
1352
+ try:
1353
+ if not self.is_installed and not self.check_installation():
1354
+ self._log("❌ paddleocr not installed", "error")
1355
+ return False
1356
+
1357
+ self._log("🔥 Loading PaddleOCR model...")
1358
+
1359
+ # Set memory-safe environment variables BEFORE importing
1360
+ import os
1361
+ os.environ['OMP_NUM_THREADS'] = '1' # Prevent OpenMP conflicts
1362
+ os.environ['MKL_NUM_THREADS'] = '1' # Prevent MKL conflicts
1363
+ os.environ['OPENBLAS_NUM_THREADS'] = '1' # Prevent OpenBLAS conflicts
1364
+ os.environ['FLAGS_use_mkldnn'] = '0' # Disable MKL-DNN
1365
+
1366
+ from paddleocr import PaddleOCR
1367
+
1368
+ # Try memory-safe configurations
1369
+ configs_to_try = [
1370
+ # Config 1: Most memory-safe configuration
1371
+ {
1372
+ 'use_angle_cls': False, # Disable angle to save memory
1373
+ 'lang': 'ch',
1374
+ 'rec_batch_num': 1, # Process one at a time
1375
+ 'max_text_length': 100, # Limit text length
1376
+ 'drop_score': 0.5, # Higher threshold to reduce detections
1377
+ 'cpu_threads': 1, # Single thread to avoid conflicts
1378
+ },
1379
+ # Config 2: Minimal memory footprint
1380
+ {
1381
+ 'lang': 'ch',
1382
+ 'rec_batch_num': 1,
1383
+ 'cpu_threads': 1,
1384
+ },
1385
+ # Config 3: Absolute minimal
1386
+ {
1387
+ 'lang': 'ch'
1388
+ },
1389
+ # Config 4: Empty config
1390
+ {}
1391
+ ]
1392
+
1393
+ for i, config in enumerate(configs_to_try):
1394
+ try:
1395
+ self._log(f" Trying configuration {i+1}/{len(configs_to_try)}: {config}")
1396
+
1397
+ # Force garbage collection before loading
1398
+ import gc
1399
+ gc.collect()
1400
+
1401
+ self.model = PaddleOCR(**config)
1402
+ self.is_loaded = True
1403
+ self.current_config = config
1404
+ self._log(f"✅ PaddleOCR loaded successfully with config: {config}")
1405
+ return True
1406
+ except Exception as e:
1407
+ error_str = str(e)
1408
+ self._log(f" Config {i+1} failed: {error_str}", "debug")
1409
+
1410
+ # Clean up on failure
1411
+ if hasattr(self, 'model'):
1412
+ del self.model
1413
+ gc.collect()
1414
+ continue
1415
+
1416
+ self._log(f"❌ PaddleOCR failed to load with any configuration", "error")
1417
+ return False
1418
+
1419
+ except Exception as e:
1420
+ self._log(f"❌ Failed to load paddleocr: {str(e)}", "error")
1421
+ import traceback
1422
+ self._log(traceback.format_exc(), "debug")
1423
+ return False
1424
+
1425
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1426
+ """Detect text with memory safety measures"""
1427
+ results = []
1428
+
1429
+ try:
1430
+ if not self.is_loaded:
1431
+ if not self.load_model():
1432
+ return results
1433
+
1434
+ import cv2
1435
+ import numpy as np
1436
+ import gc
1437
+
1438
+ # Memory safety: Ensure image isn't too large
1439
+ h, w = image.shape[:2] if len(image.shape) >= 2 else (0, 0)
1440
+
1441
+ # Limit image size to prevent memory issues
1442
+ MAX_DIMENSION = 1500
1443
+ if h > MAX_DIMENSION or w > MAX_DIMENSION:
1444
+ scale = min(MAX_DIMENSION/h, MAX_DIMENSION/w)
1445
+ new_h, new_w = int(h*scale), int(w*scale)
1446
+ self._log(f"⚠️ Resizing large image from {w}x{h} to {new_w}x{new_h} for memory safety", "warning")
1447
+ image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
1448
+ scale_factor = 1/scale
1449
+ else:
1450
+ scale_factor = 1.0
1451
+
1452
+ # Ensure correct format
1453
+ if len(image.shape) == 2: # Grayscale
1454
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
1455
+ elif len(image.shape) == 4: # Batch
1456
+ image = image[0]
1457
+
1458
+ # Ensure uint8 type
1459
+ if image.dtype != np.uint8:
1460
+ if image.max() <= 1.0:
1461
+ image = (image * 255).astype(np.uint8)
1462
+ else:
1463
+ image = image.astype(np.uint8)
1464
+
1465
+ # Make a copy to avoid memory corruption
1466
+ image_copy = image.copy()
1467
+
1468
+ # Force garbage collection before OCR
1469
+ gc.collect()
1470
+
1471
+ # Process with timeout protection
1473
+ import threading
1474
+
1475
+ ocr_results = None
1476
+ ocr_error = None
1477
+
1478
+ def run_ocr():
1479
+ nonlocal ocr_results, ocr_error
1480
+ try:
1481
+ ocr_results = self.model.ocr(image_copy)
1482
+ except Exception as e:
1483
+ ocr_error = e
1484
+
1485
+ # Run OCR in a separate thread with timeout
1486
+ ocr_thread = threading.Thread(target=run_ocr)
1487
+ ocr_thread.daemon = True
1488
+ ocr_thread.start()
1489
+ ocr_thread.join(timeout=30) # 30 second timeout
1490
+
1491
+ if ocr_thread.is_alive():
1492
+ self._log("❌ PaddleOCR timeout - taking too long", "error")
1493
+ return results
1494
+
1495
+ if ocr_error:
1496
+ raise ocr_error
1497
+
1498
+ # Parse results
1499
+ results = self._parse_ocr_results(ocr_results)
1500
+
1501
+ # Scale coordinates back if image was resized
1502
+ if scale_factor != 1.0 and results:
1503
+ for r in results:
1504
+ x, y, width, height = r.bbox
1505
+ r.bbox = (int(x*scale_factor), int(y*scale_factor),
1506
+ int(width*scale_factor), int(height*scale_factor))
1507
+ r.vertices = [(int(v[0]*scale_factor), int(v[1]*scale_factor))
1508
+ for v in r.vertices]
1509
+
1510
+ if results:
1511
+ self._log(f"✅ Detected {len(results)} text regions", "info")
1512
+ else:
1513
+ self._log("No text regions found", "debug")
1514
+
1515
+ # Clean up
1516
+ del image_copy
1517
+ gc.collect()
1518
+
1519
+ except Exception as e:
1520
+ error_msg = str(e) if str(e) else type(e).__name__
1521
+
1522
+ if "memory" in error_msg.lower() or "0x" in error_msg:
1523
+ self._log("❌ Memory access violation in PaddleOCR", "error")
1524
+ self._log(" This is a known Windows issue with PaddleOCR", "info")
1525
+ self._log(" Please switch to EasyOCR or manga-ocr instead", "warning")
1526
+ elif "trace_order.size()" in error_msg:
1527
+ self._log("❌ PaddleOCR internal error", "error")
1528
+ self._log(" Please switch to EasyOCR or manga-ocr", "warning")
1529
+ else:
1530
+ self._log(f"❌ Error in paddleocr detection: {error_msg}", "error")
1531
+
1532
+ import traceback
1533
+ self._log(traceback.format_exc(), "debug")
1534
+
1535
+ return results
1536
+
1537
+ def _parse_ocr_results(self, ocr_results) -> List[OCRResult]:
1538
+ """Parse OCR results safely"""
1539
+ results = []
1540
+
1541
+ if ocr_results is False:
1542
+ return results
1543
+
1544
+ if ocr_results is None or not isinstance(ocr_results, list):
1545
+ return results
1546
+
1547
+ if len(ocr_results) == 0:
1548
+ return results
1549
+
1550
+ # Handle batch format
1551
+ if isinstance(ocr_results[0], list) and len(ocr_results[0]) > 0:
1552
+ first_item = ocr_results[0][0]
1553
+ if isinstance(first_item, list) and len(first_item) > 0:
1554
+ if isinstance(first_item[0], (list, tuple)) and len(first_item[0]) == 2:
1555
+ ocr_results = ocr_results[0]
1556
+
1557
+ # Parse detections
1558
+ for detection in ocr_results:
1559
+ if not detection or isinstance(detection, bool):
1560
+ continue
1561
+
1562
+ if not isinstance(detection, (list, tuple)) or len(detection) < 2:
1563
+ continue
1564
+
1565
+ try:
1566
+ bbox_points = detection[0]
1567
+ text_data = detection[1]
1568
+
1569
+ if not isinstance(bbox_points, (list, tuple)) or len(bbox_points) != 4:
1570
+ continue
1571
+
1572
+ if not isinstance(text_data, (tuple, list)) or len(text_data) < 2:
1573
+ continue
1574
+
1575
+ text = str(text_data[0]).strip()
1576
+ confidence = float(text_data[1])
1577
+
1578
+ if not text or confidence < 0.3:
1579
+ continue
1580
+
1581
+ xs = [float(p[0]) for p in bbox_points]
1582
+ ys = [float(p[1]) for p in bbox_points]
1583
+ x_min, x_max = min(xs), max(xs)
1584
+ y_min, y_max = min(ys), max(ys)
1585
+
1586
+ if (x_max - x_min) < 5 or (y_max - y_min) < 5:
1587
+ continue
1588
+
1589
+ results.append(OCRResult(
1590
+ text=text,
1591
+ bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
1592
+ confidence=confidence,
1593
+ vertices=[(int(p[0]), int(p[1])) for p in bbox_points]
1594
+ ))
1595
+
1596
+ except Exception:
1597
+ continue
1598
+
1599
+ return results
1600
+
1601
+ class DocTROCRProvider(OCRProvider):
1602
+ """DocTR OCR provider"""
1603
+
1604
+ def check_installation(self) -> bool:
1605
+ """Check if doctr is installed"""
1606
+ try:
1607
+ from doctr.models import ocr_predictor
1608
+ self.is_installed = True
1609
+ return True
1610
+ except ImportError:
1611
+ return False
1612
+
1613
+ def install(self, progress_callback=None) -> bool:
1614
+ """Install doctr"""
1615
+ return False  # no auto-install implemented; install manually: pip install python-doctr
1616
+
1617
+ def load_model(self, **kwargs) -> bool:
1618
+ """Load doctr model"""
1619
+ try:
1620
+ if not self.is_installed and not self.check_installation():
1621
+ self._log("❌ doctr not installed", "error")
1622
+ return False
1623
+
1624
+ self._log("🔥 Loading DocTR model...")
1625
+ from doctr.models import ocr_predictor
1626
+
1627
+ # Load pretrained model
1628
+ self.model = ocr_predictor(pretrained=True)
1629
+ self.is_loaded = True
1630
+
1631
+ self._log("✅ DocTR model loaded successfully")
1632
+ return True
1633
+
1634
+ except Exception as e:
1635
+ self._log(f"❌ Failed to load doctr: {str(e)}", "error")
1636
+ return False
1637
+
1638
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1639
+ """Detect text using doctr"""
1640
+ results = []
1641
+
1642
+ try:
1643
+ if not self.is_loaded:
1644
+ if not self.load_model():
1645
+ return results
1646
+
1647
+ from doctr.io import DocumentFile
1648
+
1649
+ # DocTR expects document format
1650
+ # Convert numpy array to PIL and save temporarily
1651
+ import tempfile
1652
+ import cv2
1653
+
1654
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
1655
+ cv2.imwrite(tmp.name, image)
1656
+ doc = DocumentFile.from_images(tmp.name)
1657
+
1658
+ # Run OCR
1659
+ result = self.model(doc)
1660
+
1661
+ # Parse results
1662
+ h, w = image.shape[:2]
1663
+ for page in result.pages:
1664
+ for block in page.blocks:
1665
+ for line in block.lines:
1666
+ for word in line.words:
1667
+ # Handle different geometry formats
1668
+ geometry = word.geometry
1669
+
1670
+ if len(geometry) == 4:
1671
+ # Standard format: (x1, y1, x2, y2)
1672
+ x1, y1, x2, y2 = geometry
1673
+ elif len(geometry) == 2:
1674
+ # Alternative format: ((x1, y1), (x2, y2))
1675
+ (x1, y1), (x2, y2) = geometry
1676
+ else:
1677
+ self._log(f"Unexpected geometry format: {geometry}", "warning")
1678
+ continue
1679
+
1680
+ # Convert relative coordinates to absolute
1681
+ x1, x2 = int(x1 * w), int(x2 * w)
1682
+ y1, y2 = int(y1 * h), int(y2 * h)
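+ # Example (illustration only): geometry (0.1, 0.2, 0.5, 0.4) on a 1000x800 image
+ # becomes x1=100, x2=500, y1=160, y2=320, i.e. bbox (100, 160, 400, 160) below.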
1683
+
1684
+ results.append(OCRResult(
1685
+ text=word.value,
1686
+ bbox=(x1, y1, x2 - x1, y2 - y1),
1687
+ confidence=word.confidence,
1688
+ vertices=[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]
1689
+ ))
1690
+
1691
+ # Clean up temp file
1692
+ try:
1693
+ os.unlink(tmp.name)
1694
+ except:
1695
+ pass
1696
+
1697
+ self._log(f"DocTR detected {len(results)} text regions")
1698
+
1699
+ except Exception as e:
1700
+ self._log(f"Error in doctr detection: {str(e)}", "error")
1701
+ import traceback
1702
+ self._log(traceback.format_exc(), "error")
1703
+
1704
+ return results
1705
+
1706
+
1707
+ class RapidOCRProvider(OCRProvider):
1708
+ """RapidOCR provider for fast local OCR"""
1709
+
1710
+ def check_installation(self) -> bool:
1711
+ """Check if rapidocr is installed"""
1712
+ try:
1713
+ import rapidocr_onnxruntime
1714
+ self.is_installed = True
1715
+ return True
1716
+ except ImportError:
1717
+ return False
1718
+
1719
+ def install(self, progress_callback=None) -> bool:
1720
+ """Install rapidocr (requires manual pip install)"""
1721
+ # RapidOCR requires manual installation
1722
+ if progress_callback:
1723
+ progress_callback("RapidOCR requires manual pip installation")
1724
+ self._log("Run: pip install rapidocr-onnxruntime", "info")
1725
+ return False # Always return False since we can't auto-install
1726
+
1727
+ def load_model(self, **kwargs) -> bool:
1728
+ """Load RapidOCR model"""
1729
+ try:
1730
+ if not self.is_installed and not self.check_installation():
1731
+ self._log("RapidOCR not installed", "error")
1732
+ return False
1733
+
1734
+ self._log("Loading RapidOCR...")
1735
+ from rapidocr_onnxruntime import RapidOCR
1736
+
1737
+ self.model = RapidOCR()
1738
+ self.is_loaded = True
1739
+
1740
+ self._log("RapidOCR model loaded successfully")
1741
+ return True
1742
+
1743
+ except Exception as e:
1744
+ self._log(f"Failed to load RapidOCR: {str(e)}", "error")
1745
+ return False
1746
+
1747
+ def detect_text(self, image: np.ndarray, **kwargs) -> List[OCRResult]:
1748
+ """Detect text using RapidOCR"""
1749
+ if not self.is_loaded:
1750
+ self._log("RapidOCR model not loaded", "error")
1751
+ return []
1752
+
1753
+ results = []
1754
+
1755
+ try:
1756
+ # Convert numpy array to PIL Image for RapidOCR
1757
+ if len(image.shape) == 3:
1758
+ # BGR to RGB
1759
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1760
+ else:
1761
+ image_rgb = image
1762
+
1763
+ # RapidOCR expects PIL Image or numpy array
1764
+ ocr_results, _ = self.model(image_rgb)
1765
+
1766
+ if ocr_results:
1767
+ for result in ocr_results:
1768
+ # RapidOCR returns [bbox, text, confidence]
1769
+ bbox_points = result[0] # 4 corner points
1770
+ text = result[1]
1771
+ confidence = float(result[2])
1772
+
1773
+ if not text or not text.strip():
1774
+ continue
1775
+
1776
+ # Convert 4-point bbox to x,y,w,h format
1777
+ xs = [point[0] for point in bbox_points]
1778
+ ys = [point[1] for point in bbox_points]
1779
+ x_min, x_max = min(xs), max(xs)
1780
+ y_min, y_max = min(ys), max(ys)
1781
+
1782
+ results.append(OCRResult(
1783
+ text=text.strip(),
1784
+ bbox=(int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)),
1785
+ confidence=confidence,
1786
+ vertices=[(int(p[0]), int(p[1])) for p in bbox_points]
1787
+ ))
1788
+
1789
+ self._log(f"Detected {len(results)} text regions")
1790
+
1791
+ except Exception as e:
1792
+ self._log(f"Error in RapidOCR detection: {str(e)}", "error")
1793
+
1794
+ return results
1795
+
1796
+ class OCRManager:
1797
+ """Manager for multiple OCR providers"""
1798
+
1799
+ def __init__(self, log_callback=None):
1800
+ self.log_callback = log_callback
1801
+ self.providers = {
1802
+ 'custom-api': CustomAPIProvider(log_callback),
1803
+ 'manga-ocr': MangaOCRProvider(log_callback),
1804
+ 'easyocr': EasyOCRProvider(log_callback),
1805
+ 'paddleocr': PaddleOCRProvider(log_callback),
1806
+ 'doctr': DocTROCRProvider(log_callback),
1807
+ 'rapidocr': RapidOCRProvider(log_callback),
1808
+ 'Qwen2-VL': Qwen2VL(log_callback)
1809
+ }
1810
+ self.current_provider = None
1811
+ self.stop_flag = None
1812
+
1813
+ def get_provider(self, name: str) -> Optional[OCRProvider]:
1814
+ """Get OCR provider by name"""
1815
+ return self.providers.get(name)
1816
+
1817
+ def set_current_provider(self, name: str):
1818
+ """Set current active provider"""
1819
+ if name in self.providers:
1820
+ self.current_provider = name
1821
+ return True
1822
+ return False
1823
+
1824
+ def check_provider_status(self, name: str) -> Dict[str, bool]:
1825
+ """Check installation and loading status of provider"""
1826
+ provider = self.providers.get(name)
1827
+ if not provider:
1828
+ return {'installed': False, 'loaded': False}
1829
+
1830
+ result = {
1831
+ 'installed': provider.check_installation(),
1832
+ 'loaded': provider.is_loaded
1833
+ }
1834
+ if self.log_callback:
1835
+ self.log_callback(f"DEBUG: check_provider_status({name}) returning loaded={result['loaded']}", "debug")
1836
+ return result
1837
+
1838
+ def install_provider(self, name: str, progress_callback=None) -> bool:
1839
+ """Install a provider"""
1840
+ provider = self.providers.get(name)
1841
+ if not provider:
1842
+ return False
1843
+
1844
+ return provider.install(progress_callback)
1845
+
1846
+ def load_provider(self, name: str, **kwargs) -> bool:
1847
+ """Load a provider's model with optional parameters"""
1848
+ provider = self.providers.get(name)
1849
+ if not provider:
1850
+ return False
1851
+
1852
+ return provider.load_model(**kwargs) # <-- Passes model_size and any other kwargs
1853
+
1854
+ def shutdown(self):
1855
+ """Release models/processors/tokenizers for all providers and clear caches."""
1856
+ try:
1857
+ import gc
1858
+ for name, provider in list(self.providers.items()):
1859
+ try:
1860
+ if hasattr(provider, 'model'):
1861
+ provider.model = None
1862
+ if hasattr(provider, 'processor'):
1863
+ provider.processor = None
1864
+ if hasattr(provider, 'tokenizer'):
1865
+ provider.tokenizer = None
1866
+ if hasattr(provider, 'reader'):
1867
+ provider.reader = None
1868
+ if hasattr(provider, 'is_loaded'):
1869
+ provider.is_loaded = False
1870
+ except Exception:
1871
+ pass
1872
+ gc.collect()
1873
+ try:
1874
+ import torch
1875
+ torch.cuda.empty_cache()
1876
+ except Exception:
1877
+ pass
1878
+ except Exception:
1879
+ pass
1880
+
1881
+ def detect_text(self, image: np.ndarray, provider_name: str = None, **kwargs) -> List[OCRResult]:
1882
+ """Detect text using specified or current provider"""
1883
+ provider_name = provider_name or self.current_provider
1884
+ if not provider_name:
1885
+ return []
1886
+
1887
+ provider = self.providers.get(provider_name)
1888
+ if not provider:
1889
+ return []
1890
+
1891
+ return provider.detect_text(image, **kwargs)
1892
+
1893
+ def set_stop_flag(self, stop_flag):
1894
+ """Set stop flag for all providers"""
1895
+ self.stop_flag = stop_flag
1896
+ for provider in self.providers.values():
1897
+ if hasattr(provider, 'set_stop_flag'):
1898
+ provider.set_stop_flag(stop_flag)
1899
+
1900
+ def reset_stop_flags(self):
1901
+ """Reset stop flags for all providers"""
1902
+ for provider in self.providers.values():
1903
+ if hasattr(provider, 'reset_stop_flags'):
1904
+ provider.reset_stop_flags()
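+ 
+ # Minimal usage sketch (illustration only; assumes `img` is an OpenCV BGR numpy array):
+ #   manager = OCRManager(log_callback=print)
+ #   manager.set_current_provider('manga-ocr')
+ #   if manager.load_provider('manga-ocr'):
+ #       for r in manager.detect_text(img):
+ #           print(r.text, r.bbox, r.confidence)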