File size: 4,914 Bytes
838e8f6
 
 
 
 
 
 
e219479
838e8f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
from openai import OpenAI
from typing import List, Tuple, Dict, Any
from utils import format_ocr_results_for_prompt, robust_parse_reply


class PromptHandler:
    """Turn a free-form resize request into a structured instruction via an LLM.

    The handler sends the OCR results and the user's request to an OpenAI chat
    model and expects back a JSON object with ``target_text`` and
    ``scale_factor``.
    """

    def __init__(self, api_key: str = None, model: str = "gpt-4.1-nano"):
        """
        Initialize the prompt handler.

        Args:
            api_key: OpenAI API key. When omitted, the OpenAI client is
                constructed without an explicit key and the OPENAI_API_KEY
                environment variable is left untouched.
            model: Name of the chat model to use.
        """
        if api_key:
            # NOTE: this writes to the process-wide environment, so the key
            # stays visible to any later code that reads OPENAI_API_KEY.
            os.environ["OPENAI_API_KEY"] = api_key

        self.client = OpenAI()
        self.model = model

    def create_system_prompt(self) -> str:
        """
        Build the system prompt.

        Returns:
            The system prompt string describing the task and the required
            JSON output format.
        """
        return (
            "You are a helpful assistant. "
            "You are given a list of OCR results in the form [(bbox, text, score)], "
            "and a user prompt that describes what text to enlarge and how much to scale it. "
            "Your job is to:\n"
            "1. Match the user input text to the actual text in OCR results as best as possible, even if it's fuzzy or missing punctuation.\n"
            "2. Estimate a scale_factor (float > 0) based on qualitative user intent like 'a bit', 'a lot', 'shrink slightly', etc.\n"
            "3. Output only two fields:\n"
            "   target_text: the exact string from OCR result you chose\n"
            "   scale_factor: a float number\n\n"
            "Your output must be strictly in JSON format like:\n"
            "{\n  \"target_text\": \"Tools\",\n  \"scale_factor\": 1.2\n}"
        )

    def create_user_prompt(self, ocr_results: List[Tuple], user_request: str) -> str:
        """
        Build the user prompt.

        Args:
            ocr_results: List of OCR recognition results.
            user_request: The user's original request.

        Returns:
            The user prompt string containing the formatted OCR results
            followed by the quoted user request.
        """
        formatted_results = format_ocr_results_for_prompt(ocr_results)

        return f"""
Here are the OCR results:
{formatted_results}

User prompt:
"{user_request}"
"""

    def parse_user_request(self, ocr_results: List[Tuple], user_request: str) -> Dict[str, Any]:
        """
        Parse the user's request with the LLM.

        Args:
            ocr_results: List of OCR recognition results.
            user_request: The user's original request.

        Returns:
            The value returned by ``robust_parse_reply`` for the model's reply
            (expected to be a dict with ``target_text`` and ``scale_factor``).

        Raises:
            Exception: When the API call or the reply parsing fails. The
                original error is chained as ``__cause__``.
        """
        # Build the chat messages.
        messages = [
            {"role": "system", "content": self.create_system_prompt()},
            {"role": "user", "content": self.create_user_prompt(ocr_results, user_request)}
        ]

        try:
            # Call the OpenAI chat completions API.
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,
                max_tokens=300,
            )

            # Extract the reply text of the first choice.
            reply = response.choices[0].message.content

            # Parse the reply into a structured result.
            return robust_parse_reply(reply)

        except Exception as e:
            # Chain explicitly so the underlying API/parsing error and its
            # traceback stay attached to the wrapper exception.
            raise Exception(f"LLM解析失败: {str(e)}") from e

    def validate_parsed_result(self, parsed_result: Dict[str, Any], ocr_results: List[Tuple]) -> bool:
        """
        Check that a parsed result is usable.

        The result is valid when it is a dict, its ``target_text`` (ignoring
        surrounding whitespace) is a non-empty string equal to one of the
        whitespace-stripped OCR texts, and its ``scale_factor`` is a finite,
        strictly positive int or float (booleans are rejected).

        Args:
            parsed_result: The parsed result dictionary.
            ocr_results: List of OCR results; each entry is unpacked as a
                3-tuple whose middle element is the recognized text.

        Returns:
            True when validation passes, False otherwise. A diagnostic message
            is printed for every failure.
        """
        if not isinstance(parsed_result, dict):
            # A failed parse (e.g. None) is reported as invalid rather than
            # crashing on the .get() calls below.
            print(f"错误: 解析结果不是字典: {parsed_result!r}")
            return False

        target_text = parsed_result.get("target_text", "")
        scale_factor = parsed_result.get("scale_factor", 0)

        # Check that the target text is present in the OCR results. Both sides
        # are stripped so the comparison is symmetric; an empty or non-string
        # target never matches.
        ocr_texts = [text.strip() for _, text, _ in ocr_results]
        normalized_target = target_text.strip() if isinstance(target_text, str) else ""
        if not normalized_target or normalized_target not in ocr_texts:
            print(f"警告: 目标文字 '{target_text}' 未在OCR结果中找到")
            print(f"可用的文字: {ocr_texts}")
            return False

        # Check that the scale factor is sensible. bool is a subclass of int,
        # so it is excluded explicitly; the chained comparison also rejects
        # NaN (all comparisons are False) and infinity.
        is_number = isinstance(scale_factor, (int, float)) and not isinstance(scale_factor, bool)
        if not is_number or not (0 < scale_factor < float("inf")):
            print(f"错误: 缩放因子 {scale_factor} 不合法")
            return False

        return True


def get_api_key_from_env() -> str:
    """
    Read the OpenAI API key from the OPENAI_API_KEY environment variable.

    Returns:
        The API key string.

    Raises:
        ValueError: When the variable is unset or empty.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return key

    message = (
        "未找到OpenAI API密钥。请设置环境变量OPENAI_API_KEY,"
        "或在创建PromptHandler时提供api_key参数。"
    )
    raise ValueError(message)