Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import cv2 | |
| import numpy as np | |
| from typing import Dict, Any | |
def robust_parse_reply(reply: str) -> Dict[str, Any]:
    """Extract the JSON object embedded in an LLM reply.

    Markdown code-fence markers are removed first. Every ``{`` in the
    remaining text is then tried, left to right, as the start of a JSON
    object until one decodes. Decoding uses ``json.JSONDecoder.raw_decode``
    so nested objects and ``}`` characters inside string values are handled
    correctly, and any trailing text after the object is ignored.

    Args:
        reply: Raw text returned by the LLM.

    Returns:
        The decoded dictionary; it is guaranteed to contain the keys
        ``target_text`` and ``scale_factor``.

    Raises:
        ValueError: If the text contains no brace-delimited candidate, if no
            candidate decodes as JSON, or if the decoded object lacks
            ``target_text`` or ``scale_factor``.
    """
    # Drop Markdown code-fence markers such as ```json or ```.
    cleaned = re.sub(r"```(?:json)?", "", reply, flags=re.IGNORECASE).strip("` \n")

    start = cleaned.find("{")
    if start == -1 or "}" not in cleaned[start:]:
        raise ValueError("未找到 JSON 对象")

    decoder = json.JSONDecoder()
    first_error = None
    while start != -1:
        try:
            # raw_decode parses one complete JSON value beginning at `start`
            # and tolerates whatever follows it.
            parsed, _ = decoder.raw_decode(cleaned, start)
            break
        except json.JSONDecodeError as e:
            # Remember the earliest failure; it is the most useful one to
            # report if no later candidate decodes either.
            if first_error is None:
                first_error = e
            start = cleaned.find("{", start + 1)
    else:
        raise ValueError(f"JSON 解析失败: {first_error}") from first_error

    # Validate that the required fields are present.
    if "target_text" not in parsed or "scale_factor" not in parsed:
        raise ValueError("JSON 中缺少必要字段 target_text 或 scale_factor")
    return parsed
def load_image(image_path: str) -> np.ndarray:
    """Read an image file and return it in RGB channel order.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image as an array converted from BGR to RGB.

    Raises:
        ValueError: If ``cv2.imread`` returns ``None`` for the given path.
    """
    bgr_pixels = cv2.imread(image_path)
    if bgr_pixels is not None:
        # The array from imread is treated as BGR; swap to RGB for callers.
        return cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    raise ValueError(f"无法加载图像: {image_path}")
def save_image(image: np.ndarray, output_path: str) -> None:
    """Write an RGB image array to disk.

    Args:
        image: Image array in RGB channel order.
        output_path: Destination file path.

    Raises:
        ValueError: If ``cv2.imwrite`` reports that the image was not written.
    """
    # Convert to BGR, the channel order used when saving through OpenCV.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failure through its boolean return value rather
    # than by raising, so check it instead of failing silently.
    if not cv2.imwrite(output_path, image_bgr):
        raise ValueError(f"无法保存图像: {output_path}")
def validate_scale_factor(scale_factor: float) -> float:
    """Validate a scale factor and return it as a ``float``.

    Args:
        scale_factor: Candidate scale factor.

    Returns:
        The scale factor converted to ``float``.

    Raises:
        ValueError: If the value is not an ``int`` or ``float`` (booleans are
            rejected), is NaN or infinite, or is not greater than zero.
    """
    import math

    # bool is a subclass of int, so it must be excluded explicitly;
    # otherwise True would be accepted and returned as 1.0.
    if isinstance(scale_factor, bool) or not isinstance(scale_factor, (int, float)):
        raise ValueError("缩放因子必须是数字")

    try:
        value = float(scale_factor)
    except OverflowError as exc:
        # An int too large for a float cannot be a usable scale factor.
        raise ValueError("缩放因子必须是有限数值") from exc

    # NaN fails every ordering comparison, so it would slip through the
    # range checks below; infinity is likewise never a usable factor.
    if not math.isfinite(value):
        raise ValueError("缩放因子必须是有限数值")
    if value <= 0:
        raise ValueError("缩放因子必须大于0")
    if value > 10:
        # Large factors are allowed but flagged, as processing may be slow.
        print(f"警告: 缩放因子 {scale_factor} 过大,可能导致处理时间过长")
    return value
def format_ocr_results_for_prompt(results: list) -> str:
    """Render the text of each OCR result as a list literal for an LLM prompt.

    Args:
        results: OCR results; each entry is unpacked as a three-item sequence
            whose middle item is the recognised text. (Presumably
            ``(bbox, text, confidence)`` tuples; confirm against the OCR
            caller.)

    Returns:
        The string form of the list of texts, in input order.
    """
    recognized = []
    for _lead, text, _trail in results:
        recognized.append(text)
    return str(recognized)
def parse_percentage_to_scale_factor(text: str) -> float:
    """Convert a percentage phrase into a multiplicative scale factor.

    The first ``<number>%`` in the text supplies the percentage. Keywords
    decide the direction: "enlarge", "increase" or "bigger" mean growth and
    take priority; otherwise "shrink", "reduce" or "smaller" mean reduction.
    With no keyword the percentage is treated as growth.

    Args:
        text: Text such as "enlarge by 50%" or "shrink by 25%".

    Returns:
        ``1 + p/100`` for growth, ``1 - p/100`` for reduction, or ``1.0``
        when the text has no percentage. A reduction of 100% or more gives
        a zero or negative result, so callers should validate it.
    """
    lowered = text.lower()

    # Locate the first percentage figure.
    found = re.search(r'(\d+(?:\.\d+)?)%', lowered)
    if found is None:
        return 1.0  # no percentage: leave the size unchanged

    fraction = float(found.group(1)) / 100

    grow_words = ('enlarge', 'increase', 'bigger')
    shrink_words = ('shrink', 'reduce', 'smaller')

    wants_growth = any(word in lowered for word in grow_words)
    wants_shrink = any(word in lowered for word in shrink_words)

    # Growth keywords win over shrink keywords; no keyword also means growth.
    if wants_shrink and not wants_growth:
        return 1 - fraction
    return 1 + fraction
def create_output_filename(input_path: str, suffix: str = "_resized") -> str:
    """Build an output path by inserting a suffix before the file extension.

    Args:
        input_path: Path of the input file.
        suffix: Text inserted between the path stem and its extension.

    Returns:
        The input path with ``suffix`` placed just before the extension (or
        appended at the end when there is no extension).
    """
    import os

    stem, ext = os.path.splitext(input_path)
    return "{}{}{}".format(stem, suffix, ext)