yingzhac committed
Commit 838e8f6 · 1 Parent(s): 6092598

🎨 Add smart text resizing functionality


- Integrate text resizing core functionality (OCR + AI parsing)
- Add beautiful Gradio interface with dual-column layout
- Support natural language commands (e.g., 'enlarge Hello by 50%')
- Add OpenAI API integration for intelligent prompt parsing
- Support both AI parsing and fallback percentage parsing (see the usage sketch below)
- Add comprehensive styling with gradient backgrounds
- Remove example images to reduce repo size
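
For context, here is a minimal sketch of how the new modules fit together outside the Gradio UI, using only functions added in this commit; the image path, prompt, and output path are placeholders, not files in the repo:

    from core import TextResizer
    from utils import load_image, save_image, parse_percentage_to_scale_factor

    resizer = TextResizer(languages=['en', 'ch_sim'], gpu=False)
    image = load_image("example.png")                  # placeholder input path
    ocr_results = resizer.read_text(image)             # [(bbox, text, confidence), ...]

    # Fallback path (no OpenAI key): derive the scale factor from the prompt text
    scale = parse_percentage_to_scale_factor("enlarge 'Hello' by 50%")  # -> 1.5
    target = ocr_results[0][1].strip()                 # first detected text, as app.py does

    result = resizer.resize_text(image, target, scale) # inpaint, rescale, re-blend
    save_image(result, "example_resized.png")          # placeholder output path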

Files changed (6)
  1. .gitignore +91 -0
  2. app.py +305 -133
  3. core.py +200 -0
  4. prompt_handler.py +149 -0
  5. requirements.txt +8 -6
  6. utils.py +160 -0
.gitignore ADDED
@@ -0,0 +1,91 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ *.manifest
+ *.spec
+
+ # Gradio
+ .gradio/
+ gradio_cached_examples/
+ flagged/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ Pipfile.lock
+
+ # PEP 582
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # macOS
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Temporary files
+ *.tmp
+ *.temp
+ *~
app.py CHANGED
@@ -1,154 +1,326 @@
  import gradio as gr
  import numpy as np
- import random

- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use

- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32

- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]

  css = """
  #col-container {
      margin: 0 auto;
-     max-width: 640px;
  }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )

-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )

-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )

-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )

-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )

-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
-                 )

-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, # Replace with defaults that work for your model
-                 )

-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )

  if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
  import numpy as np
+ import os
+ from PIL import Image
+ import cv2

+ # Import our custom modules
+ from core import TextResizer
+ from prompt_handler import PromptHandler
+ from utils import (
+     load_image,
+     save_image,
+     validate_scale_factor,
+     parse_percentage_to_scale_factor,
+     create_output_filename
+ )

+ # Initialize the text resizer
+ text_resizer = TextResizer(languages=['en', 'ch_sim'], gpu=False)

+ def process_image(input_image, user_prompt, use_ai_parsing=True, api_key=None):
+     """
+     Process image with text resizing based on user prompt
+     """
+     try:
+         if input_image is None:
+             return None, "❌ Error: please upload an image"
+
+         # Convert PIL to RGB numpy array
+         image_rgb = np.array(input_image.convert('RGB'))
+
+         # Perform OCR
+         ocr_results = text_resizer.read_text(image_rgb)
+
+         if not ocr_results:
+             return None, "❌ Error: no text was detected in the image"
+
+         # Parse user prompt
+         try:
+             if use_ai_parsing and api_key:
+                 # Use OpenAI API parsing
+                 prompt_handler = PromptHandler(api_key=api_key)
+                 parsed_result = prompt_handler.parse_user_request(ocr_results, user_prompt)
+
+                 if not prompt_handler.validate_parsed_result(parsed_result, ocr_results):
+                     raise Exception("Validation of the AI-parsed result failed")
+
+                 target_text = parsed_result["target_text"]
+                 scale_factor = validate_scale_factor(parsed_result["scale_factor"])
+                 status_msg = f"✅ AI parsing succeeded: target text='{target_text}', scale factor={scale_factor}"
+
+             else:
+                 # Use fallback parsing
+                 scale_factor = parse_percentage_to_scale_factor(user_prompt)
+                 if scale_factor == 1.0:
+                     return None, "❌ Error: could not extract any scaling information from the instruction"
+
+                 # Use the first detected text as target
+                 target_text = ocr_results[0][1].strip()
+                 status_msg = f"✅ Fallback parsing: target text='{target_text}', scale factor={scale_factor}"
+
+         except Exception as e:
+             return None, f"❌ Error: failed to parse the instruction: {str(e)}"
+
+         # Process the image
+         try:
+             result_image = text_resizer.resize_text(image_rgb, target_text, scale_factor)
+
+             # Convert back to PIL Image
+             result_pil = Image.fromarray(result_image)
+
+             return result_pil, status_msg
+
+         except ValueError as e:
+             # Show available texts
+             available_texts = [text.strip() for _, text, _ in ocr_results]
+             error_msg = f"❌ Error: {str(e)}\n\n📝 Available texts: {available_texts}"
+             return None, error_msg
+
+     except Exception as e:
+         return None, f"❌ An error occurred during processing: {str(e)}"

+ def get_ocr_info(input_image):
+     """
+     Get OCR information from the image
+     """
+     if input_image is None:
+         return "Please upload an image first"
+
+     try:
+         # Convert PIL to RGB numpy array
+         image_rgb = np.array(input_image.convert('RGB'))
+
+         # Perform OCR
+         ocr_results = text_resizer.read_text(image_rgb)
+
+         if not ocr_results:
+             return "No text was detected"
+
+         # Format results
+         info = f"📝 Detected {len(ocr_results)} text regions:\n"
+         info += "=" * 50 + "\n"
+         for i, (bbox, text, conf) in enumerate(ocr_results):
+             info += f"{i+1:2d}. '{text}' (confidence: {conf:.2f})\n"
+         info += "=" * 50
+
+         return info
+
+     except Exception as e:
+         return f"❌ OCR failed: {str(e)}"

+ # Define CSS for styling
  css = """
  #col-container {
      margin: 0 auto;
+     max-width: 1000px;
  }

+ #input-section {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     padding: 20px;
+     border-radius: 15px;
+     margin-bottom: 20px;
+ }

+ #output-section {
+     background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+     padding: 20px;
+     border-radius: 15px;
+ }

+ .gradio-container {
+     background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
+ }

+ #title {
+     text-align: center;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     font-size: 2.5em;
+     font-weight: bold;
+     margin-bottom: 20px;
+ }

+ #description {
+     text-align: center;
+     color: #666;
+     font-size: 1.1em;
+     margin-bottom: 30px;
+     line-height: 1.6;
+ }

+ .process-button {
+     background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%);
+     color: white;
+     border: none;
+     padding: 15px 30px;
+     font-size: 16px;
+     border-radius: 10px;
+     cursor: pointer;
+     transition: all 0.3s ease;
+ }

+ .process-button:hover {
+     transform: translateY(-2px);
+     box-shadow: 0 5px 15px rgba(0,0,0,0.2);
+ }
+ """

+ # Create the Gradio interface
+ with gr.Blocks(css=css, title="Smart Text Resizing Tool") as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("# 🎨 Smart Text Resizing Tool", elem_id="title")
+         gr.Markdown(
+             """
+             🚀 **Resize text inside images with the help of AI**
+
+             📝 Natural-language instructions are supported, e.g.:
+             - `enlarge 'Hello' by 50%`
+             - `make the title bigger`
+             - `shrink the footer text`
+
+             🎯 **How to use**:
+             1. Upload an image that contains text
+             2. Enter a text-adjustment instruction
+             3. Click the process button
+             4. Check the result
+             """,
+             elem_id="description"
+         )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 with gr.Group(elem_id="input-section"):
+                     gr.Markdown("### 📤 Input")
+
+                     # Image input
+                     input_image = gr.Image(
+                         label="Upload image",
+                         type="pil",
+                         height=300,
+                         sources=["upload", "clipboard", "webcam"]
+                     )
+
+                     # Prompt input
+                     user_prompt = gr.Textbox(
+                         label="Text-adjustment instruction",
+                         placeholder="e.g. enlarge 'Hello' by 50%",
+                         lines=2,
+                         info="Natural language is supported, e.g. 'make XX bigger' or 'enlarge XX by 50%'"
+                     )
+
+                     # OCR info button
+                     ocr_button = gr.Button(
+                         "🔍 Show text found in the image",
+                         variant="secondary",
+                         size="sm"
+                     )
+
+                     # Advanced settings
+                     with gr.Accordion("⚙️ Advanced settings", open=False):
+                         use_ai_parsing = gr.Checkbox(
+                             label="🤖 Use AI parsing (requires an OpenAI API key)",
+                             value=False,
+                             info="Use a GPT model to interpret natural-language instructions"
+                         )
+
+                         api_key = gr.Textbox(
+                             label="🔑 OpenAI API key (optional)",
+                             placeholder="sk-...",
+                             type="password",
+                             info="Only needed when AI parsing is enabled"
+                         )
+
+                     # Process button
+                     process_button = gr.Button(
+                         "🎯 Start processing",
+                         variant="primary",
+                         size="lg",
+                         elem_classes="process-button"
+                     )
+
+             with gr.Column(scale=1):
+                 with gr.Group(elem_id="output-section"):
+                     gr.Markdown("### 📤 Result")
+
+                     # Output image
+                     output_image = gr.Image(
+                         label="Processed image",
+                         height=300,
+                         show_download_button=True
+                     )
+
+                     # Status message
+                     status_message = gr.Textbox(
+                         label="💬 Status",
+                         lines=4,
+                         max_lines=8,
+                         interactive=False
+                     )
+
+                     # OCR info display
+                     ocr_info = gr.Textbox(
+                         label="📝 OCR results",
+                         lines=6,
+                         max_lines=10,
+                         interactive=False
+                     )
+
+         # Examples section
+         gr.Markdown("### 📚 Example usage")
+         gr.Markdown(
+             """
+             **Example instruction formats:**
+
+             🔍 **Specific text plus an explicit percentage:**
+             - `enlarge 'Hello' by 50%`
+             - `shrink 'Title' by 30%`
+
+             🎯 **Natural-language descriptions:**
+             - `make the title bigger`
+             - `make the text smaller`
+             - `enlarge the heading`
+
+             💡 **Tips:**
+             1. Upload an image that contains text
+             2. Click "Show text found in the image" first to see which texts are available
+             3. Enter your instruction
+             4. Click "Start processing"
+             """
+         )
+
+         # Event handlers
+         process_button.click(
+             fn=process_image,
+             inputs=[input_image, user_prompt, use_ai_parsing, api_key],
+             outputs=[output_image, status_message]
+         )
+
+         ocr_button.click(
+             fn=get_ocr_info,
+             inputs=[input_image],
+             outputs=[ocr_info]
+         )
+
+         # Auto-run OCR when image is uploaded
+         input_image.change(
+             fn=get_ocr_info,
+             inputs=[input_image],
+             outputs=[ocr_info]
+         )
+
+         # Footer
+         gr.Markdown(
+             """
+             ---
+
+             🎨 **Smart Text Resizing Tool** | Intelligent in-image text editing based on OCR and AI
+
+             📧 For questions or suggestions, please contact the developer
+             """
+         )

  if __name__ == "__main__":
+     demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
core.py ADDED
@@ -0,0 +1,200 @@
+ import cv2
+ import numpy as np
+ import easyocr
+ from typing import List, Tuple, Optional
+
+
+ class TextResizer:
+     def __init__(self, languages=['en', 'ch_sim'], gpu=False):
+         """
+         Initialize the text resizer.
+
+         Args:
+             languages: languages supported by the OCR reader
+             gpu: whether to use the GPU
+         """
+         self.reader = easyocr.Reader(languages, gpu=gpu)
+
+     def read_text(self, image: np.ndarray) -> List[Tuple]:
+         """
+         Detect text in an image.
+
+         Args:
+             image: image array in RGB format
+
+         Returns:
+             List of OCR results, each element being (bbox, text, confidence)
+         """
+         return self.reader.readtext(image)
+
+     def extract_text_mask_by_content(self, image: np.ndarray, results: List[Tuple], target_text: str) -> np.ndarray:
+         """
+         Build a text mask for the regions whose OCR text matches the target text.
+
+         Args:
+             image: image array in RGB format
+             results: OCR results
+             target_text: target text content
+
+         Returns:
+             Text mask; white pixels mark text regions
+         """
+         h, w = image.shape[:2]
+         mask = np.zeros((h, w), dtype=np.uint8)
+
+         for (bbox, text, _) in results:
+             if text.strip() != target_text:
+                 continue
+
+             x_min = int(min([pt[0] for pt in bbox]))
+             x_max = int(max([pt[0] for pt in bbox]))
+             y_min = int(min([pt[1] for pt in bbox]))
+             y_max = int(max([pt[1] for pt in bbox]))
+
+             roi = image[y_min:y_max, x_min:x_max]
+             gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
+             thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+                                            cv2.THRESH_BINARY_INV, 11, 2)
+             contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+             mask_roi = np.zeros_like(thresh)
+             cv2.drawContours(mask_roi, contours, -1, 255, -1)
+             mask[y_min:y_max, x_min:x_max] = np.maximum(mask[y_min:y_max, x_min:x_max], mask_roi)
+
+         return mask
+
+     def inpaint_image(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
+         """
+         Inpaint the masked regions of the image.
+
+         Args:
+             image: image array in RGB format
+             mask: mask of the regions to repair
+
+         Returns:
+             Inpainted image
+         """
+         return cv2.inpaint(image, mask, 3, cv2.INPAINT_TELEA)
+
+     def find_target_bbox(self, results: List[Tuple], target_text: str) -> Optional[List]:
+         """
+         Find the bounding box of the target text.
+
+         Args:
+             results: OCR results
+             target_text: target text content
+
+         Returns:
+             Bounding box of the target text, or None if it is not found
+         """
+         for (bbox, text, _) in results:
+             if text.strip() == target_text:
+                 return bbox
+         return None
+
+     def create_resized_text_patch(self, image: np.ndarray, bbox: List, scale_factor: float) -> Tuple[np.ndarray, int, int]:
+         """
+         Create a resized text patch.
+
+         Args:
+             image: image array in RGB format
+             bbox: text bounding box
+             scale_factor: scale factor
+
+         Returns:
+             (resized text patch in RGBA format, original center x, original center y)
+         """
+         # Extract the ROI
+         x_min = int(min(pt[0] for pt in bbox))
+         x_max = int(max(pt[0] for pt in bbox))
+         y_min = int(min(pt[1] for pt in bbox))
+         y_max = int(max(pt[1] for pt in bbox))
+
+         roi = image[y_min:y_max, x_min:x_max]
+
+         # Build the text mask
+         gray = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
+         thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+                                        cv2.THRESH_BINARY_INV, 11, 2)
+         contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+         mask_roi = np.zeros_like(thresh)
+         cv2.drawContours(mask_roi, contours, -1, 255, -1)
+
+         # Build the RGBA patch
+         rgba_patch = cv2.cvtColor(roi, cv2.COLOR_RGB2RGBA)
+         rgba_patch[:, :, 3] = mask_roi
+
+         # Resize
+         h, w = rgba_patch.shape[:2]
+         new_size = (int(w * scale_factor), int(h * scale_factor))
+         resized_patch = cv2.resize(rgba_patch, new_size, interpolation=cv2.INTER_LINEAR)
+
+         # Compute the original center point
+         cx = (x_min + x_max) // 2
+         cy = (y_min + y_max) // 2
+
+         return resized_patch, cx, cy
+
+     def blend_text_patch(self, canvas: np.ndarray, patch: np.ndarray, center_x: int, center_y: int) -> np.ndarray:
+         """
+         Blend a text patch onto the canvas.
+
+         Args:
+             canvas: target canvas (RGB format)
+             patch: text patch in RGBA format
+             center_x: x coordinate of the placement center
+             center_y: y coordinate of the placement center
+
+         Returns:
+             Blended image
+         """
+         result = canvas.copy()
+         new_h, new_w = patch.shape[:2]
+         top_left_x = max(0, center_x - new_w // 2)
+         top_left_y = max(0, center_y - new_h // 2)
+
+         for y in range(new_h):
+             for x in range(new_w):
+                 if patch[y, x, 3] > 0:  # if alpha > 0
+                     yy = top_left_y + y
+                     xx = top_left_x + x
+                     if 0 <= yy < result.shape[0] and 0 <= xx < result.shape[1]:
+                         alpha = patch[y, x, 3] / 255.0
+                         result[yy, xx] = (
+                             (1 - alpha) * result[yy, xx] + alpha * patch[y, x, :3]
+                         ).astype(np.uint8)
+
+         return result
+
+     def resize_text(self, image: np.ndarray, target_text: str, scale_factor: float) -> np.ndarray:
+         """
+         Full text-resizing pipeline.
+
+         Args:
+             image: image array in RGB format
+             target_text: target text content
+             scale_factor: scale factor
+
+         Returns:
+             Processed image
+         """
+         # 1. Run OCR
+         results = self.read_text(image)
+
+         # 2. Find the target text
+         target_bbox = self.find_target_bbox(results, target_text)
+         if target_bbox is None:
+             raise ValueError(f"Target text not found: {target_text}")
+
+         # 3. Extract the text mask
+         text_mask = self.extract_text_mask_by_content(image, results, target_text)
+
+         # 4. Inpaint the original text away
+         inpainted = self.inpaint_image(image, text_mask)
+
+         # 5. Create the resized text patch
+         resized_patch, cx, cy = self.create_resized_text_patch(image, target_bbox, scale_factor)
+
+         # 6. Blend the text patch back in
+         result = self.blend_text_patch(inpainted, resized_patch, cx, cy)
+
+         return result
prompt_handler.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ from openai import OpenAI
+ from typing import List, Tuple, Dict, Any
+ from utils import format_ocr_results_for_prompt, robust_parse_reply
+
+
+ class PromptHandler:
+     def __init__(self, api_key: str = None, model: str = "gpt-4o-mini"):
+         """
+         Initialize the prompt handler.
+
+         Args:
+             api_key: OpenAI API key; if not provided, it is read from the environment
+             model: name of the model to use
+         """
+         if api_key:
+             os.environ["OPENAI_API_KEY"] = api_key
+
+         self.client = OpenAI()
+         self.model = model
+
+     def create_system_prompt(self) -> str:
+         """
+         Build the system prompt.
+
+         Returns:
+             The system prompt string
+         """
+         return (
+             "You are a helpful assistant. "
+             "You are given a list of OCR results in the form [(bbox, text, score)], "
+             "and a user prompt that describes what text to enlarge and how much to scale it. "
+             "Your job is to:\n"
+             "1. Match the user input text to the actual text in OCR results as best as possible, even if it's fuzzy or missing punctuation.\n"
+             "2. Estimate a scale_factor (float > 0) based on qualitative user intent like 'a bit', 'a lot', 'shrink slightly', etc.\n"
+             "3. Output only two fields:\n"
+             "   target_text: the exact string from OCR result you chose\n"
+             "   scale_factor: a float number\n\n"
+             "Your output must be strictly in JSON format like:\n"
+             "{\n  \"target_text\": \"Tools\",\n  \"scale_factor\": 1.2\n}"
+         )
+
+     def create_user_prompt(self, ocr_results: List[Tuple], user_request: str) -> str:
+         """
+         Build the user prompt.
+
+         Args:
+             ocr_results: list of OCR results
+             user_request: the user's original request
+
+         Returns:
+             The user prompt string
+         """
+         formatted_results = format_ocr_results_for_prompt(ocr_results)
+
+         return f"""
+         Here are the OCR results:
+         {formatted_results}
+
+         User prompt:
+         "{user_request}"
+         """
+
+     def parse_user_request(self, ocr_results: List[Tuple], user_request: str) -> Dict[str, Any]:
+         """
+         Parse the user request with the LLM.
+
+         Args:
+             ocr_results: list of OCR results
+             user_request: the user's original request
+
+         Returns:
+             Dictionary containing target_text and scale_factor
+
+         Raises:
+             Exception: if the API call or parsing fails
+         """
+         # Build the messages
+         messages = [
+             {"role": "system", "content": self.create_system_prompt()},
+             {"role": "user", "content": self.create_user_prompt(ocr_results, user_request)}
+         ]
+
+         try:
+             # Call the OpenAI API
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=messages,
+                 temperature=0.3,
+                 max_tokens=300,
+             )
+
+             # Get the reply
+             reply = response.choices[0].message.content
+
+             # Parse the reply
+             parsed_result = robust_parse_reply(reply)
+
+             return parsed_result
+
+         except Exception as e:
+             raise Exception(f"LLM parsing failed: {str(e)}")
+
+     def validate_parsed_result(self, parsed_result: Dict[str, Any], ocr_results: List[Tuple]) -> bool:
+         """
+         Validate the parsed result.
+
+         Args:
+             parsed_result: the parsed result dictionary
+             ocr_results: list of OCR results
+
+         Returns:
+             Whether validation passed
+         """
+         target_text = parsed_result.get("target_text", "")
+         scale_factor = parsed_result.get("scale_factor", 0)
+
+         # Check that the target text appears in the OCR results
+         ocr_texts = [text.strip() for _, text, _ in ocr_results]
+         if target_text not in ocr_texts:
+             print(f"Warning: target text '{target_text}' was not found in the OCR results")
+             print(f"Available texts: {ocr_texts}")
+             return False
+
+         # Check that the scale factor is reasonable
+         if not isinstance(scale_factor, (int, float)) or scale_factor <= 0:
+             print(f"Error: scale factor {scale_factor} is invalid")
+             return False
+
+         return True
+
+
+ def get_api_key_from_env() -> str:
+     """
+     Read the OpenAI API key from the environment.
+
+     Returns:
+         The API key string
+
+     Raises:
+         ValueError: if no API key can be found
+     """
+     api_key = os.getenv("OPENAI_API_KEY")
+     if not api_key:
+         raise ValueError(
+             "No OpenAI API key found. Set the OPENAI_API_KEY environment variable, "
+             "or pass api_key when creating PromptHandler."
+         )
+     return api_key
requirements.txt CHANGED
@@ -1,6 +1,8 @@
- accelerate
- diffusers
- invisible_watermark
- torch
- transformers
- xformers
+ opencv-python>=4.5.0
+ numpy>=1.21.0
+ easyocr>=1.6.0
+ openai>=1.0.0
+ matplotlib>=3.3.0
+ scikit-image>=0.18.0
+ Pillow>=8.0.0
+ gradio>=4.0.0
utils.py ADDED
@@ -0,0 +1,160 @@
+ import json
+ import re
+ import cv2
+ import numpy as np
+ from typing import Dict, Any
+
+
+ def robust_parse_reply(reply: str) -> Dict[str, Any]:
+     """
+     Extract a JSON object from the raw string returned by the LLM.
+
+     Args:
+         reply: raw string returned by the LLM
+
+     Returns:
+         Parsed dictionary containing target_text and scale_factor
+
+     Raises:
+         ValueError: if the JSON cannot be parsed or required fields are missing
+     """
+     # Strip Markdown code-fence markers (e.g. ```json or ```)
+     cleaned = re.sub(r"```(?:json)?", "", reply, flags=re.IGNORECASE).strip("` \n")
+
+     # Extract the most likely JSON segment (of the form {...})
+     match = re.search(r"\{.*?\}", cleaned, flags=re.DOTALL)
+     if not match:
+         raise ValueError("No JSON object found")
+
+     json_str = match.group(0)
+
+     try:
+         parsed = json.loads(json_str)
+     except json.JSONDecodeError as e:
+         raise ValueError(f"Failed to parse JSON: {e}")
+
+     # Check that the required fields are present
+     if "target_text" not in parsed or "scale_factor" not in parsed:
+         raise ValueError("JSON is missing the required fields target_text or scale_factor")
+
+     return parsed
+
+
+ def load_image(image_path: str) -> np.ndarray:
+     """
+     Load an image and convert it to RGB.
+
+     Args:
+         image_path: path to the image file
+
+     Returns:
+         Image array in RGB format
+
+     Raises:
+         ValueError: if the image cannot be loaded
+     """
+     image = cv2.imread(image_path)
+     if image is None:
+         raise ValueError(f"Failed to load image: {image_path}")
+
+     # Convert to RGB
+     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     return image_rgb
+
+
+ def save_image(image: np.ndarray, output_path: str) -> None:
+     """
+     Save an RGB image.
+
+     Args:
+         image: image array in RGB format
+         output_path: output file path
+     """
+     # Convert to BGR so OpenCV can save it
+     image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+     cv2.imwrite(output_path, image_bgr)
+
+
+ def validate_scale_factor(scale_factor: float) -> float:
+     """
+     Validate and normalize a scale factor.
+
+     Args:
+         scale_factor: raw scale factor
+
+     Returns:
+         Validated scale factor
+
+     Raises:
+         ValueError: if the scale factor is invalid
+     """
+     if not isinstance(scale_factor, (int, float)):
+         raise ValueError("The scale factor must be a number")
+
+     if scale_factor <= 0:
+         raise ValueError("The scale factor must be greater than 0")
+
+     if scale_factor > 10:
+         print(f"Warning: scale factor {scale_factor} is very large and may slow down processing")
+
+     return float(scale_factor)
+
+
+ def format_ocr_results_for_prompt(results: list) -> str:
+     """
+     Format OCR results for use in the LLM prompt.
+
+     Args:
+         results: list of OCR results
+
+     Returns:
+         Formatted string listing the detected texts
+     """
+     text_list = [text for _, text, _ in results]
+     return str(text_list)
+
+
+ def parse_percentage_to_scale_factor(text: str) -> float:
+     """
+     Convert a percentage expression into a scale factor.
+
+     Args:
+         text: text containing a percentage, e.g. "enlarge by 50%" or "shrink by 25%"
+
+     Returns:
+         The corresponding scale factor
+     """
+     # Look for a percentage number
+     percentage_match = re.search(r'(\d+(?:\.\d+)?)%', text.lower())
+     if not percentage_match:
+         return 1.0  # no scaling by default
+
+     percentage = float(percentage_match.group(1))
+
+     # Decide whether to enlarge or shrink
+     if 'enlarge' in text.lower() or 'increase' in text.lower() or 'bigger' in text.lower():
+         return 1 + (percentage / 100)
+     elif 'shrink' in text.lower() or 'reduce' in text.lower() or 'smaller' in text.lower():
+         return 1 - (percentage / 100)
+     else:
+         # Treat as enlargement by default
+         return 1 + (percentage / 100)
+
+
+ def create_output_filename(input_path: str, suffix: str = "_resized") -> str:
+     """
+     Build an output filename from the input file path.
+
+     Args:
+         input_path: input file path
+         suffix: suffix to append
+
+     Returns:
+         Output file path
+     """
+     import os
+
+     base_name = os.path.splitext(input_path)[0]
+     extension = os.path.splitext(input_path)[1]
+
+     return f"{base_name}{suffix}{extension}"
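
As a quick sanity check on the fallback percentage parsing added above (a sketch, not part of the commit), the helper maps the example instructions from app.py to scale factors as follows:

    from utils import parse_percentage_to_scale_factor

    parse_percentage_to_scale_factor("enlarge 'Hello' by 50%")   # 1.0 + 0.50 -> 1.5
    parse_percentage_to_scale_factor("shrink 'Title' by 30%")    # 1.0 - 0.30 -> 0.7
    parse_percentage_to_scale_factor("make the title bigger")    # no '%' found -> 1.0, which app.py reports as unparsable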