| """ | |
| bubble_detector.py - Modified version that works in frozen PyInstaller executables | |
| Replace your bubble_detector.py with this version | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import numpy as np | |
| import cv2 | |
| from typing import List, Tuple, Optional, Dict, Any | |
| import logging | |
| import traceback | |
| import hashlib | |
| from pathlib import Path | |
| import threading | |
| import time | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Check if we're running in a frozen environment | |
| IS_FROZEN = getattr(sys, 'frozen', False) | |
| if IS_FROZEN: | |
| # In frozen environment, set proper paths for ML libraries | |
| MEIPASS = sys._MEIPASS | |
| os.environ['TORCH_HOME'] = MEIPASS | |
| os.environ['TRANSFORMERS_CACHE'] = os.path.join(MEIPASS, 'transformers') | |
| os.environ['HF_HOME'] = os.path.join(MEIPASS, 'huggingface') | |
| logger.info(f"Running in frozen environment: {MEIPASS}") | |
| # Modified import checks for frozen environment | |
| YOLO_AVAILABLE = False | |
| YOLO = None | |
| torch = None | |
| TORCH_AVAILABLE = False | |
| ONNX_AVAILABLE = False | |
| TRANSFORMERS_AVAILABLE = False | |
| RTDetrForObjectDetection = None | |
| RTDetrImageProcessor = None | |
| PIL_AVAILABLE = False | |
| # Try to import YOLO dependencies with better error handling | |
| if IS_FROZEN: | |
| # In frozen environment, try harder to import | |
| try: | |
| # First try to import torch components individually | |
| import torch | |
| import torch.nn | |
| import torch.cuda | |
| TORCH_AVAILABLE = True | |
| logger.info("✓ PyTorch loaded in frozen environment") | |
| except Exception as e: | |
| logger.warning(f"PyTorch not available in frozen environment: {e}") | |
| TORCH_AVAILABLE = False | |
| torch = None | |
| # Try ultralytics after torch | |
| if TORCH_AVAILABLE: | |
| try: | |
| from ultralytics import YOLO | |
| YOLO_AVAILABLE = True | |
| logger.info("✓ Ultralytics YOLO loaded in frozen environment") | |
| except Exception as e: | |
| logger.warning(f"Ultralytics not available in frozen environment: {e}") | |
| YOLO_AVAILABLE = False | |
| # Try transformers | |
| try: | |
| import transformers | |
| # Try specific imports | |
| try: | |
| from transformers import RTDetrForObjectDetection, RTDetrImageProcessor | |
| TRANSFORMERS_AVAILABLE = True | |
| logger.info("✓ Transformers RT-DETR loaded in frozen environment") | |
| except ImportError: | |
| # Try alternative import | |
| try: | |
| from transformers import AutoModel, AutoImageProcessor | |
| RTDetrForObjectDetection = AutoModel | |
| RTDetrImageProcessor = AutoImageProcessor | |
| TRANSFORMERS_AVAILABLE = True | |
| logger.info("✓ Transformers loaded with AutoModel fallback") | |
| except: | |
| TRANSFORMERS_AVAILABLE = False | |
| logger.warning("Transformers RT-DETR not available in frozen environment") | |
| except Exception as e: | |
| logger.warning(f"Transformers not available in frozen environment: {e}") | |
| TRANSFORMERS_AVAILABLE = False | |
| else: | |
| # Normal environment - original import logic | |
| try: | |
| from ultralytics import YOLO | |
| YOLO_AVAILABLE = True | |
| except: | |
| YOLO_AVAILABLE = False | |
| logger.warning("Ultralytics YOLO not available") | |
| try: | |
| import torch | |
| # Test if cuda attribute exists | |
| _ = torch.cuda | |
| TORCH_AVAILABLE = True | |
| except (ImportError, AttributeError): | |
| TORCH_AVAILABLE = False | |
| torch = None | |
| logger.warning("PyTorch not available or incomplete") | |
| try: | |
| from transformers import RTDetrForObjectDetection, RTDetrImageProcessor | |
| try: | |
| from transformers import RTDetrV2ForObjectDetection | |
| RTDetrForObjectDetection = RTDetrV2ForObjectDetection | |
| except ImportError: | |
| pass | |
| TRANSFORMERS_AVAILABLE = True | |
| except: | |
| TRANSFORMERS_AVAILABLE = False | |
| logger.info("Transformers not available for RT-DETR") | |
| # Configure ORT memory behavior before importing | |
| try: | |
| os.environ.setdefault('ORT_DISABLE_MEMORY_ARENA', '1') | |
| except Exception: | |
| pass | |
| # ONNX Runtime - works well in frozen environments | |
| try: | |
| import onnxruntime as ort | |
| ONNX_AVAILABLE = True | |
| logger.info("✓ ONNX Runtime available") | |
| except ImportError: | |
| ONNX_AVAILABLE = False | |
| logger.warning("ONNX Runtime not available") | |
| # PIL | |
| try: | |
| from PIL import Image | |
| PIL_AVAILABLE = True | |
| except ImportError: | |
| PIL_AVAILABLE = False | |
| logger.info("PIL not available") | |
| class BubbleDetector: | |
| """ | |
| Combined YOLOv8 and RT-DETR speech bubble detector for comics and manga. | |
| Supports multiple model formats and provides configurable detection. | |
| Backward compatible with existing code while adding RT-DETR support. | |
| """ | |
| # Process-wide shared RT-DETR to avoid concurrent meta-device loads | |
| _rtdetr_init_lock = threading.Lock() | |
| _rtdetr_shared_model = None | |
| _rtdetr_shared_processor = None | |
| _rtdetr_loaded = False | |
| _rtdetr_repo_id = 'ogkalu/comic-text-and-bubble-detector' | |
| # Shared RT-DETR (ONNX) across process to avoid device/context storms | |
| _rtdetr_onnx_init_lock = threading.Lock() | |
| _rtdetr_onnx_shared_session = None | |
| _rtdetr_onnx_loaded = False | |
| _rtdetr_onnx_providers = None | |
| _rtdetr_onnx_model_path = None | |
| # Limit concurrent runs to avoid device hangs. Defaults to 2 for better parallelism. | |
| # Can be overridden via env DML_MAX_CONCURRENT or config rtdetr_max_concurrency | |
| try: | |
| _rtdetr_onnx_max_concurrent = int(os.environ.get('DML_MAX_CONCURRENT', '2')) | |
| except Exception: | |
| _rtdetr_onnx_max_concurrent = 2 | |
| _rtdetr_onnx_sema = threading.Semaphore(max(1, _rtdetr_onnx_max_concurrent)) | |
| _rtdetr_onnx_sema_initialized = False | |
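
    # Example (sketch): to allow more parallel RT-DETR ONNX runs, set the env var
    # before this module is imported:
    #
    #   os.environ['DML_MAX_CONCURRENT'] = '4'
    #
    # or set manga_settings.ocr.rtdetr_max_concurrency in config.json, which is
    # re-applied in __init__ below.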

    def __init__(self, config_path: str = "config.json"):
        """
        Initialize the bubble detector.

        Args:
            config_path: Path to configuration file
        """
        # Set thread limits early if environment indicates single-threaded mode
        try:
            if os.environ.get('OMP_NUM_THREADS') == '1':
                # Already in single-threaded mode, ensure it's applied to this process.
                # Check if torch is available at module level before trying to use it.
                if TORCH_AVAILABLE and torch is not None:
                    try:
                        torch.set_num_threads(1)
                    except (RuntimeError, AttributeError):
                        pass
                try:
                    import cv2
                    cv2.setNumThreads(1)
                except (ImportError, AttributeError):
                    pass
        except Exception:
            pass

        self.config_path = config_path
        self.config = self._load_config()

        # YOLOv8 components (original)
        self.model = None
        self.model_loaded = False
        self.model_type = None  # 'yolo', 'onnx', or 'torch'
        self.onnx_session = None

        # RT-DETR components (new)
        self.rtdetr_model = None
        self.rtdetr_processor = None
        self.rtdetr_loaded = False
        self.rtdetr_repo = 'ogkalu/comic-text-and-bubble-detector'

        # RT-DETR (ONNX) backend components
        self.rtdetr_onnx_session = None
        self.rtdetr_onnx_loaded = False
        self.rtdetr_onnx_repo = 'ogkalu/comic-text-and-bubble-detector'

        # RT-DETR class definitions
        self.CLASS_BUBBLE = 0       # Empty speech bubble
        self.CLASS_TEXT_BUBBLE = 1  # Bubble with text
        self.CLASS_TEXT_FREE = 2    # Text without bubble

        # Detection settings
        self.default_confidence = 0.3
        self.default_iou_threshold = 0.45

        # Allow override from settings
        try:
            ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
            self.default_max_detections = int(ocr_cfg.get('bubble_max_detections', 100))
            self.max_det_yolo = int(ocr_cfg.get('bubble_max_detections_yolo', self.default_max_detections))
            self.max_det_rtdetr = int(ocr_cfg.get('bubble_max_detections_rtdetr', self.default_max_detections))
        except Exception:
            self.default_max_detections = 100
            self.max_det_yolo = 100
            self.max_det_rtdetr = 100

        # Cache directory for ONNX conversions
        self.cache_dir = os.environ.get('BUBBLE_CACHE_DIR', 'models')
        os.makedirs(self.cache_dir, exist_ok=True)

        # RT-DETR concurrency setting from config
        try:
            rtdetr_max_conc = int(ocr_cfg.get('rtdetr_max_concurrency', 2))
            # Update class-level semaphore if not yet initialized or if value changed
            if not BubbleDetector._rtdetr_onnx_sema_initialized or rtdetr_max_conc != BubbleDetector._rtdetr_onnx_max_concurrent:
                BubbleDetector._rtdetr_onnx_max_concurrent = max(1, rtdetr_max_conc)
                BubbleDetector._rtdetr_onnx_sema = threading.Semaphore(BubbleDetector._rtdetr_onnx_max_concurrent)
                BubbleDetector._rtdetr_onnx_sema_initialized = True
                logger.info(f"RT-DETR concurrency set to: {BubbleDetector._rtdetr_onnx_max_concurrent}")
        except Exception as e:
            logger.warning(f"Failed to set RT-DETR concurrency: {e}")

        # GPU availability
        self.use_gpu = TORCH_AVAILABLE and torch.cuda.is_available()
        self.device = 'cuda' if self.use_gpu else 'cpu'

        # Quantization/precision settings
        adv_cfg = self.config.get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
        ocr_cfg = self.config.get('manga_settings', {}).get('ocr', {}) if isinstance(self.config, dict) else {}
        env_quant = os.environ.get('MODEL_QUANTIZE', 'false').lower() == 'true'
        self.quantize_enabled = bool(env_quant or adv_cfg.get('quantize_models', False) or ocr_cfg.get('quantize_bubble_detector', False))
        self.quantize_dtype = str(adv_cfg.get('torch_precision', os.environ.get('TORCH_PRECISION', 'auto'))).lower()
        # Prefer advanced.onnx_quantize; fall back to env or global quantize
        self.onnx_quantize_enabled = bool(adv_cfg.get('onnx_quantize', os.environ.get('ONNX_QUANTIZE', 'false').lower() == 'true' or self.quantize_enabled))

        # Stop flag support
        self.stop_flag = None
        self._stopped = False
        self.log_callback = None

        logger.info("🗨️ BubbleDetector initialized")
        logger.info(f"   GPU: {'Available' if self.use_gpu else 'Not available'}")
        logger.info(f"   YOLO: {'Available' if YOLO_AVAILABLE else 'Not installed'}")
        logger.info(f"   ONNX: {'Available' if ONNX_AVAILABLE else 'Not installed'}")
        logger.info(f"   RT-DETR: {'Available' if TRANSFORMERS_AVAILABLE else 'Not installed'}")
        logger.info(f"   Quantization: {'ENABLED' if self.quantize_enabled else 'disabled'} (torch_precision={self.quantize_dtype}, onnx_quantize={'on' if self.onnx_quantize_enabled else 'off'})")

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from file."""
        if os.path.exists(self.config_path):
            try:
                with open(self.config_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load config: {e}")
        return {}
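
    # A minimal config.json sketch covering the keys this module reads. All keys are
    # optional; the values shown are illustrative, not enforced defaults:
    #
    # {
    #   "manga_settings": {
    #     "ocr": {
    #       "bubble_max_detections": 100,
    #       "bubble_max_detections_yolo": 100,
    #       "bubble_max_detections_rtdetr": 100,
    #       "rtdetr_max_concurrency": 2,
    #       "quantize_bubble_detector": false
    #     },
    #     "advanced": {
    #       "quantize_models": false,
    #       "torch_precision": "auto",
    #       "onnx_quantize": false,
    #       "use_singleton_models": true
    #     }
    #   }
    # }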

    def _save_config(self):
        """Save configuration to file."""
        try:
            with open(self.config_path, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save config: {e}")

    def set_stop_flag(self, stop_flag):
        """Set the stop flag for checking interruptions"""
        self.stop_flag = stop_flag
        self._stopped = False

    def set_log_callback(self, log_callback):
        """Set log callback for GUI integration"""
        self.log_callback = log_callback

    def _check_stop(self) -> bool:
        """Check if stop has been requested"""
        if self._stopped:
            return True
        if self.stop_flag and self.stop_flag.is_set():
            self._stopped = True
            return True
        # Check global manga translator cancellation
        try:
            from manga_translator import MangaTranslator
            if MangaTranslator.is_globally_cancelled():
                self._stopped = True
                return True
        except Exception:
            pass
        return False

    def _log(self, message: str, level: str = "info"):
        """Log message with stop suppression"""
        # Suppress logs when stopped (allow only essential stop confirmation messages)
        if self._check_stop():
            essential_stop_keywords = [
                "⏹️ Translation stopped by user",
                "⏹️ Bubble detection stopped",
                "cleanup", "🧹"
            ]
            if not any(keyword in message for keyword in essential_stop_keywords):
                return
        if self.log_callback:
            self.log_callback(message, level)
        else:
            # Route to the matching logger method, defaulting to info
            getattr(logger, level, logger.info)(message)

    def reset_stop_flags(self):
        """Reset stop flags when starting new processing"""
        self._stopped = False

    def load_model(self, model_path: str, force_reload: bool = False) -> bool:
        """
        Load a YOLOv8 model for bubble detection.

        Args:
            model_path: Path to model file (.pt, .onnx, or .torchscript)
            force_reload: Force reload even if model is already loaded

        Returns:
            True if model loaded successfully, False otherwise
        """
        try:
            # If given a Hugging Face repo ID (e.g., 'owner/name'), fetch detector.onnx into models/
            if model_path and ('/' in model_path) and not os.path.exists(model_path):
                try:
                    from huggingface_hub import hf_hub_download
                    os.makedirs(self.cache_dir, exist_ok=True)
                    logger.info(f"📥 Resolving repo '{model_path}' to detector.onnx in {self.cache_dir}...")
                    resolved = hf_hub_download(repo_id=model_path, filename='detector.onnx', cache_dir=self.cache_dir, local_dir=self.cache_dir, local_dir_use_symlinks=False)
                    if resolved and os.path.exists(resolved):
                        model_path = resolved
                        logger.info(f"✅ Downloaded detector.onnx to: {model_path}")
                except Exception as repo_err:
                    logger.error(f"Failed to download from repo '{model_path}': {repo_err}")

            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return False

            # Check if it's the same model already loaded
            if self.model_loaded and not force_reload:
                last_path = self.config.get('last_model_path', '')
                if last_path == model_path:
                    logger.info("Model already loaded (same path)")
                    return True
                else:
                    logger.info(f"Model path changed from {last_path} to {model_path}, reloading...")
                    force_reload = True

            # Clear previous model if force reload
            if force_reload:
                logger.info("Force reloading model...")
                self.model = None
                self.onnx_session = None
                self.model_loaded = False
                self.model_type = None

            logger.info(f"📥 Loading bubble detection model: {model_path}")

            # Determine model type by extension
            ext = Path(model_path).suffix.lower()

            if ext in ['.pt', '.pth']:
                if not YOLO_AVAILABLE:
                    logger.warning("Ultralytics package not available in this build")
                    logger.info("Bubble detection will be disabled - this is normal for lightweight builds")
                    # Don't return False immediately, try other fallbacks
                    self.model_loaded = False
                    return False

                # Load YOLOv8 model
                try:
                    self.model = YOLO(model_path)
                    self.model_type = 'yolo'

                    # Set to eval mode
                    if hasattr(self.model, 'model'):
                        self.model.model.eval()

                    # Move to GPU if available
                    if self.use_gpu and TORCH_AVAILABLE:
                        try:
                            self.model.to('cuda')
                        except Exception as gpu_error:
                            logger.warning(f"Could not move model to GPU: {gpu_error}")

                    logger.info("✅ YOLOv8 model loaded successfully")

                    # Apply optional FP16 precision to reduce VRAM if enabled
                    if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
                        try:
                            m = self.model.model if hasattr(self.model, 'model') else self.model
                            m.half()
                            logger.info("🔻 Applied FP16 precision to YOLO model (GPU)")
                        except Exception as _e:
                            logger.warning(f"Could not switch YOLO model to FP16: {_e}")
                except Exception as yolo_error:
                    logger.error(f"Failed to load YOLO model: {yolo_error}")
                    return False

            elif ext == '.onnx':
                if not ONNX_AVAILABLE:
                    logger.warning("ONNX Runtime not available in this build")
                    logger.info("ONNX model support disabled - this is normal for lightweight builds")
                    return False

                try:
                    # Load ONNX model
                    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.use_gpu else ['CPUExecutionProvider']
                    session_path = model_path
                    if self.quantize_enabled:
                        try:
                            from onnxruntime.quantization import quantize_dynamic, QuantType
                            quant_path = os.path.splitext(model_path)[0] + ".int8.onnx"
                            if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
                                logger.info("🔻 Quantizing ONNX model weights to INT8 (dynamic)...")
                                quantize_dynamic(model_input=model_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
                            session_path = quant_path
                            self.config['last_onnx_quantized_path'] = quant_path
                            self._save_config()
                            logger.info(f"✅ Using quantized ONNX model: {quant_path}")
                        except Exception as qe:
                            logger.warning(f"ONNX quantization not applied: {qe}")

                    # Use conservative ORT memory options to reduce RAM growth
                    so = ort.SessionOptions()
                    try:
                        so.enable_mem_pattern = False
                        so.enable_cpu_mem_arena = False
                    except Exception:
                        pass
                    self.onnx_session = ort.InferenceSession(session_path, sess_options=so, providers=providers)
                    self.model_type = 'onnx'
                    logger.info("✅ ONNX model loaded successfully")
                except Exception as onnx_error:
                    logger.error(f"Failed to load ONNX model: {onnx_error}")
                    return False

            elif ext == '.torchscript':
                if not TORCH_AVAILABLE:
                    logger.warning("PyTorch not available in this build")
                    logger.info("TorchScript model support disabled - this is normal for lightweight builds")
                    return False

                try:
                    # Add safety check for torch being None
                    if torch is None:
                        logger.error("PyTorch module is None - cannot load TorchScript model")
                        return False

                    # Load TorchScript model
                    self.model = torch.jit.load(model_path, map_location='cpu')
                    self.model.eval()
                    self.model_type = 'torch'

                    if self.use_gpu:
                        try:
                            self.model = self.model.cuda()
                        except Exception as gpu_error:
                            logger.warning(f"Could not move TorchScript model to GPU: {gpu_error}")

                    logger.info("✅ TorchScript model loaded successfully")

                    # Optional FP16 precision on GPU
                    if self.quantize_enabled and self.use_gpu and TORCH_AVAILABLE:
                        try:
                            self.model = self.model.half()
                            logger.info("🔻 Applied FP16 precision to TorchScript model (GPU)")
                        except Exception as _e:
                            logger.warning(f"Could not switch TorchScript model to FP16: {_e}")
                except Exception as torch_error:
                    logger.error(f"Failed to load TorchScript model: {torch_error}")
                    return False

            else:
                logger.error(f"Unsupported model format: {ext}")
                logger.info("Supported formats: .pt/.pth (YOLOv8), .onnx (ONNX), .torchscript (TorchScript)")
                return False

            # Only set loaded if we actually succeeded
            self.model_loaded = True
            self.config['last_model_path'] = model_path
            self.config['model_type'] = self.model_type
            self._save_config()
            return True

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            logger.error(traceback.format_exc())
            self.model_loaded = False
            # Provide helpful context for .exe users
            logger.info("Note: If running from .exe, some ML libraries may not be included")
            logger.info("This is normal for lightweight builds - bubble detection will be disabled")
            return False

    def load_rtdetr_model(self, model_path: str = None, model_id: str = None, force_reload: bool = False) -> bool:
        """
        Load RT-DETR model for advanced bubble and text detection.

        This implementation avoids the 'meta tensor' copy error by:
        - Serializing the entire load under a class lock (no concurrent loads)
        - Loading directly onto the target device (CUDA if available) via device_map='auto'
        - Avoiding .to() on a potentially-meta model; no device migration post-load

        Args:
            model_path: Optional path to local model
            model_id: Optional HuggingFace model ID (default: 'ogkalu/comic-text-and-bubble-detector')
            force_reload: Force reload even if already loaded

        Returns:
            True if successful, False otherwise
        """
        if not TRANSFORMERS_AVAILABLE:
            logger.error("Transformers library required for RT-DETR. Install with: pip install transformers")
            return False

        if not PIL_AVAILABLE:
            logger.error("PIL required for RT-DETR. Install with: pip install pillow")
            return False

        if self.rtdetr_loaded and not force_reload:
            logger.info("RT-DETR model already loaded")
            return True

        # Fast path: if shared already loaded and not forcing reload, attach
        if BubbleDetector._rtdetr_loaded and not force_reload:
            self.rtdetr_model = BubbleDetector._rtdetr_shared_model
            self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
            self.rtdetr_loaded = True
            logger.info("RT-DETR model attached from shared cache")
            return True

        # Serialize the ENTIRE loading sequence to avoid concurrent init issues
        with BubbleDetector._rtdetr_init_lock:
            try:
                # Re-check after acquiring lock
                if BubbleDetector._rtdetr_loaded and not force_reload:
                    self.rtdetr_model = BubbleDetector._rtdetr_shared_model
                    self.rtdetr_processor = BubbleDetector._rtdetr_shared_processor
                    self.rtdetr_loaded = True
                    logger.info("RT-DETR model attached from shared cache (post-lock)")
                    return True

                # Use custom model_id if provided, otherwise use default
                repo_id = model_id if model_id else self.rtdetr_repo
                logger.info(f"📥 Loading RT-DETR model from {repo_id}...")

                # Ensure TorchDynamo/compile doesn't interfere on some builds
                try:
                    os.environ.setdefault('TORCHDYNAMO_DISABLE', '1')
                except Exception:
                    pass

                # Decide device strategy
                gpu_available = bool(TORCH_AVAILABLE and hasattr(torch, 'cuda') and torch.cuda.is_available())
                device_map = 'auto' if gpu_available else None

                # Choose dtype
                dtype = None
                if TORCH_AVAILABLE:
                    try:
                        dtype = torch.float16 if gpu_available else torch.float32
                    except Exception:
                        dtype = None
                low_cpu = bool(gpu_available)

                # Load processor (once)
                self.rtdetr_processor = RTDetrImageProcessor.from_pretrained(
                    repo_id,
                    size={"width": 640, "height": 640},
                    cache_dir=self.cache_dir if not model_path else None
                )

                # Prepare kwargs for from_pretrained
                from_kwargs = {
                    'cache_dir': self.cache_dir if not model_path else None,
                    'low_cpu_mem_usage': low_cpu,
                    'device_map': device_map,
                }
                # Note: dtype is handled via torch_dtype parameter in newer transformers
                if dtype is not None:
                    from_kwargs['torch_dtype'] = dtype

                # First attempt: load directly to target (CUDA if available)
                try:
                    self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
                        model_path if model_path else repo_id,
                        **from_kwargs,
                    )
                except Exception as primary_err:
                    # Fallback to a simple CPU load (no device move) if CUDA path fails
                    logger.warning(f"RT-DETR primary load failed ({primary_err}); retrying on CPU...")
                    from_kwargs_fallback = {
                        'cache_dir': self.cache_dir if not model_path else None,
                        'low_cpu_mem_usage': False,
                        'device_map': None,
                    }
                    if TORCH_AVAILABLE:
                        from_kwargs_fallback['torch_dtype'] = torch.float32
                    self.rtdetr_model = RTDetrForObjectDetection.from_pretrained(
                        model_path if model_path else repo_id,
                        **from_kwargs_fallback,
                    )

                # Optional dynamic quantization for linear layers (CPU only)
                if self.quantize_enabled and TORCH_AVAILABLE and (not gpu_available):
                    try:
                        try:
                            import torch.ao.quantization as tq
                            quantize_dynamic = tq.quantize_dynamic  # type: ignore
                        except Exception:
                            import torch.quantization as tq  # type: ignore
                            quantize_dynamic = tq.quantize_dynamic  # type: ignore
                        self.rtdetr_model = quantize_dynamic(self.rtdetr_model, {torch.nn.Linear}, dtype=torch.qint8)
                        logger.info("🔻 Applied dynamic INT8 quantization to RT-DETR linear layers (CPU)")
                    except Exception as qe:
                        logger.warning(f"RT-DETR dynamic quantization skipped: {qe}")

                # Finalize
                self.rtdetr_model.eval()

                # Sanity check: ensure no parameter is left on 'meta' device
                try:
                    for n, p in self.rtdetr_model.named_parameters():
                        dev = getattr(p, 'device', None)
                        if dev is not None and getattr(dev, 'type', '') == 'meta':
                            raise RuntimeError(f"Parameter {n} is on 'meta' device after load")
                except Exception as e:
                    logger.error(f"RT-DETR load sanity check failed: {e}")
                    self.rtdetr_loaded = False
                    return False

                # Publish shared cache
                BubbleDetector._rtdetr_shared_model = self.rtdetr_model
                BubbleDetector._rtdetr_shared_processor = self.rtdetr_processor
                BubbleDetector._rtdetr_loaded = True
                BubbleDetector._rtdetr_repo_id = repo_id
                self.rtdetr_loaded = True

                # Save the model ID that was used
                self.config['rtdetr_loaded'] = True
                self.config['rtdetr_model_id'] = repo_id
                self._save_config()

                loc = 'CUDA' if gpu_available else 'CPU'
                logger.info(f"✅ RT-DETR model loaded successfully ({loc})")
                logger.info("   Classes: Empty bubbles, Text bubbles, Free text")

                # Auto-convert to ONNX for RT-DETR only if explicitly enabled
                if os.environ.get('AUTO_CONVERT_RTDETR_ONNX', 'false').lower() == 'true':
                    onnx_path = os.path.join(self.cache_dir, 'rtdetr_comic.onnx')
                    if self.convert_to_onnx('rtdetr', onnx_path):
                        logger.info("🚀 RT-DETR converted to ONNX for faster inference")
                        # Store ONNX path for later use
                        self.config['rtdetr_onnx_path'] = onnx_path
                        self._save_config()
                        # Optionally quantize ONNX for reduced RAM
                        if self.onnx_quantize_enabled:
                            try:
                                from onnxruntime.quantization import quantize_dynamic, QuantType
                                quant_path = os.path.splitext(onnx_path)[0] + ".int8.onnx"
                                if not os.path.exists(quant_path) or os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true':
                                    logger.info("🔻 Quantizing RT-DETR ONNX to INT8 (dynamic)...")
                                    quantize_dynamic(model_input=onnx_path, model_output=quant_path, weight_type=QuantType.QInt8, op_types_to_quantize=['Conv', 'MatMul'])
                                self.config['rtdetr_onnx_quantized_path'] = quant_path
                                self._save_config()
                                logger.info(f"✅ Quantized RT-DETR ONNX saved to: {quant_path}")
                            except Exception as qe:
                                logger.warning(f"ONNX quantization for RT-DETR skipped: {qe}")
                    else:
                        logger.info("ℹ️ Skipping RT-DETR ONNX export (converter not supported in current environment)")

                return True

            except Exception as e:
                logger.error(f"❌ Failed to load RT-DETR: {e}")
                self.rtdetr_loaded = False
                return False

    def check_rtdetr_available(self, model_id: str = None) -> bool:
        """
        Check if RT-DETR model is available (cached).

        Args:
            model_id: Optional HuggingFace model ID

        Returns:
            True if model is cached and available
        """
        try:
            from pathlib import Path

            # Use provided model_id or default
            repo_id = model_id if model_id else self.rtdetr_repo

            # Check HuggingFace cache
            cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
            model_id_formatted = repo_id.replace("/", "--")

            # Look for model folder
            model_folders = list(cache_dir.glob(f"models--{model_id_formatted}*"))
            if model_folders:
                for folder in model_folders:
                    if (folder / "snapshots").exists():
                        snapshots = list((folder / "snapshots").iterdir())
                        if snapshots:
                            return True
            return False
        except Exception:
            return False

    def detect_bubbles(self,
                       image_path: str,
                       confidence: float = None,
                       iou_threshold: float = None,
                       max_detections: int = None,
                       use_rtdetr: bool = None) -> List[Tuple[int, int, int, int]]:
        """
        Detect speech bubbles in an image (backward compatible method).

        Args:
            image_path: Path to image file
            confidence: Minimum confidence threshold (0-1)
            iou_threshold: IOU threshold for NMS (0-1)
            max_detections: Maximum number of detections to return
            use_rtdetr: If True, use RT-DETR instead of YOLOv8 (if available)

        Returns:
            List of bubble bounding boxes as (x, y, width, height) tuples
        """
        # Check for stop at start
        if self._check_stop():
            self._log("⏹️ Bubble detection stopped by user", "warning")
            return []

        # Decide which model to use
        if use_rtdetr is None:
            # Auto-select: prefer RT-DETR if available
            use_rtdetr = self.rtdetr_loaded

        if use_rtdetr:
            # Prefer ONNX backend if available, else PyTorch
            if getattr(self, 'rtdetr_onnx_loaded', False):
                return self.detect_with_rtdetr_onnx(
                    image_path=image_path,
                    confidence=confidence,
                    return_all_bubbles=True
                )
            if self.rtdetr_loaded:
                return self.detect_with_rtdetr(
                    image_path=image_path,
                    confidence=confidence,
                    return_all_bubbles=True
                )

        # Original YOLOv8 detection
        if not self.model_loaded:
            logger.error("No model loaded. Call load_model() first.")
            return []

        # Use defaults if not specified
        confidence = confidence or self.default_confidence
        iou_threshold = iou_threshold or self.default_iou_threshold
        max_detections = max_detections or self.default_max_detections

        try:
            # Load image
            image = cv2.imread(image_path)
            if image is None:
                logger.error(f"Failed to load image: {image_path}")
                return []

            h, w = image.shape[:2]
            self._log(f"🔍 Detecting bubbles in {w}x{h} image")

            # Check for stop before inference
            if self._check_stop():
                self._log("⏹️ Bubble detection inference stopped by user", "warning")
                return []

            if self.model_type == 'yolo':
                # YOLOv8 inference
                results = self.model(
                    image_path,
                    conf=confidence,
                    iou=iou_threshold,
                    max_det=min(max_detections, getattr(self, 'max_det_yolo', max_detections)),
                    verbose=False
                )

                bubbles = []
                for r in results:
                    if r.boxes is not None:
                        for box in r.boxes:
                            # Get box coordinates
                            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                            x, y = int(x1), int(y1)
                            width = int(x2 - x1)
                            height = int(y2 - y1)

                            # Get confidence
                            conf = float(box.conf[0])

                            # Add to list
                            if len(bubbles) < max_detections:
                                bubbles.append((x, y, width, height))
                                logger.debug(f"   Bubble: ({x},{y}) {width}x{height} conf={conf:.2f}")

            elif self.model_type == 'onnx':
                # ONNX inference
                bubbles = self._detect_with_onnx(image, confidence, iou_threshold, max_detections)

            elif self.model_type == 'torch':
                # TorchScript inference
                bubbles = self._detect_with_torchscript(image, confidence, iou_threshold, max_detections)

            else:
                logger.error(f"Unknown model type: {self.model_type}")
                return []

            logger.info(f"✅ Detected {len(bubbles)} speech bubbles")
            time.sleep(0.1)  # Brief pause for stability
            logger.debug("💤 Bubble detection pausing briefly for stability")
            return bubbles

        except Exception as e:
            logger.error(f"Detection failed: {e}")
            logger.error(traceback.format_exc())
            return []
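
    # Example usage (sketch; file paths are illustrative):
    #
    #   detector = BubbleDetector()
    #   if detector.load_model('models/bubbles_yolov8.pt'):
    #       boxes = detector.detect_bubbles('page_001.png', confidence=0.3)
    #       for x, y, w, h in boxes:
    #           print(f"bubble at ({x},{y}) size {w}x{h}")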

    def detect_with_rtdetr(self,
                           image_path: str = None,
                           image: np.ndarray = None,
                           confidence: float = None,
                           return_all_bubbles: bool = False) -> Any:
        """
        Detect using RT-DETR model with 3-class detection (PyTorch backend).

        Args:
            image_path: Path to image file
            image: Image array (BGR format)
            confidence: Confidence threshold
            return_all_bubbles: If True, return list of bubble boxes (for compatibility)
                                If False, return dict with all classes

        Returns:
            List of bubbles if return_all_bubbles=True, else dict with classes
        """
        # Check for stop at start
        if self._check_stop():
            self._log("⏹️ RT-DETR detection stopped by user", "warning")
            if return_all_bubbles:
                return []
            return {'bubbles': [], 'text_bubbles': [], 'text_free': []}

        if not self.rtdetr_loaded:
            self._log("RT-DETR not loaded. Call load_rtdetr_model() first.", "warning")
            if return_all_bubbles:
                return []
            return {'bubbles': [], 'text_bubbles': [], 'text_free': []}

        confidence = confidence or self.default_confidence

        try:
            # Load image
            if image_path:
                image = cv2.imread(image_path)
            elif image is None:
                logger.error("No image provided")
                if return_all_bubbles:
                    return []
                return {'bubbles': [], 'text_bubbles': [], 'text_free': []}

            # Convert BGR to RGB for PIL
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image_rgb)

            # Prepare image for model
            inputs = self.rtdetr_processor(images=pil_image, return_tensors="pt")

            # Move inputs to the same device as the model and match model dtype for floating tensors
            model_device = next(self.rtdetr_model.parameters()).device if self.rtdetr_model is not None else (torch.device('cpu') if TORCH_AVAILABLE else 'cpu')
            model_dtype = None
            if TORCH_AVAILABLE and self.rtdetr_model is not None:
                try:
                    model_dtype = next(self.rtdetr_model.parameters()).dtype
                except Exception:
                    model_dtype = None
            if TORCH_AVAILABLE:
                new_inputs = {}
                for k, v in inputs.items():
                    if isinstance(v, torch.Tensor):
                        v = v.to(model_device)
                        if model_dtype is not None and torch.is_floating_point(v):
                            v = v.to(model_dtype)
                    new_inputs[k] = v
                inputs = new_inputs

            # Run inference with autocast when model is half/bfloat16 on CUDA
            use_amp = TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == 'cuda' and (model_dtype in (torch.float16, torch.bfloat16))
            autocast_dtype = model_dtype if model_dtype in (torch.float16, torch.bfloat16) else None
            with torch.no_grad():
                if use_amp and autocast_dtype is not None:
                    with torch.autocast('cuda', dtype=autocast_dtype):
                        outputs = self.rtdetr_model(**inputs)
                else:
                    outputs = self.rtdetr_model(**inputs)

            # Brief pause for stability after inference
            time.sleep(0.1)
            logger.debug("💤 RT-DETR inference pausing briefly for stability")

            # Post-process results
            target_sizes = torch.tensor([pil_image.size[::-1]]) if TORCH_AVAILABLE else None
            if TORCH_AVAILABLE and hasattr(model_device, 'type') and model_device.type == "cuda":
                target_sizes = target_sizes.to(model_device)
            results = self.rtdetr_processor.post_process_object_detection(
                outputs,
                target_sizes=target_sizes,
                threshold=confidence
            )[0]

            # Apply per-detector cap if configured
            cap = getattr(self, 'max_det_rtdetr', self.default_max_detections)
            if cap and len(results['boxes']) > cap:
                # Keep top-scoring first
                scores = results['scores']
                top_idx = scores.topk(k=cap).indices if hasattr(scores, 'topk') else range(cap)
                results = {
                    'boxes': [results['boxes'][i] for i in top_idx],
                    'scores': [results['scores'][i] for i in top_idx],
                    'labels': [results['labels'][i] for i in top_idx]
                }

            logger.info(f"📊 RT-DETR found {len(results['boxes'])} detections above {confidence:.2f} confidence")

            # Apply NMS to remove duplicate detections.
            # Group detections by class first.
            class_detections = {self.CLASS_BUBBLE: [], self.CLASS_TEXT_BUBBLE: [], self.CLASS_TEXT_FREE: []}
            for box, score, label in zip(results['boxes'], results['scores'], results['labels']):
                x1, y1, x2, y2 = map(float, box.tolist())
                label_id = label.item()
                if label_id in class_detections:
                    class_detections[label_id].append((x1, y1, x2, y2, float(score.item())))

            # Apply NMS per class to remove duplicates
            def compute_iou(box1, box2):
                """Compute IoU between two boxes (x1, y1, x2, y2)"""
                x1_1, y1_1, x2_1, y2_1 = box1[:4]
                x1_2, y1_2, x2_2, y2_2 = box2[:4]

                # Intersection
                x_left = max(x1_1, x1_2)
                y_top = max(y1_1, y1_2)
                x_right = min(x2_1, x2_2)
                y_bottom = min(y2_1, y2_2)

                if x_right < x_left or y_bottom < y_top:
                    return 0.0

                intersection = (x_right - x_left) * (y_bottom - y_top)

                # Union
                area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
                area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
                union = area1 + area2 - intersection

                return intersection / union if union > 0 else 0.0

            def apply_nms(boxes_with_scores, iou_threshold=0.45):
                """Apply Non-Maximum Suppression"""
                if not boxes_with_scores:
                    return []

                # Sort by score (descending)
                sorted_boxes = sorted(boxes_with_scores, key=lambda x: x[4], reverse=True)
                keep = []
                while sorted_boxes:
                    # Keep the box with highest score
                    current = sorted_boxes.pop(0)
                    keep.append(current)
                    # Remove boxes with high IoU
                    sorted_boxes = [box for box in sorted_boxes if compute_iou(current, box) < iou_threshold]
                return keep

            # Apply NMS and organize by class
            detections = {
                'bubbles': [],       # Empty speech bubbles
                'text_bubbles': [],  # Bubbles with text
                'text_free': []      # Text without bubbles
            }

            for class_id, boxes_list in class_detections.items():
                nms_boxes = apply_nms(boxes_list, iou_threshold=self.default_iou_threshold)
                for x1, y1, x2, y2, scr in nms_boxes:
                    width = int(x2 - x1)
                    height = int(y2 - y1)
                    # Store as (x, y, width, height) to match YOLOv8 format
                    bbox = (int(x1), int(y1), width, height)
                    if class_id == self.CLASS_BUBBLE:
                        detections['bubbles'].append(bbox)
                    elif class_id == self.CLASS_TEXT_BUBBLE:
                        detections['text_bubbles'].append(bbox)
                    elif class_id == self.CLASS_TEXT_FREE:
                        detections['text_free'].append(bbox)
                    # Stop early if we hit the configured cap across all classes
                    total_count = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
                    if total_count >= (self.config.get('manga_settings', {}).get('ocr', {}).get('bubble_max_detections', self.default_max_detections) if isinstance(self.config, dict) else self.default_max_detections):
                        break

            # Log results
            total = len(detections['bubbles']) + len(detections['text_bubbles']) + len(detections['text_free'])
            logger.info(f"✅ RT-DETR detected {total} objects:")
            logger.info(f"   - Empty bubbles: {len(detections['bubbles'])}")
            logger.info(f"   - Text bubbles: {len(detections['text_bubbles'])}")
            logger.info(f"   - Free text: {len(detections['text_free'])}")

            # Return format based on compatibility mode
            if return_all_bubbles:
                # Return all bubbles (empty + with text) for backward compatibility
                return detections['bubbles'] + detections['text_bubbles']
            else:
                return detections

        except Exception as e:
            logger.error(f"RT-DETR detection failed: {e}")
            logger.error(traceback.format_exc())
            if return_all_bubbles:
                return []
            return {'bubbles': [], 'text_bubbles': [], 'text_free': []}
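
    # Example usage (sketch): the dict form separates the three RT-DETR classes.
    #
    #   detector = BubbleDetector()
    #   if detector.load_rtdetr_model():  # defaults to ogkalu/comic-text-and-bubble-detector
    #       det = detector.detect_with_rtdetr(image_path='page_001.png', confidence=0.3)
    #       print(len(det['bubbles']), len(det['text_bubbles']), len(det['text_free']))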

    def detect_all_text_regions(self, image_path: str = None, image: np.ndarray = None) -> List[Tuple[int, int, int, int]]:
        """
        Detect all text regions using RT-DETR (both in bubbles and free text).

        Returns:
            List of bounding boxes for all text regions
        """
        if not self.rtdetr_loaded:
            logger.warning("RT-DETR required for text detection")
            return []

        detections = self.detect_with_rtdetr(image_path=image_path, image=image, return_all_bubbles=False)

        # Combine text bubbles and free text
        all_text = detections['text_bubbles'] + detections['text_free']
        logger.info(f"📝 Found {len(all_text)} text regions total")
        return all_text

    def _detect_with_onnx(self, image: np.ndarray, confidence: float,
                          iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
        """Run detection using ONNX model."""
        # Preprocess image
        img_size = 640  # Standard YOLOv8 input size
        img_resized = cv2.resize(image, (img_size, img_size))
        img_norm = img_resized.astype(np.float32) / 255.0
        img_transposed = np.transpose(img_norm, (2, 0, 1))
        img_batch = np.expand_dims(img_transposed, axis=0)

        # Run inference
        input_name = self.onnx_session.get_inputs()[0].name
        outputs = self.onnx_session.run(None, {input_name: img_batch})

        # Process outputs (YOLOv8 format)
        predictions = outputs[0][0]  # Remove batch dimension

        # Filter by confidence and apply NMS
        bubbles = []
        boxes = []
        scores = []

        for pred in predictions.T:  # Transpose to get predictions per detection
            if len(pred) >= 5:
                x_center, y_center, width, height, obj_conf = pred[:5]
                if obj_conf >= confidence:
                    # Convert to corner coordinates
                    x1 = x_center - width / 2
                    y1 = y_center - height / 2

                    # Scale to original image size
                    h, w = image.shape[:2]
                    x1 = int(x1 * w / img_size)
                    y1 = int(y1 * h / img_size)
                    width = int(width * w / img_size)
                    height = int(height * h / img_size)

                    boxes.append([x1, y1, x1 + width, y1 + height])
                    scores.append(float(obj_conf))

        # Apply NMS
        if boxes:
            indices = cv2.dnn.NMSBoxes(boxes, scores, confidence, iou_threshold)
            if len(indices) > 0:
                indices = indices.flatten()[:max_detections]
                for i in indices:
                    x1, y1, x2, y2 = boxes[i]
                    bubbles.append((x1, y1, x2 - x1, y2 - y1))

        return bubbles

    def _detect_with_torchscript(self, image: np.ndarray, confidence: float,
                                 iou_threshold: float, max_detections: int) -> List[Tuple[int, int, int, int]]:
        """Run detection using TorchScript model."""
        # Similar to ONNX but using PyTorch tensors
        img_size = 640
        img_resized = cv2.resize(image, (img_size, img_size))
        img_norm = img_resized.astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0)

        if self.use_gpu:
            img_tensor = img_tensor.cuda()

        with torch.no_grad():
            outputs = self.model(img_tensor)

        # Process outputs similar to ONNX.
        # Implementation depends on the exact model output format;
        # this is a placeholder - adjust based on your model.
        return []
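
    # NOTE: the TorchScript path above is intentionally a stub. A decode step for a
    # YOLO-style export might look roughly like the sketch below — this assumes the
    # model emits (1, 5+, N) predictions like the ONNX branch, which may not match
    # your particular export:
    #
    #   preds = outputs[0].cpu().numpy()[0]   # drop batch dimension
    #   for pred in preds.T:                  # one column per candidate detection
    #       x_c, y_c, bw, bh, conf = pred[:5]
    #       if conf >= confidence:
    #           ...  # rescale to original size, collect, then NMS as in _detect_with_onnx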

    def visualize_detections(self, image_path: str, bubbles: List[Tuple[int, int, int, int]] = None,
                             output_path: str = None, use_rtdetr: bool = False) -> np.ndarray:
        """
        Visualize detected bubbles on the image.

        Args:
            image_path: Path to original image
            bubbles: List of bubble bounding boxes (if None, will detect)
            output_path: Optional path to save visualization
            use_rtdetr: Use RT-DETR for visualization with class colors

        Returns:
            Image with drawn bounding boxes
        """
        image = cv2.imread(image_path)
        if image is None:
            logger.error(f"Failed to load image: {image_path}")
            return None

        vis_image = image.copy()

        if use_rtdetr and self.rtdetr_loaded:
            # RT-DETR visualization with different colors per class
            detections = self.detect_with_rtdetr(image_path=image_path, return_all_bubbles=False)

            # Colors for each class (BGR)
            colors = {
                'bubbles': (0, 255, 0),       # Green for empty bubbles
                'text_bubbles': (255, 0, 0),  # Blue for text bubbles
                'text_free': (0, 0, 255)      # Red for free text
            }

            # Draw detections
            for class_name, bboxes in detections.items():
                color = colors[class_name]
                for i, (x, y, w, h) in enumerate(bboxes):
                    # Draw rectangle
                    cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, 2)

                    # Add label
                    label = f"{class_name.replace('_', ' ').title()} {i+1}"
                    label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                    cv2.rectangle(vis_image, (x, y - label_size[1] - 4),
                                  (x + label_size[0], y), color, -1)
                    cv2.putText(vis_image, label, (x, y - 2),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        else:
            # Original YOLOv8 visualization
            if bubbles is None:
                bubbles = self.detect_bubbles(image_path)

            # Draw bounding boxes
            for i, (x, y, w, h) in enumerate(bubbles):
                # Draw rectangle
                color = (0, 255, 0)  # Green
                thickness = 2
                cv2.rectangle(vis_image, (x, y), (x + w, y + h), color, thickness)

                # Add label
                label = f"Bubble {i+1}"
                label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                cv2.rectangle(vis_image, (x, y - label_size[1] - 4), (x + label_size[0], y), color, -1)
                cv2.putText(vis_image, label, (x, y - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Save if output path provided
        if output_path:
            cv2.imwrite(output_path, vis_image)
            logger.info(f"💾 Visualization saved to: {output_path}")

        return vis_image

    def convert_to_onnx(self, model_path: str, output_path: str = None) -> bool:
        """
        Convert a YOLOv8 or RT-DETR model to ONNX format.

        Args:
            model_path: Path to model file or 'rtdetr' for loaded RT-DETR
            output_path: Path for ONNX output (auto-generated if None)

        Returns:
            True if conversion successful, False otherwise
        """
        try:
            logger.info(f"🔄 Converting {model_path} to ONNX...")

            # Generate output path if not provided
            if output_path is None:
                if model_path == 'rtdetr' and self.rtdetr_loaded:
                    base_name = 'rtdetr_comic'
                else:
                    base_name = Path(model_path).stem
                output_path = os.path.join(self.cache_dir, f"{base_name}.onnx")

            # Check if already exists
            if os.path.exists(output_path) and os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() != 'true':
                logger.info(f"✅ ONNX model already exists: {output_path}")
                return True

            # Handle RT-DETR conversion
            if model_path == 'rtdetr' and self.rtdetr_loaded:
                if not TORCH_AVAILABLE:
                    logger.error("PyTorch required for RT-DETR ONNX conversion")
                    return False

                # RT-DETR specific conversion
                self.rtdetr_model.eval()

                # Create dummy input (pixel values): BxCxHxW
                dummy_input = torch.randn(1, 3, 640, 640)
                if self.device == 'cuda':
                    dummy_input = dummy_input.to('cuda')

                # Wrap the model to return only tensors (logits, pred_boxes)
                class _RTDetrExportWrapper(torch.nn.Module):
                    def __init__(self, mdl):
                        super().__init__()
                        self.mdl = mdl

                    def forward(self, images):
                        out = self.mdl(pixel_values=images)
                        # Handle dict/ModelOutput/tuple outputs
                        logits = None
                        boxes = None
                        try:
                            if isinstance(out, dict):
                                logits = out.get('logits', None)
                                boxes = out.get('pred_boxes', out.get('boxes', None))
                            else:
                                logits = getattr(out, 'logits', None)
                                boxes = getattr(out, 'pred_boxes', getattr(out, 'boxes', None))
                        except Exception:
                            pass
                        if (logits is None or boxes is None) and isinstance(out, (tuple, list)) and len(out) >= 2:
                            logits, boxes = out[0], out[1]
                        return logits, boxes

                wrapper = _RTDetrExportWrapper(self.rtdetr_model)
                if self.device == 'cuda':
                    wrapper = wrapper.to('cuda')

                # Try PyTorch 2.x dynamo_export first (more tolerant of newer aten ops)
                try:
                    success = False
                    try:
                        from torch.onnx import dynamo_export
                        exp = dynamo_export(wrapper, dummy_input)
                        # exp may have save(); otherwise, it may expose model_proto
                        try:
                            exp.save(output_path)  # type: ignore
                            success = True
                        except Exception:
                            try:
                                import onnx as _onnx
                                _onnx.save(exp.model_proto, output_path)  # type: ignore
                                success = True
                            except Exception as _se:
                                logger.warning(f"dynamo_export produced model but could not save: {_se}")
                    except Exception as de:
                        logger.warning(f"dynamo_export failed; falling back to legacy exporter: {de}")
                    if success:
                        logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (dynamo_export)")
                        return True
                except Exception as de2:
                    logger.warning(f"dynamo_export path error: {de2}")

                # Legacy exporter with opset fallback
                last_err = None
                for opset in [19, 18, 17, 16, 15, 14, 13]:
                    try:
                        torch.onnx.export(
                            wrapper,
                            dummy_input,
                            output_path,
                            export_params=True,
                            opset_version=opset,
                            do_constant_folding=True,
                            input_names=['pixel_values'],
                            output_names=['logits', 'boxes'],
                            dynamic_axes={
                                'pixel_values': {0: 'batch', 2: 'height', 3: 'width'},
                                'logits': {0: 'batch'},
                                'boxes': {0: 'batch'}
                            }
                        )
                        logger.info(f"✅ RT-DETR ONNX saved to: {output_path} (opset {opset})")
                        return True
                    except Exception as _e:
                        last_err = _e
                        logger.warning(f"RT-DETR ONNX export failed at opset {opset}: {_e}")
                        continue

                logger.error(f"All RT-DETR ONNX export attempts failed. Last error: {last_err}")
                return False

            # Handle YOLOv8 conversion
            elif YOLO_AVAILABLE and os.path.exists(model_path):
                logger.info(f"Loading YOLOv8 model from: {model_path}")

                # Load model
                model = YOLO(model_path)

                # Export to ONNX - this returns the path to the exported model
                logger.info("Exporting to ONNX format...")
                exported_path = model.export(format='onnx', imgsz=640, simplify=True)

                # exported_path could be a string or Path object
                exported_path = str(exported_path) if exported_path else None

                if exported_path and os.path.exists(exported_path):
                    # Move to desired location if different
                    if exported_path != output_path:
                        import shutil
                        logger.info(f"Moving ONNX from {exported_path} to {output_path}")
                        shutil.move(exported_path, output_path)
                    logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
                    return True
                else:
                    # Fallback: check if it was created with expected name
                    expected_onnx = model_path.replace('.pt', '.onnx')
                    if os.path.exists(expected_onnx):
                        if expected_onnx != output_path:
                            import shutil
                            shutil.move(expected_onnx, output_path)
                        logger.info(f"✅ YOLOv8 ONNX saved to: {output_path}")
                        return True
                    else:
                        logger.error("ONNX export failed - no output file found")
                        return False

            else:
                logger.error(f"Cannot convert {model_path}: Model not found or dependencies missing")
                return False

        except Exception as e:
            logger.error(f"Conversion failed: {e}")
            # Avoid noisy full stack trace in production logs; return False gracefully
            return False
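
    # Example usage (sketch): convert a local YOLOv8 checkpoint, then reload the
    # exported artifact through the normal load_model() path. Paths are illustrative.
    #
    #   detector = BubbleDetector()
    #   if detector.convert_to_onnx('models/bubbles_yolov8.pt', 'models/bubbles_yolov8.onnx'):
    #       detector.load_model('models/bubbles_yolov8.onnx')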

    def batch_detect(self, image_paths: List[str], **kwargs) -> Dict[str, List[Tuple[int, int, int, int]]]:
        """
        Detect bubbles in multiple images.

        Args:
            image_paths: List of image paths
            **kwargs: Detection parameters (confidence, iou_threshold, max_detections, use_rtdetr)

        Returns:
            Dictionary mapping image paths to bubble lists
        """
        results = {}
        for i, image_path in enumerate(image_paths):
            logger.info(f"Processing image {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")
            results[image_path] = self.detect_bubbles(image_path, **kwargs)
        return results

    def unload(self, release_shared: bool = False):
        """Release model resources held by this detector instance.

        Args:
            release_shared: If True, also clear class-level shared RT-DETR caches.
        """
        try:
            # Release instance-level models and sessions
            try:
                if getattr(self, 'onnx_session', None) is not None:
                    self.onnx_session = None
            except Exception:
                pass
            try:
                if getattr(self, 'rtdetr_onnx_session', None) is not None:
                    self.rtdetr_onnx_session = None
            except Exception:
                pass
            for attr in ['model', 'rtdetr_model', 'rtdetr_processor']:
                try:
                    if hasattr(self, attr):
                        setattr(self, attr, None)
                except Exception:
                    pass
            for flag in ['model_loaded', 'rtdetr_loaded', 'rtdetr_onnx_loaded']:
                try:
                    if hasattr(self, flag):
                        setattr(self, flag, False)
                except Exception:
                    pass

            # Optional: release shared caches
            if release_shared:
                try:
                    BubbleDetector._rtdetr_shared_model = None
                    BubbleDetector._rtdetr_shared_processor = None
                    BubbleDetector._rtdetr_loaded = False
                except Exception:
                    pass

            # Free CUDA cache and trigger GC
            try:
                if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                import gc
                gc.collect()
            except Exception:
                pass
        except Exception:
            # Best-effort only
            pass

    def get_bubble_masks(self, image_path: str, bubbles: List[Tuple[int, int, int, int]]) -> np.ndarray:
        """
        Create a mask image with bubble regions.

        Args:
            image_path: Path to original image
            bubbles: List of bubble bounding boxes

        Returns:
            Binary mask with bubble regions as white (255)
        """
        image = cv2.imread(image_path)
        if image is None:
            return None

        h, w = image.shape[:2]
        mask = np.zeros((h, w), dtype=np.uint8)

        # Fill bubble regions
        for x, y, bw, bh in bubbles:
            cv2.rectangle(mask, (x, y), (x + bw, y + bh), 255, -1)

        return mask

    def filter_bubbles_by_size(self, bubbles: List[Tuple[int, int, int, int]],
                               min_area: int = 100,
                               max_area: int = None) -> List[Tuple[int, int, int, int]]:
        """
        Filter bubbles by area.

        Args:
            bubbles: List of bubble bounding boxes
            min_area: Minimum area in pixels
            max_area: Maximum area in pixels (None for no limit)

        Returns:
            Filtered list of bubbles
        """
        filtered = []
        for x, y, w, h in bubbles:
            area = w * h
            if area >= min_area and (max_area is None or area <= max_area):
                filtered.append((x, y, w, h))
        return filtered
| def merge_overlapping_bubbles(self, bubbles: List[Tuple[int, int, int, int]], | |
| overlap_threshold: float = 0.1) -> List[Tuple[int, int, int, int]]: | |
| """ | |
| Merge overlapping bubble detections. | |
| Args: | |
| bubbles: List of bubble bounding boxes | |
| overlap_threshold: Minimum overlap ratio to merge | |
| Returns: | |
| Merged list of bubbles | |
| """ | |
| if not bubbles: | |
| return [] | |
| # Convert to numpy array for easier manipulation | |
| boxes = np.array([(x, y, x+w, y+h) for x, y, w, h in bubbles]) | |
| merged = [] | |
| used = set() | |
| for i, box1 in enumerate(boxes): | |
| if i in used: | |
| continue | |
| # Start with current box | |
| x1, y1, x2, y2 = box1 | |
| # Check for overlaps with remaining boxes | |
| for j in range(i + 1, len(boxes)): | |
| if j in used: | |
| continue | |
| box2 = boxes[j] | |
| # Calculate intersection | |
| ix1 = max(x1, box2[0]) | |
| iy1 = max(y1, box2[1]) | |
| ix2 = min(x2, box2[2]) | |
| iy2 = min(y2, box2[3]) | |
| if ix1 < ix2 and iy1 < iy2: | |
| # Calculate overlap ratio | |
| intersection = (ix2 - ix1) * (iy2 - iy1) | |
| area1 = (x2 - x1) * (y2 - y1) | |
| area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) | |
| overlap = intersection / min(area1, area2) | |
| if overlap >= overlap_threshold: | |
| # Merge boxes | |
| x1 = min(x1, box2[0]) | |
| y1 = min(y1, box2[1]) | |
| x2 = max(x2, box2[2]) | |
| y2 = max(y2, box2[3]) | |
| used.add(j) | |
| merged.append((int(x1), int(y1), int(x2 - x1), int(y2 - y1))) | |
| return merged | |
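| # Worked example (illustrative boxes): (10, 10, 50, 40) and (30, 20, 60, 50) | |
| # intersect over a 30x30 = 900 px region; 900 / min(2000, 3000) = 0.45 >= 0.1, | |
| # so they merge into one box spanning both: | |
| #   det.merge_overlapping_bubbles([(10, 10, 50, 40), (30, 20, 60, 50)]) | |
| #   -> [(10, 10, 80, 60)] | |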
| # ============================ | |
| # RT-DETR (ONNX) BACKEND | |
| # ============================ | |
| def load_rtdetr_onnx_model(self, model_id: str = None, force_reload: bool = False) -> bool: | |
| """ | |
| Load RT-DETR ONNX model using onnxruntime. Downloads detector.onnx and config.json | |
| from the provided Hugging Face repo if not already cached. | |
| """ | |
| if not ONNX_AVAILABLE: | |
| logger.error("ONNX Runtime not available for RT-DETR ONNX backend") | |
| return False | |
| try: | |
| # If singleton mode and already loaded, just attach shared session | |
| try: | |
| adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {} | |
| singleton = bool(adv.get('use_singleton_models', True)) | |
| except Exception: | |
| singleton = True | |
| if singleton and BubbleDetector._rtdetr_onnx_loaded and not force_reload and BubbleDetector._rtdetr_onnx_shared_session is not None: | |
| self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session | |
| self.rtdetr_onnx_loaded = True | |
| return True | |
| repo = model_id or self.rtdetr_onnx_repo | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| except Exception as e: | |
| logger.error(f"huggingface-hub required to fetch RT-DETR ONNX: {e}") | |
| return False | |
| # Ensure local models dir (use configured cache_dir directly: e.g., 'models') | |
| cache_dir = self.cache_dir | |
| os.makedirs(cache_dir, exist_ok=True) | |
| # Download files into models/ and avoid symlinks so the file is visible there | |
| try: | |
| _ = hf_hub_download(repo_id=repo, filename='config.json', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False) | |
| except Exception: | |
| pass | |
| onnx_fp = hf_hub_download(repo_id=repo, filename='detector.onnx', cache_dir=cache_dir, local_dir=cache_dir, local_dir_use_symlinks=False) | |
| BubbleDetector._rtdetr_onnx_model_path = onnx_fp | |
| # Pick providers: prefer CUDA if available; otherwise CPU. Do NOT use DML. | |
| providers = ['CPUExecutionProvider'] | |
| try: | |
| avail = ort.get_available_providers() if ONNX_AVAILABLE else [] | |
| if 'CUDAExecutionProvider' in avail: | |
| providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] | |
| except Exception: | |
| pass | |
| # Session options with reduced memory arena and optional thread limiting in singleton mode | |
| so = ort.SessionOptions() | |
| try: | |
| so.enable_mem_pattern = False | |
| so.enable_cpu_mem_arena = False | |
| except Exception: | |
| pass | |
| # If singleton models mode is enabled in config, limit ORT threading to reduce CPU spikes | |
| try: | |
| adv = (self.config or {}).get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {} | |
| if bool(adv.get('use_singleton_models', True)): | |
| so.intra_op_num_threads = 1 | |
| so.inter_op_num_threads = 1 | |
| try: | |
| so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL | |
| except Exception: | |
| pass | |
| try: | |
| so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC | |
| except Exception: | |
| pass | |
| except Exception: | |
| pass | |
| # Create session (serialize creation in singleton mode to avoid device storms) | |
| if singleton: | |
| with BubbleDetector._rtdetr_onnx_init_lock: | |
| # Re-check after acquiring lock | |
| if BubbleDetector._rtdetr_onnx_loaded and BubbleDetector._rtdetr_onnx_shared_session is not None and not force_reload: | |
| self.rtdetr_onnx_session = BubbleDetector._rtdetr_onnx_shared_session | |
| self.rtdetr_onnx_loaded = True | |
| return True | |
| sess = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so) | |
| BubbleDetector._rtdetr_onnx_shared_session = sess | |
| BubbleDetector._rtdetr_onnx_loaded = True | |
| BubbleDetector._rtdetr_onnx_providers = providers | |
| self.rtdetr_onnx_session = sess | |
| self.rtdetr_onnx_loaded = True | |
| else: | |
| self.rtdetr_onnx_session = ort.InferenceSession(onnx_fp, providers=providers, sess_options=so) | |
| self.rtdetr_onnx_loaded = True | |
| logger.info("✅ RT-DETR (ONNX) model ready") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Failed to load RT-DETR ONNX: {e}") | |
| self.rtdetr_onnx_session = None | |
| self.rtdetr_onnx_loaded = False | |
| return False | |
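| # Usage sketch (hypothetical instance): in singleton mode a second load call | |
| # re-attaches the cached shared session instead of rebuilding it: | |
| #   det = BubbleDetector() | |
| #   det.load_rtdetr_onnx_model()   # downloads detector.onnx, creates the session | |
| #   det.load_rtdetr_onnx_model()   # fast path: reuses the shared session | |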
| def detect_with_rtdetr_onnx(self, | |
| image_path: str = None, | |
| image: np.ndarray = None, | |
| confidence: float = 0.3, | |
| return_all_bubbles: bool = False) -> Any: | |
| """Detect using RT-DETR ONNX backend. | |
| Returns bubbles list if return_all_bubbles else dict by classes similar to PyTorch path. | |
| """ | |
| if not self.rtdetr_onnx_loaded or self.rtdetr_onnx_session is None: | |
| logger.warning("RT-DETR ONNX not loaded") | |
| return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []} | |
| try: | |
| # Acquire image (cv2 is imported at module load, so no local import is needed) | |
| if image_path is not None: | |
| image = cv2.imread(image_path) | |
| if image is None: | |
| raise RuntimeError(f"Failed to read image: {image_path}") | |
| image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| else: | |
| if image is None: | |
| raise RuntimeError("No image provided") | |
| # Assume the array is BGR (OpenCV convention) and convert to RGB | |
| try: | |
| image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| except Exception: | |
| image_rgb = image | |
| # To PIL then resize 640x640 as in reference | |
| from PIL import Image as _PILImage | |
| pil_image = _PILImage.fromarray(image_rgb) | |
| im_resized = pil_image.resize((640, 640)) | |
| arr = np.asarray(im_resized, dtype=np.float32) / 255.0 | |
| arr = np.transpose(arr, (2, 0, 1)) # (3,H,W) | |
| im_data = arr[np.newaxis, ...] | |
| w, h = pil_image.size | |
| orig_size = np.array([[w, h]], dtype=np.int64) | |
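| # At this point im_data has shape (1, 3, 640, 640) in float32 [0, 1], and | |
| # orig_size carries the pre-resize (width, height) so the exported graph can | |
| # map boxes back to original-image coordinates (per the reference pipeline) | |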
| # Run with a concurrency guard to prevent device hangs and limit memory usage | |
| # Apply semaphore for ALL providers (not just DML) to control concurrency | |
| providers = BubbleDetector._rtdetr_onnx_providers or [] | |
| def _do_run(session): | |
| return session.run(None, { | |
| 'images': im_data, | |
| 'orig_target_sizes': orig_size | |
| }) | |
| # Always use semaphore to limit concurrent RT-DETR calls | |
| acquired = False | |
| try: | |
| BubbleDetector._rtdetr_onnx_sema.acquire() | |
| acquired = True | |
| # Defensive DML error handling (providers chosen here exclude DML, | |
| # but a shared session created elsewhere may still use it) | |
| if 'DmlExecutionProvider' in providers: | |
| try: | |
| outputs = _do_run(self.rtdetr_onnx_session) | |
| except Exception as dml_err: | |
| msg = str(dml_err) | |
| if '887A0005' in msg or '887A0006' in msg or 'Dml' in msg: | |
| # Device removed/reset: rebuild a CPU session and retry once | |
| base_path = BubbleDetector._rtdetr_onnx_model_path | |
| if base_path: | |
| so = ort.SessionOptions() | |
| so.enable_mem_pattern = False | |
| so.enable_cpu_mem_arena = False | |
| cpu_providers = ['CPUExecutionProvider'] | |
| # Serialize the rebuild across threads | |
| with BubbleDetector._rtdetr_onnx_init_lock: | |
| sess = ort.InferenceSession(base_path, providers=cpu_providers, sess_options=so) | |
| BubbleDetector._rtdetr_onnx_shared_session = sess | |
| BubbleDetector._rtdetr_onnx_providers = cpu_providers | |
| self.rtdetr_onnx_session = sess | |
| outputs = _do_run(self.rtdetr_onnx_session) | |
| else: | |
| raise | |
| else: | |
| raise | |
| else: | |
| # Non-DML providers: run directly | |
| outputs = _do_run(self.rtdetr_onnx_session) | |
| finally: | |
| if acquired: | |
| try: | |
| BubbleDetector._rtdetr_onnx_sema.release() | |
| except Exception: | |
| pass | |
| # outputs expected: labels, boxes, scores | |
| labels, boxes, scores = outputs[:3] | |
| if labels.ndim == 2 and labels.shape[0] == 1: | |
| labels = labels[0] | |
| if scores.ndim == 2 and scores.shape[0] == 1: | |
| scores = scores[0] | |
| if boxes.ndim == 3 and boxes.shape[0] == 1: | |
| boxes = boxes[0] | |
| # Apply NMS to remove duplicate detections | |
| # Group detections by class and apply NMS per class | |
| class_detections = {self.CLASS_BUBBLE: [], self.CLASS_TEXT_BUBBLE: [], self.CLASS_TEXT_FREE: []} | |
| for lab, box, scr in zip(labels, boxes, scores): | |
| if float(scr) < float(confidence): | |
| continue | |
| label_id = int(lab) | |
| if label_id in class_detections: | |
| x1, y1, x2, y2 = map(float, box) | |
| class_detections[label_id].append((x1, y1, x2, y2, float(scr))) | |
| # Apply NMS per class to remove duplicates | |
| def compute_iou(box1, box2): | |
| """Compute IoU between two boxes (x1, y1, x2, y2)""" | |
| x1_1, y1_1, x2_1, y2_1 = box1[:4] | |
| x1_2, y1_2, x2_2, y2_2 = box2[:4] | |
| # Intersection | |
| x_left = max(x1_1, x1_2) | |
| y_top = max(y1_1, y1_2) | |
| x_right = min(x2_1, x2_2) | |
| y_bottom = min(y2_1, y2_2) | |
| if x_right < x_left or y_bottom < y_top: | |
| return 0.0 | |
| intersection = (x_right - x_left) * (y_bottom - y_top) | |
| # Union | |
| area1 = (x2_1 - x1_1) * (y2_1 - y1_1) | |
| area2 = (x2_2 - x1_2) * (y2_2 - y1_2) | |
| union = area1 + area2 - intersection | |
| return intersection / union if union > 0 else 0.0 | |
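| # e.g. boxes (0, 0, 10, 10) and (5, 5, 15, 15): intersection 5*5 = 25, | |
| # union 100 + 100 - 25 = 175, IoU = 25 / 175 ≈ 0.143 | |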
| def apply_nms(boxes_with_scores, iou_threshold=0.45): | |
| """Apply Non-Maximum Suppression""" | |
| if not boxes_with_scores: | |
| return [] | |
| # Sort by score (descending) | |
| sorted_boxes = sorted(boxes_with_scores, key=lambda x: x[4], reverse=True) | |
| keep = [] | |
| while sorted_boxes: | |
| # Keep the box with highest score | |
| current = sorted_boxes.pop(0) | |
| keep.append(current) | |
| # Remove boxes with high IoU | |
| sorted_boxes = [box for box in sorted_boxes if compute_iou(current, box) < iou_threshold] | |
| return keep | |
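| # e.g. two boxes with IoU 0.6 and scores 0.9 / 0.8: the 0.9 box is kept | |
| # and the 0.8 box is suppressed, since 0.6 >= the 0.45 threshold | |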
| # Apply NMS and build final detections | |
| detections = {'bubbles': [], 'text_bubbles': [], 'text_free': []} | |
| bubbles_all = [] | |
| for class_id, boxes_list in class_detections.items(): | |
| nms_boxes = apply_nms(boxes_list, iou_threshold=self.default_iou_threshold) | |
| for x1, y1, x2, y2, scr in nms_boxes: | |
| bbox = (int(x1), int(y1), int(x2 - x1), int(y2 - y1)) | |
| if class_id == self.CLASS_BUBBLE: | |
| detections['bubbles'].append(bbox) | |
| bubbles_all.append(bbox) | |
| elif class_id == self.CLASS_TEXT_BUBBLE: | |
| detections['text_bubbles'].append(bbox) | |
| bubbles_all.append(bbox) | |
| elif class_id == self.CLASS_TEXT_FREE: | |
| detections['text_free'].append(bbox) | |
| return bubbles_all if return_all_bubbles else detections | |
| except Exception as e: | |
| logger.error(f"RT-DETR ONNX detection failed: {e}") | |
| return [] if return_all_bubbles else {'bubbles': [], 'text_bubbles': [], 'text_free': []} | |
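| # Usage sketch (hypothetical path): | |
| #   det = BubbleDetector() | |
| #   if det.load_rtdetr_onnx_model(): | |
| #       by_class = det.detect_with_rtdetr_onnx(image_path="page.png") | |
| #       flat = det.detect_with_rtdetr_onnx(image_path="page.png", return_all_bubbles=True) | |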
| # Standalone utility functions | |
| def download_model_from_huggingface(repo_id: str = "ogkalu/comic-speech-bubble-detector-yolov8m", | |
| filename: str = "comic-speech-bubble-detector-yolov8m.pt", | |
| cache_dir: str = "models") -> Optional[str]: | |
| """ | |
| Download a model file from the Hugging Face Hub. | |
| Args: | |
| repo_id: Hugging Face repository ID | |
| filename: Model filename in the repository | |
| cache_dir: Local directory to cache the model | |
| Returns: | |
| Path to the downloaded model file, or None on failure | |
| """ | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| os.makedirs(cache_dir, exist_ok=True) | |
| logger.info(f"📥 Downloading {filename} from {repo_id}...") | |
| model_path = hf_hub_download( | |
| repo_id=repo_id, | |
| filename=filename, | |
| cache_dir=cache_dir, | |
| local_dir=cache_dir | |
| ) | |
| logger.info(f"✅ Model downloaded to: {model_path}") | |
| return model_path | |
| except ImportError: | |
| logger.error("huggingface-hub package required. Install with: pip install huggingface-hub") | |
| return None | |
| except Exception as e: | |
| logger.error(f"Download failed: {e}") | |
| return None | |
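| # Usage sketch (defaults from the signature): | |
| #   path = download_model_from_huggingface() | |
| #   if path: | |
| #       det = BubbleDetector() | |
| #       det.load_model(path) | |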
| def download_rtdetr_model(cache_dir: str = "models") -> bool: | |
| """ | |
| Download RT-DETR model for advanced detection. | |
| Args: | |
| cache_dir: Directory to cache the model | |
| Returns: | |
| True if successful | |
| """ | |
| if not TRANSFORMERS_AVAILABLE: | |
| logger.error("Transformers required. Install with: pip install transformers") | |
| return False | |
| try: | |
| logger.info("📥 Downloading RT-DETR model...") | |
| from transformers import RTDetrForObjectDetection, RTDetrImageProcessor | |
| # This will download and cache the model | |
| processor = RTDetrImageProcessor.from_pretrained( | |
| "ogkalu/comic-text-and-bubble-detector", | |
| cache_dir=cache_dir | |
| ) | |
| model = RTDetrForObjectDetection.from_pretrained( | |
| "ogkalu/comic-text-and-bubble-detector", | |
| cache_dir=cache_dir | |
| ) | |
| logger.info("✅ RT-DETR model downloaded successfully") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Download failed: {e}") | |
| return False | |
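| # Usage sketch: pre-fetch the RT-DETR weights (e.g. during setup) so the | |
| # first detection call does not block on a download: | |
| #   if download_rtdetr_model(cache_dir="models"): | |
| #       BubbleDetector().load_rtdetr_model() | |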
| # Example usage and testing | |
| if __name__ == "__main__": | |
| # Create detector | |
| detector = BubbleDetector() | |
| if len(sys.argv) > 1: | |
| if sys.argv[1] == "download": | |
| # Download model from Hugging Face | |
| model_path = download_model_from_huggingface() | |
| if model_path: | |
| print(f"YOLOv8 model downloaded to: {model_path}") | |
| # Also download RT-DETR | |
| if download_rtdetr_model(): | |
| print("RT-DETR model downloaded") | |
| elif sys.argv[1] == "detect" and len(sys.argv) > 3: | |
| # Detect bubbles in an image | |
| model_path = sys.argv[2] | |
| image_path = sys.argv[3] | |
| # Load appropriate model | |
| if 'rtdetr' in model_path.lower(): | |
| if detector.load_rtdetr_model(): | |
| # Use RT-DETR | |
| results = detector.detect_with_rtdetr(image_path) | |
| print(f"RT-DETR Detection:") | |
| print(f" Empty bubbles: {len(results['bubbles'])}") | |
| print(f" Text bubbles: {len(results['text_bubbles'])}") | |
| print(f" Free text: {len(results['text_free'])}") | |
| else: | |
| if detector.load_model(model_path): | |
| bubbles = detector.detect_bubbles(image_path, confidence=0.5) | |
| print(f"YOLOv8 detected {len(bubbles)} bubbles:") | |
| for i, (x, y, w, h) in enumerate(bubbles): | |
| print(f" Bubble {i+1}: position=({x},{y}) size=({w}x{h})") | |
| # Optionally visualize | |
| if len(sys.argv) > 4: | |
| output_path = sys.argv[4] | |
| detector.visualize_detections(image_path, output_path=output_path, | |
| use_rtdetr='rtdetr' in model_path.lower()) | |
| elif sys.argv[1] == "test-both" and len(sys.argv) > 2: | |
| # Test both models | |
| image_path = sys.argv[2] | |
| # Load YOLOv8 | |
| yolo_path = "models/comic-speech-bubble-detector-yolov8m.pt" | |
| if os.path.exists(yolo_path): | |
| detector.load_model(yolo_path) | |
| yolo_bubbles = detector.detect_bubbles(image_path, use_rtdetr=False) | |
| print(f"YOLOv8: {len(yolo_bubbles)} bubbles") | |
| # Load RT-DETR | |
| if detector.load_rtdetr_model(): | |
| rtdetr_bubbles = detector.detect_bubbles(image_path, use_rtdetr=True) | |
| print(f"RT-DETR: {len(rtdetr_bubbles)} bubbles") | |
| else: | |
| print("Usage:") | |
| print(" python bubble_detector.py download") | |
| print(" python bubble_detector.py detect <model_path> <image_path> [output_path]") | |
| print(" python bubble_detector.py test-both <image_path>") | |
| else: | |
| print("Bubble Detector Module (YOLOv8 + RT-DETR)") | |
| print("Usage:") | |
| print(" python bubble_detector.py download") | |
| print(" python bubble_detector.py detect <model_path> <image_path> [output_path]") | |
| print(" python bubble_detector.py test-both <image_path>") | |