File size: 5,968 Bytes
fcaa164 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import logging
import re
from typing import Iterable, List
from pydantic import BaseModel
from docling.datamodel.base_models import (
AssembledUnit,
ContainerElement,
FigureElement,
Page,
PageElement,
Table,
TextElement,
)
from docling.datamodel.document import ConversionResult
from docling.models.base_model import BasePageModel
from docling.models.layout_model import LayoutModel
from docling.utils.profiling import TimeRecorder
# Module-level logger, named after this module.
_log = logging.getLogger(name=__name__)
class PageAssembleOptions(BaseModel):
    """Options for PageAssembleModel; this model currently defines no fields."""
class PageAssembleModel(BasePageModel):
    """Assemble a page's layout clusters into page elements.

    For every valid page, each layout cluster becomes a text, table, figure
    or container element. The result is stored on ``page.assembled`` as an
    ``AssembledUnit`` with all elements plus a header/body split.
    """

    # Runs of word characters, used to find the words on either side of a
    # line-ending hyphen. Compiled once rather than on every line pair.
    _WORD_PATTERN = re.compile(r"\b[\w]+\b")

    def __init__(self, options: PageAssembleOptions):
        """Store the assembly options.

        Args:
            options: Configuration for this model.
        """
        self.options = options

    def sanitize_text(self, lines: List[str]) -> str:
        """Join the text lines of one cluster into a single string.

        Each pair of consecutive lines is joined as follows:

        * If the earlier line ends with "-" and both the last word of that
          line and the first word of the next line are alphanumeric, the
          hyphen is dropped and the lines are concatenated directly.
        * If the earlier line ends with "-" but that test fails, the line is
          kept as-is and concatenated to the next one with no separator.
        * Otherwise a space is appended to the earlier line.

        The input sequence is not modified.

        Args:
            lines: Text lines in reading order.

        Returns:
            The joined text. With two or more lines the result is stripped of
            leading/trailing whitespace. Zero or one line is returned via
            ``" ".join(lines)`` without stripping.
        """
        if len(lines) <= 1:
            return " ".join(lines)

        # Work on a copy so the caller's list is left untouched.
        parts = list(lines)
        for ix, (prev_line, line) in enumerate(zip(lines, lines[1:])):
            if not prev_line.endswith("-"):
                parts[ix] = prev_line + " "
            elif self._is_hyphenated_word_break(prev_line, line):
                parts[ix] = prev_line[:-1]
            # NOTE(review): a trailing "-" that fails the word test is kept
            # with no separator, e.g. ["-", "item"] -> "-item". Confirm this
            # is intended.

        return "".join(parts).strip()

    @classmethod
    def _is_hyphenated_word_break(cls, prev_line: str, line: str) -> bool:
        """Return True if the hyphen ending ``prev_line`` should be removed.

        The check compares the last word found anywhere in ``prev_line`` with
        the first word found anywhere in ``line``. Words are runs of word
        characters, which include "_". A word containing "_" therefore fails
        ``str.isalnum`` and blocks the join.
        """
        prev_words = cls._WORD_PATTERN.findall(prev_line)
        line_words = cls._WORD_PATTERN.findall(line)
        return (
            bool(prev_words)
            and bool(line_words)
            and prev_words[-1].isalnum()
            and line_words[0].isalnum()
        )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Assemble each page of ``page_batch`` and yield it.

        Pages whose backend reports ``is_valid()`` as false are yielded
        without assembly. For the other pages, ``page.assembled`` is set
        inside ``TimeRecorder(conv_res, "page_assemble")``.

        Args:
            conv_res: Conversion result handed to ``TimeRecorder``.
            page_batch: Pages to process. Each must have a backend and, when
                that backend is valid, layout predictions.

        Yields:
            Every page from ``page_batch``, in order.
        """
        for page in page_batch:
            assert page._backend is not None
            if page._backend.is_valid():
                with TimeRecorder(conv_res, "page_assemble"):
                    page.assembled = self._assemble_page(page)
            yield page

    def _assemble_page(self, page: Page) -> AssembledUnit:
        """Convert the layout clusters of ``page`` into an AssembledUnit.

        Clusters are visited in layout order. Text clusters whose label is a
        page-header label go to ``headers``. Every other handled element goes
        to ``body``. Clusters matching none of the handled label groups are
        skipped.
        """
        assert page.predictions.layout is not None

        elements: List[PageElement] = []
        headers: List[PageElement] = []
        body: List[PageElement] = []

        for cluster in page.predictions.layout.clusters:
            label = cluster.label
            element: PageElement
            is_header = False
            # Branch order matters: the first matching label group wins.
            if label in LayoutModel.TEXT_ELEM_LABELS:
                element = self._build_text_element(page, cluster)
                is_header = label in LayoutModel.PAGE_HEADER_LABELS
            elif label in LayoutModel.TABLE_LABELS:
                element = self._build_table_element(page, cluster)
            elif label == LayoutModel.FIGURE_LABEL:
                element = self._build_figure_element(page, cluster)
            elif label in LayoutModel.CONTAINER_LABELS:
                element = self._build_container_element(page, cluster)
            else:
                continue

            elements.append(element)
            if is_header:
                headers.append(element)
            else:
                body.append(element)

        return AssembledUnit(elements=elements, headers=headers, body=body)

    def _build_text_element(self, page: Page, cluster) -> TextElement:
        """Build a TextElement from the non-blank cells of ``cluster``."""
        # "\x02" is replaced with "-" before joining. Presumably the PDF
        # backend emits it in place of a hyphen; TODO confirm.
        textlines = [
            cell.text.replace("\x02", "-").strip()
            for cell in cluster.cells
            if cell.text.strip()
        ]
        return TextElement(
            label=cluster.label,
            id=cluster.id,
            text=self.sanitize_text(textlines),
            page_no=page.page_no,
            cluster=cluster,
        )

    def _build_table_element(self, page: Page, cluster) -> Table:
        """Return the predicted table for ``cluster``, or an empty placeholder.

        The placeholder has no text, OTSL sequence or cells. It keeps the
        table region in the output when no table-structure prediction exists
        for this cluster.
        """
        tbl = None
        if page.predictions.tablestructure:
            tbl = page.predictions.tablestructure.table_map.get(cluster.id, None)
        if not tbl:
            tbl = Table(
                label=cluster.label,
                id=cluster.id,
                text="",
                otsl_seq=[],
                table_cells=[],
                cluster=cluster,
                page_no=page.page_no,
            )
        return tbl

    def _build_figure_element(self, page: Page, cluster) -> FigureElement:
        """Return the classified figure for ``cluster``, or an empty placeholder.

        The placeholder has no text and ``data=None``. It keeps the figure
        region in the output when no classification exists for this cluster.
        """
        fig = None
        if page.predictions.figures_classification:
            fig = page.predictions.figures_classification.figure_map.get(
                cluster.id, None
            )
        if not fig:
            fig = FigureElement(
                label=cluster.label,
                id=cluster.id,
                text="",
                data=None,
                cluster=cluster,
                page_no=page.page_no,
            )
        return fig

    def _build_container_element(self, page: Page, cluster) -> ContainerElement:
        """Build a ContainerElement for ``cluster``."""
        return ContainerElement(
            label=cluster.label,
            id=cluster.id,
            page_no=page.page_no,
            cluster=cluster,
        )
|