import os
from openai import OpenAI
from typing import List, Tuple, Dict, Any

from utils import format_ocr_results_for_prompt, robust_parse_reply


class PromptHandler:
    def __init__(self, api_key: str = None, model: str = "gpt-4.1-nano"):
        """
        Initialize the prompt handler.

        Args:
            api_key: OpenAI API key; if not provided, it is read from the environment.
            model: Name of the model to use.
        """
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        self.client = OpenAI()
        self.model = model

    def create_system_prompt(self) -> str:
        """
        Build the system prompt.

        Returns:
            The system prompt string.
        """
        return (
            "You are a helpful assistant. "
            "You are given a list of OCR results in the form [(bbox, text, score)], "
            "and a user prompt that describes what text to enlarge and how much to scale it. "
            "Your job is to:\n"
            "1. Match the user input text to the actual text in OCR results as best as possible, even if it's fuzzy or missing punctuation.\n"
            "2. Estimate a scale_factor (float > 0) based on qualitative user intent like 'a bit', 'a lot', 'shrink slightly', etc.\n"
            "3. Output only two fields:\n"
            "   target_text: the exact string from OCR result you chose\n"
            "   scale_factor: a float number\n\n"
            "Your output must be strictly in JSON format like:\n"
            "{\n  \"target_text\": \"Tools\",\n  \"scale_factor\": 1.2\n}"
        )

    def create_user_prompt(self, ocr_results: List[Tuple], user_request: str) -> str:
        """
        Build the user prompt.

        Args:
            ocr_results: List of OCR recognition results.
            user_request: The user's original request.

        Returns:
            The user prompt string.
        """
        formatted_results = format_ocr_results_for_prompt(ocr_results)
        return f"""
Here are the OCR results:
{formatted_results}

User prompt: "{user_request}"
"""

    def parse_user_request(self, ocr_results: List[Tuple], user_request: str) -> Dict[str, Any]:
        """
        Parse the user request with the LLM.

        Args:
            ocr_results: List of OCR recognition results.
            user_request: The user's original request.

        Returns:
            A dict containing target_text and scale_factor.

        Raises:
            Exception: If the API call or reply parsing fails.
        """
        # Build the chat messages
        messages = [
            {"role": "system", "content": self.create_system_prompt()},
            {"role": "user", "content": self.create_user_prompt(ocr_results, user_request)},
        ]

        try:
            # Call the OpenAI API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,
                max_tokens=300,
            )

            # Extract the reply text
            reply = response.choices[0].message.content

            # Parse the reply into a dict
            parsed_result = robust_parse_reply(reply)
            return parsed_result

        except Exception as e:
            raise Exception(f"LLM parsing failed: {str(e)}")

    def validate_parsed_result(self, parsed_result: Dict[str, Any], ocr_results: List[Tuple]) -> bool:
        """
        Validate the parsed result.

        Args:
            parsed_result: The parsed result dict.
            ocr_results: List of OCR recognition results.

        Returns:
            Whether validation passed.
        """
        target_text = parsed_result.get("target_text", "")
        scale_factor = parsed_result.get("scale_factor", 0)

        # Check that the target text appears in the OCR results
        ocr_texts = [text.strip() for _, text, _ in ocr_results]
        if target_text not in ocr_texts:
            print(f"Warning: target text '{target_text}' not found in OCR results")
            print(f"Available texts: {ocr_texts}")
            return False

        # Check that the scale factor is a positive number
        if not isinstance(scale_factor, (int, float)) or scale_factor <= 0:
            print(f"Error: scale factor {scale_factor} is invalid")
            return False

        return True


def get_api_key_from_env() -> str:
    """
    Get the OpenAI API key from the environment.

    Returns:
        The API key string.

    Raises:
        ValueError: If no API key is found.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OpenAI API key not found. Set the OPENAI_API_KEY environment variable, "
            "or pass api_key when creating PromptHandler."
        )
    return api_key
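

# Minimal usage sketch (assumption: an upstream OCR step has already produced
# results in the [(bbox, text, score)] format expected by this module; the
# sample tuples below are illustrative only, not real OCR output).
if __name__ == "__main__":
    sample_ocr_results = [
        ([[10, 10], [80, 10], [80, 30], [10, 30]], "Tools", 0.98),
        ([[10, 40], [120, 40], [120, 60], [10, 60]], "Settings", 0.95),
    ]

    handler = PromptHandler(api_key=get_api_key_from_env())
    parsed = handler.parse_user_request(sample_ocr_results, "make Tools a bit bigger")

    if handler.validate_parsed_result(parsed, sample_ocr_results):
        print(f"Enlarge '{parsed['target_text']}' by x{parsed['scale_factor']}")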