| """ | |
| Local inpainting implementation - COMPATIBLE VERSION WITH JIT SUPPORT | |
| Maintains full backward compatibility while adding proper JIT model support | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import numpy as np | |
| import cv2 | |
| from typing import Optional, List, Tuple, Dict, Any | |
| import logging | |
| import traceback | |
| import re | |
| import hashlib | |
| import urllib.request | |
| from pathlib import Path | |
| import threading | |
| import time | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Check if we're running in a frozen environment | |
| IS_FROZEN = getattr(sys, 'frozen', False) | |
| if IS_FROZEN: | |
| MEIPASS = sys._MEIPASS | |
| os.environ['TORCH_HOME'] = MEIPASS | |
| os.environ['TRANSFORMERS_CACHE'] = os.path.join(MEIPASS, 'transformers') | |
| os.environ['HF_HOME'] = os.path.join(MEIPASS, 'huggingface') | |
| logger.info(f"Running in frozen environment: {MEIPASS}") | |
| # Environment variables for ONNX | |
| ONNX_CACHE_DIR = os.environ.get('ONNX_CACHE_DIR', 'models') | |
| AUTO_CONVERT_TO_ONNX = os.environ.get('AUTO_CONVERT_TO_ONNX', 'false').lower() == 'true' | |
| SKIP_ONNX_FOR_CKPT = os.environ.get('SKIP_ONNX_FOR_CKPT', 'true').lower() == 'true' | |
| FORCE_ONNX_REBUILD = os.environ.get('FORCE_ONNX_REBUILD', 'false').lower() == 'true' | |
| CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', os.path.expanduser('~/.cache/inpainting')) | |
| # Modified import handling for frozen environment | |
| TORCH_AVAILABLE = False | |
| torch = None | |
| nn = None | |
| F = None | |
| BaseModel = object | |
| try: | |
| import onnxruntime_extensions | |
| ONNX_EXTENSIONS_AVAILABLE = True | |
| except ImportError: | |
| ONNX_EXTENSIONS_AVAILABLE = False | |
| logger.info("ONNX Runtime Extensions not available - FFT models won't work in ONNX") | |
| if IS_FROZEN: | |
| # In frozen environment, try harder to import | |
| try: | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| TORCH_AVAILABLE = True | |
| BaseModel = nn.Module | |
| logger.info("✓ PyTorch loaded in frozen environment") | |
| except Exception as e: | |
| logger.error(f"PyTorch not available in frozen environment: {e}") | |
| logger.error("❌ Inpainting disabled - PyTorch is required") | |
| else: | |
| # Normal environment | |
| try: | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| TORCH_AVAILABLE = True | |
| BaseModel = nn.Module | |
| except ImportError: | |
| TORCH_AVAILABLE = False | |
| logger.error("PyTorch not available - inpainting disabled") | |
| # Configure ORT memory behavior before importing | |
| try: | |
| os.environ.setdefault('ORT_DISABLE_MEMORY_ARENA', '1') | |
| except Exception: | |
| pass | |
| # ONNX Runtime - usually works well in frozen environments | |
| ONNX_AVAILABLE = False | |
| try: | |
| import onnx | |
| import onnxruntime as ort | |
| ONNX_AVAILABLE = True | |
| logger.info("✓ ONNX Runtime available") | |
| except ImportError: | |
| ONNX_AVAILABLE = False | |
| logger.warning("ONNX Runtime not available") | |
| # Bubble detector - optional | |
| BUBBLE_DETECTOR_AVAILABLE = False | |
| try: | |
| from bubble_detector import BubbleDetector | |
| BUBBLE_DETECTOR_AVAILABLE = True | |
| logger.info("✓ Bubble detector available") | |
| except ImportError: | |
| logger.info("Bubble detector not available - basic inpainting will be used") | |
# JIT Model URLs (for automatic download)
LAMA_JIT_MODELS = {
    'lama': {
        'url': 'https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt',
        'md5': 'e3aa4aaa15225a33ec84f9f4bc47e500',
        'name': 'BigLama'
    },
    'anime': {
        'url': 'https://github.com/Sanster/models/releases/download/AnimeMangaInpainting/anime-manga-big-lama.pt',
        'md5': '29f284f36a0a510bcacf39ecf4c4d54f',
        'name': 'Anime-Manga BigLama'
    },
    'lama_official': {
        'url': 'https://github.com/Sanster/models/releases/download/lama/lama.pt',
        'md5': '4b1a1de53b7a74e0ff9dd622834e8e1e',
        'name': 'LaMa Official'
    },
    'aot': {
        'url': 'https://huggingface.co/ogkalu/aot-inpainting-jit/resolve/main/aot_traced.pt',
        'md5': '5ecdac562c1d56267468fc4fbf80db27',
        'name': 'AOT GAN'
    },
    'aot_onnx': {
        'url': 'https://huggingface.co/ogkalu/aot-inpainting/resolve/main/aot.onnx',
        # 64 hex chars: a SHA-256 digest despite the 'md5' key name
        'md5': 'ffd39ed8e2a275869d3b49180d030f0d8b8b9c2c20ed0e099ecd207201f0eada',
        'name': 'AOT ONNX (Fast)',
        'is_onnx': True
    },
    'lama_onnx': {
        'url': 'https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx',
        'md5': None,  # Add a digest here to enable verification
        'name': 'LaMa ONNX (Carve)',
        'is_onnx': True  # Flag to indicate this is ONNX, not JIT
    },
    'anime_onnx': {
        'url': 'https://huggingface.co/ogkalu/lama-manga-onnx-dynamic/resolve/main/lama-manga-dynamic.onnx',
        # 64 hex chars: a SHA-256 digest despite the 'md5' key name
        'md5': 'de31ffa5ba26916b8ea35319f6c12151ff9654d4261bccf0583a69bb095315f9',
        'name': 'Anime/Manga ONNX (Dynamic)',
        'is_onnx': True  # Flag to indicate this is ONNX
    }
}
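
# Usage sketch (illustrative): the registry above maps a method name to a URL
# and checksum that download_model() below consumes, e.g.
#   info = LAMA_JIT_MODELS['anime']
#   path = download_model(info['url'], info['md5'])
# which caches the file under CACHE_DIR keyed by the URL's basename.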
def norm_img(img: np.ndarray) -> np.ndarray:
    """Normalize image to [0, 1] range"""
    if img.dtype == np.uint8:
        return img.astype(np.float32) / 255.0
    return img
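
# Example: norm_img(np.full((2, 2), 255, dtype=np.uint8)) returns float32 ones;
# non-uint8 inputs are assumed to already be in [0, 1] and pass through unchanged.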
def get_cache_path_by_url(url: str) -> str:
    """Get cache path for a model URL"""
    os.makedirs(CACHE_DIR, exist_ok=True)
    filename = os.path.basename(url)
    return os.path.join(CACHE_DIR, filename)
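
# Example: get_cache_path_by_url('https://example.com/big-lama.pt') yields
# os.path.join(CACHE_DIR, 'big-lama.pt'); note that distinct URLs sharing a
# basename would collide in the cache.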
def download_model(url: str, md5: str = None) -> str:
    """Download model if not cached; best-effort checksum check when a digest is given"""
    cache_path = get_cache_path_by_url(url)
    if os.path.exists(cache_path):
        logger.info(f"✅ Model already cached: {cache_path}")
        return cache_path
    logger.info(f"📥 Downloading model from {url}")
    try:
        urllib.request.urlretrieve(url, cache_path)
        if md5:
            # Best-effort verification (previously the digest was accepted but
            # never checked): some registry entries store a 64-hex SHA-256 under
            # the 'md5' key, so pick the algorithm by digest length, and only
            # warn on mismatch rather than failing hard.
            algo = hashlib.sha256() if len(md5) == 64 else hashlib.md5()
            with open(cache_path, 'rb') as f:
                for chunk in iter(lambda: f.read(1 << 20), b''):
                    algo.update(chunk)
            if algo.hexdigest() != md5.lower():
                logger.warning(f"⚠️ Checksum mismatch for {cache_path} (expected {md5})")
        logger.info(f"✅ Model downloaded to: {cache_path}")
        return cache_path
    except Exception as e:
        logger.error(f"❌ Download failed: {e}")
        if os.path.exists(cache_path):
            os.remove(cache_path)
        raise
class FFCInpaintModel(BaseModel):  # Use BaseModel instead of nn.Module
    """FFC model for LaMa inpainting - for checkpoint compatibility"""
    def __init__(self):
        if not TORCH_AVAILABLE:
            # Initialize as a simple object when PyTorch is not available
            super().__init__()
            logger.warning("PyTorch not available - FFCInpaintModel initialized as placeholder")
            self._pytorch_available = False
            return
        # Additional safety check for nn being None
        if nn is None:
            super().__init__()
            logger.error("Neural network modules not available - FFCInpaintModel disabled")
            self._pytorch_available = False
            return
        super().__init__()
        self._pytorch_available = True
        try:
            # Encoder
            self.model_1_ffc_convl2l = nn.Conv2d(4, 64, 7, padding=3)
            self.model_1_bn_l = nn.BatchNorm2d(64)
            self.model_2_ffc_convl2l = nn.Conv2d(64, 128, 3, padding=1)
            self.model_2_bn_l = nn.BatchNorm2d(128)
            self.model_3_ffc_convl2l = nn.Conv2d(128, 256, 3, padding=1)
            self.model_3_bn_l = nn.BatchNorm2d(256)
            self.model_4_ffc_convl2l = nn.Conv2d(256, 128, 3, padding=1)
            self.model_4_ffc_convl2g = nn.Conv2d(256, 384, 3, padding=1)
            self.model_4_bn_l = nn.BatchNorm2d(128)
            self.model_4_bn_g = nn.BatchNorm2d(384)
            # FFC blocks
            for i in range(5, 23):
                for conv_type in ['conv1', 'conv2']:
                    setattr(self, f'model_{i}_{conv_type}_ffc_convl2l', nn.Conv2d(128, 128, 3, padding=1))
                    setattr(self, f'model_{i}_{conv_type}_ffc_convl2g', nn.Conv2d(128, 384, 3, padding=1))
                    setattr(self, f'model_{i}_{conv_type}_ffc_convg2l', nn.Conv2d(384, 128, 3, padding=1))
                    setattr(self, f'model_{i}_{conv_type}_ffc_convg2g_conv1_0', nn.Conv2d(384, 192, 1))
                    setattr(self, f'model_{i}_{conv_type}_ffc_convg2g_conv1_1', nn.BatchNorm2d(192))
                    setattr(self, f'model_{i}_{conv_type}_ffc_convg2g_fu_conv_layer', nn.Conv2d(384, 384, 1))
                    setattr(self, f'model_{i}_{conv_type}_ffc_convg2g_fu_bn', nn.BatchNorm2d(384))
                    setattr(self, f'model_{i}_{conv_type}_ffc_convg2g_conv2', nn.Conv2d(192, 384, 1))
                    setattr(self, f'model_{i}_{conv_type}_bn_l', nn.BatchNorm2d(128))
                    setattr(self, f'model_{i}_{conv_type}_bn_g', nn.BatchNorm2d(384))
            # Decoder
            self.model_24 = nn.Conv2d(512, 256, 3, padding=1)
            self.model_25 = nn.BatchNorm2d(256)
            self.model_27 = nn.Conv2d(256, 128, 3, padding=1)
            self.model_28 = nn.BatchNorm2d(128)
            self.model_30 = nn.Conv2d(128, 64, 3, padding=1)
            self.model_31 = nn.BatchNorm2d(64)
            self.model_34 = nn.Conv2d(64, 3, 7, padding=3)
            # Activation functions
            self.relu = nn.ReLU(inplace=True)
            self.tanh = nn.Tanh()
            logger.info("FFCInpaintModel initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize FFCInpaintModel: {e}")
            self._pytorch_available = False
            raise

    def forward(self, image, mask):
        if not self._pytorch_available:
            logger.error("PyTorch not available for forward pass")
            raise RuntimeError("PyTorch not available for forward pass")
        if not TORCH_AVAILABLE or torch is None:
            logger.error("PyTorch not available for forward pass")
            raise RuntimeError("PyTorch not available for forward pass")
        try:
            x = torch.cat([image, mask], dim=1)
            x = self.relu(self.model_1_bn_l(self.model_1_ffc_convl2l(x)))
            x = self.relu(self.model_2_bn_l(self.model_2_ffc_convl2l(x)))
            x = self.relu(self.model_3_bn_l(self.model_3_ffc_convl2l(x)))
            x_l = self.relu(self.model_4_bn_l(self.model_4_ffc_convl2l(x)))
            x_g = self.relu(self.model_4_bn_g(self.model_4_ffc_convl2g(x)))
            for i in range(5, 23):
                identity_l, identity_g = x_l, x_g
                x_l, x_g = self._ffc_block(x_l, x_g, i, 'conv1')
                x_l, x_g = self._ffc_block(x_l, x_g, i, 'conv2')
                x_l = x_l + identity_l
                x_g = x_g + identity_g
            x = torch.cat([x_l, x_g], dim=1)
            x = self.relu(self.model_25(self.model_24(x)))
            x = self.relu(self.model_28(self.model_27(x)))
            x = self.relu(self.model_31(self.model_30(x)))
            x = self.tanh(self.model_34(x))
            mask_3ch = mask.repeat(1, 3, 1, 1)
            return x * mask_3ch + image * (1 - mask_3ch)
        except Exception as e:
            logger.error(f"Forward pass failed: {e}")
            raise RuntimeError(f"Forward pass failed: {e}")

    def _ffc_block(self, x_l, x_g, idx, conv_type):
        if not self._pytorch_available:
            raise RuntimeError("PyTorch not available for FFC block")
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch not available for FFC block")
        try:
            convl2l = getattr(self, f'model_{idx}_{conv_type}_ffc_convl2l')
            convl2g = getattr(self, f'model_{idx}_{conv_type}_ffc_convl2g')
            convg2l = getattr(self, f'model_{idx}_{conv_type}_ffc_convg2l')
            convg2g_conv1 = getattr(self, f'model_{idx}_{conv_type}_ffc_convg2g_conv1_0')
            convg2g_bn1 = getattr(self, f'model_{idx}_{conv_type}_ffc_convg2g_conv1_1')
            fu_conv = getattr(self, f'model_{idx}_{conv_type}_ffc_convg2g_fu_conv_layer')
            fu_bn = getattr(self, f'model_{idx}_{conv_type}_ffc_convg2g_fu_bn')
            convg2g_conv2 = getattr(self, f'model_{idx}_{conv_type}_ffc_convg2g_conv2')
            bn_l = getattr(self, f'model_{idx}_{conv_type}_bn_l')
            bn_g = getattr(self, f'model_{idx}_{conv_type}_bn_g')
            out_xl = convl2l(x_l) + convg2l(x_g)
            out_xg = convl2g(x_l) + convg2g_conv2(self.relu(convg2g_bn1(convg2g_conv1(x_g)))) + self.relu(fu_bn(fu_conv(x_g)))
            return self.relu(bn_l(out_xl)), self.relu(bn_g(out_xg))
        except Exception as e:
            logger.error(f"FFC block failed: {e}")
            raise RuntimeError(f"FFC block failed: {e}")
class LocalInpainter:
    """Local inpainter with full backward compatibility"""

    # MAINTAIN ORIGINAL SUPPORTED_METHODS for compatibility
    SUPPORTED_METHODS = {
        'lama': ('LaMa Inpainting', FFCInpaintModel),
        'mat': ('MAT Inpainting', FFCInpaintModel),
        'aot': ('AOT GAN Inpainting', FFCInpaintModel),
        'aot_onnx': ('AOT ONNX (Fast)', FFCInpaintModel),
        'sd': ('Stable Diffusion Inpainting', FFCInpaintModel),
        'anime': ('Anime/Manga Inpainting', FFCInpaintModel),
        'anime_onnx': ('Anime ONNX (Fast)', FFCInpaintModel),
        'lama_official': ('Official LaMa', FFCInpaintModel),
    }
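
    # Note: every entry maps to FFCInpaintModel because that class is only the
    # checkpoint-compatibility fallback; the JIT and ONNX paths in load_model()
    # bypass it entirely.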
    def __init__(self, config_path="config.json"):
        # Set thread limits early if environment indicates single-threaded mode
        try:
            if os.environ.get('OMP_NUM_THREADS') == '1':
                # Already in single-threaded mode, ensure it's applied to this process
                # Check if torch is available at module level before trying to use it
                if TORCH_AVAILABLE and torch is not None:
                    try:
                        torch.set_num_threads(1)
                    except (RuntimeError, AttributeError):
                        pass
                try:
                    import cv2
                    cv2.setNumThreads(1)
                except (ImportError, AttributeError):
                    pass
        except Exception:
            pass

        self.config_path = config_path
        self.config = self._load_config()
        self.model = None
        self.model_loaded = False
        self.current_method = None
        self.use_opencv_fallback = False  # FORCED DISABLED - No OpenCV fallback allowed
        self.onnx_session = None
        self.use_onnx = False
        self.is_jit_model = False
        self.pad_mod = 8

        # Default tiling settings - OFF by default for most models
        self.tiling_enabled = False
        self.tile_size = 512
        self.tile_overlap = 64

        # ONNX-specific settings
        self.onnx_model_loaded = False
        self.onnx_input_size = None  # Will be detected from model

        # Quantization diagnostics flags
        self.onnx_quantize_applied = False
        self.torch_quantize_applied = False

        # Bubble detection
        self.bubble_detector = None
        self.bubble_model_loaded = False

        # Create directories
        os.makedirs(ONNX_CACHE_DIR, exist_ok=True)
        os.makedirs(CACHE_DIR, exist_ok=True)
        logger.info(f"📁 ONNX cache directory: {ONNX_CACHE_DIR}")
        logger.info(f"   Contents: {os.listdir(ONNX_CACHE_DIR) if os.path.exists(ONNX_CACHE_DIR) else 'Directory does not exist'}")

        # Check GPU availability safely
        self.use_gpu = False
        self.device = None
        if TORCH_AVAILABLE and torch is not None:
            try:
                self.use_gpu = torch.cuda.is_available()
                self.device = torch.device('cuda' if self.use_gpu else 'cpu')
                if self.use_gpu:
                    logger.info(f"🚀 GPU: {torch.cuda.get_device_name(0)}")
                else:
                    logger.info("💻 Using CPU")
            except AttributeError:
                # torch module exists but doesn't have cuda attribute
                self.use_gpu = False
                self.device = None
                logger.info("⚠️ PyTorch incomplete - inpainting disabled")
        else:
            logger.info("⚠️ PyTorch not available - inpainting disabled")

        # Quantization/precision toggle (off by default)
        try:
            adv_cfg = self.config.get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
            # Track singleton mode from settings for thread limiting (deprecated - kept for compatibility)
            self.singleton_mode = bool(adv_cfg.get('use_singleton_models', True))
            env_quant = os.environ.get('MODEL_QUANTIZE', 'false').lower() == 'true'
            self.quantize_enabled = bool(env_quant or adv_cfg.get('quantize_models', False))
            # ONNX quantization is now strictly opt-in (config or env), decoupled from general quantize_models
            self.onnx_quantize_enabled = bool(
                adv_cfg.get('onnx_quantize', os.environ.get('ONNX_QUANTIZE', 'false').lower() == 'true')
            )
            self.torch_precision = str(adv_cfg.get('torch_precision', os.environ.get('TORCH_PRECISION', 'auto'))).lower()
            self.int8_enabled = bool(
                adv_cfg.get('int8_quantize', False)
                or adv_cfg.get('quantize_int8', False)
                or os.environ.get('TORCH_INT8', 'false').lower() == 'true'
                or self.torch_precision in ('int8', 'int8_dynamic')
            )
            logger.info(
                f"Quantization: {'ENABLED' if self.quantize_enabled else 'disabled'} for Local Inpainter; "
                f"onnx_quantize={'on' if self.onnx_quantize_enabled else 'off'}; "
                f"torch_precision={self.torch_precision}; int8={'on' if self.int8_enabled else 'off'}"
            )
        except Exception:
            self.quantize_enabled = False
            self.onnx_quantize_enabled = False
            self.torch_precision = 'auto'
            self.int8_enabled = False

        # HD strategy defaults (mirror of comic-translate behavior)
        try:
            adv_cfg = self.config.get('manga_settings', {}).get('advanced', {}) if isinstance(self.config, dict) else {}
        except Exception:
            adv_cfg = {}
        try:
            self.hd_strategy = str(os.environ.get('HD_STRATEGY', adv_cfg.get('hd_strategy', 'resize'))).lower()
        except Exception:
            self.hd_strategy = 'resize'
        try:
            self.hd_resize_limit = int(os.environ.get('HD_RESIZE_LIMIT', adv_cfg.get('hd_strategy_resize_limit', 1536)))
        except Exception:
            self.hd_resize_limit = 1536
        try:
            self.hd_crop_margin = int(os.environ.get('HD_CROP_MARGIN', adv_cfg.get('hd_strategy_crop_margin', 16)))
        except Exception:
            self.hd_crop_margin = 16
        try:
            self.hd_crop_trigger_size = int(os.environ.get('HD_CROP_TRIGGER', adv_cfg.get('hd_strategy_crop_trigger_size', 1024)))
        except Exception:
            self.hd_crop_trigger_size = 1024
        logger.info(f"HD strategy: {self.hd_strategy} (resize_limit={self.hd_resize_limit}, crop_margin={self.hd_crop_margin}, crop_trigger={self.hd_crop_trigger_size})")

        # Stop flag support
        self.stop_flag = None
        self._stopped = False
        self.log_callback = None

        # Initialize bubble detector if available
        if BUBBLE_DETECTOR_AVAILABLE:
            try:
                self.bubble_detector = BubbleDetector()
                logger.info("🗨️ Bubble detection available")
            except Exception:
                self.bubble_detector = None
                logger.info("🗨️ Bubble detection not available")
    def _load_config(self):
        try:
            if self.config_path and os.path.exists(self.config_path):
                with open(self.config_path, 'r', encoding='utf-8') as f:
                    content = f.read().strip()
                    if not content:
                        return {}
                    try:
                        return json.loads(content)
                    except json.JSONDecodeError:
                        # Likely a concurrent write; retry once after a short delay
                        try:
                            time.sleep(0.05)  # 'time' is imported at module level
                            with open(self.config_path, 'r', encoding='utf-8') as f2:
                                return json.load(f2)
                        except Exception:
                            return {}
        except Exception:
            return {}
        return {}
    def _save_config(self):
        # Don't save if config is empty (prevents purging)
        if not getattr(self, 'config', None):
            return
        try:
            # Load existing (best-effort)
            full_config = {}
            if self.config_path and os.path.exists(self.config_path):
                try:
                    with open(self.config_path, 'r', encoding='utf-8') as f:
                        full_config = json.load(f)
                except Exception as read_err:
                    logger.debug(f"Config read during save failed (non-critical): {read_err}")
                    full_config = {}
            # Update
            full_config.update(self.config)
            # Atomic write: write to temp then replace
            tmp_path = (self.config_path or 'config.json') + '.tmp'
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(full_config, f, indent=2, ensure_ascii=False)
            try:
                os.replace(tmp_path, self.config_path or 'config.json')
            except Exception as replace_err:
                logger.debug(f"Config atomic replace failed, trying direct write: {replace_err}")
                # Fallback to direct write
                with open(self.config_path or 'config.json', 'w', encoding='utf-8') as f:
                    json.dump(full_config, f, indent=2, ensure_ascii=False)
        except Exception as save_err:
            # Never crash on config save, but log for debugging
            logger.debug(f"Config save failed (non-critical): {save_err}")
    def set_stop_flag(self, stop_flag):
        """Set the stop flag for checking interruptions"""
        self.stop_flag = stop_flag
        self._stopped = False

    def set_log_callback(self, log_callback):
        """Set log callback for GUI integration"""
        self.log_callback = log_callback

    def _check_stop(self) -> bool:
        """Check if stop has been requested"""
        if self._stopped:
            return True
        if self.stop_flag and self.stop_flag.is_set():
            self._stopped = True
            return True
        # Check global manga translator cancellation
        try:
            from manga_translator import MangaTranslator
            if MangaTranslator.is_globally_cancelled():
                self._stopped = True
                return True
        except Exception:
            pass
        return False

    def _log(self, message: str, level: str = "info"):
        """Log message with stop suppression"""
        # Suppress logs when stopped (allow only essential stop confirmation messages)
        if self._check_stop():
            essential_stop_keywords = [
                "⏹️ Translation stopped by user",
                "⏹️ Inpainting stopped",
                "cleanup", "🧹"
            ]
            if not any(keyword in message for keyword in essential_stop_keywords):
                return
        if self.log_callback:
            self.log_callback(message, level)
        else:
            # getattr covers 'info' as well, so one dispatch suffices
            getattr(logger, level, logger.info)(message)

    def reset_stop_flags(self):
        """Reset stop flags when starting new processing"""
        self._stopped = False
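
    # Stop-flag wiring sketch (any object exposing .is_set() works; a
    # threading.Event is the obvious choice):
    #   stop = threading.Event()
    #   inpainter.set_stop_flag(stop)
    #   stop.set()  # from another thread; _check_stop() then returns True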
    def convert_to_onnx(self, model_path: str, method: str) -> Optional[str]:
        """Convert a PyTorch model to ONNX format with FFT handling via custom operators"""
        if not ONNX_AVAILABLE:
            logger.warning("ONNX not available, skipping conversion")
            return None
        try:
            # Generate ONNX path
            model_name = os.path.basename(model_path).replace('.pt', '')
            onnx_path = os.path.join(ONNX_CACHE_DIR, f"{model_name}_{method}.onnx")
            # Check if ONNX already exists
            if os.path.exists(onnx_path) and not FORCE_ONNX_REBUILD:
                logger.info(f"✅ ONNX model already exists: {onnx_path}")
                return onnx_path
            logger.info(f"🔄 Converting {method} model to ONNX...")
            # The model should already be loaded at this point
            if not self.model_loaded or self.current_method != method:
                logger.error("Model not loaded for ONNX conversion")
                return None
            # FFT models can't be converted directly, so bail out before
            # allocating dummy inputs
            fft_models = ['lama', 'anime', 'lama_official']
            if method in fft_models:
                logger.warning(f"⚠️ {method.upper()} uses FFT operations that cannot be exported")
                return None  # Just return None, don't suggest Carve
            # Create dummy inputs
            dummy_image = torch.randn(1, 3, 512, 512).to(self.device)
            dummy_mask = torch.randn(1, 1, 512, 512).to(self.device)
            # Standard export for non-FFT models
            try:
                torch.onnx.export(
                    self.model,
                    (dummy_image, dummy_mask),
                    onnx_path,
                    export_params=True,
                    opset_version=13,
                    do_constant_folding=True,
                    input_names=['image', 'mask'],
                    output_names=['output'],
                    dynamic_axes={
                        'image': {0: 'batch', 2: 'height', 3: 'width'},
                        'mask': {0: 'batch', 2: 'height', 3: 'width'},
                        'output': {0: 'batch', 2: 'height', 3: 'width'}
                    }
                )
                logger.info(f"✅ ONNX model saved to: {onnx_path}")
                return onnx_path
            except torch.onnx.errors.UnsupportedOperatorError as e:
                logger.error(f"❌ Unsupported operator: {e}")
                return None
        except Exception as e:
            logger.error(f"❌ ONNX conversion failed: {e}")
            logger.error(traceback.format_exc())
            return None
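
    # The dynamic_axes mapping above marks batch, height, and width as symbolic
    # dimensions, so the exported graph accepts variable image sizes; a fixed
    # export (e.g. Carve's lama_fp32.onnx) instead pins H and W to 512.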
    def load_onnx_model(self, onnx_path: str) -> bool:
        """Load an ONNX model with custom operator support"""
        if not ONNX_AVAILABLE:
            logger.error("ONNX Runtime not available")
            return False
        # Check if this exact ONNX model is already loaded
        if (self.onnx_session is not None and
                hasattr(self, 'current_onnx_path') and
                self.current_onnx_path == onnx_path):
            logger.debug(f"✅ ONNX model already loaded: {onnx_path}")
            return True
        try:
            # Don't log here if we already logged in load_model
            logger.debug(f"📦 ONNX Runtime loading: {onnx_path}")
            # Store the path for later checking
            self.current_onnx_path = onnx_path

            # Check if this is a Carve model (fixed 512x512)
            is_carve_model = "lama_fp32" in onnx_path or "carve" in onnx_path.lower()
            if is_carve_model:
                logger.info("📦 Detected Carve ONNX model (fixed 512x512 input)")
                self.onnx_fixed_size = (512, 512)
            else:
                self.onnx_fixed_size = None

            # Standard ONNX loading: prefer CUDA if available; otherwise CPU. Do NOT use DML.
            try:
                avail = ort.get_available_providers() if ONNX_AVAILABLE else []
            except Exception:
                avail = []
            if 'CUDAExecutionProvider' in avail:
                providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            else:
                providers = ['CPUExecutionProvider']

            session_path = onnx_path
            try:
                fname_lower = os.path.basename(onnx_path).lower()
            except Exception:
                fname_lower = str(onnx_path).lower()

            # Device-aware policy for LaMa-type ONNX (Carve or contains 'lama')
            is_lama_model = is_carve_model or ('lama' in fname_lower)
            if is_lama_model:
                base = os.path.splitext(onnx_path)[0]
                if self.use_gpu:
                    # Prefer FP16 on CUDA
                    fp16_path = base + '.fp16.onnx'
                    if (not os.path.exists(fp16_path)) or FORCE_ONNX_REBUILD:
                        try:
                            import onnx as _onnx
                            try:
                                from onnxruntime_tools.transformers.float16 import convert_float_to_float16 as _to_fp16
                            except Exception:
                                try:
                                    from onnxconverter_common import float16

                                    def _to_fp16(m, keep_io_types=True):
                                        return float16.convert_float_to_float16(m, keep_io_types=keep_io_types)
                                except Exception:
                                    _to_fp16 = None
                            if _to_fp16 is not None:
                                m = _onnx.load(onnx_path)
                                m_fp16 = _to_fp16(m, keep_io_types=True)
                                _onnx.save(m_fp16, fp16_path)
                                logger.info(f"✅ Generated FP16 ONNX for LaMa: {fp16_path}")
                        except Exception as e:
                            logger.warning(f"FP16 conversion for LaMa failed: {e}")
                    if os.path.exists(fp16_path):
                        session_path = fp16_path
                else:
                    # CPU path for LaMa: quantize only if enabled, and MatMul-only to avoid artifacts
                    if self.onnx_quantize_enabled:
                        try:
                            from onnxruntime.quantization import quantize_dynamic, QuantType
                            quant_path = base + '.matmul.int8.onnx'
                            if (not os.path.exists(quant_path)) or FORCE_ONNX_REBUILD:
                                logger.info("🔻 LaMa: Quantizing ONNX weights to INT8 (dynamic, ops=['MatMul'])...")
                                quantize_dynamic(
                                    model_input=onnx_path,
                                    model_output=quant_path,
                                    weight_type=QuantType.QInt8,
                                    op_types_to_quantize=['MatMul']
                                )
                                self.onnx_quantize_applied = True
                                # Validate dynamic quant result
                                try:
                                    import onnx as _onnx
                                    _m_q = _onnx.load(quant_path)
                                    _onnx.checker.check_model(_m_q)
                                except Exception as _qchk:
                                    logger.warning(f"LaMa dynamic quant model invalid; deleting and falling back: {_qchk}")
                                    try:
                                        os.remove(quant_path)
                                    except Exception:
                                        pass
                                    quant_path = None
                        except Exception as dy_err:
                            logger.warning(f"LaMa dynamic quantization failed: {dy_err}")
                            quant_path = None
                        # Fallback: static QDQ MatMul-only with zero data reader
                        if quant_path is None:
                            try:
                                import onnx as _onnx
                                from onnxruntime.quantization import (
                                    CalibrationDataReader, quantize_static,
                                    QuantFormat, QuantType, CalibrationMethod
                                )
                                m = _onnx.load(onnx_path)
                                shapes = {}
                                for inp in m.graph.input:
                                    dims = []
                                    for d in inp.type.tensor_type.shape.dim:
                                        dims.append(d.dim_value if d.dim_value > 0 else 1)
                                    shapes[inp.name] = dims

                                class _ZeroReader(CalibrationDataReader):
                                    def __init__(self, shapes):
                                        self.shapes = shapes
                                        self.done = False

                                    def get_next(self):
                                        if self.done:
                                            return None
                                        feed = {}
                                        for name, s in self.shapes.items():
                                            ss = list(s)
                                            if len(ss) == 4:
                                                if ss[2] <= 1:
                                                    ss[2] = 512
                                                if ss[3] <= 1:
                                                    ss[3] = 512
                                                if ss[1] <= 1 and 'mask' not in name.lower():
                                                    ss[1] = 3
                                            feed[name] = np.zeros(ss, dtype=np.float32)
                                        self.done = True
                                        return feed

                                dr = _ZeroReader(shapes)
                                quant_path = base + '.matmul.int8.onnx'
                                quantize_static(
                                    model_input=onnx_path,
                                    model_output=quant_path,
                                    calibration_data_reader=dr,
                                    quant_format=QuantFormat.QDQ,
                                    activation_type=QuantType.QUInt8,
                                    weight_type=QuantType.QInt8,
                                    per_channel=False,
                                    calibrate_method=CalibrationMethod.MinMax,
                                    op_types_to_quantize=['MatMul']
                                )
                                # Validate
                                try:
                                    _m_q = _onnx.load(quant_path)
                                    _onnx.checker.check_model(_m_q)
                                except Exception as _qchk2:
                                    logger.warning(f"LaMa static MatMul-only quant model invalid; deleting: {_qchk2}")
                                    try:
                                        os.remove(quant_path)
                                    except Exception:
                                        pass
                                    quant_path = None
                                else:
                                    logger.info(f"✅ Generated MatMul-only INT8 ONNX for LaMa: {quant_path}")
                                    self.onnx_quantize_applied = True
                            except Exception as st_err:
                                logger.warning(f"LaMa static MatMul-only quantization failed: {st_err}")
                                quant_path = None
                        # Use the quantized model if valid
                        if quant_path and os.path.exists(quant_path):
                            session_path = quant_path
                            logger.info(f"✅ Using LaMa quantized ONNX model: {quant_path}")
                    # If quantization not enabled or failed, session_path remains onnx_path (FP32)

            # Optional dynamic/static quantization for other models (opt-in)
            if (not is_lama_model) and self.onnx_quantize_enabled:
                base = os.path.splitext(onnx_path)[0]
                fname = os.path.basename(onnx_path).lower()
                is_aot = 'aot' in fname
                # For AOT: ignore any MatMul-only file and prefer Conv+MatMul
                if is_aot:
                    try:
                        ignored_matmul = base + ".matmul.int8.onnx"
                        if os.path.exists(ignored_matmul):
                            logger.info(f"⏭️ Ignoring MatMul-only quantized file for AOT: {ignored_matmul}")
                    except Exception:
                        pass
                # Choose target quant file and ops
                if is_aot:
                    quant_path = base + ".int8.onnx"
                    ops_to_quant = ['MatMul']
                    # Use MatMul-only for safer quantization across models
                    ops_for_static = ['MatMul']
                    # Try to simplify AOT graph prior to quantization
                    quant_input_path = onnx_path
                    try:
                        import onnx as _onnx
                        try:
                            from onnxsim import simplify as _onnx_simplify
                            _model = _onnx.load(onnx_path)
                            _sim_model, _check = _onnx_simplify(_model)
                            if _check:
                                sim_path = base + ".sim.onnx"
                                _onnx.save(_sim_model, sim_path)
                                quant_input_path = sim_path
                                logger.info(f"🧰 Simplified AOT ONNX before quantization: {sim_path}")
                        except Exception as _sim_err:
                            logger.info(f"AOT simplification skipped: {_sim_err}")
                        # No ONNX shape inference; keep original graph structure
                        # Ensure opset >= 13 for QDQ (axis attribute on DequantizeLinear)
                        try:
                            _m_tmp = _onnx.load(quant_input_path)
                            _opset = max([op.version for op in _m_tmp.opset_import]) if _m_tmp.opset_import else 11
                            if _opset < 13:
                                from onnx import version_converter as _vc
                                _m13 = _vc.convert_version(_m_tmp, 13)
                                up_path = base + ".op13.onnx"
                                _onnx.save(_m13, up_path)
                                quant_input_path = up_path
                                logger.info(f"🧰 Upgraded ONNX opset to 13 before QDQ quantization: {up_path}")
                        except Exception as _operr:
                            logger.info(f"Opset upgrade skipped: {_operr}")
                    except Exception:
                        quant_input_path = onnx_path
                else:
                    quant_path = base + ".matmul.int8.onnx"
                    ops_to_quant = ['MatMul']
                    ops_for_static = ops_to_quant
                    quant_input_path = onnx_path
                # Perform quantization if needed
                if not os.path.exists(quant_path) or FORCE_ONNX_REBUILD:
                    if is_aot:
                        # Directly perform static QDQ quantization for MatMul only (avoid Conv activations)
                        try:
                            import onnx as _onnx
                            from onnxruntime.quantization import CalibrationDataReader, quantize_static, QuantFormat, QuantType, CalibrationMethod
                            _model = _onnx.load(quant_input_path)
                            # Build input shapes from the model graph
                            input_shapes = {}
                            for inp in _model.graph.input:
                                dims = []
                                for d in inp.type.tensor_type.shape.dim:
                                    if d.dim_value > 0:
                                        dims.append(d.dim_value)
                                    else:
                                        # default fallback dimension
                                        dims.append(1)
                                input_shapes[inp.name] = dims

                            class _ZeroDataReader(CalibrationDataReader):
                                def __init__(self, input_shapes):
                                    self._shapes = input_shapes
                                    self._provided = False

                                def get_next(self):
                                    if self._provided:
                                        return None
                                    feed = {}
                                    for name, shape in self._shapes.items():
                                        # Ensure reasonable default spatial size
                                        s = list(shape)
                                        if len(s) == 4:
                                            if s[2] <= 1:
                                                s[2] = 512
                                            if s[3] <= 1:
                                                s[3] = 512
                                            # channel fallback
                                            if s[1] <= 1 and 'mask' not in name.lower():
                                                s[1] = 3
                                        feed[name] = np.zeros(s, dtype=np.float32)
                                    self._provided = True
                                    return feed

                            dr = _ZeroDataReader(input_shapes)
                            quantize_static(
                                model_input=quant_input_path,
                                model_output=quant_path,
                                calibration_data_reader=dr,
                                quant_format=QuantFormat.QDQ,
                                activation_type=QuantType.QUInt8,
                                weight_type=QuantType.QInt8,
                                per_channel=True,
                                calibrate_method=CalibrationMethod.MinMax,
                                op_types_to_quantize=ops_for_static
                            )
                            # Validate quantized model to catch structural errors early
                            try:
                                _m_q = _onnx.load(quant_path)
                                _onnx.checker.check_model(_m_q)
                            except Exception as _qchk:
                                logger.warning(f"Quantized AOT model validation failed: {_qchk}")
                                # Remove broken quantized file to force fallback
                                try:
                                    os.remove(quant_path)
                                except Exception:
                                    pass
                            else:
                                logger.info(f"✅ Static INT8 quantization produced: {quant_path}")
                        except Exception as st_err:
                            logger.warning(f"Static ONNX quantization failed: {st_err}")
                    else:
                        # First attempt: dynamic quantization (MatMul)
                        try:
                            from onnxruntime.quantization import quantize_dynamic, QuantType
                            logger.info("🔻 Quantizing ONNX inpainting model weights to INT8 (dynamic, ops=['MatMul'])...")
                            quantize_dynamic(
                                model_input=quant_input_path,
                                model_output=quant_path,
                                weight_type=QuantType.QInt8,
                                op_types_to_quantize=['MatMul']
                            )
                        except Exception as dy_err:
                            logger.warning(f"Dynamic ONNX quantization failed: {dy_err}; attempting static quantization...")
                            # Fallback: static quantization with a zero data reader
                            try:
                                import onnx as _onnx
                                from onnxruntime.quantization import CalibrationDataReader, quantize_static, QuantFormat, QuantType, CalibrationMethod
                                _model = _onnx.load(quant_input_path)
                                # Build input shapes from the model graph
                                input_shapes = {}
                                for inp in _model.graph.input:
                                    dims = []
                                    for d in inp.type.tensor_type.shape.dim:
                                        if d.dim_value > 0:
                                            dims.append(d.dim_value)
                                        else:
                                            # default fallback dimension
                                            dims.append(1)
                                    input_shapes[inp.name] = dims

                                class _ZeroDataReader(CalibrationDataReader):
                                    def __init__(self, input_shapes):
                                        self._shapes = input_shapes
                                        self._provided = False

                                    def get_next(self):
                                        if self._provided:
                                            return None
                                        feed = {}
                                        for name, shape in self._shapes.items():
                                            # Ensure reasonable default spatial size
                                            s = list(shape)
                                            if len(s) == 4:
                                                if s[2] <= 1:
                                                    s[2] = 512
                                                if s[3] <= 1:
                                                    s[3] = 512
                                                # channel fallback
                                                if s[1] <= 1 and 'mask' not in name.lower():
                                                    s[1] = 3
                                            feed[name] = np.zeros(s, dtype=np.float32)
                                        self._provided = True
                                        return feed

                                dr = _ZeroDataReader(input_shapes)
                                quantize_static(
                                    model_input=quant_input_path,
                                    model_output=quant_path,
                                    calibration_data_reader=dr,
                                    quant_format=QuantFormat.QDQ,
                                    activation_type=QuantType.QUInt8,
                                    weight_type=QuantType.QInt8,
                                    per_channel=True,
                                    calibrate_method=CalibrationMethod.MinMax,
                                    op_types_to_quantize=ops_for_static
                                )
                                # Validate quantized model to catch structural errors early
                                try:
                                    _m_q = _onnx.load(quant_path)
                                    _onnx.checker.check_model(_m_q)
                                except Exception as _qchk:
                                    logger.warning(f"Quantized model validation failed: {_qchk}")
                                    # Remove broken quantized file to force fallback
                                    try:
                                        os.remove(quant_path)
                                    except Exception:
                                        pass
                                else:
                                    logger.info(f"✅ Static INT8 quantization produced: {quant_path}")
                            except Exception as st_err:
                                logger.warning(f"Static ONNX quantization failed: {st_err}")
                # Prefer the quantized file if it now exists
                if os.path.exists(quant_path):
                    # Validate existing quantized model before using it
                    try:
                        import onnx as _onnx
                        _m_q = _onnx.load(quant_path)
                        _onnx.checker.check_model(_m_q)
                    except Exception as _qchk:
                        logger.warning(f"Existing quantized ONNX invalid; deleting and falling back: {_qchk}")
                        try:
                            os.remove(quant_path)
                        except Exception:
                            pass
                    else:
                        session_path = quant_path
                        logger.info(f"✅ Using quantized ONNX model: {quant_path}")
                else:
                    logger.warning("ONNX quantization not applied: quantized file not created")

            # Use conservative ORT memory options to reduce RAM growth
            so = ort.SessionOptions()
            try:
                so.enable_mem_pattern = False
                so.enable_cpu_mem_arena = False
            except Exception:
                pass
            # Enable extended graph optimizations; thread counts are left at
            # ONNX Runtime's defaults, which pick an optimal CPU thread count
            try:
                so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
            except Exception:
                pass

            # Try to create an inference session, with graceful fallbacks
            try:
                self.onnx_session = ort.InferenceSession(session_path, sess_options=so, providers=providers)
            except Exception as e:
                err = str(e)
                logger.warning(f"ONNX session creation failed for {session_path}: {err}")
                # If the quantized path failed due to unsupported ops or an invalid graph, remove it and retry unquantized
                if session_path != onnx_path and ('ConvInteger' in err or 'NOT_IMPLEMENTED' in err or 'INVALID_ARGUMENT' in err):
                    try:
                        if os.path.exists(session_path):
                            os.remove(session_path)
                            logger.info(f"🧹 Deleted invalid quantized model: {session_path}")
                    except Exception:
                        pass
                    try:
                        logger.info("Retrying with unquantized ONNX model...")
                        self.onnx_session = ort.InferenceSession(onnx_path, sess_options=so, providers=providers)
                        session_path = onnx_path
                    except Exception as e2:
                        logger.warning(f"Unquantized ONNX session failed with current providers: {e2}")
                        # As a last resort, try CPU-only
                        try:
                            logger.info("Retrying ONNX on CPUExecutionProvider only...")
                            self.onnx_session = ort.InferenceSession(onnx_path, sess_options=so, providers=['CPUExecutionProvider'])
                            session_path = onnx_path
                            providers = ['CPUExecutionProvider']
                        except Exception as e3:
                            logger.error(f"Failed to create ONNX session on CPU: {e3}")
                            raise
                else:
                    # If we weren't quantized but failed on CUDA, try CPU-only once
                    if self.use_gpu and 'NOT_IMPLEMENTED' in err:
                        try:
                            logger.info("Retrying ONNX on CPUExecutionProvider only...")
                            self.onnx_session = ort.InferenceSession(session_path, sess_options=so, providers=['CPUExecutionProvider'])
                            providers = ['CPUExecutionProvider']
                        except Exception as e4:
                            logger.error(f"Failed to create ONNX session on CPU: {e4}")
                            raise

            # Get input/output names
            if self.onnx_session is None:
                raise RuntimeError("ONNX session was not created")
            self.onnx_input_names = [i.name for i in self.onnx_session.get_inputs()]
            self.onnx_output_names = [o.name for o in self.onnx_session.get_outputs()]
            # Check input shapes to detect fixed-size models
            input_shape = self.onnx_session.get_inputs()[0].shape
            if len(input_shape) == 4 and input_shape[2] == 512 and input_shape[3] == 512:
                self.onnx_fixed_size = (512, 512)
                logger.info("   Model expects fixed size: 512x512")
            # Log success with I/O info in a single line
            logger.debug(f"✅ ONNX session created - Inputs: {self.onnx_input_names}, Outputs: {self.onnx_output_names}")
            self.use_onnx = True
            # CRITICAL: Set model_loaded flag when ONNX session is successfully created
            # This ensures preloaded spares are recognized as valid loaded instances
            self.model_loaded = True
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load ONNX: {e}")
            logger.debug(f"ONNX load traceback: {traceback.format_exc()}")
            self.use_onnx = False
            self.model_loaded = False
            return False
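
    # Inference sketch (hedged: input names vary per model; the LaMa/AOT
    # exports used here take NCHW float32 tensors in [0, 1]):
    #   feed = {self.onnx_input_names[0]: img_nchw, self.onnx_input_names[1]: mask_nchw}
    #   out = self.onnx_session.run(None, feed)[0]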
    def _convert_checkpoint_key(self, key):
        """Convert checkpoint key format to model format"""
        # model.24.weight -> model_24.weight
        if re.match(r'^model\.(\d+)\.(weight|bias|running_mean|running_var)$', key):
            return re.sub(r'model\.(\d+)\.', r'model_\1.', key)
        # model.5.conv1.ffc.weight -> model_5_conv1_ffc.weight
        if key.startswith('model.'):
            parts = key.split('.')
            if parts[-1] in ['weight', 'bias', 'running_mean', 'running_var']:
                return '_'.join(parts[:-1]) + '.' + parts[-1]
        return key.replace('.', '_')
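
    # Examples of the mapping above:
    #   'model.24.weight'                  -> 'model_24.weight'
    #   'model.5.conv1.ffc.convl2l.weight' -> 'model_5_conv1_ffc_convl2l.weight'
    #   'relu.something'                   -> 'relu_something' (dot fallback)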
    def _load_weights_with_mapping(self, model, state_dict):
        """Load weights with proper mapping"""
        model_dict = model.state_dict()
        logger.info(f"📊 Model expects {len(model_dict)} weights")
        logger.info(f"📊 Checkpoint has {len(state_dict)} weights")
        # Filter out num_batches_tracked
        actual_weights = {k: v for k, v in state_dict.items() if 'num_batches_tracked' not in k}
        logger.info(f"   Actual weights: {len(actual_weights)}")
        mapped = {}
        unmapped_ckpt = []
        unmapped_model = list(model_dict.keys())
        # Map checkpoint weights
        for ckpt_key, ckpt_val in actual_weights.items():
            success = False
            converted_key = self._convert_checkpoint_key(ckpt_key)
            if converted_key in model_dict:
                target_shape = model_dict[converted_key].shape
                if target_shape == ckpt_val.shape:
                    mapped[converted_key] = ckpt_val
                    success = True
                elif len(ckpt_val.shape) == 4 and len(target_shape) == 4:
                    # 4D permute for decoder convs
                    permuted = ckpt_val.permute(1, 0, 2, 3)
                    if target_shape == permuted.shape:
                        mapped[converted_key] = permuted
                        logger.info(f"   ✅ Permuted: {ckpt_key}")
                        success = True
                elif len(ckpt_val.shape) == 2 and len(target_shape) == 2:
                    # 2D transpose
                    transposed = ckpt_val.transpose(0, 1)
                    if target_shape == transposed.shape:
                        mapped[converted_key] = transposed
                        success = True
            if success and converted_key in unmapped_model:
                unmapped_model.remove(converted_key)
            if not success:
                unmapped_ckpt.append(ckpt_key)
        # Try fallback mapping for unmapped
        if unmapped_ckpt:
            logger.info(f"   🔧 Fallback mapping for {len(unmapped_ckpt)} weights...")
            for ckpt_key in unmapped_ckpt[:]:
                ckpt_val = actual_weights[ckpt_key]
                for model_key in unmapped_model[:]:
                    if model_dict[model_key].shape == ckpt_val.shape:
                        if ('weight' in ckpt_key and 'weight' in model_key) or \
                           ('bias' in ckpt_key and 'bias' in model_key):
                            mapped[model_key] = ckpt_val
                            unmapped_model.remove(model_key)
                            unmapped_ckpt.remove(ckpt_key)
                            logger.info(f"   ✅ Mapped: {ckpt_key} -> {model_key}")
                            break
        # Initialize missing weights
        complete_dict = model_dict.copy()
        complete_dict.update(mapped)
        for key in unmapped_model:
            param = complete_dict[key]
            if 'weight' in key:
                if 'conv' in key.lower():
                    nn.init.kaiming_normal_(param, mode='fan_out', nonlinearity='relu')
                else:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in key:
                nn.init.zeros_(param)
            elif 'running_mean' in key:
                nn.init.zeros_(param)
            elif 'running_var' in key:
                nn.init.ones_(param)
        # Report
        logger.info(f"✅ Mapped {len(actual_weights) - len(unmapped_ckpt)}/{len(actual_weights)} checkpoint weights")
        logger.info(f"   Filled {len(mapped)}/{len(model_dict)} model positions")
        if unmapped_model:
            pct = (len(unmapped_model) / len(model_dict)) * 100
            logger.info(f"   ⚠️ Initialized {len(unmapped_model)} missing weights ({pct:.1f}%)")
            if pct > 20:
                logger.warning("   ⚠️ May produce artifacts - checkpoint is incomplete")
                logger.warning("   💡 Consider downloading JIT model for better quality:")
                logger.warning(f"      inpainter.download_jit_model('{self.current_method or 'lama'}')")
        model.load_state_dict(complete_dict, strict=True)
        return True
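
    # The permute/transpose fallbacks above cover checkpoints whose conv
    # weights were saved with (in, out, kH, kW) ordering instead of PyTorch's
    # (out, in, kH, kW); shape matching alone decides, so this is heuristic.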
    def download_jit_model(self, method: str) -> Optional[str]:
        """Download JIT model for a method; returns the cached path or None"""
        if method in LAMA_JIT_MODELS:
            model_info = LAMA_JIT_MODELS[method]
            logger.info(f"📥 Downloading {model_info['name']}...")
            try:
                model_path = download_model(model_info['url'], model_info['md5'])
                return model_path
            except Exception as e:
                logger.error(f"Failed to download {method}: {e}")
        else:
            logger.warning(f"No JIT model available for {method}")
        return None
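
    # Usage sketch: despite the 'JIT' name, this also fetches the ONNX registry
    # entries (e.g. download_jit_model('anime_onnx')); load_model() below calls
    # it automatically when a configured path does not exist on disk.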
| def load_model(self, method, model_path, force_reload=False): | |
| """Load model - supports both JIT and checkpoint files with ONNX conversion""" | |
| try: | |
| if not TORCH_AVAILABLE: | |
| logger.warning("PyTorch not available in this build") | |
| logger.info("Inpainting features will be disabled - this is normal for lightweight builds") | |
| logger.info("The application will continue to work without local inpainting") | |
| self.model_loaded = False | |
| return False | |
| # Additional safety check for torch being None | |
| if torch is None or nn is None: | |
| logger.warning("PyTorch modules not properly loaded") | |
| logger.info("Inpainting features will be disabled - this is normal for lightweight builds") | |
| self.model_loaded = False | |
| return False | |
| # Check if model path changed - but only if we had a previous path saved | |
| current_saved_path = self.config.get(f'{method}_model_path', '') | |
| if current_saved_path and current_saved_path != model_path: | |
| logger.info(f"📍 Model path changed for {method}") | |
| logger.info(f" Old: {current_saved_path}") | |
| logger.info(f" New: {model_path}") | |
| force_reload = True | |
| if not os.path.exists(model_path): | |
| # Try to auto-download JIT model if path doesn't exist | |
| logger.warning(f"Model not found: {model_path}") | |
| logger.info("Attempting to download JIT model...") | |
| try: | |
| jit_path = self.download_jit_model(method) | |
| if jit_path and os.path.exists(jit_path): | |
| model_path = jit_path | |
| logger.info(f"Using downloaded JIT model: {jit_path}") | |
| else: | |
| logger.error(f"Model not found and download failed: {model_path}") | |
| logger.info("Inpainting will be unavailable for this session") | |
| return False | |
| except Exception as download_error: | |
| logger.error(f"Download failed: {download_error}") | |
| logger.info("Inpainting will be unavailable for this session") | |
| return False | |
| # Check if already loaded in THIS instance | |
| if self.model_loaded and self.current_method == method and not force_reload: | |
| # Additional check for ONNX - make sure the session exists | |
| if self.use_onnx and self.onnx_session is not None: | |
| logger.debug(f"✅ {method.upper()} ONNX already loaded (skipping reload)") | |
| return True | |
| elif not self.use_onnx and self.model is not None: | |
| logger.debug(f"✅ {method.upper()} already loaded (skipping reload)") | |
| return True | |
| else: | |
| # Model claims to be loaded but objects are missing - force reload | |
| logger.warning(f"⚠️ Model claims loaded but session/model object is None - forcing reload") | |
| force_reload = True | |
| self.model_loaded = False | |
| # Clear previous model if force reload | |
| if force_reload: | |
| logger.info(f"🔄 Force reloading {method} model...") | |
| self.model = None | |
| self.onnx_session = None | |
| self.model_loaded = False | |
| self.is_jit_model = False | |
| # Only log loading message when actually loading | |
| logger.info(f"📥 Loading {method} from {model_path}") | |
| elif self.model_loaded and self.current_method != method: | |
| # If we have a model loaded but it's a different method, clear it | |
| logger.info(f"🔄 Switching from {self.current_method} to {method}") | |
| self.model = None | |
| self.onnx_session = None | |
| self.model_loaded = False | |
| self.is_jit_model = False | |
| # Only log loading message when actually loading | |
| logger.info(f"📥 Loading {method} from {model_path}") | |
| elif not self.model_loaded: | |
| # Only log when we're actually going to load | |
| logger.info(f"📥 Loading {method} from {model_path}") | |
| # else: model is loaded and current, no logging needed | |
| # Normalize path and enforce expected extension for certain methods | |
| try: | |
| _ext = os.path.splitext(model_path)[1].lower() | |
| _method_lower = str(method).lower() | |
| # For explicit ONNX methods, ensure we use a .onnx path | |
| if _method_lower in ("lama_onnx", "anime_onnx", "aot_onnx") and _ext != ".onnx": | |
| # If the file exists, try to detect if it's actually an ONNX model and correct the extension | |
| if os.path.exists(model_path) and ONNX_AVAILABLE: | |
| try: | |
| import onnx as _onnx | |
| _ = _onnx.load(model_path) # will raise if not ONNX | |
| # Build a corrected path under the ONNX cache dir | |
| base_name = os.path.splitext(os.path.basename(model_path))[0] | |
| if base_name.endswith('.pt'): | |
| base_name = base_name[:-3] | |
| corrected_path = os.path.join(ONNX_CACHE_DIR, base_name + ".onnx") | |
| # Avoid overwriting a valid file with an invalid one | |
| if model_path != corrected_path: | |
| try: | |
| import shutil as _shutil | |
| _shutil.copy2(model_path, corrected_path) | |
| model_path = corrected_path | |
| logger.info(f"🔧 Corrected ONNX model extension/path: {model_path}") | |
| except Exception as _cp_e: | |
| # As a fallback, try in-place rename to .onnx | |
| try: | |
| in_place = os.path.splitext(model_path)[0] + ".onnx" | |
| os.replace(model_path, in_place) | |
| model_path = in_place | |
| logger.info(f"🔧 Renamed ONNX model to: {model_path}") | |
| except Exception: | |
| logger.warning(f"Could not correct ONNX extension automatically: {_cp_e}") | |
| except Exception: | |
| # Not an ONNX file; leave as-is | |
| pass | |
| # If the path doesn't exist or still wrong, prefer the known ONNX download for this method | |
| if (not os.path.exists(model_path)) or (os.path.splitext(model_path)[1].lower() != ".onnx"): | |
| try: | |
| # Download the appropriate ONNX model based on the method | |
| if _method_lower == "anime_onnx": | |
| _dl = self.download_jit_model("anime_onnx") | |
| elif _method_lower == "aot_onnx": | |
| _dl = self.download_jit_model("aot_onnx") | |
| else: | |
| _dl = self.download_jit_model("lama_onnx") | |
| if _dl and os.path.exists(_dl): | |
| model_path = _dl | |
| logger.info(f"🔧 Using downloaded {_method_lower.upper()} model: {model_path}") | |
| except Exception: | |
| pass | |
| except Exception: | |
| pass | |
| # Check file signature to detect ONNX files (even with wrong extension) | |
| # or check file extension | |
| ext = model_path.lower().split('.')[-1] | |
| is_onnx = False | |
| # Check by file signature | |
| try: | |
| with open(model_path, 'rb') as f: | |
| file_header = f.read(8) | |
| if file_header.startswith(b'\x08'): | |
| is_onnx = True | |
| logger.debug("📦 Detected ONNX file signature") | |
| except Exception: | |
| pass | |
| # Check by extension | |
| if ext == 'onnx': | |
| is_onnx = True | |
| # Handle ONNX files | |
| if is_onnx: | |
| # Note: load_onnx_model will handle its own logging | |
| try: | |
| onnx_load_result = self.load_onnx_model(model_path) | |
| if onnx_load_result: | |
| # CRITICAL: Set model_loaded flag FIRST before any other operations | |
| # This ensures concurrent threads see the correct state immediately | |
| self.model_loaded = True | |
| self.use_onnx = True | |
| self.is_jit_model = False | |
| # Ensure aot_onnx is properly set as current method | |
| if 'aot' in method.lower(): | |
| self.current_method = 'aot_onnx' | |
| else: | |
| self.current_method = method | |
| # Save with BOTH key formats for compatibility (non-critical - do last) | |
| try: | |
| self.config[f'{method}_model_path'] = model_path | |
| self.config[f'manga_{method}_model_path'] = model_path | |
| self._save_config() | |
| except Exception as cfg_err: | |
| logger.debug(f"Config save after ONNX load failed (non-critical): {cfg_err}") | |
| logger.info(f"✅ {method.upper()} ONNX loaded with method: {self.current_method}") | |
| # Double-check model_loaded flag is still set | |
| if not self.model_loaded: | |
| logger.error("❌ CRITICAL: model_loaded flag was unset after successful ONNX load!") | |
| self.model_loaded = True | |
| return True | |
| else: | |
| logger.error("Failed to load ONNX model - load_onnx_model returned False") | |
| self.model_loaded = False | |
| return False | |
| except Exception as onnx_err: | |
| logger.error(f"Exception during ONNX model loading: {onnx_err}") | |
| import traceback | |
| logger.debug(traceback.format_exc()) | |
| self.model_loaded = False | |
| return False | |
| # Check if it's a JIT model (.pt) or checkpoint (.ckpt/.pth) | |
| if model_path.endswith('.pt'): | |
| try: | |
| # Try loading as JIT/TorchScript | |
| logger.info("📦 Attempting to load as JIT model...") | |
| self.model = torch.jit.load(model_path, map_location=self.device or 'cpu') | |
| self.model.eval() | |
| if self.use_gpu and self.device: | |
| try: | |
| self.model = self.model.to(self.device) | |
| except Exception as gpu_error: | |
| logger.warning(f"Could not move model to GPU: {gpu_error}") | |
| logger.info("Using CPU instead") | |
| self.is_jit_model = True | |
| self.model_loaded = True | |
| self.current_method = method | |
| logger.info("✅ JIT model loaded successfully!") | |
| time.sleep(0.1) # Brief pause for stability | |
| logger.debug("💤 JIT model loading pausing briefly for stability") | |
| # Optional FP16 precision on GPU to reduce VRAM | |
| if self.quantize_enabled and self.use_gpu: | |
| try: | |
| if self.torch_precision in ('fp16', 'auto'): | |
| self.model = self.model.half() | |
| logger.info("🔻 Applied FP16 precision to inpainting model (GPU)") | |
| else: | |
| logger.info("Torch precision set to fp32; skipping half()") | |
| except Exception as _e: | |
| logger.warning(f"Could not switch inpainting model precision: {_e}") | |
| # Optional INT8 dynamic quantization for CPU TorchScript (best-effort) | |
| if (self.int8_enabled or (self.quantize_enabled and not self.use_gpu and self.torch_precision in ('auto', 'int8'))) and not self.use_gpu: | |
| try: | |
| applied = False | |
| # Try TorchScript dynamic quantization API (older PyTorch) | |
| try: | |
| from torch.quantization import quantize_dynamic_jit # type: ignore | |
| self.model = quantize_dynamic_jit(self.model, {"aten::linear"}, dtype=torch.qint8) # type: ignore | |
| applied = True | |
| except Exception: | |
| pass | |
| # Try eager-style dynamic quantization on the scripted module (may no-op) | |
| if not applied: | |
| try: | |
| import torch.ao.quantization as tq # type: ignore | |
| self.model = tq.quantize_dynamic(self.model, {nn.Linear}, dtype=torch.qint8) # type: ignore | |
| applied = True | |
| except Exception: | |
| pass | |
| # Always try to optimize TorchScript for inference | |
| try: | |
| self.model = torch.jit.optimize_for_inference(self.model) # type: ignore | |
| except Exception: | |
| pass | |
| if applied: | |
| logger.info("🔻 Applied INT8 dynamic quantization to JIT inpainting model (CPU)") | |
| self.torch_quantize_applied = True | |
| else: | |
| logger.info("ℹ️ INT8 dynamic quantization not applied (unsupported for this JIT graph); using FP32 CPU") | |
| except Exception as _qe: | |
| logger.warning(f"INT8 quantization skipped: {_qe}") | |
| # Save with BOTH key formats for compatibility | |
| self.config[f'{method}_model_path'] = model_path | |
| self.config[f'manga_{method}_model_path'] = model_path | |
| self._save_config() | |
| # ONNX CONVERSION (optionally in background) | |
| if AUTO_CONVERT_TO_ONNX and self.model_loaded: | |
| def _convert_and_switch(): | |
| try: | |
| onnx_path = self.convert_to_onnx(model_path, method) | |
| if onnx_path and self.load_onnx_model(onnx_path): | |
| logger.info("🚀 Using ONNX model for inference") | |
| else: | |
| logger.info("📦 Using PyTorch JIT model for inference") | |
| except Exception as onnx_error: | |
| logger.warning(f"ONNX conversion failed: {onnx_error}") | |
| logger.info("📦 Using PyTorch JIT model for inference") | |
| if os.environ.get('AUTO_CONVERT_TO_ONNX_BACKGROUND', 'true').lower() == 'true': | |
| threading.Thread(target=_convert_and_switch, daemon=True).start() | |
| else: | |
| _convert_and_switch() | |
| return True | |
| except Exception as jit_error: | |
| logger.info(f" Not a JIT model, trying as regular checkpoint... ({jit_error})") | |
| try: | |
| checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) | |
| self.is_jit_model = False | |
| except Exception as load_error: | |
| logger.error(f"Failed to load checkpoint: {load_error}") | |
| return False | |
| else: | |
| # Load as regular checkpoint | |
| try: | |
| checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) | |
| self.is_jit_model = False | |
| except Exception as load_error: | |
| logger.error(f"Failed to load checkpoint: {load_error}") | |
| logger.info("This may happen if PyTorch is not fully available in the .exe build") | |
| return False | |
| # If we get here, it's not JIT, so load as checkpoint | |
| if not self.is_jit_model: | |
| try: | |
| # Try to create the model - this might fail if nn.Module is None | |
| self.model = FFCInpaintModel() | |
| if isinstance(checkpoint, dict): | |
| if 'gen_state_dict' in checkpoint: | |
| state_dict = checkpoint['gen_state_dict'] | |
| logger.info("📦 Found gen_state_dict") | |
| elif 'state_dict' in checkpoint: | |
| state_dict = checkpoint['state_dict'] | |
| elif 'model' in checkpoint: | |
| state_dict = checkpoint['model'] | |
| else: | |
| state_dict = checkpoint | |
| else: | |
| state_dict = checkpoint | |
| self._load_weights_with_mapping(self.model, state_dict) | |
| self.model.eval() | |
| if self.use_gpu and self.device: | |
| try: | |
| self.model = self.model.to(self.device) | |
| except Exception as gpu_error: | |
| logger.warning(f"Could not move model to GPU: {gpu_error}") | |
| logger.info("Using CPU instead") | |
| # Optional INT8 dynamic quantization for CPU eager model | |
| if not self.use_gpu and (self.int8_enabled or (self.quantize_enabled and self.torch_precision in ('auto', 'int8'))): | |
| try: | |
| import torch.ao.quantization as tq # type: ignore | |
| self.model = tq.quantize_dynamic(self.model, {nn.Linear}, dtype=torch.qint8) # type: ignore | |
| logger.info("🔻 Applied dynamic INT8 quantization to inpainting model (CPU)") | |
| self.torch_quantize_applied = True | |
| except Exception as qe: | |
| logger.warning(f"INT8 dynamic quantization not applied: {qe}") | |
| except Exception as model_error: | |
| logger.error(f"Failed to create or initialize model: {model_error}") | |
| logger.info("This may happen if PyTorch neural network modules are not available in the .exe build") | |
| return False | |
| self.model_loaded = True | |
| self.current_method = method | |
| self.config[f'{method}_model_path'] = model_path | |
| self._save_config() | |
| logger.info(f"✅ {method.upper()} loaded!") | |
| # ONNX CONVERSION (optionally in background) | |
| if AUTO_CONVERT_TO_ONNX and model_path.endswith('.pt') and self.model_loaded: | |
| def _convert_and_switch(): | |
| try: | |
| onnx_path = self.convert_to_onnx(model_path, method) | |
| if onnx_path and self.load_onnx_model(onnx_path): | |
| logger.info("🚀 Using ONNX model for inference") | |
| except Exception as onnx_error: | |
| logger.warning(f"ONNX conversion failed: {onnx_error}") | |
| logger.info("📦 Continuing with PyTorch model") | |
| if os.environ.get('AUTO_CONVERT_TO_ONNX_BACKGROUND', 'true').lower() == 'true': | |
| threading.Thread(target=_convert_and_switch, daemon=True).start() | |
| else: | |
| _convert_and_switch() | |
| return True | |
| except Exception as e: | |
| logger.error(f"❌ Failed to load model: {e}") | |
| logger.error(traceback.format_exc()) | |
| logger.info("Note: If running from .exe, some ML libraries may not be included") | |
| logger.info("This is normal for lightweight builds - inpainting will be disabled") | |
| self.model_loaded = False | |
| return False | |
| def load_model_with_retry(self, method, model_path, force_reload=False, retries: int = 2, retry_delay: float = 0.5) -> bool: | |
| """Attempt to load a model with retries. | |
| Returns True if loaded; False if all attempts fail. On failure, the inpainter will safely no-op. | |
| """ | |
| try: | |
| attempts = max(0, int(retries)) + 1 | |
| except Exception: | |
| attempts = 1 | |
| for attempt in range(attempts): | |
| try: | |
| ok = self.load_model(method, model_path, force_reload=force_reload) | |
| if ok: | |
| return True | |
| except Exception as e: | |
| logger.warning(f"Load attempt {attempt+1} failed with exception: {e}") | |
| # brief delay before next try | |
| if attempt < attempts - 1: | |
| try: | |
| time.sleep(max(0.0, float(retry_delay))) | |
| except Exception: | |
| pass | |
| # If we reach here, loading failed. Leave model unloaded so inpaint() no-ops and returns original image. | |
| logger.warning("All load attempts failed; local inpainting will fall back to returning original images (no-op)") | |
| self.model_loaded = False | |
| # Keep current_method for logging/context if provided | |
| try: | |
| self.current_method = method | |
| except Exception: | |
| pass | |
| return False | |
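| # Example (hedged) usage of the retry wrapper; the model path below is a placeholder, not a bundled file: | |
| #   inpainter = LocalInpainter() | |
| #   if not inpainter.load_model_with_retry('lama', 'models/lama.pt', retries=2, retry_delay=0.5): | |
| #       logger.warning("Inpainting unavailable; images will pass through unchanged") | |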
| def unload(self): | |
| """Release all heavy resources held by this inpainter instance.""" | |
| try: | |
| # Release ONNX session and metadata | |
| try: | |
| if self.onnx_session is not None: | |
| self.onnx_session = None | |
| except Exception: | |
| pass | |
| for attr in ['onnx_input_names', 'onnx_output_names', 'current_onnx_path', 'onnx_fixed_size']: | |
| try: | |
| if hasattr(self, attr): | |
| setattr(self, attr, None) | |
| except Exception: | |
| pass | |
| # Release PyTorch model | |
| try: | |
| if self.model is not None: | |
| if TORCH_AVAILABLE and torch is not None: | |
| try: | |
| # Move to CPU then drop reference | |
| self.model = self.model.to('cpu') if hasattr(self.model, 'to') else None | |
| except Exception: | |
| pass | |
| self.model = None | |
| except Exception: | |
| pass | |
| # Drop bubble detector reference (not the global cache) | |
| try: | |
| self.bubble_detector = None | |
| except Exception: | |
| pass | |
| # Update flags | |
| self.model_loaded = False | |
| self.use_onnx = False | |
| self.is_jit_model = False | |
| # Free CUDA cache and trigger GC | |
| try: | |
| if TORCH_AVAILABLE and torch is not None and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| except Exception: | |
| pass | |
| try: | |
| import gc | |
| gc.collect() | |
| except Exception: | |
| pass | |
| except Exception: | |
| # Never raise from unload | |
| pass | |
| def pad_img_to_modulo(self, img: np.ndarray, mod: int) -> Tuple[np.ndarray, Tuple[int, int, int, int]]: | |
| """Pad image to be divisible by mod""" | |
| if len(img.shape) == 2: | |
| height, width = img.shape | |
| else: | |
| height, width = img.shape[:2] | |
| pad_height = (mod - height % mod) % mod | |
| pad_width = (mod - width % mod) % mod | |
| pad_top = pad_height // 2 | |
| pad_bottom = pad_height - pad_top | |
| pad_left = pad_width // 2 | |
| pad_right = pad_width - pad_left | |
| if len(img.shape) == 2: | |
| padded = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode='reflect') | |
| else: | |
| padded = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode='reflect') | |
| return padded, (pad_top, pad_bottom, pad_left, pad_right) | |
| def remove_padding(self, img: np.ndarray, padding: Tuple[int, int, int, int]) -> np.ndarray: | |
| """Remove padding from image""" | |
| pad_top, pad_bottom, pad_left, pad_right = padding | |
| if len(img.shape) == 2: | |
| return img[pad_top:img.shape[0]-pad_bottom, pad_left:img.shape[1]-pad_right] | |
| else: | |
| return img[pad_top:img.shape[0]-pad_bottom, pad_left:img.shape[1]-pad_right, :] | |
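| # Worked example (hedged): with mod=8, a 300x500 image pads to 304x504, since | |
| # pad_height = (8 - 300 % 8) % 8 = 4 (2 top, 2 bottom) and pad_width = (8 - 500 % 8) % 8 = 4 (2 left, 2 right); | |
| # remove_padding() with the returned tuple restores the original view: | |
| #   padded, pads = self.pad_img_to_modulo(img, 8)   # pads == (2, 2, 2, 2) | |
| #   restored = self.remove_padding(padded, pads)    # restored.shape[:2] == img.shape[:2] | |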
| def _inpaint_tiled(self, image, mask, tile_size, overlap, refinement='normal'): | |
| """Process image in tiles""" | |
| orig_h, orig_w = image.shape[:2] | |
| result = image.copy() | |
| # Calculate tile positions | |
| for y in range(0, orig_h, tile_size - overlap): | |
| for x in range(0, orig_w, tile_size - overlap): | |
| # Calculate tile boundaries | |
| x_end = min(x + tile_size, orig_w) | |
| y_end = min(y + tile_size, orig_h) | |
| # Adjust start to ensure full tile size if possible | |
| if x_end - x < tile_size and x > 0: | |
| x = max(0, x_end - tile_size) | |
| if y_end - y < tile_size and y > 0: | |
| y = max(0, y_end - tile_size) | |
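| # Worked example: width 2000, tile_size 512, overlap 64 → x steps by 448 (0, 448, 896, 1344, 1792); | |
| # at x=1792, x_end=2000 would leave a 208px sliver, so x is pulled back to 1488 for a full 512px tile. | |
| # (Reassigning x here only widens this tile; the range() iterator itself is unaffected.) | |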
| # Extract tile | |
| tile_img = image[y:y_end, x:x_end] | |
| tile_mask = mask[y:y_end, x:x_end] | |
| # Skip if no inpainting needed | |
| if np.sum(tile_mask) == 0: | |
| continue | |
| # Process this tile with the actual model | |
| processed_tile = self._process_single_tile(tile_img, tile_mask, tile_size, refinement) | |
| # Auto-retry for tile if no visible change | |
| try: | |
| if self._is_noop(tile_img, processed_tile, tile_mask): | |
| kernel = np.ones((3, 3), np.uint8) | |
| expanded = cv2.dilate(tile_mask, kernel, iterations=1) | |
| processed_tile = self._process_single_tile(tile_img, expanded, tile_size, 'fast') | |
| if self._is_noop(tile_img, processed_tile, expanded): | |
| logger.warning("Tile remained unchanged after retry; proceeding without further fallback") | |
| except Exception as e: | |
| logger.debug(f"Tiled no-op detection error: {e}") | |
| # Blend tile back into result | |
| if overlap > 0 and (x > 0 or y > 0): | |
| result[y:y_end, x:x_end] = self._blend_tile( | |
| result[y:y_end, x:x_end], | |
| processed_tile, | |
| x > 0, | |
| y > 0, | |
| overlap | |
| ) | |
| else: | |
| result[y:y_end, x:x_end] = processed_tile | |
| logger.info(f"✅ Tiled inpainting complete ({orig_w}x{orig_h} in {tile_size}x{tile_size} tiles)") | |
| time.sleep(0.1) # Brief pause for stability | |
| logger.debug("💤 Tiled inpainting completion pausing briefly for stability") | |
| return result | |
| def _process_single_tile(self, tile_img, tile_mask, tile_size, refinement): | |
| """Process a single tile without tiling""" | |
| # Temporarily disable tiling | |
| old_tiling = self.tiling_enabled | |
| self.tiling_enabled = False | |
| result = self.inpaint(tile_img, tile_mask, refinement, _skip_hd=True) | |
| self.tiling_enabled = old_tiling | |
| return result | |
| def _blend_tile(self, existing, new_tile, blend_x, blend_y, overlap): | |
| """Blend a tile with existing result""" | |
| if not blend_x and not blend_y: | |
| # No blending needed for first tile | |
| return new_tile | |
| h, w = new_tile.shape[:2] | |
| result = new_tile.copy() | |
| # Create blend weights | |
| if blend_x and overlap > 0 and w > overlap: | |
| # Horizontal blend on left edge | |
| for i in range(overlap): | |
| alpha = i / overlap | |
| result[:, i] = existing[:, i] * (1 - alpha) + new_tile[:, i] * alpha | |
| if blend_y and overlap > 0 and h > overlap: | |
| # Vertical blend on top edge | |
| for i in range(overlap): | |
| alpha = i / overlap | |
| result[i, :] = existing[i, :] * (1 - alpha) + new_tile[i, :] * alpha | |
| return result | |
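| # A vectorized equivalent of the left-edge blend above (a sketch, assuming a 3-channel tile); | |
| # the per-column loop is kept in the method for readability: | |
| #   alphas = (np.arange(overlap, dtype=np.float32) / overlap)[None, :, None]  # shape (1, overlap, 1) | |
| #   result[:, :overlap] = (existing[:, :overlap] * (1 - alphas) | |
| #                          + new_tile[:, :overlap] * alphas).astype(new_tile.dtype) | |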
| def _is_noop(self, original: np.ndarray, result: np.ndarray, mask: np.ndarray, threshold: float = 0.75) -> bool: | |
| """Return True if inpainting produced negligible change within the masked area.""" | |
| try: | |
| if original is None or result is None: | |
| return True | |
| if original.shape != result.shape: | |
| return False | |
| # Normalize mask to single channel boolean | |
| if mask is None: | |
| return False | |
| if len(mask.shape) == 3: | |
| mask_gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) | |
| else: | |
| mask_gray = mask | |
| m = mask_gray > 0 | |
| if not np.any(m): | |
| return False | |
| # Fast path | |
| if np.array_equal(original, result): | |
| return True | |
| diff = cv2.absdiff(result, original) | |
| if len(diff.shape) == 3: | |
| diff_gray = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY) | |
| else: | |
| diff_gray = diff | |
| mean_diff = float(np.mean(diff_gray[m])) | |
| return mean_diff < threshold | |
| except Exception as e: | |
| logger.debug(f"No-op detection failed: {e}") | |
| return False | |
| def _is_white_paste(self, result: np.ndarray, mask: np.ndarray, white_threshold: int = 245, ratio: float = 0.90) -> bool: | |
| """Detect 'white paste' failure: masked area mostly saturated near white.""" | |
| try: | |
| if result is None or mask is None: | |
| return False | |
| if len(mask.shape) == 3: | |
| mask_gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) | |
| else: | |
| mask_gray = mask | |
| m = mask_gray > 0 | |
| if not np.any(m): | |
| return False | |
| if len(result.shape) == 3: | |
| white = (result[..., 0] >= white_threshold) & (result[..., 1] >= white_threshold) & (result[..., 2] >= white_threshold) | |
| else: | |
| white = result >= white_threshold | |
| count_mask = int(np.count_nonzero(m)) | |
| count_white = int(np.count_nonzero(white & m)) | |
| if count_mask == 0: | |
| return False | |
| frac = count_white / float(count_mask) | |
| return frac >= ratio | |
| except Exception as e: | |
| logger.debug(f"White paste detection failed: {e}") | |
| return False | |
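| # Minimal self-check sketch for the two failure detectors above (synthetic BGR arrays, | |
| # assuming a constructed LocalInpainter instance `inpainter`; not part of the pipeline): | |
| #   img = np.zeros((64, 64, 3), np.uint8) | |
| #   msk = np.zeros((64, 64), np.uint8); msk[16:48, 16:48] = 255 | |
| #   assert inpainter._is_noop(img, img.copy(), msk)    # identical arrays → no-op | |
| #   white = img.copy(); white[16:48, 16:48] = 255 | |
| #   assert inpainter._is_white_paste(white, msk)       # masked area saturated near-white | |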
| def _log_inpaint_diag(self, path: str, result: np.ndarray, mask: np.ndarray): | |
| try: | |
| h, w = result.shape[:2] | |
| stats = (float(result.min()), float(result.max()), float(result.mean()))  # same stats for gray or color | |
| logger.info(f"[Diag] Path={path} onnx_quant={self.onnx_quantize_applied} torch_quant={self.torch_quantize_applied} size={w}x{h} stats(min,max,mean)={stats}") | |
| if self._is_white_paste(result, mask): | |
| logger.warning(f"[Diag] White-paste detected (mask>0 mostly white)") | |
| except Exception as e: | |
| logger.debug(f"Diag log failed: {e}") | |
| def inpaint(self, image, mask, refinement='normal', _retry_attempt: int = 0, _skip_hd: bool = False, _skip_tiling: bool = False): | |
| """Inpaint - compatible with JIT, checkpoint, and ONNX models | |
| Implements HD strategy (Resize/Crop) similar to comic-translate to speed up large images. | |
| """ | |
| # Check for stop at start | |
| if self._check_stop(): | |
| self._log("⏹️ Inpainting stopped by user", "warning") | |
| return image | |
| if not self.model_loaded: | |
| self._log("No model loaded", "error") | |
| return image | |
| try: | |
| # Store original dimensions | |
| orig_h, orig_w = image.shape[:2] | |
| # HD strategy (mirror of comic-translate): optional RESIZE or CROP before core inpainting | |
| if not _skip_hd: | |
| try: | |
| strategy = getattr(self, 'hd_strategy', 'resize') or 'resize' | |
| except Exception: | |
| strategy = 'resize' | |
| H, W = orig_h, orig_w | |
| limit = max(16, int(getattr(self, 'hd_resize_limit', 1536))) | |
| if strategy == 'resize' and max(H, W) > limit: | |
| ratio = float(limit) / float(max(H, W)) | |
| new_w = max(1, int(W * ratio + 0.5)) | |
| new_h = max(1, int(H * ratio + 0.5)) | |
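| # Worked example: H=4000, W=2800 with hd_resize_limit=1536 → ratio = 1536/4000 = 0.384, | |
| # so the model sees a 1536x1075 image and only the masked pixels of the upscaled result are pasted back. | |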
| image_small = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4) | |
| mask_small = mask if len(mask.shape) == 2 else cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) | |
| mask_small = cv2.resize(mask_small, (new_w, new_h), interpolation=cv2.INTER_NEAREST) | |
| result_small = self.inpaint(image_small, mask_small, refinement, 0, _skip_hd=True, _skip_tiling=True) | |
| result_full = cv2.resize(result_small, (W, H), interpolation=cv2.INTER_LANCZOS4) | |
| # Paste only masked area | |
| mask_gray = mask_small # already gray but at small size | |
| mask_gray = cv2.resize(mask_gray, (W, H), interpolation=cv2.INTER_NEAREST) | |
| m = mask_gray > 0 | |
| out = image.copy() | |
| out[m] = result_full[m] | |
| return out | |
| elif strategy == 'crop' and max(H, W) > max(16, int(getattr(self, 'hd_crop_trigger_size', 1024))): | |
| mask_gray0 = mask if len(mask.shape) == 2 else cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) | |
| _, thresh = cv2.threshold(mask_gray0, 127, 255, cv2.THRESH_BINARY) | |
| contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| if contours: | |
| out = image.copy() | |
| margin = max(0, int(getattr(self, 'hd_crop_margin', 16))) | |
| for cnt in contours: | |
| x, y, w, h = cv2.boundingRect(cnt) | |
| l = max(0, x - margin); t = max(0, y - margin) | |
| r = min(W, x + w + margin); b = min(H, y + h + margin) | |
| if r <= l or b <= t: | |
| continue | |
| crop_img = image[t:b, l:r] | |
| crop_mask = mask_gray0[t:b, l:r] | |
| patch = self.inpaint(crop_img, crop_mask, refinement, 0, _skip_hd=True, _skip_tiling=True) | |
| out[t:b, l:r] = patch | |
| return out | |
| if len(mask.shape) == 3: | |
| mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) | |
| # Apply dilation for anime method | |
| if self.current_method == 'anime': | |
| kernel = np.ones((7, 7), np.uint8) | |
| mask = cv2.dilate(mask, kernel, iterations=1) | |
| # Use instance tiling settings for ALL models | |
| logger.info(f"🔍 Tiling check: enabled={self.tiling_enabled}, tile_size={self.tile_size}, image_size={orig_h}x{orig_w}") | |
| # If tiling is enabled and image is larger than tile size | |
| if (not _skip_tiling) and self.tiling_enabled and (orig_h > self.tile_size or orig_w > self.tile_size): | |
| logger.info(f"🔲 Using tiled inpainting: {self.tile_size}x{self.tile_size} tiles with {self.tile_overlap}px overlap") | |
| return self._inpaint_tiled(image, mask, self.tile_size, self.tile_overlap, refinement) | |
| # ONNX inference path | |
| if self.use_onnx and self.onnx_session: | |
| logger.debug("Using ONNX inference") | |
| # CRITICAL: Convert BGR (OpenCV default) to RGB (ML model expected) | |
| image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| # Check if this is a Carve model | |
| is_carve_model = False | |
| if hasattr(self, 'current_onnx_path'): | |
| is_carve_model = "lama_fp32" in self.current_onnx_path or "carve" in self.current_onnx_path.lower() | |
| # Handle fixed-size models (resize instead of padding) | |
| if hasattr(self, 'onnx_fixed_size') and self.onnx_fixed_size: | |
| fixed_h, fixed_w = self.onnx_fixed_size | |
| # Resize to fixed size | |
| image_resized = cv2.resize(image_rgb, (fixed_w, fixed_h), interpolation=cv2.INTER_LANCZOS4) | |
| mask_resized = cv2.resize(mask, (fixed_w, fixed_h), interpolation=cv2.INTER_NEAREST) | |
| # Prepare inputs based on model type | |
| if is_carve_model: | |
| # Carve model expects normalized input [0, 1] range | |
| logger.debug("Using Carve model normalization [0, 1]") | |
| img_np = image_resized.astype(np.float32) / 255.0 | |
| mask_np = mask_resized.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0.5) * 1.0 # Binary mask | |
| elif 'aot' in str(self.current_method).lower(): | |
| # AOT normalization: [-1, 1] range for image | |
| logger.debug("Using AOT model normalization [-1, 1] for image, [0, 1] for mask") | |
| img_np = (image_resized.astype(np.float32) / 127.5) - 1.0 | |
| mask_np = mask_resized.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0.5) * 1.0 # Binary mask | |
| img_np = img_np * (1 - mask_np[:, :, np.newaxis]) # Mask out regions | |
| elif 'anime' in str(self.current_method).lower(): | |
| # Anime/Manga LaMa normalization: [0, 1] range with optional input masking for stability | |
| logger.debug("Using Anime/Manga LaMa normalization [0, 1] with input masking") | |
| img_np = image_resized.astype(np.float32) / 255.0 | |
| mask_np = mask_resized.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0.5) * 1.0 # Binary mask | |
| # CRITICAL: Mask out input regions for better text region stability | |
| # This helps the model focus on generating content rather than being influenced by text artifacts | |
| img_np = img_np * (1 - mask_np[:, :, np.newaxis]) | |
| else: | |
| # Standard LaMa normalization: [0, 1] range | |
| logger.debug("Using standard LaMa normalization [0, 1]") | |
| img_np = image_resized.astype(np.float32) / 255.0 | |
| mask_np = mask_resized.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0) * 1.0 | |
| # Convert to NCHW format | |
| img_np = img_np.transpose(2, 0, 1)[np.newaxis, ...] | |
| mask_np = mask_np[np.newaxis, np.newaxis, ...] | |
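| # Shape walkthrough: image (H, W, 3) → transpose(2, 0, 1) → (3, H, W) → [np.newaxis] → (1, 3, H, W); | |
| # mask (H, W) → (1, 1, H, W). Both now match the NCHW layout the ONNX graph expects. | |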
| # Run ONNX inference | |
| ort_inputs = { | |
| self.onnx_input_names[0]: img_np.astype(np.float32), | |
| self.onnx_input_names[1]: mask_np.astype(np.float32) | |
| } | |
| ort_outputs = self.onnx_session.run(self.onnx_output_names, ort_inputs) | |
| output = ort_outputs[0] | |
| # Post-process output based on model type | |
| if is_carve_model: | |
| # CRITICAL: Carve model outputs values ALREADY in [0, 255] range! | |
| # DO NOT multiply by 255 or apply any scaling | |
| logger.debug("Carve model output is already in [0, 255] range") | |
| raw_output = output[0].transpose(1, 2, 0) | |
| logger.debug(f"Carve output stats: min={raw_output.min():.3f}, max={raw_output.max():.3f}, mean={raw_output.mean():.3f}") | |
| result = raw_output # Just transpose, no scaling | |
| elif 'aot' in str(self.current_method).lower(): | |
| # AOT: [-1, 1] to [0, 255] | |
| result = ((output[0].transpose(1, 2, 0) + 1.0) * 127.5) | |
| else: | |
| # Standard: [0, 1] to [0, 255] | |
| result = output[0].transpose(1, 2, 0) * 255 | |
| result = np.clip(np.round(result), 0, 255).astype(np.uint8) | |
| # CRITICAL: Convert RGB (model output) back to BGR (OpenCV expected) | |
| result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) | |
| # Resize back to original size | |
| result = cv2.resize(result, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4) | |
| self._log_inpaint_diag('onnx-fixed', result, mask) | |
| else: | |
| # Variable-size models (use padding) | |
| image_padded, padding = self.pad_img_to_modulo(image_rgb, self.pad_mod) | |
| mask_padded, _ = self.pad_img_to_modulo(mask, self.pad_mod) | |
| # Prepare inputs based on model type | |
| if is_carve_model: | |
| # Carve model normalization [0, 1] | |
| logger.debug("Using Carve model normalization [0, 1]") | |
| img_np = image_padded.astype(np.float32) / 255.0 | |
| mask_np = mask_padded.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0.5) * 1.0 | |
| elif 'aot' in str(self.current_method).lower(): | |
| # AOT normalization: [-1, 1] for image | |
| logger.debug("Using AOT model normalization [-1, 1] for image, [0, 1] for mask") | |
| img_np = (image_padded.astype(np.float32) / 127.5) - 1.0 | |
| mask_np = mask_padded.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0.5) * 1.0 | |
| img_np = img_np * (1 - mask_np[:, :, np.newaxis]) # Mask out regions | |
| elif 'anime' in str(self.current_method).lower(): | |
| # Anime/Manga LaMa normalization: [0, 1] range with optional input masking for stability | |
| logger.debug("Using Anime/Manga LaMa normalization [0, 1] with input masking") | |
| img_np = image_padded.astype(np.float32) / 255.0 | |
| mask_np = mask_padded.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0.5) * 1.0 # Binary mask | |
| # CRITICAL: Mask out input regions for better text region stability | |
| # This helps the model focus on generating content rather than being influenced by text artifacts | |
| img_np = img_np * (1 - mask_np[:, :, np.newaxis]) | |
| else: | |
| # Standard LaMa normalization: [0, 1] | |
| logger.debug("Using standard LaMa normalization [0, 1]") | |
| img_np = image_padded.astype(np.float32) / 255.0 | |
| mask_np = mask_padded.astype(np.float32) / 255.0 | |
| mask_np = (mask_np > 0) * 1.0 | |
| # Convert to NCHW format | |
| img_np = img_np.transpose(2, 0, 1)[np.newaxis, ...] | |
| mask_np = mask_np[np.newaxis, np.newaxis, ...] | |
| # Check for stop before inference | |
| if self._check_stop(): | |
| self._log("⏹️ ONNX inference stopped by user", "warning") | |
| return image | |
| # Run ONNX inference | |
| ort_inputs = { | |
| self.onnx_input_names[0]: img_np.astype(np.float32), | |
| self.onnx_input_names[1]: mask_np.astype(np.float32) | |
| } | |
| ort_outputs = self.onnx_session.run(self.onnx_output_names, ort_inputs) | |
| output = ort_outputs[0] | |
| # Post-process output | |
| if is_carve_model: | |
| # CRITICAL: Carve model outputs values ALREADY in [0, 255] range! | |
| logger.debug("Carve model output is already in [0, 255] range") | |
| raw_output = output[0].transpose(1, 2, 0) | |
| logger.debug(f"Carve output stats: min={raw_output.min():.3f}, max={raw_output.max():.3f}, mean={raw_output.mean():.3f}") | |
| result = raw_output # Just transpose, no scaling | |
| elif 'aot' in str(self.current_method).lower(): | |
| result = ((output[0].transpose(1, 2, 0) + 1.0) * 127.5) | |
| else: | |
| result = output[0].transpose(1, 2, 0) * 255 | |
| result = np.clip(np.round(result), 0, 255).astype(np.uint8) | |
| # CRITICAL: Convert RGB (model output) back to BGR (OpenCV expected) | |
| result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) | |
| # Remove padding | |
| result = self.remove_padding(result, padding) | |
| self._log_inpaint_diag('onnx-padded', result, mask) | |
| elif self.is_jit_model: | |
| # JIT model processing | |
| if self.current_method == 'aot': | |
| # Special handling for AOT model | |
| logger.debug("Using AOT-specific preprocessing") | |
| # CRITICAL: Convert BGR (OpenCV) to RGB (AOT model expected) | |
| image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| # Pad images to be divisible by mod | |
| image_padded, padding = self.pad_img_to_modulo(image_rgb, self.pad_mod) | |
| mask_padded, _ = self.pad_img_to_modulo(mask, self.pad_mod) | |
| # AOT normalization: [-1, 1] range | |
| img_torch = torch.from_numpy(image_padded).permute(2, 0, 1).unsqueeze_(0).float() / 127.5 - 1.0 | |
| mask_torch = torch.from_numpy(mask_padded).unsqueeze_(0).unsqueeze_(0).float() / 255.0 | |
| # Binarize mask for AOT | |
| mask_torch[mask_torch < 0.5] = 0 | |
| mask_torch[mask_torch >= 0.5] = 1 | |
| # Move to device | |
| img_torch = img_torch.to(self.device) | |
| mask_torch = mask_torch.to(self.device) | |
| # Optional FP16 on GPU for lower VRAM | |
| if self.quantize_enabled and self.use_gpu: | |
| try: | |
| if self.torch_precision == 'fp16' or self.torch_precision == 'auto': | |
| img_torch = img_torch.half() | |
| mask_torch = mask_torch.half() | |
| except Exception: | |
| pass | |
| # CRITICAL FOR AOT: Apply mask to input image | |
| img_torch = img_torch * (1 - mask_torch) | |
| logger.debug(f"AOT Image shape: {img_torch.shape}, Mask shape: {mask_torch.shape}") | |
| # Run inference | |
| with torch.no_grad(): | |
| inpainted = self.model(img_torch, mask_torch) | |
| # Post-process AOT output: denormalize from [-1, 1] to [0, 255] | |
| result = ((inpainted.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5) | |
| result = np.clip(np.round(result), 0, 255).astype(np.uint8) | |
| # CRITICAL: Convert RGB (model output) back to BGR (OpenCV expected) | |
| result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) | |
| # Remove padding | |
| result = self.remove_padding(result, padding) | |
| self._log_inpaint_diag('jit-aot', result, mask) | |
| else: | |
| # LaMa/Anime model processing | |
| logger.debug(f"Using standard processing for {self.current_method}") | |
| # CRITICAL: Convert BGR (OpenCV) to RGB (LaMa/JIT models expected) | |
| image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |
| # Pad images to be divisible by mod | |
| image_padded, padding = self.pad_img_to_modulo(image_rgb, self.pad_mod) | |
| mask_padded, _ = self.pad_img_to_modulo(mask, self.pad_mod) | |
| # CRITICAL: Normalize to [0, 1] range for LaMa models | |
| image_norm = image_padded.astype(np.float32) / 255.0 | |
| mask_norm = mask_padded.astype(np.float32) / 255.0 | |
| # Binary mask (values > 0 become 1) | |
| mask_binary = (mask_norm > 0) * 1.0 | |
| # For anime models: mask out input regions for better text stability | |
| if 'anime' in str(self.current_method).lower(): | |
| logger.debug("Applying input masking for anime model (text region stability)") | |
| image_norm = image_norm * (1 - mask_binary[:, :, np.newaxis]) | |
| # Convert to PyTorch tensors with correct shape | |
| # Image should be [B, C, H, W] | |
| image_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).float() | |
| mask_tensor = torch.from_numpy(mask_binary).unsqueeze(0).unsqueeze(0).float() | |
| # Move to device | |
| image_tensor = image_tensor.to(self.device) | |
| mask_tensor = mask_tensor.to(self.device) | |
| # Optional FP16 on GPU for lower VRAM | |
| if self.quantize_enabled and self.use_gpu: | |
| try: | |
| if self.torch_precision == 'fp16' or self.torch_precision == 'auto': | |
| image_tensor = image_tensor.half() | |
| mask_tensor = mask_tensor.half() | |
| except Exception: | |
| pass | |
| # Debug shapes | |
| logger.debug(f"Image tensor shape: {image_tensor.shape}") # Should be [1, 3, H, W] | |
| logger.debug(f"Mask tensor shape: {mask_tensor.shape}") # Should be [1, 1, H, W] | |
| # Ensure spatial dimensions match | |
| if image_tensor.shape[2:] != mask_tensor.shape[2:]: | |
| logger.warning(f"Spatial dimension mismatch: image {image_tensor.shape[2:]}, mask {mask_tensor.shape[2:]}") | |
| # Resize mask to match image | |
| mask_tensor = F.interpolate(mask_tensor, size=image_tensor.shape[2:], mode='nearest') | |
| # Run inference with proper error handling | |
| with torch.no_grad(): | |
| try: | |
| # Standard LaMa JIT models expect (image, mask) | |
| inpainted = self.model(image_tensor, mask_tensor) | |
| except RuntimeError as e: | |
| error_str = str(e) | |
| logger.error(f"Model inference failed: {error_str}") | |
| # If tensor size mismatch, log detailed info | |
| if "size of tensor" in error_str.lower(): | |
| logger.error(f"Image shape: {image_tensor.shape}") | |
| logger.error(f"Mask shape: {mask_tensor.shape}") | |
| # Try transposing if needed | |
| if "dimension 3" in error_str and "880" in error_str: | |
| # This suggests the tensors might be in wrong format | |
| # Try different permutation | |
| logger.info("Attempting to fix tensor format...") | |
| # Ensure image is [B, C, H, W] not [B, H, W, C] | |
| if image_tensor.shape[1] > 3: | |
| image_tensor = image_tensor.permute(0, 3, 1, 2) | |
| logger.info(f"Permuted image to: {image_tensor.shape}") | |
| # Try again | |
| inpainted = self.model(image_tensor, mask_tensor) | |
| else: | |
| # As last resort, try swapped arguments | |
| logger.info("Trying swapped arguments (mask, image)...") | |
| inpainted = self.model(mask_tensor, image_tensor) | |
| else: | |
| raise | |
| # Process output | |
| # Output should be [B, C, H, W] | |
| if len(inpainted.shape) == 4: | |
| # Remove batch dimension and permute to [H, W, C] | |
| result = inpainted[0].permute(1, 2, 0).detach().cpu().numpy() | |
| else: | |
| # Handle unexpected output shape | |
| result = inpainted.detach().cpu().numpy() | |
| if len(result.shape) == 3 and result.shape[0] == 3: | |
| result = result.transpose(1, 2, 0) | |
| # Denormalize to 0-255 range | |
| result = np.clip(result * 255, 0, 255).astype(np.uint8) | |
| # CRITICAL: Convert RGB (model output) back to BGR (OpenCV expected) | |
| result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR) | |
| # Remove padding | |
| result = self.remove_padding(result, padding) | |
| self._log_inpaint_diag('jit-lama', result, mask) | |
| else: | |
| # Original checkpoint model processing (keep as is) | |
| h, w = image.shape[:2] | |
| size = 768 if self.current_method == 'anime' else 512 | |
| img_resized = cv2.resize(image, (size, size), interpolation=cv2.INTER_LANCZOS4) | |
| mask_resized = cv2.resize(mask, (size, size), interpolation=cv2.INTER_NEAREST) | |
| img_norm = img_resized.astype(np.float32) / 127.5 - 1 | |
| mask_norm = mask_resized.astype(np.float32) / 255.0 | |
| img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).float() | |
| mask_tensor = torch.from_numpy(mask_norm).unsqueeze(0).unsqueeze(0).float() | |
| if self.use_gpu and self.device: | |
| img_tensor = img_tensor.to(self.device) | |
| mask_tensor = mask_tensor.to(self.device) | |
| with torch.no_grad(): | |
| output = self.model(img_tensor, mask_tensor) | |
| result = output.squeeze(0).permute(1, 2, 0).cpu().numpy() | |
| result = ((result + 1) * 127.5).clip(0, 255).astype(np.uint8) | |
| result = cv2.resize(result, (w, h), interpolation=cv2.INTER_LANCZOS4) | |
| self._log_inpaint_diag('ckpt', result, mask) | |
| # Ensure result matches original size exactly | |
| if result.shape[:2] != (orig_h, orig_w): | |
| result = cv2.resize(result, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4) | |
| # Apply refinement blending if requested | |
| if refinement != 'fast': | |
| # Ensure mask is same size as result | |
| if mask.shape[:2] != (orig_h, orig_w): | |
| mask = cv2.resize(mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST) | |
| mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) / 255.0 | |
| kernel = cv2.getGaussianKernel(21, 5) | |
| kernel = kernel @ kernel.T | |
| mask_blur = cv2.filter2D(mask_3ch, -1, kernel) | |
| result = (result * mask_blur + image * (1 - mask_blur)).astype(np.uint8) | |
| # No-op detection and auto-retry | |
| try: | |
| if self._is_noop(image, result, mask): | |
| if _retry_attempt == 0: | |
| logger.warning("⚠️ Inpainting produced no visible change; retrying with slight mask dilation and fast refinement") | |
| kernel = np.ones((3, 3), np.uint8) | |
| expanded_mask = cv2.dilate(mask, kernel, iterations=1) | |
| return self.inpaint(image, expanded_mask, refinement='fast', _retry_attempt=1) | |
| elif _retry_attempt == 1: | |
| logger.warning("⚠️ Still no visible change after retry; attempting a second dilation and fast refinement") | |
| kernel = np.ones((5, 5), np.uint8) | |
| expanded_mask2 = cv2.dilate(mask, kernel, iterations=1) | |
| return self.inpaint(image, expanded_mask2, refinement='fast', _retry_attempt=2) | |
| else: | |
| logger.warning("⚠️ No further retries; returning last result without fallback") | |
| except Exception as e: | |
| logger.debug(f"No-op detection step failed: {e}") | |
| logger.info("✅ Inpainted successfully!") | |
| # Force garbage collection to reduce memory spikes | |
| try: | |
| import gc | |
| gc.collect() | |
| # Clear CUDA cache if using GPU | |
| if torch is not None and hasattr(torch, 'cuda') and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| except Exception: | |
| pass | |
| time.sleep(0.1) # Brief pause for stability | |
| logger.debug("💤 Inpainting completion pausing briefly for stability") | |
| return result | |
| except Exception as e: | |
| logger.error(f"❌ Inpainting failed: {e}") | |
| logger.error(traceback.format_exc()) | |
| # Return original image on failure | |
| logger.warning("Returning original image due to error") | |
| return image | |
| def inpaint_with_prompt(self, image, mask, prompt=None): | |
| """Compatibility method""" | |
| return self.inpaint(image, mask) | |
| def batch_inpaint(self, images, masks): | |
| """Batch inpainting""" | |
| return [self.inpaint(img, mask) for img, mask in zip(images, masks)] | |
| def load_bubble_model(self, model_path: str) -> bool: | |
| """Load bubble detection model""" | |
| if not BUBBLE_DETECTOR_AVAILABLE: | |
| logger.warning("Bubble detector not available") | |
| return False | |
| if self.bubble_detector is None: | |
| self.bubble_detector = BubbleDetector() | |
| if self.bubble_detector.load_model(model_path): | |
| self.bubble_model_loaded = True | |
| self.config['bubble_model_path'] = model_path | |
| self._save_config() | |
| logger.info("✅ Bubble detection model loaded") | |
| return True | |
| return False | |
| def detect_bubbles(self, image_path: str, confidence: float = 0.5) -> List[Tuple[int, int, int, int]]: | |
| """Detect speech bubbles in image""" | |
| if not self.bubble_model_loaded or self.bubble_detector is None: | |
| logger.warning("No bubble model loaded") | |
| return [] | |
| return self.bubble_detector.detect_bubbles(image_path, confidence=confidence) | |
| def create_bubble_mask(self, image: np.ndarray, bubbles: List[Tuple[int, int, int, int]], | |
| expand_pixels: int = 5) -> np.ndarray: | |
| """Create mask from detected bubbles""" | |
| h, w = image.shape[:2] | |
| mask = np.zeros((h, w), dtype=np.uint8) | |
| for x, y, bw, bh in bubbles: | |
| x1 = max(0, x - expand_pixels) | |
| y1 = max(0, y - expand_pixels) | |
| x2 = min(w, x + bw + expand_pixels) | |
| y2 = min(h, y + bh + expand_pixels) | |
| cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1) | |
| return mask | |
| def inpaint_with_bubble_detection(self, image_path: str, confidence: float = 0.5, | |
| expand_pixels: int = 5, refinement: str = 'normal') -> np.ndarray: | |
| """Inpaint using automatic bubble detection""" | |
| image = cv2.imread(image_path) | |
| if image is None: | |
| logger.error(f"Failed to load image: {image_path}") | |
| return None | |
| bubbles = self.detect_bubbles(image_path, confidence) | |
| if not bubbles: | |
| logger.warning("No bubbles detected") | |
| return image | |
| logger.info(f"Detected {len(bubbles)} bubbles") | |
| mask = self.create_bubble_mask(image, bubbles, expand_pixels) | |
| result = self.inpaint(image, mask, refinement) | |
| return result | |
| def batch_inpaint_with_bubbles(self, image_paths: List[str], **kwargs) -> List[np.ndarray]: | |
| """Batch inpaint multiple images with bubble detection""" | |
| results = [] | |
| for i, image_path in enumerate(image_paths): | |
| logger.info(f"Processing image {i+1}/{len(image_paths)}") | |
| result = self.inpaint_with_bubble_detection(image_path, **kwargs) | |
| results.append(result) | |
| return results | |
| # Compatibility classes - MAINTAIN ALL ORIGINAL CLASSES | |
| class LaMaModel(FFCInpaintModel): | |
| pass | |
| class MATModel(FFCInpaintModel): | |
| pass | |
| class AOTModel(FFCInpaintModel): | |
| pass | |
| class SDInpaintModel(FFCInpaintModel): | |
| pass | |
| class AnimeMangaInpaintModel(FFCInpaintModel): | |
| pass | |
| class LaMaOfficialModel(FFCInpaintModel): | |
| pass | |
| class HybridInpainter: | |
| """Hybrid inpainter for compatibility""" | |
| def __init__(self): | |
| self.inpainters = {} | |
| def add_method(self, name, method, model_path): | |
| """Add a method - maintains compatibility""" | |
| try: | |
| inpainter = LocalInpainter() | |
| if inpainter.load_model(method, model_path): | |
| self.inpainters[name] = inpainter | |
| return True | |
| except Exception as e: | |
| logger.debug(f"HybridInpainter.add_method failed for '{name}': {e}") | |
| return False | |
| def inpaint_ensemble(self, image: np.ndarray, mask: np.ndarray, | |
| weights: Dict[str, float] = None) -> np.ndarray: | |
| """Ensemble inpainting""" | |
| if not self.inpainters: | |
| logger.error("No inpainters loaded") | |
| return image | |
| if weights is None: | |
| weights = {name: 1.0 / len(self.inpainters) for name in self.inpainters} | |
| results = [] | |
| for name, inpainter in self.inpainters.items(): | |
| result = inpainter.inpaint(image, mask) | |
| weight = weights.get(name, 1.0 / len(self.inpainters)) | |
| results.append(result * weight) | |
| ensemble = np.clip(np.sum(results, axis=0), 0, 255).astype(np.uint8) | |
| return ensemble | |
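| # Example (hedged) ensemble usage; the model paths are placeholders, not bundled files: | |
| #   hybrid = HybridInpainter() | |
| #   hybrid.add_method('lama', 'lama', 'models/lama.pt') | |
| #   hybrid.add_method('anime', 'anime', 'models/anime.pt') | |
| #   blended = hybrid.inpaint_ensemble(image, mask, weights={'lama': 0.6, 'anime': 0.4}) | |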
| # Helper function for quick setup | |
| def setup_inpainter_for_manga(auto_download=True): | |
| """Quick setup for manga inpainting""" | |
| inpainter = LocalInpainter() | |
| if auto_download: | |
| # Try to download anime JIT model | |
| jit_path = inpainter.download_jit_model('anime') | |
| if jit_path: | |
| inpainter.load_model('anime', jit_path) | |
| logger.info("✅ Manga inpainter ready with JIT model") | |
| return inpainter | |
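| # Example (hedged) quick-start; assumes a BGR page image and a white-on-black text mask: | |
| #   inpainter = setup_inpainter_for_manga() | |
| #   page = cv2.imread('page.png') | |
| #   cleaned = inpainter.inpaint(page, mask) | |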
| if __name__ == "__main__": | |
| if len(sys.argv) > 1: | |
| if sys.argv[1] == "download_jit": | |
| # Download JIT models | |
| inpainter = LocalInpainter() | |
| for method in ['lama', 'anime', 'lama_official']: | |
| print(f"\nDownloading {method}...") | |
| path = inpainter.download_jit_model(method) | |
| if path: | |
| print(f" ✅ Downloaded to: {path}") | |
| else: | |
| # Test with a model path: python local_inpainter.py <model_path> | |
| inpainter = LocalInpainter() | |
| inpainter.load_model('lama', sys.argv[1]) | |
| print("Model loaded - check logs for details") | |
| else: | |
| print("\nLocal Inpainter - Compatible Version") | |
| print("=====================================") | |
| print("\nSupports both:") | |
| print(" - JIT models (.pt) - RECOMMENDED") | |
| print(" - Checkpoint files (.ckpt) - With warnings") | |
| print("\nTo download JIT models:") | |
| print(" python local_inpainter.py download_jit") | |
| print("\nTo test:") | |
| print(" python local_inpainter.py <model_path>") |