# text_editing/prompt_handler.py
import os
from openai import OpenAI
from typing import List, Tuple, Dict, Any, Optional
from utils import format_ocr_results_for_prompt, robust_parse_reply


class PromptHandler:
    """Parses user text-scaling requests against OCR results via an LLM."""

    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4.1-nano"):
        """
        Initialize the prompt handler.

        Args:
            api_key: OpenAI API key; if not provided, the client falls back to
                the OPENAI_API_KEY environment variable.
            model: Name of the model to use.
        """
        # Pass the key directly to the client instead of mutating the process
        # environment.
        self.client = OpenAI(api_key=api_key) if api_key else OpenAI()
        self.model = model

def create_system_prompt(self) -> str:
"""
创建系统提示词
Returns:
系统提示词字符串
"""
return (
"You are a helpful assistant. "
"You are given a list of OCR results in the form [(bbox, text, score)], "
"and a user prompt that describes what text to enlarge and how much to scale it. "
"Your job is to:\n"
"1. Match the user input text to the actual text in OCR results as best as possible, even if it's fuzzy or missing punctuation.\n"
"2. Estimate a scale_factor (float > 0) based on qualitative user intent like 'a bit', 'a lot', 'shrink slightly', etc.\n"
"3. Output only two fields:\n"
" target_text: the exact string from OCR result you chose\n"
" scale_factor: a float number\n\n"
"Your output must be strictly in JSON format like:\n"
"{\n \"target_text\": \"Tools\",\n \"scale_factor\": 1.2\n}"
)

    def create_user_prompt(self, ocr_results: List[Tuple], user_request: str) -> str:
        """
        Build the user prompt.

        Args:
            ocr_results: List of OCR results as (bbox, text, score) tuples.
            user_request: The user's original request.

        Returns:
            The user prompt string.
        """
formatted_results = format_ocr_results_for_prompt(ocr_results)
return f"""
Here are the OCR results:
{formatted_results}
User prompt:
"{user_request}"
"""

    def parse_user_request(self, ocr_results: List[Tuple], user_request: str) -> Dict[str, Any]:
        """
        Parse the user request with the LLM.

        Args:
            ocr_results: List of OCR results as (bbox, text, score) tuples.
            user_request: The user's original request.

        Returns:
            A dict containing target_text and scale_factor.

        Raises:
            RuntimeError: If the API call or reply parsing fails.
        """
        # Build the chat messages.
        messages = [
            {"role": "system", "content": self.create_system_prompt()},
            {"role": "user", "content": self.create_user_prompt(ocr_results, user_request)},
        ]
        try:
            # Call the OpenAI chat completions API.
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,
                max_tokens=300,
            )
            # Extract and parse the model's reply.
            reply = response.choices[0].message.content
            return robust_parse_reply(reply)
        except Exception as e:
            raise RuntimeError(f"LLM parsing failed: {e}") from e

    def validate_parsed_result(self, parsed_result: Dict[str, Any], ocr_results: List[Tuple]) -> bool:
        """
        Validate the parsed result.

        Args:
            parsed_result: The parsed result dict.
            ocr_results: List of OCR results as (bbox, text, score) tuples.

        Returns:
            True if the result is valid, False otherwise.
        """
target_text = parsed_result.get("target_text", "")
scale_factor = parsed_result.get("scale_factor", 0)
        # Check that the target text appears in the OCR results.
        ocr_texts = [text.strip() for _, text, _ in ocr_results]
        if target_text not in ocr_texts:
            print(f"Warning: target text '{target_text}' was not found in the OCR results")
            print(f"Available texts: {ocr_texts}")
            return False
        # Check that the scale factor is a positive number.
        if not isinstance(scale_factor, (int, float)) or scale_factor <= 0:
            print(f"Error: scale factor {scale_factor} is invalid")
            return False
return True


def get_api_key_from_env() -> str:
    """
    Read the OpenAI API key from the environment.

    Returns:
        The API key string.

    Raises:
        ValueError: If no API key is found.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OpenAI API key not found. Set the OPENAI_API_KEY environment variable, "
            "or pass api_key when creating a PromptHandler."
        )
    return api_key
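

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the OCR results
    # below are hypothetical stand-ins for an OCR engine that yields
    # (bbox, text, score) tuples.
    sample_ocr_results = [
        ([[10, 10], [80, 10], [80, 30], [10, 30]], "Tools", 0.98),
        ([[10, 40], [120, 40], [120, 60], [10, 60]], "Settings", 0.95),
    ]
    handler = PromptHandler(api_key=get_api_key_from_env())
    result = handler.parse_user_request(sample_ocr_results, "make Tools a bit bigger")
    if handler.validate_parsed_result(result, sample_ocr_results):
        print(result)  # e.g. {"target_text": "Tools", "scale_factor": 1.2}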