import os
from typing import Any, Dict, List, Optional, Tuple

from openai import OpenAI

from utils import format_ocr_results_for_prompt, robust_parse_reply


class PromptHandler:
    def __init__(self, api_key: Optional[str] = None, model: str = "gpt-4.1-nano"):
        """
        Initialize the prompt handler.

        Args:
            api_key: OpenAI API key; if omitted, it is read from the environment.
            model: Name of the model to use.
        """
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        self.client = OpenAI()
        self.model = model
    def create_system_prompt(self) -> str:
        """
        Create the system prompt.

        Returns:
            The system prompt string.
        """
        return (
            "You are a helpful assistant. "
            "You are given a list of OCR results in the form [(bbox, text, score)], "
            "and a user prompt that describes what text to enlarge and how much to scale it. "
            "Your job is to:\n"
            "1. Match the user input text to the actual text in OCR results as best as possible, even if it's fuzzy or missing punctuation.\n"
            "2. Estimate a scale_factor (float > 0) based on qualitative user intent like 'a bit', 'a lot', 'shrink slightly', etc.\n"
            "3. Output only two fields:\n"
            "   target_text: the exact string from OCR result you chose\n"
            "   scale_factor: a float number\n\n"
            "Your output must be strictly in JSON format like:\n"
            "{\n  \"target_text\": \"Tools\",\n  \"scale_factor\": 1.2\n}"
        )
    def create_user_prompt(self, ocr_results: List[Tuple], user_request: str) -> str:
        """
        Create the user prompt.

        Args:
            ocr_results: List of OCR recognition results.
            user_request: The user's original request.

        Returns:
            The user prompt string.
        """
        formatted_results = format_ocr_results_for_prompt(ocr_results)
        return f"""
Here are the OCR results:
{formatted_results}

User prompt:
"{user_request}"
"""
    def parse_user_request(self, ocr_results: List[Tuple], user_request: str) -> Dict[str, Any]:
        """
        Parse the user request with the LLM.

        Args:
            ocr_results: List of OCR recognition results.
            user_request: The user's original request.

        Returns:
            A dict containing target_text and scale_factor.

        Raises:
            Exception: If the API call or reply parsing fails.
        """
        # Build the chat messages
        messages = [
            {"role": "system", "content": self.create_system_prompt()},
            {"role": "user", "content": self.create_user_prompt(ocr_results, user_request)},
        ]
        try:
            # Call the OpenAI API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,
                max_tokens=300,
            )
            # Extract the reply text
            reply = response.choices[0].message.content
            # Parse the reply into a dict
            parsed_result = robust_parse_reply(reply)
            return parsed_result
        except Exception as e:
            raise Exception(f"LLM parsing failed: {e}") from e
    def validate_parsed_result(self, parsed_result: Dict[str, Any], ocr_results: List[Tuple]) -> bool:
        """
        Validate the parsed result.

        Args:
            parsed_result: The parsed result dict.
            ocr_results: List of OCR recognition results.

        Returns:
            True if the result is valid, False otherwise.
        """
        target_text = parsed_result.get("target_text", "")
        scale_factor = parsed_result.get("scale_factor", 0)

        # Check that the target text appears in the OCR results
        ocr_texts = [text.strip() for _, text, _ in ocr_results]
        if target_text not in ocr_texts:
            print(f"Warning: target text '{target_text}' not found in OCR results")
            print(f"Available texts: {ocr_texts}")
            return False

        # Check that the scale factor is reasonable
        if not isinstance(scale_factor, (int, float)) or scale_factor <= 0:
            print(f"Error: scale factor {scale_factor} is invalid")
            return False

        return True
def get_api_key_from_env() -> str:
    """
    Get the OpenAI API key from the environment.

    Returns:
        The API key string.

    Raises:
        ValueError: If no API key is found.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OpenAI API key not found. Set the OPENAI_API_KEY environment variable, "
            "or pass api_key when creating PromptHandler."
        )
    return api_key
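

# A minimal usage sketch. The OCR tuples below are made-up sample data in the
# (bbox, text, score) shape the prompts above assume; a real caller would pass
# the output of an actual OCR engine instead.
if __name__ == "__main__":
    sample_ocr_results = [
        ([[10, 10], [90, 10], [90, 30], [10, 30]], "Tools", 0.98),
        ([[10, 40], [120, 40], [120, 60], [10, 60]], "Settings", 0.95),
    ]

    handler = PromptHandler(api_key=get_api_key_from_env())
    parsed = handler.parse_user_request(sample_ocr_results, "make Tools a bit bigger")
    if handler.validate_parsed_result(parsed, sample_ocr_results):
        print(parsed)  # e.g. {"target_text": "Tools", "scale_factor": 1.2}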