Spaces:
Sleeping
Sleeping
File size: 4,914 Bytes
838e8f6 e219479 838e8f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
from openai import OpenAI
from typing import List, Tuple, Dict, Any
from utils import format_ocr_results_for_prompt, robust_parse_reply
class PromptHandler:
    """Turn a free-form resize request into a target text and scale factor.

    Builds system/user prompts from OCR results, sends them to an OpenAI
    chat model, parses the reply with ``robust_parse_reply`` and offers a
    validator for the parsed result.
    """

    def __init__(self, api_key: str = None, model: str = "gpt-4.1-nano"):
        """Create the OpenAI client and remember the model name.

        Args:
            api_key: OpenAI API key. When given, it is written to the
                ``OPENAI_API_KEY`` environment variable (a process-wide
                side effect) before the client is constructed. When
                omitted, the key is expected to come from the environment.
            model: Name of the model to call.
        """
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        self.client = OpenAI()
        self.model = model

    def create_system_prompt(self) -> str:
        """Build the system prompt that fixes the model's task and output.

        Returns:
            The system prompt string, which asks for a JSON object with
            ``target_text`` and ``scale_factor``.
        """
        return (
            "You are a helpful assistant. "
            "You are given a list of OCR results in the form [(bbox, text, score)], "
            "and a user prompt that describes what text to enlarge and how much to scale it. "
            "Your job is to:\n"
            "1. Match the user input text to the actual text in OCR results as best as possible, even if it's fuzzy or missing punctuation.\n"
            "2. Estimate a scale_factor (float > 0) based on qualitative user intent like 'a bit', 'a lot', 'shrink slightly', etc.\n"
            "3. Output only two fields:\n"
            " target_text: the exact string from OCR result you chose\n"
            " scale_factor: a float number\n\n"
            "Your output must be strictly in JSON format like:\n"
            "{\n \"target_text\": \"Tools\",\n \"scale_factor\": 1.2\n}"
        )

    def create_user_prompt(self, ocr_results: List[Tuple], user_request: str) -> str:
        """Build the user prompt from the OCR results and the raw request.

        Args:
            ocr_results: OCR results, formatted for the prompt by
                ``format_ocr_results_for_prompt``.
            user_request: The user's original request text.

        Returns:
            The user prompt string.
        """
        formatted_results = format_ocr_results_for_prompt(ocr_results)
        # The literal is deliberately flush-left: any indentation here
        # would become part of the prompt text sent to the model.
        return f"""
Here are the OCR results:
{formatted_results}
User prompt:
"{user_request}"
"""

    def parse_user_request(self, ocr_results: List[Tuple], user_request: str) -> Dict[str, Any]:
        """Ask the model to interpret the user's request.

        Args:
            ocr_results: OCR results to show to the model.
            user_request: The user's original request text.

        Returns:
            The value produced by ``robust_parse_reply`` for the model's
            reply (expected to hold ``target_text`` and ``scale_factor``).

        Raises:
            Exception: If the API call or the reply parsing fails. The
                underlying error is chained as ``__cause__``.
        """
        messages = [
            {"role": "system", "content": self.create_system_prompt()},
            {"role": "user", "content": self.create_user_prompt(ocr_results, user_request)},
        ]
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,
                max_tokens=300,
            )
            reply = response.choices[0].message.content
            return robust_parse_reply(reply)
        except Exception as e:
            # Chain explicitly so the original traceback stays attached.
            raise Exception(f"LLM解析失败: {str(e)}") from e

    def validate_parsed_result(self, parsed_result: Dict[str, Any], ocr_results: List[Tuple]) -> bool:
        """Check that a parsed result names a known text and a sane scale.

        Args:
            parsed_result: Dict with ``target_text`` and ``scale_factor``.
            ocr_results: OCR results as 3-item entries whose second item
                is the recognised text.

        Returns:
            True when the target text (ignoring surrounding whitespace)
            matches one of the OCR texts and the scale factor is a finite
            number greater than zero; False otherwise. A diagnostic is
            printed on failure.
        """
        import math  # local import: the module header does not import math

        target_text = parsed_result.get("target_text", "")
        scale_factor = parsed_result.get("scale_factor", 0)

        # OCR texts are compared stripped, so normalise the target the same
        # way; otherwise a verbatim copy of a padded OCR string never matches.
        normalized_target = target_text.strip() if isinstance(target_text, str) else target_text
        ocr_texts = [text.strip() for _, text, _ in ocr_results]
        if normalized_target not in ocr_texts:
            print(f"警告: 目标文字 '{target_text}' 未在OCR结果中找到")
            print(f"可用的文字: {ocr_texts}")
            return False

        # bool is a subclass of int, and NaN/inf slip past a plain "<= 0"
        # comparison, so both are rejected explicitly.
        is_number = isinstance(scale_factor, (int, float)) and not isinstance(scale_factor, bool)
        is_finite = not isinstance(scale_factor, float) or math.isfinite(scale_factor)
        if not is_number or not is_finite or scale_factor <= 0:
            print(f"错误: 缩放因子 {scale_factor} 不合法")
            return False
        return True
def get_api_key_from_env() -> str:
    """Return the OpenAI API key stored in ``OPENAI_API_KEY``.

    Surrounding whitespace (for example a trailing newline from a pasted
    value) is stripped, and a blank value is treated as missing.

    Returns:
        The API key string without surrounding whitespace.

    Raises:
        ValueError: If the variable is unset, empty or whitespace-only.
    """
    api_key = os.getenv("OPENAI_API_KEY", "").strip()
    if not api_key:
        raise ValueError(
            "未找到OpenAI API密钥。请设置环境变量OPENAI_API_KEY,"
            "或在创建PromptHandler时提供api_key参数。"
        )
    return api_key