|
|
import base64 |
|
|
import io |
|
|
import json |
|
|
import os |
|
|
from typing import Dict, List, Tuple, Any, Optional |
|
|
|
|
|
import requests |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
# Endpoint of the parsing service, taken from the API_URL environment
# variable (None when the variable is unset).
DEFAULT_API_URL = os.getenv("API_URL")
# Header logo shown at the top of the page.
LOGO_IMAGE_PATH = './assets/logo.jpg'
# <link> tag injected into the page <head> to load the Noto Sans SC web font.
GOOGLE_FONTS_URL = "<link href='https://fonts.googleapis.com/css2?family=Noto+Sans+SC:wght@400;700&display=swap' rel='stylesheet'>"
# Math delimiters handed to gr.Markdown(latex_delimiters=...); entries with
# display=True are rendered as display (block) math.
LATEX_DELIMS = [
    dict(left="$$", right="$$", display=True),
    dict(left="$", right="$", display=False),
    dict(left="\\(", right="\\)", display=False),
    dict(left="\\[", right="\\]", display=True),
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def image_to_base64_data_url(filepath: str) -> str:
    """Read a local image file and encode it as a Base64 ``data:`` URL.

    The MIME type is derived from the file extension (case-insensitive);
    unknown extensions fall back to ``image/jpeg``.

    Args:
        filepath: Path of the image file on disk.

    Returns:
        ``"data:<mime>;base64,<payload>"``, or ``""`` if the file cannot be
        read. Errors are printed rather than raised so callers can still
        render something.
    """
    # Cover every extension the upload/example helpers accept (.bmp and
    # .webp included) so those images are not mislabelled as JPEG.
    mime_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
        '.webp': 'image/webp',
    }
    try:
        ext = os.path.splitext(filepath)[1].lower()
        mime_type = mime_types.get(ext, 'image/jpeg')
        with open(filepath, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
        return f"data:{mime_type};base64,{encoded_string}"
    except Exception as e:
        # Best-effort: a missing/unreadable image must not break page rendering.
        print(f"Error encoding image to Base64: {e}")
        return ""
|
|
|
|
|
def _get_examples_from_dir(dir_path: str) -> List[List[str]]:
    """Collect example images from a directory for a Gradio gallery.

    Args:
        dir_path: Directory to scan (non-recursive).

    Returns:
        One single-element list ``[path]`` per supported image file, ordered
        by filename. Returns ``[]`` when ``dir_path`` does not exist or is not
        a directory.
    """
    supported_exts = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
    # isdir rather than exists: os.listdir raises NotADirectoryError on a
    # regular file, and this runs at import time.
    if not os.path.isdir(dir_path):
        return []
    examples = []
    for filename in sorted(os.listdir(dir_path)):
        full_path = os.path.join(dir_path, filename)
        # Skip sub-directories that merely have an image-like name.
        if os.path.splitext(filename)[1].lower() in supported_exts and os.path.isfile(full_path):
            examples.append([full_path])
    return examples
|
|
|
|
|
# Example-image folders for the two tabs (relative to the working directory).
TARGETED_EXAMPLES_DIR = "examples/targeted"
COMPLEX_EXAMPLES_DIR = "examples/complex"
# Scanned once at import time; each entry is a single-element [path] list.
targeted_recognition_examples, complex_document_examples = (
    _get_examples_from_dir(TARGETED_EXAMPLES_DIR),
    _get_examples_from_dir(COMPLEX_EXAMPLES_DIR),
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def render_uploaded_image_div(file_path: str) -> str:
    """Build the HTML snippet that previews a local image.

    The image is embedded as a Base64 data URL, so no static-file route is
    needed. If encoding fails, the ``src`` attribute is left empty.
    """
    src = image_to_base64_data_url(file_path)
    img_tag = f'<img src="{src}" alt="Uploaded image" style="width:100%;height:100%;object-fit:contain;"/>'
    return f"""
    <div class="uploaded-image">
        {img_tag}
    </div>
    """
|
|
|
|
|
def update_preview_visibility(file_path: Optional[str]) -> Dict:
    """Return a Gradio update for the preview HTML component.

    The preview is filled and shown when a file is present, and cleared and
    hidden otherwise.
    """
    if not file_path:
        return gr.update(value="", visible=False)
    return gr.update(value=render_uploaded_image_div(file_path), visible=True)
|
|
|
|
|
def _on_gallery_select(example_paths: List[str], evt: gr.SelectData):
    """Translate a gallery click into the clicked example's file path.

    Returns None when the event's index cannot be used to look up a path.
    """
    try:
        return example_paths[evt.index]
    except Exception:
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _file_to_b64_image_only(file_path: str) -> Tuple[str, int]:
    """Base64-encode an uploaded image for the parsing API.

    Returns:
        ``(base64_payload, 1)``. The ``1`` is sent as the request's
        ``fileType``; presumably the API's code for images (confirm against
        the API docs).

    Raises:
        ValueError: if no path is given or the extension is not a supported
            image format.
    """
    if not file_path:
        raise ValueError("Please upload an image first.")
    allowed_suffixes = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
    _, suffix = os.path.splitext(file_path)
    if suffix.lower() not in allowed_suffixes:
        raise ValueError("Only image files are supported.")
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8"), 1
|
|
|
|
|
def _call_api(api_url: str, file_path: str, use_layout_detection: bool, prompt_label: Optional[str], use_chart_recognition: bool = False) -> Dict[str, Any]:
    """POST one image to the parsing service and return its JSON reply.

    Args:
        api_url: Service endpoint (callers pass DEFAULT_API_URL).
        file_path: Local path of the image to send.
        use_layout_detection: True for full-document parsing. False for
            single-element recognition, which then requires ``prompt_label``.
        prompt_label: Recognition type for non-layout mode (e.g. "ocr",
            "table"); stripped and lower-cased before sending.
        use_chart_recognition: Request chart parsing. Only sent in
            layout-detection mode.

    Returns:
        The decoded JSON body: a dict whose ``errorCode`` is 0.

    Raises:
        gr.Error: for every user-facing failure (missing endpoint, invalid
            upload, missing recognition type, network/HTTP error, non-JSON or
            malformed reply, non-zero errorCode). Gradio shows the message of
            a gr.Error but only a generic "Error" for other exceptions.
    """
    if not api_url:
        raise gr.Error("API URL is not configured. Please set the API_URL environment variable.")
    try:
        b64, file_type = _file_to_b64_image_only(file_path)
    except ValueError as e:
        raise gr.Error(str(e)) from e

    payload = {"file": b64, "useLayoutDetection": bool(use_layout_detection), "fileType": file_type, "layoutMergeBboxesMode": "union"}

    if not use_layout_detection:
        if not (prompt_label and prompt_label.strip()):
            raise gr.Error("Please select a recognition type.")
        payload["promptLabel"] = prompt_label.strip().lower()

    if use_layout_detection and use_chart_recognition:
        # Every other request field is camelCase, so the server most likely
        # expects "useChartRecognition". The original snake_case key is kept
        # for compatibility with any server that relied on it.
        # NOTE(review): confirm against the serving API docs and drop one key.
        payload["useChartRecognition"] = True
        payload["use_chart_recognition"] = True

    try:
        resp = requests.post(api_url, json=payload, timeout=120)
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"API request failed: {e}") from e

    # JSON decoding is handled separately. Since requests 2.27, its
    # JSONDecodeError is also a RequestException, so a single try block would
    # never reach this message. Every JSON decode error subclasses ValueError.
    try:
        data = resp.json()
    except ValueError as e:
        raise gr.Error(f"Invalid JSON response from server:\n{resp.text}") from e

    if not isinstance(data, dict):
        raise gr.Error(f"Unexpected response format from server:\n{resp.text}")
    if data.get("errorCode", -1) != 0:
        raise gr.Error(f"API returned an error: errorCode={data.get('errorCode')} errorMsg={data.get('errorMsg')}")
    return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def url_to_base64_data_url(url: str) -> str:
    """Download an image and return it as a Base64 ``data:`` URL for Markdown.

    The MIME type comes from the response's Content-Type header. Parameters
    such as ``; charset=...`` are stripped and the type is lower-cased.
    Missing or non-image types fall back to ``image/jpeg``.

    Args:
        url: Remote image URL.

    Returns:
        The data URL. On any failure the original ``url`` is returned
        unchanged, so the Markdown still points at the remote image.
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        content_type = response.headers.get('Content-Type', 'image/jpeg')
        # "image/png; charset=binary" would otherwise be pasted verbatim into
        # the data URL and produce an invalid header.
        mime_type = content_type.split(';', 1)[0].strip().lower()
        if not mime_type.startswith('image/'):
            print(f"Warning: URL did not return an image content type. Got: {content_type}")
            mime_type = 'image/jpeg'
        encoded_string = base64.b64encode(response.content).decode('utf-8')
        return f"data:{mime_type};base64,{encoded_string}"
    except requests.exceptions.RequestException as e:
        print(f"Error fetching markdown image from URL {url}: {e}")
        return url
    except Exception as e:
        print(f"An unexpected error occurred while processing markdown URL {url}: {e}")
        return url
|
|
|
|
|
def replace_image_urls_with_data_urls(md_text: str, md_images_map: Dict[str, str]) -> str:
    """Inline the Markdown's image placeholders as Base64 data URLs.

    Args:
        md_text: Markdown text, possibly containing HTML ``<img>`` tags.
        md_images_map: Maps each placeholder path used in ``md_text`` to the
            URL the image can be downloaded from.

    Returns:
        ``md_text`` with every ``src="<placeholder>"``, ``src='<placeholder>'``
        and ``](<placeholder>)`` occurrence replaced by the image's data URL.
        If a download fails, the remote URL is used instead.
    """
    if not md_images_map:
        return md_text
    for placeholder_path, image_url in md_images_map.items():
        # Skip the network round-trip for images the text never references.
        if placeholder_path not in md_text:
            continue
        print(f"Processing markdown image for '{placeholder_path}' from URL: {image_url}")
        data_url = url_to_base64_data_url(image_url)
        md_text = (
            md_text.replace(f'src="{placeholder_path}"', f'src="{data_url}"')
            .replace(f"src='{placeholder_path}'", f"src='{data_url}'")
            .replace(f']({placeholder_path})', f']({data_url})')
        )
    return md_text
|
|
|
|
|
def url_to_pil_image(url: str) -> Optional[Image.Image]:
    """Fetch an http(s) image and return it as an RGB PIL image for gr.Image.

    Returns None, after printing a message, when the URL is empty or not
    http(s), when the download fails, or when the bytes cannot be decoded.
    """
    is_http_url = bool(url) and url.startswith(('http://', 'https://'))
    if not is_http_url:
        print(f"Warning: Invalid URL provided for visualization image: {url}")
        return None
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")
    except requests.exceptions.RequestException as err:
        print(f"Error fetching visualization image from URL {url}: {err}")
    except Exception as err:
        print(f"Error processing visualization image from URL {url}: {err}")
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _process_api_response_page(result: Dict[str, Any]) -> Tuple[str, Optional[Image.Image], str]:
    """Turn the first page of an API ``result`` into the UI's outputs.

    1. Inlines the Markdown's image URLs as Base64 data URLs.
    2. Downloads one visualization image from ``outputImages`` as a PIL image.

    Args:
        result: The ``result`` object of the API response. ``None``, or a
            result with no ``layoutParsingResults``, is treated as empty.

    Returns:
        ``(markdown_for_preview, visualization_image, raw_markdown)``. The
        preview falls back to ``"(Empty result)"`` when the Markdown is empty.
        The image is None when none could be loaded.
    """
    layout_results = (result or {}).get("layoutParsingResults", [])
    if not layout_results:
        return "No content was recognized.", None, ""

    # Only the first page is used.
    page0 = layout_results[0] or {}
    md_text = _inline_markdown_images(page0.get("markdown") or {})
    output_image = _pick_visualization_image(page0.get("outputImages") or {})
    return md_text or "(Empty result)", output_image, md_text


def _inline_markdown_images(md_data: Dict[str, Any]) -> str:
    """Return the page's Markdown text with its image placeholders inlined."""
    md_text = md_data.get("text", "") or ""
    md_images_map = md_data.get("images", {})
    if md_images_map:
        md_text = replace_image_urls_with_data_urls(md_text, md_images_map)
    return md_text


def _pick_visualization_image(out_imgs: Dict[str, str]) -> Optional[Image.Image]:
    """Load the single visualization image to display.

    Images are tried in sorted-key order. The second one that loads
    successfully is preferred, falling back to the first if only one loads.
    Loading stops once the choice is settled, so unused images are not
    downloaded. Presumably the second sorted image is the more informative
    overlay (NOTE(review): confirm against the API's outputImages keys).
    """
    loaded: List[Image.Image] = []
    for _, img_url in sorted(out_imgs.items()):
        pil_image = url_to_pil_image(img_url)
        if pil_image:
            loaded.append(pil_image)
            if len(loaded) == 2:
                break
        else:
            print(f"Warning: Failed to load visualization image from URL: {img_url}")
    # With two loads this is the second image; with one it is the first.
    return loaded[-1] if loaded else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_complex_doc(file_path: str, use_chart_recognition: bool) -> Tuple[str, Optional[Image.Image], str]:
    """Click handler for "Parse Document": full layout parsing of one image.

    Returns:
        (markdown preview, visualization image, raw markdown).

    Raises:
        gr.Error: when no image has been uploaded.
    """
    if not file_path:
        raise gr.Error("Please upload an image first.")
    response = _call_api(
        DEFAULT_API_URL,
        file_path,
        use_layout_detection=True,
        prompt_label=None,
        use_chart_recognition=use_chart_recognition,
    )
    return _process_api_response_page(response.get("result", {}))
|
|
|
|
|
def handle_targeted_recognition(file_path: str, prompt_choice: str) -> Tuple[str, str]:
    """Click handler for the single-element recognition buttons.

    Args:
        file_path: Uploaded image path.
        prompt_choice: Button caption. It is mapped to the API's promptLabel,
            and unknown captions fall back to "ocr".

    Returns:
        (markdown preview, raw markdown).

    Raises:
        gr.Error: when no image has been uploaded.
    """
    if not file_path:
        raise gr.Error("Please upload an image first.")
    label_for_choice = {
        "Text Recognition": "ocr",
        "Formula Recognition": "formula",
        "Table Recognition": "table",
        "Chart Recognition": "chart",
    }
    response = _call_api(
        DEFAULT_API_URL,
        file_path,
        use_layout_detection=False,
        prompt_label=label_for_choice.get(prompt_choice, "ocr"),
    )
    preview_md, _vis_image, raw_md = _process_api_response_page(response.get("result", {}))
    return preview_md, raw_md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). It sets the CJK-friendly font stack,
# compacts Gradio's default paddings, and styles the quick-link bar and the
# recognition-button grid. It also sizes the image preview, visualization and
# Markdown panes (matched by the elem_id values given in the UI) to a fraction
# of the viewport height. The comments inside the string are CSS comments and
# part of the runtime text.
custom_css = '''
body, button, input, textarea, select, p, label { font-family: "Microsoft YaHei","微软雅黑","Microsoft YaHei UI", "Noto Sans SC","PingFang SC",sans-serif !important; }
.app-header { text-align: center; max-width: 900px; margin: 0 auto 4px !important; padding: 0 !important; }
.gradio-container { padding-top: 2px !important; padding-bottom: 2px !important; }
.gradio-container .tabs { margin-top: 0px !important; }
.gradio-container .tabitem { padding-top: 4px !important; }
.prompt-grid { gap: 8px !important; margin-top: 4px !important; }
.prompt-grid button { height: 40px !important; min-height: 0 !important; padding: 0 12px !important; border-radius: 8px !important; font-weight: 600 !important; font-size: 13px !important; letter-spacing: .2px; }
.quick-links { text-align: center; padding: 8px 0; border: 1px solid #e5e7eb; border-radius: 8px; margin: 8px auto !important; max-width: 900px; }
.quick-links a { margin: 0 15px; font-size: 14px; font-weight: 600; text-decoration: none; color: #3b82f6; }
.quick-links a:hover { text-decoration: underline; }
#image_preview_vl, #image_preview_doc { height: 60vh !important; overflow: auto; }
#vis_image_doc { height: 42vh !important; }
#image_preview_vl .uploaded-image, #image_preview_doc .uploaded-image { height: 100%; }
#image_preview_vl img, #image_preview_doc img, #vis_image_doc img { width: 100% !important; height: 100% !important; object-fit: contain !important; }
#md_preview_vl, #md_preview_doc { max-height: 60vh; overflow: auto; scrollbar-gutter: stable both-edges; }
#md_preview_doc .prose img,
#md_preview_vl .prose img {
display: block !important;
margin-left: auto !important;
margin-right: auto !important; /* 块级元素用 margin auto 居中 */
height: auto; /* 可选:保持比例 */
max-width: 100%; /* 可选:避免溢出 */
}
#md_preview_vl .prose, #md_preview_doc .prose { line-height: 1.7 !important; font-family: 'Microsoft YaHei','Noto Sans SC','PingFang SC',sans-serif !important; }
'''
|
|
|
|
|
# Gradio UI with two tabs:
#   * "Document Parsing": full-page parsing with layout detection.
#   * "Content Recognition": single-element recognition selected by button.
with gr.Blocks(head=GOOGLE_FONTS_URL, css=custom_css, theme=gr.themes.Soft()) as demo:
    # The logo is inlined as a data URL. When the file is missing or cannot be
    # read, the header is skipped instead of rendering an <img src=""> (a
    # broken-image icon).
    logo_data_url = image_to_base64_data_url(LOGO_IMAGE_PATH) if os.path.exists(LOGO_IMAGE_PATH) else ""
    if logo_data_url:
        gr.HTML(f"""
        <div class="app-header">
            <img src="{logo_data_url}" alt="App Logo" style="max-height:10%; width: auto; margin: 10px auto; display: block;">
        </div>
        """)
    gr.HTML("""
    <div class="quick-links">
        <a href="https://github.com/PaddlePaddle/PaddleOCR" target="_blank">GitHub</a> |
        <a href="https://github.com/PaddlePaddle/PaddleOCR/blob/main/ppstructure/docs/vls.md" target="_blank">Technical Report</a> |
        <a href="https://xinghe.baidu.com/" target="_blank">Model</a>
    </div>
    """)

    with gr.Tabs():
        with gr.Tab("Document Parsing"):
            with gr.Row():
                # Left column: upload, preview, options and example gallery.
                with gr.Column(scale=5):
                    file_doc = gr.File(label="Upload Image", file_count="single", type="filepath", file_types=["image"])
                    preview_doc_html = gr.HTML(value="", elem_id="image_preview_doc", visible=False)
                    gr.Markdown("_( Use this mode for recognizing full-page documents with structured layouts, such as reports, papers, or magazines.)_")
                    gr.Markdown("💡 *To recognize a single, pre-cropped element (e.g., a table or formula), switch to the 'Content Recognition' tab for better results.*")
                    with gr.Row(variant="panel"):
                        chart_parsing_switch = gr.Checkbox(label="Enable chart parsing", value=False, scale=1)
                        btn_parse = gr.Button("Parse Document", variant="primary", scale=2)
                    if complex_document_examples:
                        complex_paths = [e[0] for e in complex_document_examples]
                        complex_state = gr.State(complex_paths)
                        gr.Markdown("**Document Examples (Click an image to load)**")
                        gallery_complex = gr.Gallery(value=complex_paths, columns=4, height=400, preview=False, label=None, allow_preview=False)
                        # Clicking a thumbnail loads that example into the upload widget.
                        gallery_complex.select(fn=_on_gallery_select, inputs=[complex_state], outputs=[file_doc])

                # Right column: rendered Markdown, visualization, raw source.
                with gr.Column(scale=7):
                    with gr.Tabs():
                        with gr.Tab("Markdown Preview"):
                            md_preview_doc = gr.Markdown("Please upload an image and click 'Parse Document'.", latex_delimiters=LATEX_DELIMS, elem_id="md_preview_doc")
                        with gr.Tab("Visualization"):
                            vis_image_doc = gr.Image(label="Detection Visualization", interactive=False, elem_id="vis_image_doc")
                        with gr.Tab("Markdown Source"):
                            md_raw_doc = gr.Code(label="Markdown Source Code", language="markdown")

            file_doc.change(fn=update_preview_visibility, inputs=[file_doc], outputs=[preview_doc_html])
            btn_parse.click(fn=handle_complex_doc, inputs=[file_doc, chart_parsing_switch], outputs=[md_preview_doc, vis_image_doc, md_raw_doc])

        with gr.Tab("Content Recognition"):
            with gr.Row():
                with gr.Column(scale=5):
                    file_vl = gr.File(label="Upload Image", file_count="single", type="filepath", file_types=["image"])
                    preview_vl_html = gr.HTML(value="", elem_id="image_preview_vl", visible=False)
                    gr.Markdown("_(Best for images with a **simple, single-column layout** (e.g., pure text), or for a **pre-cropped single element** like a table, formula, or chart.)_")
                    gr.Markdown("Choose a recognition type:")
                    with gr.Row(elem_classes=["prompt-grid"]):
                        btn_ocr = gr.Button("Text Recognition", variant="secondary")
                        # `variant` is keyword-only in gr.Button; the original
                        # passed "secondary" positionally.
                        btn_formula = gr.Button("Formula Recognition", variant="secondary")
                    with gr.Row(elem_classes=["prompt-grid"]):
                        btn_table = gr.Button("Table Recognition", variant="secondary")
                        btn_chart = gr.Button("Chart Recognition", variant="secondary")
                    if targeted_recognition_examples:
                        targeted_paths = [e[0] for e in targeted_recognition_examples]
                        targeted_state = gr.State(targeted_paths)
                        gr.Markdown("**Content Recognition Examples (Click an image to load)**")
                        gallery_targeted = gr.Gallery(value=targeted_paths, columns=4, height=400, preview=False, label=None, allow_preview=False)
                        gallery_targeted.select(fn=_on_gallery_select, inputs=[targeted_state], outputs=[file_vl])

                with gr.Column(scale=7):
                    with gr.Tabs():
                        with gr.Tab("Recognition Result"):
                            md_preview_vl = gr.Markdown("Please upload an image and click a recognition type.", latex_delimiters=LATEX_DELIMS, elem_id="md_preview_vl")
                        with gr.Tab("Raw Output"):
                            md_raw_vl = gr.Code(label="Raw Output", language="markdown")

            file_vl.change(fn=update_preview_visibility, inputs=[file_vl], outputs=[preview_vl_html])
            # One handler serves all four buttons; each button passes its fixed
            # choice string through a gr.State input.
            for recog_button, recog_choice in (
                (btn_ocr, "Text Recognition"),
                (btn_formula, "Formula Recognition"),
                (btn_table, "Table Recognition"),
                (btn_chart, "Chart Recognition"),
            ):
                recog_button.click(fn=handle_targeted_recognition, inputs=[file_vl, gr.State(recog_choice)], outputs=[md_preview_vl, md_raw_vl])
|
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue (Blocks.queue() returns the Blocks itself),
    # then serve on all interfaces at port 8812 without a public share link.
    demo.queue().launch(server_name="0.0.0.0", server_port=8812, share=False)