Spaces:
Sleeping
Sleeping
File size: 4,261 Bytes
838e8f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import json
import re
import cv2
import numpy as np
from typing import Dict, Any
def robust_parse_reply(reply: str) -> Dict[str, Any]:
"""
从LLM返回的字符串中提取JSON格式的回复
Args:
reply: LLM返回的原始字符串
Returns:
解析后的字典,包含target_text和scale_factor
Raises:
ValueError: 当无法解析JSON或缺少必要字段时
"""
# 尝试去除 Markdown 代码块标记(如 ```json 或 ```)
cleaned = re.sub(r"```(?:json)?", "", reply, flags=re.IGNORECASE).strip("` \n")
# 尝试提取最可能的 JSON 段(形如 {...})
match = re.search(r"\{.*?\}", cleaned, flags=re.DOTALL)
if not match:
raise ValueError("未找到 JSON 对象")
json_str = match.group(0)
try:
parsed = json.loads(json_str)
except json.JSONDecodeError as e:
raise ValueError(f"JSON 解析失败: {e}")
# 校验字段完整性
if "target_text" not in parsed or "scale_factor" not in parsed:
raise ValueError("JSON 中缺少必要字段 target_text 或 scale_factor")
return parsed
def load_image(image_path: str) -> np.ndarray:
"""
加载图像并转换为RGB格式
Args:
image_path: 图像文件路径
Returns:
RGB格式的图像数组
Raises:
ValueError: 当图像加载失败时
"""
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法加载图像: {image_path}")
# 转换为RGB格式
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image_rgb
def save_image(image: np.ndarray, output_path: str) -> None:
"""
保存RGB格式的图像
Args:
image: RGB格式的图像数组
output_path: 输出文件路径
"""
# 转换为BGR格式以便OpenCV保存
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imwrite(output_path, image_bgr)
def validate_scale_factor(scale_factor: float) -> float:
"""
验证并标准化缩放因子
Args:
scale_factor: 原始缩放因子
Returns:
验证后的缩放因子
Raises:
ValueError: 当缩放因子不合法时
"""
if not isinstance(scale_factor, (int, float)):
raise ValueError("缩放因子必须是数字")
if scale_factor <= 0:
raise ValueError("缩放因子必须大于0")
if scale_factor > 10:
print(f"警告: 缩放因子 {scale_factor} 过大,可能导致处理时间过长")
return float(scale_factor)
def format_ocr_results_for_prompt(results: list) -> str:
"""
格式化OCR结果以用于LLM prompt
Args:
results: OCR识别结果列表
Returns:
格式化后的文字列表字符串
"""
text_list = [text for _, text, _ in results]
return str(text_list)
def parse_percentage_to_scale_factor(text: str) -> float:
"""
将百分比表示转换为缩放因子
Args:
text: 包含百分比的文本,如 "enlarge by 50%" 或 "shrink by 25%"
Returns:
对应的缩放因子
"""
# 查找百分比数字
percentage_match = re.search(r'(\d+(?:\.\d+)?)%', text.lower())
if not percentage_match:
return 1.0 # 默认不缩放
percentage = float(percentage_match.group(1))
# 判断是放大还是缩小
if 'enlarge' in text.lower() or 'increase' in text.lower() or 'bigger' in text.lower():
return 1 + (percentage / 100)
elif 'shrink' in text.lower() or 'reduce' in text.lower() or 'smaller' in text.lower():
return 1 - (percentage / 100)
else:
# 默认当作放大处理
return 1 + (percentage / 100)
def create_output_filename(input_path: str, suffix: str = "_resized") -> str:
"""
根据输入文件路径创建输出文件名
Args:
input_path: 输入文件路径
suffix: 添加的后缀
Returns:
输出文件路径
"""
import os
base_name = os.path.splitext(input_path)[0]
extension = os.path.splitext(input_path)[1]
return f"{base_name}{suffix}{extension}" |