Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| This code is copied from https://github.com/allenai/olmocr | |
| Under the Apache 2.0 license. | |
| All credit goes to the original authors. | |
| """ | |
| from dataclasses import dataclass | |
| import re | |
| import tempfile | |
| from PIL import Image | |
| import subprocess | |
| import base64 | |
| from typing import List, Literal | |
| import random | |
| import ftfy | |
| from pypdf.generic import RectangleObject | |
| from pypdf import PdfReader | |
class Element:
    """Marker base class shared by the items extracted from a PDF page."""
@dataclass(frozen=True)
class BoundingBox:
    """Axis-aligned rectangle described by two opposite corners.

    The class is a frozen dataclass so that it can be built positionally or by
    keyword (both forms are used in this file) and so that instances are
    hashable and comparable by value.
    """

    # Coordinates of the first corner.
    x0: float
    y0: float
    # Coordinates of the opposite corner.
    x1: float
    y1: float

    @staticmethod
    def from_rectangle(rect: "RectangleObject") -> "BoundingBox":
        """Build a BoundingBox from the first four entries of *rect*.

        :param rect: An indexable rectangle (e.g. a pypdf ``RectangleObject``)
            whose items 0..3 are x0, y0, x1, y1.
        :return: The equivalent BoundingBox.
        """
        # Coerce to plain floats so the stored fields match their declared type.
        return BoundingBox(float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3]))
@dataclass(frozen=True)
class TextElement(Element):
    """A piece of text together with the position it was reported at.

    Frozen dataclass: instances are created positionally by ``_pdf_report``
    and are stored in a set by ``_linearize_pdf_report``, so they need a
    generated ``__init__`` and must be hashable.
    """

    # Text content of the element.
    text: str
    # Position of the element (taken from the text-to-user-space matrix in _pdf_report).
    x: float
    y: float
@dataclass(frozen=True)
class ImageElement(Element):
    """An image placed on a page, identified by name and bounding box.

    Frozen dataclass: instances are created by keyword in
    ``_merge_image_elements`` and are stored in a set by
    ``_linearize_pdf_report``, so they need a generated ``__init__`` and must
    be hashable.
    """

    # XObject name of the image (merged images join names with "+").
    name: str
    # Area covered by the image.
    bbox: BoundingBox
@dataclass(frozen=True)
class PageReport:
    """Everything ``_pdf_report`` collects about a single PDF page.

    Declared as a dataclass because ``_pdf_report`` constructs it with keyword
    arguments.
    """

    # The page's MediaBox.
    mediabox: BoundingBox
    # Text runs found on the page, in extraction order.
    text_elements: List[TextElement]
    # Images drawn on the page, in extraction order.
    image_elements: List[ImageElement]
def image_to_pdf(image_path):
    """Convert an image file into a single-page PDF stored in a temporary file.

    :param image_path: Path (or file object) accepted by ``PIL.Image.open``.
    :return: Path of the newly created ``.pdf`` file, or ``None`` if the image
        could not be opened or converted. The caller owns the returned file
        and is responsible for deleting it.
    """
    import os  # local import: only needed for cleanup on failure

    pdf_path = None
    try:
        # The context manager closes the underlying image file handle.
        with Image.open(image_path) as img:
            # Reserve a unique file name; the file must outlive this function,
            # hence delete=False.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                pdf_path = tmp.name
            # Convert image to RGB if necessary and save as PDF.
            rgb_img = img if img.mode == "RGB" else img.convert("RGB")
            rgb_img.save(pdf_path, "PDF")
        return pdf_path
    except Exception:
        # Best-effort conversion: callers expect None on any failure.
        # Do not leave a half-written temp file behind.
        if pdf_path is not None:
            try:
                os.unlink(pdf_path)
            except OSError:
                pass
        return None
def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:
    """
    Return the MediaBox width and height of one PDF page, as reported by the pdfinfo command.

    :param local_pdf_path: Path to the PDF file
    :param page_num: The page number for which to extract MediaBox dimensions
    :return: A ``(width, height)`` tuple computed from the MediaBox corners
    :raises ValueError: If pdfinfo exits with a non-zero status or prints no MediaBox line
    """
    page_arg = str(page_num)
    # Restrict pdfinfo to the single requested page and ask for box information.
    proc = subprocess.run(
        ["pdfinfo", "-f", page_arg, "-l", page_arg, "-box", "-enc", "UTF-8", local_pdf_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if proc.returncode != 0:
        raise ValueError(f"Error running pdfinfo: {proc.stderr}")

    # The first output line mentioning MediaBox carries the four corner values.
    box_line = next((ln for ln in proc.stdout.splitlines() if "MediaBox" in ln), None)
    if box_line is None:
        raise ValueError("MediaBox not found in the PDF info.")

    coords: List[float] = [float(token) for token in box_line.split(":")[1].strip().split()]
    width = abs(coords[0] - coords[2])
    height = abs(coords[3] - coords[1])
    return width, height
def render_pdf_to_base64png(local_pdf_path: str, page_num: int, target_longest_image_dim: int = 2048) -> str:
    """Render one PDF page to PNG with pdftoppm and return it base64-encoded.

    :param local_pdf_path: Path to the PDF file
    :param page_num: Page number passed to pdftoppm as both first and last page
    :param target_longest_image_dim: Desired size in pixels of the longest side of the image
    :return: The PNG bytes written by pdftoppm to stdout, base64-encoded as a UTF-8 string
    :raises ValueError: Propagated from get_pdf_media_box_width_height
    :raises AssertionError: If pdftoppm exits with a non-zero status (message is its stderr)
    :raises subprocess.TimeoutExpired: If pdftoppm runs longer than 120 seconds
    """
    longest_dim = max(get_pdf_media_box_width_height(local_pdf_path, page_num))

    # MediaBox sizes are in PDF points (72 per inch), so this resolution in DPI
    # makes the longest side come out at target_longest_image_dim pixels.
    resolution = target_longest_image_dim * 72 / longest_dim

    # Convert PDF page to PNG using pdftoppm
    pdftoppm_result = subprocess.run(
        [
            "pdftoppm",
            "-png",
            "-f",
            str(page_num),
            "-l",
            str(page_num),
            "-r",
            str(resolution),
            local_pdf_path,
        ],
        timeout=120,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # Raised explicitly (instead of an assert statement) so the check still
    # runs under `python -O`; AssertionError is kept for existing callers.
    if pdftoppm_result.returncode != 0:
        raise AssertionError(pdftoppm_result.stderr)

    return base64.b64encode(pdftoppm_result.stdout).decode("utf-8")
def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
    """Render a PageReport as text, sampling elements when it would exceed *max_length*.

    The output starts with a "Page dimensions" line, followed by one line per
    merged image and per non-blank text element. If everything fits within
    *max_length* characters, all lines are emitted in collection order
    (images first, then text). Otherwise the elements with extreme
    coordinates are kept, the rest are added in random order until the next
    one would not fit, and the kept lines are sorted by position.
    """
    header = f"Page dimensions: {report.mediabox.x1:.1f}x{report.mediabox.y1:.1f}\n"
    if max_length < 20:
        return header

    merged_images = _merge_image_elements(report.image_elements)
    # Blank text elements are never rendered and never count as edge elements.
    visible_text = [txt for txt in report.text_elements if len(txt.text.strip()) > 0]

    # Each entry is (kind, element, rendered line, (x, y) sort position).
    entries = [
        (
            "image",
            img,
            f"[Image {img.bbox.x0:.0f}x{img.bbox.y0:.0f} to {img.bbox.x1:.0f}x{img.bbox.y1:.0f}]\n",
            (img.bbox.x0, img.bbox.y0),
        )
        for img in merged_images
    ]
    entries.extend(
        (
            "text",
            txt,
            f"[{txt.x:.0f}x{txt.y:.0f}]{_cleanup_element_text(txt.text)}\n",
            (txt.x, txt.y),
        )
        for txt in visible_text
    )

    # Fast path: everything fits, keep collection order.
    if len(header) + sum(len(entry[2]) for entry in entries) <= max_length:
        return header + "".join(entry[2] for entry in entries)

    # Elements sitting at the min/max coordinates are always kept.
    extremes = set()
    if merged_images:
        extremes |= {
            min(merged_images, key=lambda e: e.bbox.x0),
            max(merged_images, key=lambda e: e.bbox.x1),
            min(merged_images, key=lambda e: e.bbox.y0),
            max(merged_images, key=lambda e: e.bbox.y1),
        }
    if visible_text:
        extremes |= {
            min(visible_text, key=lambda e: e.x),
            max(visible_text, key=lambda e: e.x),
            min(visible_text, key=lambda e: e.y),
            max(visible_text, key=lambda e: e.y),
        }

    # Take edge elements first; object ids guard against picking one twice.
    kept = []
    kept_ids = set()
    for entry in entries:
        elem = entry[1]
        if elem in extremes and id(elem) not in kept_ids:
            kept.append(entry)
            kept_ids.add(id(elem))

    used = len(header) + sum(len(entry[2]) for entry in kept)

    # Everything not yet kept competes for the remaining budget in random order.
    leftovers = [entry for entry in entries if id(entry[1]) not in kept_ids]
    random.shuffle(leftovers)
    for entry in leftovers:
        line_len = len(entry[2])
        if line_len > max_length - used:
            # Stop at the first element that does not fit.
            break
        kept.append(entry)
        used += line_len

    # Order the kept lines by (x, y) position.
    kept.sort(key=lambda entry: entry[3])
    return header + "".join(entry[2] for entry in kept)
| def _cap_split_string(text: str, max_length: int) -> str: | |
| if len(text) <= max_length: | |
| return text | |
| head_length = max_length // 2 - 3 | |
| tail_length = head_length | |
| head = text[:head_length].rsplit(" ", 1)[0] or text[:head_length] | |
| tail = text[-tail_length:].split(" ", 1)[-1] or text[-tail_length:] | |
| return f"{head} ... {tail}" | |
def _cleanup_element_text(element_text: str) -> str:
    """Normalise one text element for inclusion in the linearized report.

    The text is passed through ``ftfy.fix_text``, stripped, has square
    brackets and newline/carriage-return/tab characters replaced by
    backslash escapes, and is finally capped at 250 characters.
    """
    MAX_TEXT_ELEMENT_LENGTH = 250
    # Single-character replacements, applied in one pass by str.translate.
    escape_table = str.maketrans({"[": "\\[", "]": "\\]", "\n": "\\n", "\r": "\\r", "\t": "\\t"})

    fixed = ftfy.fix_text(element_text).strip()
    escaped = fixed.translate(escape_table)
    return _cap_split_string(escaped, MAX_TEXT_ELEMENT_LENGTH)
def _merge_image_elements(images: List[ImageElement], tolerance: float = 0.5) -> List[ImageElement]:
    """Merge images whose bounding boxes overlap or lie within *tolerance* of each other.

    Images are clustered with a union-find structure; each cluster becomes a
    single ImageElement whose bbox encloses all members and whose name is the
    members' names joined with "+". Clusters of one image keep their name and
    bbox.
    """
    count = len(images)
    leader = list(range(count))  # union-find parent pointers, one per image

    def root_of(node):
        # Locate the representative, then point every visited node at it.
        top = node
        while leader[top] != top:
            top = leader[top]
        while leader[node] != node:
            leader[node], node = top, leader[node]
        return top

    # Link every pair of images whose boxes are close enough on both axes.
    for a in range(count):
        box_a = images[a].bbox
        for b in range(a + 1, count):
            box_b = images[b].bbox
            # Gap between the boxes on each axis; 0 when they overlap.
            gap_x = max(0, max(box_a.x0, box_b.x0) - min(box_a.x1, box_b.x1))
            gap_y = max(0, max(box_a.y0, box_b.y0) - min(box_a.y1, box_b.y1))
            if gap_x <= tolerance and gap_y <= tolerance:
                root_a = root_of(a)
                root_b = root_of(b)
                if root_a != root_b:
                    leader[root_a] = root_b

    # Collect image indices per representative.
    clusters: dict[int, list[int]] = {}
    for idx in range(count):
        clusters.setdefault(root_of(idx), []).append(idx)

    # Collapse each cluster into one element.
    combined = []
    for members in clusters.values():
        first, *others = members
        box = images[first].bbox
        label = images[first].name
        for idx in others:
            extra = images[idx].bbox
            # Grow the box so it also covers this member.
            box = BoundingBox(
                x0=min(box.x0, extra.x0),
                y0=min(box.y0, extra.y0),
                x1=max(box.x1, extra.x1),
                y1=max(box.y1, extra.y1),
            )
            label += f"+{images[idx].name}"
        combined.append(ImageElement(name=label, bbox=box))
    return combined
| def _transform_point(x, y, m): | |
| x_new = m[0] * x + m[2] * y + m[4] | |
| y_new = m[1] * x + m[3] * y + m[5] | |
| return x_new, y_new | |
| def _mult(m: List[float], n: List[float]) -> List[float]: | |
| return [ | |
| m[0] * n[0] + m[1] * n[2], | |
| m[0] * n[1] + m[1] * n[3], | |
| m[2] * n[0] + m[3] * n[2], | |
| m[2] * n[1] + m[3] * n[3], | |
| m[4] * n[0] + m[5] * n[2] + n[4], | |
| m[4] * n[1] + m[5] * n[3] + n[5], | |
| ] | |
def _pdf_report(local_pdf_path: str, page_num: int) -> PageReport:
    """Collect the text runs and image placements of one PDF page using pypdf.

    :param local_pdf_path: Path to the PDF file
    :param page_num: 1-based page number (index ``page_num - 1`` is read)
    :return: A PageReport holding the page's MediaBox, its text elements and
        its image elements, in the order pypdf reported them
    """
    reader = PdfReader(local_pdf_path)
    page = reader.pages[page_num - 1]
    resources = page.get("/Resources", {})
    xobjects = resources.get("/XObject", {})
    text_elements, image_elements = [], []

    # Corners of the unit square that an image occupies before the CTM is applied.
    unit_corners = ((0, 0), (1, 0), (0, 1), (1, 1))

    def visitor_body(text, cm, tm, font_dict, font_size):
        # Combine the text matrix with the CTM; entries 4 and 5 of the result
        # are the translation, used as the text position.
        txt2user = _mult(tm, cm)
        text_elements.append(TextElement(text, txt2user[4], txt2user[5]))

    def visitor_op(op, args, cm, tm):
        # Only the "Do" operator (paint an XObject) is of interest.
        if op != b"Do":
            return
        xobject_name = args[0]
        xobject = xobjects.get(xobject_name)
        if xobject and xobject["/Subtype"] == "/Image":
            # The image is placed according to the CTM. Transform all four
            # corners of the unit square so the box is also right when the
            # CTM rotates or shears the image (two opposite corners alone
            # can collapse to zero width in that case).
            corners = [_transform_point(ux, uy, cm) for ux, uy in unit_corners]
            xs = [cx for cx, _ in corners]
            ys = [cy for _, cy in corners]
            image_elements.append(ImageElement(xobject_name, BoundingBox(min(xs), min(ys), max(xs), max(ys))))

    page.extract_text(visitor_text=visitor_body, visitor_operand_before=visitor_op)

    return PageReport(
        mediabox=BoundingBox.from_rectangle(page.mediabox),
        text_elements=text_elements,
        image_elements=image_elements,
    )
def get_anchor_text(
    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
) -> str:
    """Return the anchor text for one page of a PDF.

    :param local_pdf_path: Path to the PDF file
    :param page: 1-based page number
    :param pdf_engine: Extraction engine; only "pdfreport" is implemented here
    :param target_length: Maximum length of the returned text, passed to
        _linearize_pdf_report as ``max_length``
    :return: The linearized page report
    :raises AssertionError: If *page* is not positive
    :raises NotImplementedError: If *pdf_engine* is anything other than "pdfreport"
    """
    # Raised explicitly (instead of an assert statement) so the check still
    # runs under `python -O`; otherwise page 0 would index pages[-1] and
    # silently return the last page.
    if not page > 0:
        raise AssertionError("Pages are 1-indexed in pdf-land")

    if pdf_engine != "pdfreport":
        raise NotImplementedError("Unknown engine")

    return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)