# text_editing / core.py
# Author: yingzhac
# Commit 838e8f6: 🎨 Add smart text resizing functionality
import cv2
import numpy as np
import easyocr
from typing import List, Tuple, Optional
class TextResizer:
    """Erase OCR-detected text from an image and redraw it at a new scale.

    Pipeline (see ``resize_text``): OCR with easyocr -> locate the target
    text -> build a glyph mask -> inpaint the original glyphs -> paste a
    scaled RGBA copy of the glyphs back, centred on the original text box.
    """

    def __init__(self, languages: Optional[List[str]] = None, gpu: bool = False):
        """Initialise the text resizer.

        Args:
            languages: Languages passed to ``easyocr.Reader``. When None,
                defaults to ``['en', 'ch_sim']``. A None sentinel is used
                instead of a list literal so the default is not shared
                between calls.
            gpu: Whether easyocr should use the GPU.
        """
        if languages is None:
            languages = ['en', 'ch_sim']
        self.reader = easyocr.Reader(languages, gpu=gpu)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _matches(text: str, target_text: str) -> bool:
        """Return True if an OCR string equals the target, ignoring surrounding whitespace."""
        return text.strip() == target_text.strip()

    @staticmethod
    def _bbox_to_rect(bbox: List, image_shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
        """Convert a polygon bbox to an axis-aligned rect clipped to the image.

        Args:
            bbox: Sequence of (x, y) points, as returned by easyocr.
            image_shape: Shape of the image; only (height, width) is used.

        Returns:
            (x_min, y_min, x_max, y_max), each clipped to the image. The rect
            is empty (x_min >= x_max or y_min >= y_max) when the box does
            not overlap the image.
        """
        h, w = image_shape[:2]
        xs = [pt[0] for pt in bbox]
        ys = [pt[1] for pt in bbox]
        # Clip so that a negative coordinate can never become a
        # wrap-around (end-relative) slice index.
        x_min = min(max(int(min(xs)), 0), w)
        x_max = min(max(int(max(xs)), 0), w)
        y_min = min(max(int(min(ys)), 0), h)
        y_max = min(max(int(max(ys)), 0), h)
        return x_min, y_min, x_max, y_max

    @staticmethod
    def _glyph_mask(roi: np.ndarray) -> np.ndarray:
        """Return a uint8 mask (255 = glyph) of the filled outer contours in an RGB ROI."""
        gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
        # Inverted adaptive threshold: pixels darker than their local
        # Gaussian-weighted mean (11x11 block, minus C=2) become 255.
        # NOTE(review): this assumes dark text on a lighter background.
        thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       cv2.THRESH_BINARY_INV, 11, 2)
        # [-2] picks the contour list on both OpenCV 3 (3 return values)
        # and OpenCV 4 (2 return values).
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        mask = np.zeros_like(thresh)
        # thickness=-1 fills each external contour, so holes inside glyphs are included.
        cv2.drawContours(mask, contours, -1, 255, -1)
        return mask

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def read_text(self, image: np.ndarray) -> List[Tuple]:
        """Run OCR on an image.

        Args:
            image: Image array in RGB format.

        Returns:
            easyocr results: a list of (bbox, text, confidence) tuples.
        """
        return self.reader.readtext(image)

    def extract_text_mask_by_content(self, image: np.ndarray, results: List[Tuple], target_text: str) -> np.ndarray:
        """Build a glyph mask for every OCR result whose text equals ``target_text``.

        Args:
            image: Image array in RGB format.
            results: OCR results as returned by ``read_text``.
            target_text: Text to match (surrounding whitespace is ignored).

        Returns:
            A uint8 mask with the same height and width as ``image``.
            Glyph pixels are 255; everything else is 0.
        """
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for bbox, text, _ in results:
            if not self._matches(text, target_text):
                continue
            x_min, y_min, x_max, y_max = self._bbox_to_rect(bbox, image.shape)
            if x_min >= x_max or y_min >= y_max:
                continue  # Box lies entirely outside the image.
            region = mask[y_min:y_max, x_min:x_max]  # A view, updated in place below.
            np.maximum(region, self._glyph_mask(image[y_min:y_max, x_min:x_max]), out=region)
        return mask

    def inpaint_image(self, image: np.ndarray, mask: np.ndarray, radius: int = 3) -> np.ndarray:
        """Fill the masked region of an image using OpenCV's Telea inpainting.

        Args:
            image: Image array in RGB format.
            mask: Non-zero pixels mark the region to fill.
            radius: Inpainting neighbourhood radius (default 3, as before).

        Returns:
            The inpainted image.
        """
        return cv2.inpaint(image, mask, radius, cv2.INPAINT_TELEA)

    def find_target_bbox(self, results: List[Tuple], target_text: str) -> Optional[List]:
        """Return the bbox of the first OCR result matching ``target_text``, or None if there is none."""
        return next(
            (bbox for bbox, text, _ in results if self._matches(text, target_text)),
            None,
        )

    def create_resized_text_patch(self, image: np.ndarray, bbox: List, scale_factor: float) -> Tuple[np.ndarray, int, int]:
        """Cut out the glyphs inside ``bbox`` as a scaled RGBA patch.

        Args:
            image: Image array in RGB format.
            bbox: Text bounding box (polygon points).
            scale_factor: Positive scale factor.

        Returns:
            (RGBA patch, centre x, centre y). The alpha channel is the glyph
            mask. The centre is that of the original (clipped) box.

        Raises:
            ValueError: If ``scale_factor`` is not positive, or if ``bbox``
                does not overlap the image.
        """
        if scale_factor <= 0:
            raise ValueError(f"scale_factor must be positive, got {scale_factor}")
        x_min, y_min, x_max, y_max = self._bbox_to_rect(bbox, image.shape)
        if x_min >= x_max or y_min >= y_max:
            raise ValueError(f"bbox does not overlap the image: {bbox}")
        roi = image[y_min:y_max, x_min:x_max]

        rgba_patch = cv2.cvtColor(roi, cv2.COLOR_RGB2RGBA)
        rgba_patch[:, :, 3] = self._glyph_mask(roi)

        h, w = rgba_patch.shape[:2]
        # cv2.resize rejects a zero-sized target, so keep at least 1 px per side.
        new_size = (max(1, int(w * scale_factor)), max(1, int(h * scale_factor)))
        resized_patch = cv2.resize(rgba_patch, new_size, interpolation=cv2.INTER_LINEAR)

        cx = (x_min + x_max) // 2
        cy = (y_min + y_max) // 2
        return resized_patch, cx, cy

    def blend_text_patch(self, canvas: np.ndarray, patch: np.ndarray, center_x: int, center_y: int) -> np.ndarray:
        """Alpha-blend an RGBA patch onto a copy of ``canvas``, centred at (center_x, center_y).

        Any part of the patch that falls outside the canvas is cropped, on
        every side. Earlier code shifted the patch at the left/top edges, so
        it was no longer centred there. Only pixels with alpha > 0 are
        written. Blended values are truncated to uint8, as before.

        Args:
            canvas: Target image (RGB, uint8). It is not modified.
            patch: RGBA patch; channel 3 is the alpha.
            center_x: x coordinate of the patch centre on the canvas.
            center_y: y coordinate of the patch centre on the canvas.

        Returns:
            The blended image.
        """
        result = canvas.copy()
        patch_h, patch_w = patch.shape[:2]
        canvas_h, canvas_w = result.shape[:2]

        # Patch top-left in canvas coordinates; may be negative.
        left = center_x - patch_w // 2
        top = center_y - patch_h // 2

        # Intersection of the patch rectangle with the canvas.
        x0, y0 = max(left, 0), max(top, 0)
        x1, y1 = min(left + patch_w, canvas_w), min(top + patch_h, canvas_h)
        if x0 >= x1 or y0 >= y1:
            return result  # Patch lies entirely outside the canvas.

        sub = patch[y0 - top:y1 - top, x0 - left:x1 - left]
        alpha = sub[:, :, 3:4].astype(np.float64) / 255.0
        region = result[y0:y1, x0:x1]  # A view into result.
        blended = (1 - alpha) * region + alpha * sub[:, :, :3]
        visible = sub[:, :, 3] > 0
        region[visible] = blended[visible].astype(np.uint8)
        return result

    def resize_text(self, image: np.ndarray, target_text: str, scale_factor: float) -> np.ndarray:
        """Run the full pipeline: erase every occurrence of ``target_text`` and redraw it scaled.

        Args:
            image: Image array in RGB format.
            target_text: Text to resize (surrounding whitespace is ignored).
            scale_factor: Positive scale factor.

        Returns:
            The processed image.

        Raises:
            ValueError: If ``scale_factor`` is not positive, or if the
                target text is not found.
        """
        # Validate before the (expensive) OCR call.
        if scale_factor <= 0:
            raise ValueError(f"scale_factor must be positive, got {scale_factor}")

        results = self.read_text(image)
        target_bboxes = [bbox for bbox, text, _ in results if self._matches(text, target_text)]
        if not target_bboxes:
            raise ValueError(f"未找到目标文字: {target_text}")

        # Every matching occurrence is erased here, so every one must be
        # redrawn below. Otherwise all but the first would disappear.
        text_mask = self.extract_text_mask_by_content(image, results, target_text)
        result = self.inpaint_image(image, text_mask)

        for bbox in target_bboxes:
            x_min, y_min, x_max, y_max = self._bbox_to_rect(bbox, image.shape)
            if x_min >= x_max or y_min >= y_max:
                continue  # Nothing was erased for this box, so nothing to redraw.
            # Patches are cut from the original image, not the inpainted one.
            patch, cx, cy = self.create_resized_text_patch(image, bbox, scale_factor)
            result = self.blend_text_patch(result, patch, cx, cy)
        return result