diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..428c8cc3b119c8babda9e0e5c4f58f0f4ed21ee5
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,5 @@
+/venv/
+/.venv/
+.git
+/detectron2/
+/images/
\ No newline at end of file
diff --git a/.gitattributes b/.gitattributes
index c5fb4604cc75da6c8eecfb51311d21e8c0ed2f87..8e0dbcb6180d54d5b9f9aa174c926be8b23aaf0a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,8 +1,8 @@
-# Handle Python code and text files
-*.py text eol=lf
-*.md text eol=lf
-*.txt text eol=lf
-
-# Handle binary files
-*.pdf binary
-*.docx binary
\ No newline at end of file
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d544f31a8409819e5bdc6c09d0215167d37a84de
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1 @@
+custom: ["https://huridocs.org/donate/"]
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000000000000000000000000000000000000..38002d52daca4b35470c67a5c906689e4056f713
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,17 @@
+version: 2
+updates:
+ - package-ecosystem: "pip"
+ directory: "/"
+ schedule:
+ interval: "daily"
+ open-pull-requests-limit: 5
+ labels:
+ - "dependencies"
+ - package-ecosystem: "github-actions"
+ directory: "/"
+ schedule:
+ interval: "daily"
+ - package-ecosystem: "docker"
+ directory: "/"
+ schedule:
+ interval: "daily"
diff --git a/.github/workflows/push_docker_image.yml b/.github/workflows/push_docker_image.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f23df8a46cb452e06ecf241083319f301c71c256
--- /dev/null
+++ b/.github/workflows/push_docker_image.yml
@@ -0,0 +1,53 @@
+name: Create and publish Docker image
+
+on:
+ push:
+ tags:
+ - 'v*'
+
+env:
+ REGISTRY: ghcr.io
+ IMAGE_NAME: huridocs/pdf-document-layout-analysis
+
+jobs:
+ build-and-push-image:
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ packages: write
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Install dependencies
+ run: sudo apt-get install -y just
+
+ - name: Log in to the Container registry
+ uses: docker/login-action@v3
+ with:
+ registry: ${{ env.REGISTRY }}
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Extract metadata (tags, labels) for Docker
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+ tags: |
+ type=ref,event=branch
+ type=ref,event=pr
+ type=semver,pattern={{version}}
+ type=semver,pattern={{major}}.{{minor}}
+
+ - name: Create folder models
+ run: mkdir -p models
+
+ - name: Build and push
+ uses: docker/build-push-action@v6
+ with:
+ context: .
+ file: Dockerfile
+ push: ${{ github.event_name != 'pull_request' }}
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..916b7dce1cfc7d1bb148992846daf98a3dcb7e66
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,49 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Test
+
+on:
+ push:
+ branches: [ main ]
+ pull_request:
+ branches: [ main ]
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.11'
+
+ - name: Install dependencies
+ run: sudo apt-get update; sudo apt-get install -y pdftohtml qpdf just
+
+ - name: Free up space
+ run: just free_up_space
+
+ - name: Install venv
+ run: just install_venv
+
+ - name: Lint with black
+ run: just check_format
+
+ - name: Start service
+ run: just start_detached
+
+ - name: Check API ready
+ uses: emilioschepis/wait-for-endpoint@v1.0.3
+ with:
+ url: http://localhost:5060
+ method: GET
+ expected-status: 200
+ timeout: 120000
+ interval: 500
+
+ - name: Test with unittest
+ run: just test
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2e3ced585586cdd610cba1211664b215fc18373d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,167 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+/models/
+/word_grids/
+/jsons/
+/model_output/
+/pdf_outputs/
+/detectron2/
+/ocr/
diff --git a/Dockerfile b/Dockerfile
new file mode 100755
index 0000000000000000000000000000000000000000..aaf501a5524f91d5fe126f395a6c75a8f4c0fed6
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,55 @@
+FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+RUN apt-get update
+RUN apt-get install --fix-missing -y -q --no-install-recommends libgomp1 ffmpeg libsm6 pdftohtml libxext6 git ninja-build g++ qpdf pandoc
+
+
+RUN apt-get install -y ocrmypdf \
+    tesseract-ocr-fra tesseract-ocr-spa tesseract-ocr-deu tesseract-ocr-ara \
+    tesseract-ocr-mya tesseract-ocr-hin tesseract-ocr-tam tesseract-ocr-tha \
+    tesseract-ocr-chi-sim tesseract-ocr-tur tesseract-ocr-ukr tesseract-ocr-ell \
+    tesseract-ocr-rus tesseract-ocr-kor tesseract-ocr-kor-vert
+
+
+RUN mkdir -p /app/src
+RUN mkdir -p /app/models
+
+RUN addgroup --system python && adduser --system --group python
+RUN chown -R python:python /app
+USER python
+
+ENV VIRTUAL_ENV=/app/.venv
+RUN python -m venv $VIRTUAL_ENV
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+
+COPY requirements.txt requirements.txt
+RUN uv pip install --upgrade pip
+RUN uv pip install -r requirements.txt
+
+WORKDIR /app
+
+RUN cd src; git clone https://github.com/facebookresearch/detectron2;
+RUN cd src/detectron2; git checkout 70f454304e1a38378200459dd2dbca0f0f4a5ab4; python setup.py build develop
+RUN uv pip install pycocotools==2.0.8
+
+COPY ./start.sh ./start.sh
+COPY ./src/. ./src
+COPY ./models/. ./models/
+RUN python src/download_models.py
+
+ENV PYTHONPATH="${PYTHONPATH}:/app/src"
+ENV TRANSFORMERS_VERBOSITY=error
+ENV TRANSFORMERS_NO_ADVISORY_WARNINGS=1
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..650934a0bb575687497f6d5bffd7d26c1787c878
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024-present HURIDOCS
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..10db0fd3e3f7ab706b4b7c7d67c79af5741fa4a6
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,78 @@
+HAS_GPU := $(shell command -v nvidia-smi > /dev/null && echo 1 || echo 0)
+
+install:
+ . .venv/bin/activate; pip install -Ur requirements.txt
+
+activate:
+ . .venv/bin/activate
+
+install_venv:
+ python3 -m venv .venv
+ . .venv/bin/activate; python -m pip install --upgrade pip
+ . .venv/bin/activate; python -m pip install -r dev-requirements.txt
+
+formatter:
+ . .venv/bin/activate; command black --line-length 125 .
+
+check_format:
+ . .venv/bin/activate; command black --line-length 125 . --check
+
+remove_docker_containers:
+ docker compose ps -q | xargs docker rm
+
+remove_docker_images:
+ docker compose config --images | xargs docker rmi
+
+start:
+ifeq ($(OS), Windows_NT)
+ if not exist models mkdir models
+else
+ mkdir -p ./models
+endif
+ifeq ($(HAS_GPU), 1)
+ @echo "NVIDIA GPU detected, using docker-compose-gpu.yml"
+ docker compose -f docker-compose-gpu.yml up --build
+else
+ @echo "No NVIDIA GPU detected, using docker-compose.yml"
+ docker compose -f docker-compose.yml up --build
+endif
+
+
+start_no_gpu:
+ mkdir -p ./models
+ docker compose up --build
+
+stop:
+ docker compose stop
+
+test:
+ . .venv/bin/activate; command cd src; command python -m pytest
+
+free_up_space:
+ df -h
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /opt/ghc
+ sudo rm -rf "/usr/local/share/boost"
+	sudo rm -rf "$$AGENT_TOOLSDIRECTORY"
+ sudo apt-get remove -y '^llvm-.*' || true
+ sudo apt-get remove -y 'php.*' || true
+ sudo apt-get remove -y google-cloud-sdk hhvm google-chrome-stable firefox mono-devel || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /usr/local/lib/android
+ sudo rm -rf /opt/hostedtoolcache/CodeQL
+ sudo docker image prune --all --force
+ df -h
+
+
+start_detached:
+ mkdir -p ./models
+ docker compose up --build -d
+
+start_detached_gpu:
+ mkdir -p ./models
+ RESTART_IF_NO_GPU=true docker compose -f docker-compose-gpu.yml up --build -d
+
+upgrade:
+ . .venv/bin/activate; pip-upgrade
\ No newline at end of file
diff --git a/README.md b/README.md
index bf73e2526156324698a63a952f6f3fc2b925c196..e6533cec175eb0197702cd8ab9cdc0af63e32b1c 100644
--- a/README.md
+++ b/README.md
@@ -1,40 +1,910 @@
+
+# PDF Document Layout Analysis
+A Docker-powered microservice for intelligent PDF document layout analysis, OCR, and content extraction
+
---
-title: Audit Report Generator
-emoji: 📝
-colorFrom: purple
-colorTo: indigo
-sdk: gradio
-sdk_version: 5.38.2
-app_file: app.py
-pinned: false
+
+## 🚀 Overview
+
+This project provides a powerful and flexible PDF analysis microservice built with **Clean Architecture** principles. The service enables OCR, segmentation, and classification of different parts of PDF pages, identifying elements such as texts, titles, pictures, tables, formulas, and more. Additionally, it determines the correct reading order of these identified elements and can convert PDFs to various formats including Markdown and HTML.
+
+### ✨ Key Features
+
+- 🔍 **Advanced PDF Layout Analysis** - Segment and classify PDF content with high accuracy
+- 🖼️ **Visual & Fast Models** - Choose between VGT (Vision Grid Transformer) for accuracy or LightGBM for speed
+- 📝 **Multi-format Output** - Export to JSON, Markdown, HTML, and visualize PDF segmentations
+- 🌐 **OCR Support** - 150+ language support with Tesseract OCR
+- 📊 **Table & Formula Extraction** - Extract tables as HTML and formulas as LaTeX
+- 🏗️ **Clean Architecture** - Modular, testable, and maintainable codebase
+- 🐳 **Docker-Ready** - Easy deployment with GPU support
+- ⚡ **RESTful API** - Comprehensive API with 10+ endpoints
+
+
+### 🔗 Project Links
+
+- **GitHub**: [pdf-document-layout-analysis](https://github.com/huridocs/pdf-document-layout-analysis)
+- **HuggingFace**: [pdf-document-layout-analysis](https://huggingface.co/HURIDOCS/pdf-document-layout-analysis)
+- **DockerHub**: [pdf-document-layout-analysis](https://hub.docker.com/r/huridocs/pdf-document-layout-analysis/)
+
---
-# NHVAS Audit Report Generator
+## 🚀 Quick Start
+
+### 1. Start the Service
+
+**With GPU support (recommended for better performance):**
+```bash
+make start
+```
+
+**Without GPU support:**
+```bash
+make start_no_gpu
+```
+
+The service will be available at `http://localhost:5060`
+
+**Check service status:**
+
+```bash
+curl http://localhost:5060/info
+```
+
+### 2. Basic PDF Analysis
+
+**Analyze a PDF document (VGT model - high accuracy):**
+```bash
+curl -X POST -F 'file=@/path/to/your/document.pdf' http://localhost:5060
+```
+
+**Fast analysis (LightGBM models - faster processing):**
+```bash
+curl -X POST -F 'file=@/path/to/your/document.pdf' -F "fast=true" http://localhost:5060
+```
+
+### 3. Stop the Service
+
+```bash
+make stop
+```
+
+> 💡 **Tip**: Replace `/path/to/your/document.pdf` with the actual path to your PDF file. The service will return a JSON response with segmented content and metadata.
+
+
+## 📋 Table of Contents
+
+- [🚀 Quick Start](#-quick-start)
+- [⚙️ Dependencies](#-dependencies)
+- [📋 Requirements](#-requirements)
+- [📚 API Reference](#-api-reference)
+- [💡 Usage Examples](#-usage-examples)
+- [🏗️ Architecture](#-architecture)
+- [🤖 Models](#-models)
+- [📊 Data](#-data)
+- [🔧 Development](#-development)
+- [📈 Benchmarks](#-benchmarks)
+ - [Performance](#performance)
+ - [Speed](#speed)
+- [🌐 Installation of More Languages for OCR](#-installation-of-more-languages-for-ocr)
+- [🔗 Related Services](#-related-services)
+- [🤝 Contributing](#-contributing)
+
+
+
+## ⚙️ Dependencies
+
+### Required
+- **Docker Desktop 4.25.0+** - [Installation Guide](https://www.docker.com/products/docker-desktop/)
+- **Python 3.10+** (for local development)
+
+### Optional
+- **NVIDIA Container Toolkit** - [Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) (for GPU support)
+
+## 📋 Requirements
+
+### System Requirements
+- **RAM**: 2 GB minimum
+- **GPU Memory**: 5 GB (optional; falls back to CPU if unavailable)
+- **Disk Space**: 10 GB for models and dependencies
+- **CPU**: Multi-core recommended for better performance
+
+### Docker Requirements
+- Docker Engine 20.10+
+- Docker Compose 2.0+
+
+## 📚 API Reference
+
+The service provides a comprehensive RESTful API with the following endpoints:
+
+### Core Analysis Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/` | POST | Analyze PDF layout and extract segments | `file`, `fast`, `parse_tables_and_math` |
+| `/save_xml/{filename}` | POST | Analyze PDF and save XML output | `file`, `xml_file_name`, `fast` |
+| `/get_xml/{filename}` | GET | Retrieve saved XML analysis | `xml_file_name` |
+
+### Content Extraction Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/text` | POST | Extract text by content types | `file`, `fast`, `types` |
+| `/toc` | POST | Extract table of contents | `file`, `fast` |
+| `/toc_legacy_uwazi_compatible` | POST | Extract TOC (Uwazi compatible) | `file` |
+
+### Format Conversion Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/markdown` | POST | Convert PDF to Markdown (includes segmentation data in zip) | `file`, `fast`, `extract_toc`, `dpi`, `output_file` |
+| `/html` | POST | Convert PDF to HTML (includes segmentation data in zip) | `file`, `fast`, `extract_toc`, `dpi`, `output_file` |
+| `/visualize` | POST | Visualize segmentation results on the PDF | `file`, `fast` |
+
+### OCR & Utility Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/ocr` | POST | Apply OCR to PDF | `file`, `language` |
+| `/info` | GET | Get service information | - |
+| `/` | GET | Health check and system info | - |
+| `/error` | GET | Test error handling | - |
+
+### Common Parameters
+
+- **`file`**: PDF file to process (multipart/form-data)
+- **`fast`**: Use LightGBM models instead of VGT (boolean, default: false)
+- **`parse_tables_and_math`**: Apply OCR to table regions and convert formulas to LaTeX (boolean, default: false)
+- **`language`**: OCR language code (string, default: "en")
+- **`types`**: Comma-separated content types to extract (string, default: "all")
+- **`extract_toc`**: Include table of contents at the beginning of the output (boolean, default: false)
+- **`dpi`**: Image resolution for conversion (integer, default: 120)
+
+## 💡 Usage Examples
+
+### Basic PDF Analysis
+
+**Standard analysis with VGT model:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060
+```
+
+**Fast analysis with LightGBM models:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'fast=true' \
+ http://localhost:5060
+```
+
+**Analysis with table and math parsing:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'parse_tables_and_math=true' \
+ http://localhost:5060
+```
+
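+The same requests can be sent from Python. Below is a minimal sketch using the `requests` library (assumed installed; the service is assumed to be running on `localhost:5060`):
+
+```python
+import requests
+
+
+def analyze_pdf(pdf_path: str, fast: bool = False) -> list[dict]:
+    """Post a PDF to the analysis endpoint and return the list of segment boxes."""
+    with open(pdf_path, "rb") as pdf:
+        response = requests.post(
+            "http://localhost:5060",
+            files={"file": pdf},
+            data={"fast": "true"} if fast else {},  # fast=true selects the LightGBM models
+        )
+    response.raise_for_status()
+    return response.json()
+
+
+segments = analyze_pdf("document.pdf", fast=True)
+print(f"Found {len(segments)} segments")
+```
+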
+### Text Extraction
+
+**Extract all text:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'types=all' \
+ http://localhost:5060/text
+```
+
+**Extract specific content types:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'types=title,text,table' \
+ http://localhost:5060/text
+```
+
+### Format Conversion
+
+**Convert to Markdown:**
+```bash
+curl -X POST http://localhost:5060/markdown \
+ -F 'file=@document.pdf' \
+ -F 'extract_toc=true' \
+ -F 'output_file=document.md' \
+ --output 'document.zip'
+```
+
+**Convert to HTML:**
+```bash
+curl -X POST http://localhost:5060/html \
+ -F 'file=@document.pdf' \
+ -F 'extract_toc=true' \
+ -F 'output_file=document.html' \
+ --output 'document.zip'
+```
+
+> **📋 Segmentation Data**: Format conversion endpoints automatically include detailed segmentation data in the zip output. The resulting zip file contains a `{filename}_segmentation.json` file with information about each detected document segment including:
+> - **Coordinates**: `left`, `top`, `width`, `height`
+> - **Page information**: `page_number`, `page_width`, `page_height`
+> - **Content**: `text` content and segment `type` (e.g., "Title", "Text", "Table", "Picture")
+
+
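+As an illustration, the returned zip can be unpacked programmatically. The sketch below uses `requests` plus the standard library and relies only on the documented `*_segmentation.json` naming; the other member names are discovered from the archive rather than assumed:
+
+```python
+import io
+import json
+import zipfile
+
+import requests
+
+# Convert a PDF to Markdown and read the bundled segmentation data from the zip.
+with open("document.pdf", "rb") as pdf:
+    response = requests.post(
+        "http://localhost:5060/markdown",
+        files={"file": pdf},
+        data={"extract_toc": "true", "output_file": "document.md"},
+    )
+response.raise_for_status()
+
+with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
+    members = archive.namelist()
+    segmentation_name = next(n for n in members if n.endswith("_segmentation.json"))
+    segmentation = json.loads(archive.read(segmentation_name))
+
+print(f"Zip members: {members}")
+print(f"{len(segmentation)} segments described in {segmentation_name}")
+```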
+### OCR Processing
+
+**OCR in English:**
+```bash
+curl -X POST \
+ -F 'file=@scanned_document.pdf' \
+ -F 'language=en' \
+ http://localhost:5060/ocr \
+ --output ocr_processed.pdf
+```
+
+**OCR in other languages:**
+```bash
+# French
+curl -X POST \
+ -F 'file=@document_french.pdf' \
+ -F 'language=fr' \
+ http://localhost:5060/ocr \
+ --output ocr_french.pdf
+
+# Spanish
+curl -X POST \
+ -F 'file=@document_spanish.pdf' \
+ -F 'language=es' \
+ http://localhost:5060/ocr \
+ --output ocr_spanish.pdf
+```
+
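+The same call from Python only needs the binary response written to disk; a minimal sketch assuming a running service:
+
+```python
+import requests
+
+# Send a scanned PDF to /ocr and save the searchable result.
+with open("scanned_document.pdf", "rb") as pdf:
+    response = requests.post(
+        "http://localhost:5060/ocr",
+        files={"file": pdf},
+        data={"language": "en"},
+    )
+response.raise_for_status()
+
+with open("ocr_processed.pdf", "wb") as output_pdf:
+    output_pdf.write(response.content)  # the endpoint returns the processed PDF itself
+```
+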
+### Visualization
+
+**Generate visualization PDF:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060/visualize \
+ --output visualization.pdf
+```
+
+### Table of Contents Extraction
+
+**Extract structured TOC:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060/toc
+```
+
+### XML Storage and Retrieval
+
+**Analyze and save XML:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060/save_xml/my_analysis
+```
+
+**Retrieve saved XML:**
+```bash
+curl http://localhost:5060/get_xml/my_analysis.xml
+```
+
+### Service Information
+
+**Get service info and supported languages:**
+```bash
+curl http://localhost:5060/info
+```
+
+**Health check:**
+```bash
+curl http://localhost:5060/
+```
+
+### Response Format
+
+Most endpoints return JSON with segment information:
+
+```json
+[
+ {
+ "left": 72.0,
+ "top": 84.0,
+ "width": 451.2,
+ "height": 23.04,
+ "page_number": 1,
+ "page_width": 595.32,
+ "page_height": 841.92,
+ "text": "Document Title",
+ "type": "Title"
+ },
+ {
+ "left": 72.0,
+ "top": 120.0,
+ "width": 451.2,
+ "height": 200.0,
+ "page_number": 1,
+ "page_width": 595.32,
+ "page_height": 841.92,
+ "text": "This is the main text content...",
+ "type": "Text"
+ }
+]
+```
+
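+A response in this shape can be regrouped or filtered with a few lines of Python. The snippet below is a sketch over the documented fields only (it is not part of the service):
+
+```python
+from collections import defaultdict
+
+# `segments` stands in for the JSON array returned by the analysis endpoint.
+segments = [
+    {"left": 72.0, "top": 84.0, "width": 451.2, "height": 23.04, "page_number": 1,
+     "page_width": 595.32, "page_height": 841.92, "text": "Document Title", "type": "Title"},
+    {"left": 72.0, "top": 120.0, "width": 451.2, "height": 200.0, "page_number": 1,
+     "page_width": 595.32, "page_height": 841.92, "text": "This is the main text content...", "type": "Text"},
+]
+
+# Group segment text by page, keeping the order returned by the service.
+pages = defaultdict(list)
+for segment in segments:
+    pages[segment["page_number"]].append(segment["text"])
+
+titles = [s["text"] for s in segments if s["type"] == "Title"]
+print(titles)               # ['Document Title']
+print("\n".join(pages[1]))  # text of page 1 in reading order
+```
+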
+### Supported Content Types
+
+- `Caption` - Image and table captions
+- `Footnote` - Footnote text
+- `Formula` - Mathematical formulas
+- `List item` - List items and bullet points
+- `Page footer` - Footer content
+- `Page header` - Header content
+- `Picture` - Images and figures
+- `Section header` - Section headings
+- `Table` - Table content
+- `Text` - Regular text paragraphs
+- `Title` - Document and section titles
+
+
+## 🏗️ Architecture
+
+This project follows **Clean Architecture** principles, ensuring separation of concerns, testability, and maintainability. The codebase is organized into distinct layers:
+
+### Directory Structure
+
+```
+src/
+├── domain/ # Enterprise Business Rules
+│ ├── PdfImages.py # PDF image handling domain logic
+│ ├── PdfSegment.py # PDF segment entity
+│ ├── Prediction.py # ML prediction entity
+│ └── SegmentBox.py # Core segment box entity
+├── use_cases/ # Application Business Rules
+│ ├── pdf_analysis/ # PDF analysis use case
+│ ├── text_extraction/ # Text extraction use case
+│ ├── toc_extraction/ # Table of contents extraction
+│ ├── visualization/ # PDF visualization use case
+│ ├── ocr/ # OCR processing use case
+│ ├── markdown_conversion/ # Markdown conversion use case
+│ └── html_conversion/ # HTML conversion use case
+├── adapters/ # Interface Adapters
+│ ├── infrastructure/ # External service adapters
+│ ├── ml/ # Machine learning model adapters
+│ ├── storage/ # File storage adapters
+│ └── web/ # Web framework adapters
+├── ports/ # Interface definitions
+│ ├── services/ # Service interfaces
+│ └── repositories/ # Repository interfaces
+└── drivers/ # Frameworks & Drivers
+ └── web/ # FastAPI application setup
+```
+
+### Layer Responsibilities
+
+- **Domain Layer**: Contains core business entities and rules independent of external concerns
+- **Use Cases Layer**: Orchestrates domain entities to fulfill specific application requirements
+- **Adapters Layer**: Implements interfaces defined by inner layers and adapts external frameworks
+- **Drivers Layer**: Contains frameworks, databases, and external agency configurations
+
+### Key Benefits
+
+- 🔄 **Dependency Inversion**: High-level modules don't depend on low-level modules
+- 🧪 **Testability**: Easy to unit test business logic in isolation
+- 🔧 **Maintainability**: Changes to external frameworks don't affect business rules
+- 📈 **Scalability**: Easy to add new features without modifying existing code
+
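+As a rough illustration of these boundaries, a use case depends only on a port (an interface), while a concrete model adapter is plugged in from the outer layers. All names below are hypothetical and only sketch the dependency direction, not the actual classes in `src/`:
+
+```python
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+
+@dataclass
+class SegmentBox:
+    """Domain entity: plain data, no framework imports."""
+    left: float
+    top: float
+    width: float
+    height: float
+    page_number: int
+    type: str
+    text: str = ""
+
+
+class LayoutAnalyzerPort(ABC):
+    """Port: what the use case needs, independent of how it is implemented."""
+    @abstractmethod
+    def analyze(self, pdf_path: str) -> list[SegmentBox]: ...
+
+
+class AnalyzePdfUseCase:
+    """Use case: orchestrates the port, knows nothing about VGT or LightGBM."""
+    def __init__(self, analyzer: LayoutAnalyzerPort):
+        self.analyzer = analyzer
+
+    def execute(self, pdf_path: str) -> list[SegmentBox]:
+        return self.analyzer.analyze(pdf_path)
+
+
+class FakeAnalyzerAdapter(LayoutAnalyzerPort):
+    """Adapter: a stand-in for a real model adapter living in the outer layers."""
+    def analyze(self, pdf_path: str) -> list[SegmentBox]:
+        return [SegmentBox(72.0, 84.0, 451.2, 23.0, 1, "Title", "Document Title")]
+
+
+segments = AnalyzePdfUseCase(FakeAnalyzerAdapter()).execute("document.pdf")
+print(segments[0].type)  # Title
+```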
+
+## 🤖 Models
+
+The service offers two complementary model approaches, each optimized for different use cases:
+
+### 1. Vision Grid Transformer (VGT) - High Accuracy Model
+
+**Overview**: A state-of-the-art visual model developed by Alibaba Research Group that "sees" the entire page layout.
+
+**Key Features**:
+- 🎯 **High Accuracy**: Best-in-class performance on document layout analysis
+- 👁️ **Visual Understanding**: Analyzes the entire page context including spatial relationships
+- 📊 **Trained on DocLayNet**: Uses the comprehensive [DocLayNet dataset](https://github.com/DS4SD/DocLayNet)
+- 🔬 **Research-Backed**: Based on [Advanced Literate Machinery](https://github.com/AlibabaResearch/AdvancedLiterateMachinery)
+
+**Resource Requirements**:
+- GPU: 5GB+ VRAM (recommended)
+- CPU: Falls back automatically if GPU unavailable
+- Processing Speed: ~1.75 seconds/page (GPU [GTX 1070]) or ~13.5 seconds/page (CPU [i7-8700])
+
+### 2. LightGBM Models - Fast & Efficient
+
+**Overview**: Lightweight ensemble of two specialized models using XML-based features from Poppler.
+
+**Key Features**:
+- ⚡ **High Speed**: ~0.42 seconds per page on CPU (i7-8700)
+- 💾 **Low Resource Usage**: CPU-only, minimal memory footprint
+- 🔄 **Dual Model Approach**:
+ - **Token Type Classifier**: Identifies content types (title, text, table, etc.)
+ - **Segmentation Model**: Determines proper content boundaries
+- 📄 **XML-Based**: Uses Poppler's PDF-to-XML conversion for feature extraction
+
+**Trade-offs**:
+- Slightly lower accuracy compared to VGT
+- No visual context understanding
+- Excellent for batch processing and resource-constrained environments
+
+### OCR Integration
+
+Both models integrate seamlessly with OCR capabilities:
+
+- **Engine**: [Tesseract OCR](https://github.com/tesseract-ocr/tesseract)
+- **Processing**: [ocrmypdf](https://ocrmypdf.readthedocs.io/en/latest/index.html)
+- **Languages**: 150+ supported languages
+- **Output**: Searchable PDFs with preserved layout
+
+### Model Selection Guide
+
+| Use Case | Recommended Model | Reason |
+|----------|------------------|---------|
+| High accuracy requirements | VGT | Superior visual understanding |
+| Batch processing | LightGBM | Faster processing, lower resources |
+| GPU available | VGT | Leverages GPU acceleration |
+| CPU-only environment | LightGBM | Optimized for CPU processing |
+| Real-time applications | LightGBM | Consistent fast response times |
+| Research/analysis | VGT | Best accuracy for detailed analysis |
+
+## 📊 Data
-This tool automatically extracts relevant fields from an NHVAS PDF audit summary and populates a Word report template with the extracted data.
+### Training Dataset
-## Features
+Both model types are trained on the comprehensive [DocLayNet dataset](https://github.com/DS4SD/DocLayNet), a large-scale document layout analysis dataset containing over 80,000 document pages.
-- Upload an NHVAS PDF report
-- Upload your Word `.docx` report template
-- Automatically fills red-text placeholders in the Word document
-- Supports 7 module combinations (Mass, Maintenance, Fatigue, and their permutations)
-- Download the completed report instantly
+### Document Categories
-## How to Use
+The models can identify and classify 11 distinct content types:
+
-1. Upload your **PDF audit report**.
-2. Upload your **Word template (.docx)** with red-text placeholders.
-3. Click **Generate Report**.
-4. Download the updated Word document.
+| ID | Category | Description |
+|----|----------|-------------|
+| 1 | **Caption** | Image and table captions |
+| 2 | **Footnote** | Footnote references and text |
+| 3 | **Formula** | Mathematical equations and formulas |
+| 4 | **List item** | Bulleted and numbered list items |
+| 5 | **Page footer** | Footer content and page numbers |
+| 6 | **Page header** | Header content and titles |
+| 7 | **Picture** | Images, figures, and graphics |
+| 8 | **Section header** | Section and subsection headings |
+| 9 | **Table** | Tabular data and structures |
+| 10 | **Text** | Regular paragraph text |
+| 11 | **Title** | Document and chapter titles |
-## Tech Stack
+### Dataset Characteristics
-- Python 🐍
-- Gradio UI (via Hugging Face Spaces)
-- PyMuPDF (for PDF parsing)
-- python-docx (for Word file editing)
+- **Domain Coverage**: Academic papers, technical documents, reports
+- **Language**: Primarily English with multilingual support
+- **Quality**: High-quality annotations with bounding boxes and labels
+- **Diversity**: Various document layouts, fonts, and formatting styles
+
+For detailed information about the dataset, visit the [DocLayNet repository](https://github.com/DS4SD/DocLayNet).
+
+## 🔧 Development
+
+### Local Development Setup
+
+1. **Clone the repository:**
+ ```bash
+ git clone https://github.com/huridocs/pdf-document-layout-analysis.git
+ cd pdf-document-layout-analysis
+ ```
+
+2. **Create virtual environment:**
+ ```bash
+ make install_venv
+ ```
+
+3. **Activate environment:**
+ ```bash
+ make activate
+ # or manually: source .venv/bin/activate
+ ```
+
+4. **Install dependencies:**
+ ```bash
+ make install
+ ```
+
+### Code Quality
+
+**Format code:**
+```bash
+make formatter
+```
+
+**Check formatting:**
+```bash
+make check_format
+```
+
+### Testing
+
+**Run tests:**
+```bash
+make test
+```
+
+**Integration tests:**
+```bash
+# Tests are located in src/tests/integration/
+python -m pytest src/tests/integration/test_end_to_end.py
+```
+
+### Docker Development
+
+**Build and start (detached mode):**
+```bash
+# With GPU
+make start_detached_gpu
+
+# Without GPU
+make start_detached
+```
+
+**Clean up Docker resources:**
+```bash
+# Remove containers
+make remove_docker_containers
+
+# Remove images
+make remove_docker_images
+```
+
+### Project Structure
+
+```
+pdf-document-layout-analysis/
+├── src/ # Source code
+│ ├── domain/ # Business entities
+│ ├── use_cases/ # Application logic
+│ ├── adapters/ # External integrations
+│ ├── ports/ # Interface definitions
+│ └── drivers/ # Framework configurations
+├── test_pdfs/ # Test PDF files
+├── models/ # ML model storage
+├── docker-compose.yml # Docker configuration
+├── Dockerfile # Container definition
+├── Makefile # Development commands
+├── pyproject.toml # Python project configuration
+└── requirements.txt # Python dependencies
+```
+
+### Environment Variables
+
+Key configuration options:
+
+```bash
+# OCR configuration
+OCR_SOURCE=/tmp/ocr_source
+
+# Model paths (auto-configured)
+MODELS_PATH=./models
+
+# Service configuration
+HOST=0.0.0.0
+PORT=5060
+```
+
+### Adding New Features
+
+1. **Domain Logic**: Add entities in `src/domain/`
+2. **Use Cases**: Implement business logic in `src/use_cases/`
+3. **Adapters**: Create integrations in `src/adapters/`
+4. **Ports**: Define interfaces in `src/ports/`
+5. **Controllers**: Add endpoints in `src/adapters/web/`
+
+### Debugging
+
+**View logs:**
+```bash
+docker compose logs -f
+```
+
+**Access container:**
+```bash
+docker exec -it pdf-document-layout-analysis /bin/bash
+```
+
+**Free up disk space:**
+```bash
+make free_up_space
+```
+
+### Order of Output Elements
+
+The service returns SegmentBox elements in a carefully determined reading order:
+
+#### Reading Order Algorithm
+
+1. **Poppler Integration**: Uses [Poppler](https://poppler.freedesktop.org) PDF-to-XML conversion to establish initial token reading order
+2. **Segment Averaging**: Calculates average reading order for multi-token segments
+3. **Type-Based Sorting**: Prioritizes content types:
+ - **Headers** placed first
+ - **Main content** in reading order
+ - **Footers and footnotes** placed last
+
+#### Non-Text Elements
+
+For segments without text (e.g., images):
+- Processed after text-based sorting
+- Positioned by proximity to the nearest text segment
+- Uses spatial distance as the primary criterion
+
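+The sketch below illustrates this heuristic in simplified form. It is not the service's implementation: the field names (`token_orders`, `center`) and the type groupings are illustrative only.
+
+```python
+from statistics import mean
+
+FIRST_TYPES = {"Page header"}
+LAST_TYPES = {"Page footer", "Footnote"}
+
+
+def type_rank(segment_type: str) -> int:
+    """Headers first, footers and footnotes last, everything else in between."""
+    if segment_type in FIRST_TYPES:
+        return 0
+    return 2 if segment_type in LAST_TYPES else 1
+
+
+def order_segments(segments: list[dict]) -> list[dict]:
+    """segments: dicts with 'type', 'token_orders' (Poppler token indices) and 'center' (x, y)."""
+    text_segments = [s for s in segments if s["token_orders"]]
+    non_text_segments = [s for s in segments if not s["token_orders"]]
+
+    # Average the Poppler token order per segment, then sort by type priority.
+    text_segments.sort(key=lambda s: (type_rank(s["type"]), mean(s["token_orders"])))
+
+    # Segments without text (e.g. pictures) are slotted next to the closest text segment.
+    ordered = list(text_segments)
+    for segment in non_text_segments:
+        distances = [
+            (abs(segment["center"][0] - t["center"][0]) + abs(segment["center"][1] - t["center"][1]), index)
+            for index, t in enumerate(ordered)
+        ]
+        _, closest_index = min(distances, default=(0.0, -1))
+        ordered.insert(closest_index + 1, segment)
+    return ordered
+```
+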
+### Advanced Table and Formula Extraction
+
+#### Default Behavior
+- **Formulas**: Automatically extracted as LaTeX format in the `text` property
+- **Tables**: Basic text extraction included by default
+
+#### Enhanced Table Extraction
+
+Parse tables and extract them in HTML format by setting `parse_tables_and_math=true`:
+
+```bash
+curl -X POST -F 'file=@document.pdf' -F 'parse_tables_and_math=true' http://localhost:5060
+```
+
+
+#### Extraction Engines
+- **Formulas**: [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
+- **Tables**: [RapidTable](https://github.com/RapidAI/RapidTable)
+
+
+## 📈 Benchmarks
+
+### Performance
+
+VGT model performance on PubLayNet dataset:
+
+| Metric | Overall | Text | Title | List | Table | Figure |
+|--------|---------|------|-------|------|-------|--------|
+| **F1 Score** | **0.962** | 0.950 | 0.939 | 0.968 | 0.981 | 0.971 |
+
+> 📊 **Comparison**: View comprehensive model comparisons at [Papers With Code](https://paperswithcode.com/sota/document-layout-analysis-on-publaynet-val)
+
+### Speed
+
+Performance benchmarks on 15-page academic documents:
+
+| Model | Hardware | Speed (sec/page) | Use Case |
+|-------|----------|------------------|----------|
+| **LightGBM** | CPU (i7-8700 3.2GHz) | **0.42** | Fast processing |
+| **VGT** | GPU (GTX 1070) | **1.75** | High accuracy |
+| **VGT** | CPU (i7-8700 3.2GHz) | 13.5 | CPU fallback |
+
+### Performance Recommendations
+
+- **GPU Available**: Use VGT for best accuracy-speed balance
+- **CPU Only**: Use LightGBM for optimal performance
+- **Batch Processing**: LightGBM for consistent throughput
+- **High Accuracy**: VGT with GPU for best results
+
+
+## 🌐 Installation of More Languages for OCR
+
+The service uses Tesseract OCR with support for 150+ languages. The Docker image includes only common languages to minimize image size.
+
+### Installing Additional Languages
+
+#### 1. Access the Container
+```bash
+docker exec -it --user root pdf-document-layout-analysis /bin/bash
+```
+
+#### 2. Install Language Packs
+```bash
+# Install specific language
+apt-get update
+apt-get install tesseract-ocr-[LANGCODE]
+```
+
+#### 3. Common Language Examples
+
+```bash
+# Korean
+apt-get install tesseract-ocr-kor
+
+# German
+apt-get install tesseract-ocr-deu
+
+# French
+apt-get install tesseract-ocr-fra
+
+# Spanish
+apt-get install tesseract-ocr-spa
+
+# Chinese Simplified
+apt-get install tesseract-ocr-chi-sim
+
+# Arabic
+apt-get install tesseract-ocr-ara
+
+# Japanese
+apt-get install tesseract-ocr-jpn
+```
+
+#### 4. Verify Installation
+
+```bash
+curl http://localhost:5060/info
+```
+
+### Language Code Reference
+
+Find Tesseract language codes in the [ISO to Tesseract mapping](https://github.com/huridocs/pdf-document-layout-analysis/blob/main/src/adapters/infrastructure/ocr/languages.py).
+
+### Supported Languages
+
+Common language codes:
+- `eng` - English
+- `fra` - French
+- `deu` - German
+- `spa` - Spanish
+- `ita` - Italian
+- `por` - Portuguese
+- `rus` - Russian
+- `chi-sim` - Chinese Simplified
+- `chi-tra` - Chinese Traditional
+- `jpn` - Japanese
+- `kor` - Korean
+- `ara` - Arabic
+- `hin` - Hindi
+
+### Usage with Multiple Languages
+
+```bash
+# OCR with specific language
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'language=fr' \
+ http://localhost:5060/ocr \
+ --output french_ocr.pdf
+```
+
+
+## 🔗 Related Services
+
+Explore our ecosystem of PDF processing services built on this foundation:
+
+### [PDF Table of Contents Extractor](https://github.com/huridocs/pdf-table-of-contents-extractor)
+🔍 **Purpose**: Intelligent extraction of structured table of contents from PDF documents
+
+**Key Features**:
+- Leverages layout analysis for accurate TOC identification
+- Hierarchical structure recognition
+- Multiple output formats supported
+- Integration-ready API
+
+### [PDF Text Extraction](https://github.com/huridocs/pdf-text-extraction)
+📝 **Purpose**: Advanced text extraction with layout awareness
+
+**Key Features**:
+- Content-type aware extraction
+- Preserves document structure
+- Reading order optimization
+- Clean text output with metadata
+
+### Integration Benefits
+
+These services work seamlessly together:
+- **Shared Analysis**: Reuse layout analysis results across services
+- **Consistent Output**: Standardized JSON format for easy integration
+- **Scalable Architecture**: Deploy services independently or together
+- **Docker Ready**: All services containerized for easy deployment
+
+## 🤝 Contributing
+
+We welcome contributions to improve the PDF Document Layout Analysis service!
+
+### How to Contribute
+
+1. **Fork the Repository**
+ ```bash
+ git clone https://github.com/your-username/pdf-document-layout-analysis.git
+ ```
+
+2. **Create a Feature Branch**
+ ```bash
+ git checkout -b feature/your-feature-name
+ ```
+
+3. **Set Up Development Environment**
+ ```bash
+ make install_venv
+ make install
+ ```
+
+4. **Make Your Changes**
+ - Follow the Clean Architecture principles
+ - Add tests for new features
+ - Update documentation as needed
+
+5. **Run Tests and Quality Checks**
+ ```bash
+ make test
+ make check_format
+ ```
+
+6. **Submit a Pull Request**
+ - Provide clear description of changes
+ - Include test results
+ - Reference any related issues
+
+### Contribution Guidelines
+
+#### Code Standards
+- **Python**: Follow PEP 8 with 125-character line length
+- **Architecture**: Maintain Clean Architecture boundaries
+- **Testing**: Include unit tests for new functionality
+- **Documentation**: Update README and docstrings
+
+#### Areas for Contribution
+
+- 🐛 **Bug Fixes**: Report and fix issues
+- ✨ **New Features**: Add new endpoints or functionality
+- 📚 **Documentation**: Improve guides and examples
+- 🧪 **Testing**: Expand test coverage
+- 🚀 **Performance**: Optimize processing speed
+- 🌐 **Internationalization**: Add language support
+
+#### Development Workflow
+
+1. **Issue First**: Create or comment on relevant issues
+2. **Small PRs**: Keep pull requests focused and manageable
+3. **Clean Commits**: Use descriptive commit messages
+4. **Documentation**: Update relevant documentation
+5. **Testing**: Ensure all tests pass
+
+### Getting Help
+
+- 📚 **Documentation**: Check this README and inline docs
+- 💬 **Issues**: Search existing issues or create new ones
+- 🔍 **Code**: Explore the codebase structure
+- 📧 **Contact**: Reach out to maintainers for guidance
+
+---
-## Author
+### License
-Built by Shami (Muhammad Ahtesham Ahmad)
\ No newline at end of file
+This project is licensed under the terms specified in the [LICENSE](LICENSE) file.
diff --git a/app.py b/app.py
index 2998ab4872caf3635584faebe9824282a0522e34..b14b6ca9e0cf38627e47707b268dcde01d41c52e 100644
--- a/app.py
+++ b/app.py
@@ -3,34 +3,108 @@ import tempfile
import os
import shutil
import subprocess
+from pathlib import Path
+
+SCRIPT_DIR = Path(__file__).resolve().parent
+
+def run_cmd(cmd, cwd=None, env=None):
+ """Run a command, print nice logs, and also save them to run.log in cwd."""
+ cwd = str(cwd or os.getcwd())
+ print(f"🟦 Running: {' '.join(cmd)} (cwd={cwd})")
+ proc = subprocess.run(
+ cmd,
+ cwd=cwd,
+ env=env,
+ capture_output=True,
+ text=True
+ )
+ if proc.stdout:
+ print("🟩 STDOUT:")
+ print(proc.stdout)
+ if proc.stderr:
+ print("🟥 STDERR:")
+ print(proc.stderr)
+ # Save to run.log for debugging
+ try:
+ runlog = Path(cwd) / "run.log"
+ with open(runlog, "a", encoding="utf-8") as f:
+ f.write(f"$ {' '.join(cmd)}\n")
+ if proc.stdout:
+ f.write(proc.stdout + "\n")
+ if proc.stderr:
+ f.write(proc.stderr + "\n")
+ print(f"🧾 Run log saved to: {runlog}")
+ except Exception as e:
+ print(f"⚠️ Could not write run.log: {e}")
+
+ if proc.returncode != 0:
+ # Let Gradio see the failure so it surfaces properly
+ raise subprocess.CalledProcessError(proc.returncode, cmd, proc.stdout, proc.stderr)
+ return proc
+
+def _locate_pdf_json(temp_dir: str) -> str:
+ """
+    The PDF extractor writes a JSON file named like <stem>_comprehensive_data.json.
+    Find it, checking a few known filenames and then a glob pattern; raise if not found.
+ """
+ td = Path(temp_dir)
+
+ # Prefer exactly-named file if present
+ candidates = [
+ td / "pdf_data.json", # legacy name (if ever created)
+ td / "input_comprehensive_data.json", # most common from your logs
+ td / "comprehensive_data.json", # another common alias
+ td / "output.json", # generic
+ ]
+ for p in candidates:
+ if p.exists():
+ print(f"✅ Using PDF JSON: {p}")
+ return str(p)
+
+ # Generic pattern: anything *_comprehensive_data.json
+ globs = list(td.glob("*_comprehensive_data.json"))
+ if globs:
+ print(f"✅ Using PDF JSON (glob): {globs[0]}")
+ return str(globs[0])
+
+ # If still not found, surface a helpful error
+ searched = ", ".join(str(p) for p in candidates) + ", " + str(td / "*_comprehensive_data.json")
+ raise FileNotFoundError(
+ f"PDF JSON not found. Looked for: {searched}\nTemp dir: {temp_dir}"
+ )
def process_files(pdf_file, word_file):
# Create a unique temporary directory for this run
temp_dir = tempfile.mkdtemp(prefix="hf_redtext_")
+ print(f"📂 Temp dir: {temp_dir}")
# Define standard filenames for use in the pipeline
pdf_path = os.path.join(temp_dir, "input.pdf")
word_path = os.path.join(temp_dir, "input.docx")
- pdf_txt_path = os.path.join(temp_dir, "pdf_data.txt")
word_json_path = os.path.join(temp_dir, "word_data.json")
updated_json_path = os.path.join(temp_dir, "updated_word_data.json")
final_docx_path = os.path.join(temp_dir, "updated.docx")
# Copy the uploaded files to the temp directory
shutil.copy(pdf_file, pdf_path)
+ print(f"📄 PDF copied to: {pdf_path}")
shutil.copy(word_file, word_path)
+ print(f"📝 DOCX copied to: {word_path}")
+
+ # 1) PDF → JSON (extractor writes _comprehensive_data.json into cwd)
+ run_cmd(["python", str(SCRIPT_DIR / "extract_pdf_data.py"), pdf_path], cwd=temp_dir)
- # Step 1: Extract text from the PDF
- subprocess.run(["python", "extract_pdf_data.py", pdf_path, pdf_txt_path], check=True)
+ # Find the JSON produced by the extractor
+ pdf_json_path = _locate_pdf_json(temp_dir)
- # Step 2: Extract red text from the Word document
- subprocess.run(["python", "extract_red_text.py", word_path, word_json_path], check=True)
+ # 2) DOCX red text → JSON
+ run_cmd(["python", str(SCRIPT_DIR / "extract_red_text.py"), word_path, word_json_path], cwd=temp_dir)
- # Step 3: Update the Word JSON using the PDF text (calls OpenAI)
- subprocess.run(["python", "update_docx_with_pdf.py", word_json_path, pdf_txt_path, updated_json_path], check=True)
+ # 3) Merge JSON (uses the resolved pdf_json_path)
+ run_cmd(["python", str(SCRIPT_DIR / "update_docx_with_pdf.py"), word_json_path, pdf_json_path, updated_json_path], cwd=temp_dir)
- # Step 4: Apply the updated JSON to the Word doc to create the final output
- subprocess.run(["python", "updated_word.py", word_path, updated_json_path, final_docx_path], check=True)
+ # 4) Apply updates to DOCX
+ run_cmd(["python", str(SCRIPT_DIR / "updated_word.py"), word_path, updated_json_path, final_docx_path], cwd=temp_dir)
# Return the final .docx file
return final_docx_path
diff --git a/dev-requirements.txt b/dev-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e7aad2f31f2adf9db297965c20bc76370f4df939
--- /dev/null
+++ b/dev-requirements.txt
@@ -0,0 +1,4 @@
+-r requirements.txt
+pytest==8.2.2
+black==24.4.2
+pip-upgrader==1.4.15
\ No newline at end of file
diff --git a/docker-compose-gpu.yml b/docker-compose-gpu.yml
new file mode 100755
index 0000000000000000000000000000000000000000..d60fb944f3d089d92e0f732b3f334b2f8c91bca5
--- /dev/null
+++ b/docker-compose-gpu.yml
@@ -0,0 +1,14 @@
+services:
+ pdf-document-layout-analysis-gpu:
+ extends:
+ file: docker-compose.yml
+ service: pdf-document-layout-analysis
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [ gpu ]
+ environment:
+ - RESTART_IF_NO_GPU=$RESTART_IF_NO_GPU
\ No newline at end of file
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100755
index 0000000000000000000000000000000000000000..b69105f323843ca37f2309993da2d2d862130fa9
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,11 @@
+services:
+ pdf-document-layout-analysis:
+ container_name: pdf-document-layout-analysis
+ entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"]
+ init: true
+ restart: unless-stopped
+ build:
+ context: .
+ dockerfile: Dockerfile
+ ports:
+ - "5060:5060"
diff --git a/extract_pdf_data.py b/extract_pdf_data.py
index 4a60260c1e4584bfe43fd2524f75740ff84e383c..a0f7e2bb1206fc97da217db01de9bb8c2b34f526 100644
--- a/extract_pdf_data.py
+++ b/extract_pdf_data.py
@@ -1,39 +1,534 @@
-import pdfplumber
-from pdf2image import convert_from_path
-import pytesseract
+#!/usr/bin/env python3
+"""
+Fixed PDF Data Extractor - Addresses key issues in comprehensive_extract.py
-def extract_pdf_full_text(pdf_path, txt_path):
- raw_texts = []
- need_ocr = []
+Key fixes:
+1. Better table extraction and cleaning
+2. Improved key-value pair extraction
+3. More robust text processing
+4. Enhanced vehicle registration extraction
+5. Better date/number pattern recognition
+"""
+
+import json
+import re
+import pandas as pd
+from typing import Dict, List, Any, Optional
+import logging
+from pathlib import Path
+import sys
+from datetime import datetime
+
+try:
+ import pdfplumber
+ HAS_PDFPLUMBER = True
+except ImportError:
+ HAS_PDFPLUMBER = False
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fixed_pdf_extractor")
+
+class FixedPDFExtractor:
+ def __init__(self):
+ logger.info("🚀 Initializing Fixed PDF Extractor")
+
+ def extract_everything(self, pdf_path: str) -> Dict[str, Any]:
+ if not HAS_PDFPLUMBER:
+ raise RuntimeError("pdfplumber is required. Install with: pip install pdfplumber")
+
+ logger.info(f"📖 Processing PDF: {pdf_path}")
+ result = {
+ "document_info": {
+ "filename": Path(pdf_path).name,
+ "total_pages": 0,
+ "extraction_timestamp": datetime.now().isoformat()
+ },
+ "extracted_data": {
+ "all_text_content": [],
+ "all_tables": [],
+ "key_value_pairs": {},
+ "audit_information": {},
+ "operator_information": {},
+ "vehicle_registrations": [],
+ "driver_records": [],
+ "compliance_summary": {},
+ "dates_and_numbers": {}
+ }
+ }
+
+ all_text_blocks, all_tables = [], []
+
+ with pdfplumber.open(pdf_path) as pdf:
+ result["document_info"]["total_pages"] = len(pdf.pages)
+
+ for page_num, page in enumerate(pdf.pages, 1):
+ logger.info(f"📄 Processing page {page_num}")
+
+ # Extract text with better handling
+ page_text = self._extract_page_text(page)
+ if page_text:
+ all_text_blocks.append({
+ "page": page_num,
+ "text": page_text,
+ "word_count": len(page_text.split())
+ })
+
+ # Extract tables with improved cleaning
+ tables = self._extract_page_tables(page, page_num)
+ all_tables.extend(tables)
+
+ result["extracted_data"]["all_text_content"] = all_text_blocks
+ result["extracted_data"]["all_tables"] = all_tables
+
+ # Process extracted data with improved methods
+ combined_text = "\n\n".join(b["text"] for b in all_text_blocks)
+
+ result["extracted_data"]["key_value_pairs"] = self._extract_key_value_pairs_improved(combined_text)
+ result["extracted_data"]["audit_information"] = self._extract_audit_info(combined_text, all_tables)
+ result["extracted_data"]["operator_information"] = self._extract_operator_info(combined_text, all_tables)
+ result["extracted_data"]["vehicle_registrations"] = self._extract_vehicle_registrations(all_tables)
+ result["extracted_data"]["driver_records"] = self._extract_driver_records(all_tables)
+ result["extracted_data"]["compliance_summary"] = self._extract_compliance_summary(combined_text, all_tables)
+ result["extracted_data"]["dates_and_numbers"] = self._extract_dates_and_numbers_improved(combined_text)
+
+ # Generate summary
+ result["extraction_summary"] = {
+ "text_blocks_found": len(all_text_blocks),
+ "tables_found": len(all_tables),
+ "key_value_pairs_found": len(result["extracted_data"]["key_value_pairs"]),
+ "vehicle_registrations_found": len(result["extracted_data"]["vehicle_registrations"]),
+ "driver_records_found": len(result["extracted_data"]["driver_records"]),
+ "total_characters": len(combined_text),
+ "processing_timestamp": datetime.now().isoformat()
+ }
+
+ logger.info("✅ Extraction completed!")
+ return result
+
+ def _extract_page_text(self, page) -> Optional[str]:
+ """Extract text from page with better handling"""
+ try:
+ text = page.extract_text()
+ if text:
+ # Clean up text
+ text = re.sub(r'[ \t]+', ' ', text.strip())
+ text = re.sub(r'\n\s*\n', '\n', text)
+ return text
+ except Exception as e:
+ logger.warning(f"Failed to extract text from page: {e}")
+ return None
+
+ def _extract_page_tables(self, page, page_num: int) -> List[Dict]:
+ """Extract tables with improved processing"""
+ tables = []
+ try:
+ raw_tables = page.extract_tables()
+ if raw_tables:
+ for table_idx, table in enumerate(raw_tables):
+ cleaned_table = self._clean_table_improved(table)
+ if cleaned_table and len(cleaned_table) > 0:
+ tables.append({
+ "page": page_num,
+ "table_index": table_idx + 1,
+ "headers": cleaned_table[0] if cleaned_table else [],
+ "data": cleaned_table[1:] if len(cleaned_table) > 1 else [],
+ "raw_data": cleaned_table,
+ "row_count": len(cleaned_table) - 1 if len(cleaned_table) > 1 else 0,
+ "column_count": len(cleaned_table[0]) if cleaned_table else 0
+ })
+ except Exception as e:
+ logger.warning(f"Failed to extract tables from page {page_num}: {e}")
+
+ return tables
+
+ def _clean_table_improved(self, table: List[List]) -> List[List[str]]:
+ """Improved table cleaning with better cell processing"""
+ if not table:
+ return []
+
+ cleaned = []
+ for row in table:
+ cleaned_row = []
+ for cell in row:
+ if cell is None:
+ cleaned_cell = ""
+ else:
+ cleaned_cell = str(cell).strip()
+ cleaned_cell = re.sub(r'\s+', ' ', cleaned_cell)
+ cleaned_cell = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', cleaned_cell)
+ cleaned_row.append(cleaned_cell)
+ if any(cell.strip() for cell in cleaned_row):
+ cleaned.append(cleaned_row)
+
+        return cleaned
+
+ def _extract_key_value_pairs_improved(self, text: str) -> Dict[str, str]:
+ """Improved key-value pair extraction with better cleaning"""
+ pairs: Dict[str, str] = {}
+
+ # Normalize text a bit for regex stability
+ t = text.replace('\r', '\n')
+
+ # Pattern 1: colon-separated pairs (key: value)
+ pattern1 = re.compile(
+ r'([A-Za-z][\w\s()/\-.]{2,80}?):\s*([^\n\r:][^\n\r]*)'
+ )
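+        # e.g. "Date of Audit: 15 March 2024" yields ("Date of Audit", "15 March 2024") (illustrative input)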
+ for key, val in pattern1.findall(t):
+ k = key.strip()
+ v = val.strip()
+ # Filter junk: very long values, pure separators, or obvious headers
+ if not v or len(v) > 200:
+ continue
+ if re.fullmatch(r'[-_/\.]+', v):
+ continue
+ # Avoid capturing the next key as value by trimming trailing key-like tokens
+ v = re.sub(r'\s+[A-Z][\w\s()/\-.]{2,40}:$', '', v).strip()
+ # Skip values that are just long digit runs (likely id lists without meaning)
+ if re.fullmatch(r'\d{6,}', v):
+ continue
+ pairs[k] = v
+
+ # Pattern 2: inline “Key – Value” or “Key — Value”
+ pattern2 = re.compile(r'([A-Za-z][\w\s()/\-.]{2,80}?)\s*[–—-]\s*([^\n\r]+)')
+ for key, val in pattern2.findall(t):
+ k = key.strip()
+ v = val.strip()
+ if v and len(v) <= 200 and not re.fullmatch(r'\d{6,}', v):
+ pairs.setdefault(k, v)
+
+ return pairs
+
+ def _extract_audit_info(self, text: str, tables: List[Dict]) -> Dict[str, Any]:
+ """Extract audit-specific information with better filtering"""
+ audit_info: Dict[str, Any] = {}
+
+ # Prefer tables
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+ joined = ' '.join(headers)
+ if "audit information" in joined or "auditinformation" in joined:
+ data = table.get("data", [])
+ for row in data:
+ if len(row) >= 2 and row[0] and row[1]:
+ key = str(row[0]).strip()
+ value = str(row[1]).strip()
+ # Skip numbered list rows (e.g., "1.", "2)")
+ if re.match(r'^\s*\d+\s*[.)]\s*$', key):
+ continue
+ if key and value:
+ audit_info[key] = value
+
+ # Backup from text
+ candidates = {
+ "Date of Audit": r'Date\s+of\s+Audit[:\s]*([^\n\r]+)',
+ "Location of audit": r'Location\s+of\s+audit[:\s]*([^\n\r]+)',
+ "Auditor name": r'Auditor\s+name[:\s]*([^\n\r]+)',
+ "Audit Matrix Identifier (Name or Number)": r'Audit\s+Matrix\s+Identifier.*?[:\s]*([^\n\r]+)',
+ }
+ for k, pat in candidates.items():
+ if k not in audit_info:
+ m = re.search(pat, text, re.IGNORECASE)
+ if m:
+ audit_info[k] = m.group(1).strip()
+
+ return audit_info
+
+ def _extract_operator_info(self, text: str, tables: List[Dict]) -> Dict[str, Any]:
+ """Extract operator information with better table parsing"""
+ operator_info: Dict[str, Any] = {}
+
+ # Look for operator information in tables first
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+ if ("operatorinformation" in ' '.join(headers) or
+ "operator information" in ' '.join(headers) or
+ "operatorcontactdetails" in ' '.join(headers)):
+
+ data = table.get("data", [])
+ for row in data:
+ if len(row) >= 2 and row[0] and row[1]:
+ key = str(row[0]).strip()
+ value = str(row[1]).strip()
+ if key and value:
+ # Clean up key names
+ kl = key.lower()
+ if "operator name" in kl:
+ operator_info["operator_name"] = value
+ elif "trading name" in kl:
+ operator_info["trading_name"] = value
+ elif "company number" in kl:
+ if len(row) > 2:
+ company_parts = [str(r).strip() for r in row[1:] if str(r).strip()]
+ operator_info["company_number"] = "".join(company_parts)
+ else:
+ operator_info["company_number"] = value
+ elif "business address" in kl:
+ operator_info["business_address"] = value
+ elif "postal address" in kl:
+ operator_info["postal_address"] = value
+ elif "email" in kl:
+ operator_info["email"] = value
+ elif "telephone" in kl or "phone" in kl:
+ operator_info["phone"] = value
+ elif "nhvas accreditation" in kl:
+ operator_info["nhvas_accreditation"] = value
+ elif "nhvas manual" in kl:
+ operator_info["nhvas_manual"] = value
+
+ # Extract from text patterns as backup
+ patterns = {
+ 'operator_name': r'Operator\s*name[:\s\(]*([^\n\r\)]+?)(?=\s*NHVAS|\s*Registered|$)',
+ 'trading_name': r'Registered\s*trading\s*name[:\s\/]*([^\n\r]+?)(?=\s*Australian|$)',
+ 'company_number': r'Australian\s*Company\s*Number[:\s]*([0-9\s]+?)(?=\s*NHVAS|$)',
+ 'business_address': r'Operator\s*business\s*address[:\s]*([^\n\r]+?)(?=\s*Operator\s*Postal|$)',
+ 'postal_address': r'Operator\s*Postal\s*address[:\s]*([^\n\r]+?)(?=\s*Email|$)',
+ 'email': r'Email\s*address[:\s]*([^\s\n\r]+)',
+ 'phone': r'Operator\s*Telephone\s*Number[:\s]*([^\s\n\r]+)',
+ 'nhvas_accreditation': r'NHVAS\s*Accreditation\s*No\.[:\s\(]*([^\n\r\)]+)',
+ }
+
+ for key, pattern in patterns.items():
+ if key not in operator_info: # Only use text if not found in tables
+ match = re.search(pattern, text, re.IGNORECASE)
+ if match:
+ value = match.group(1).strip()
+ if value and len(value) < 200:
+ if key == 'company_number':
+ value = re.sub(r'\s+', '', value)
+ operator_info[key] = value
+
+ return operator_info
+
+ def _extract_vehicle_registrations(self, tables: List[Dict]) -> List[Dict]:
+ """Extract vehicle registration information from tables"""
+ vehicles: List[Dict[str, Any]] = []
+
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+
+ # Look for vehicle registration tables
+ if any(keyword in ' '.join(headers) for keyword in ['registration', 'vehicle', 'number']):
+ reg_col = None
+ for i, header in enumerate(headers):
+ if 'registration' in header and 'number' in header:
+ reg_col = i
+ break
+
+ if reg_col is not None:
+ data = table.get("data", [])
+ for row in data:
+ if len(row) > reg_col and row[reg_col]:
+ reg_num = str(row[reg_col]).strip()
+ # Validate registration format (letters/numbers)
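+                            # e.g. "ABC 123" or "XS 45 BC" would pass this loose check (illustrative values)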
+ if re.match(r'^[A-Z]{1,3}\s*\d{1,3}\s*[A-Z]{0,3}$', reg_num):
+ vehicle_info = {"registration_number": reg_num}
+
+ # Add other columns as additional info
+ for i, header in enumerate(table.get("headers", [])):
+ if i < len(row) and i != reg_col:
+ vehicle_info[str(header)] = str(row[i]).strip()
+
+ vehicles.append(vehicle_info)
+
+ return vehicles
+
+ def _extract_driver_records(self, tables: List[Dict]) -> List[Dict]:
+ """Extract driver records from tables"""
+ drivers: List[Dict[str, Any]] = []
+
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+
+ # Look for driver/scheduler tables
+ if any(keyword in ' '.join(headers) for keyword in ['driver', 'scheduler', 'name']):
+ name_col = None
+ for i, header in enumerate(headers):
+ if 'name' in header:
+ name_col = i
+ break
+
+ if name_col is not None:
+ data = table.get("data", [])
+ for row in data:
+ if len(row) > name_col and row[name_col]:
+ name = str(row[name_col]).strip()
+ # Basic name validation
+ if re.match(r'^[A-Za-z\s]{2,}$', name) and len(name.split()) >= 2:
+ driver_info = {"name": name}
+
+ # Add other columns
+ for i, header in enumerate(table.get("headers", [])):
+ if i < len(row) and i != name_col:
+ driver_info[str(header)] = str(row[i]).strip()
+
+ drivers.append(driver_info)
+
+ return drivers
+
+ def _extract_compliance_summary(self, text: str, tables: List[Dict]) -> Dict[str, Any]:
+ """Extract compliance information"""
+ compliance = {
+ "standards_compliance": {},
+ "compliance_codes": {},
+ "audit_results": []
+ }
+
+ # Look for compliance tables
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+
+ if any(keyword in ' '.join(headers) for keyword in ['compliance', 'standard', 'requirement']):
+ data = table.get("data", [])
+ for row in data:
+ if len(row) >= 2:
+ standard = str(row[0]).strip()
+ code = str(row[1]).strip()
+ if standard.startswith('Std') and code in ['V', 'NC', 'SFI', 'NAP', 'NA']:
+ compliance["standards_compliance"][standard] = code
+
+ # Extract compliance codes definitions
+ code_patterns = {
+ 'V': r'\bV\b\s+([^\n\r]+)',
+ 'NC': r'\bNC\b\s+([^\n\r]+)',
+ 'SFI': r'\bSFI\b\s+([^\n\r]+)',
+ 'NAP': r'\bNAP\b\s+([^\n\r]+)',
+ 'NA': r'\bNA\b\s+([^\n\r]+)',
+ }
+
+ for code, pattern in code_patterns.items():
+ match = re.search(pattern, text, re.IGNORECASE)
+ if match:
+ compliance["compliance_codes"][code] = match.group(1).strip()
+
+ return compliance
+
+ def _extract_dates_and_numbers_improved(self, text: str) -> Dict[str, Any]:
+ """Improved date and number extraction"""
+ result = {
+ "dates": [],
+ "registration_numbers": [],
+ "phone_numbers": [],
+ "email_addresses": [],
+ "reference_numbers": []
+ }
+
+ # Date patterns
+ date_patterns = [
+ r'\b(\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})\b',
+ r'\b(\d{1,2}/\d{1,2}/\d{4})\b',
+ r'\b(\d{1,2}-\d{1,2}-\d{4})\b',
+ r'\b(\d{1,2}\.\d{1,2}\.\d{4})\b',
+ ]
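+        # e.g. "21st June 2023", "21/06/2023", "21-06-2023" and "21.06.2023" each match one of the patterns above (illustrative)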
+ for pattern in date_patterns:
+ result["dates"].extend(re.findall(pattern, text))
+
+ # Registration numbers (Australian format-ish)
+ reg_pattern = r'\b([A-Z]{1,3}\s*\d{1,3}\s*[A-Z]{0,3})\b'
+ result["registration_numbers"] = list(set(re.findall(reg_pattern, text)))
+
+ # Phone numbers (AU)
+ phone_pattern = r'\b((?:\+61|0)[2-9]\s?\d{4}\s?\d{4})\b'
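+        # e.g. "02 9876 5432" or "0298765432" would match (illustrative numbers)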
+ result["phone_numbers"] = list(set(re.findall(phone_pattern, text)))
+
+ # Email addresses
+ email_pattern = r'\b([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})\b'
+ result["email_addresses"] = list(set(re.findall(email_pattern, text)))
+
+ # Reference numbers
+ ref_patterns = [
+ (r'RF(?:S)?\s*#?\s*(\d+)', 'RFS_Certifications'),
+ (r'NHVAS\s+Accreditation\s+No\.?\s*(\d+)', 'NHVAS_Numbers'),
+ (r'Registration\s+Number\s*#?\s*(\d+)', 'Registration_Numbers'),
+ ]
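+        # e.g. "NHVAS Accreditation No. 12345" is recorded as "NHVAS_Numbers: 12345" (illustrative number)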
+ for pattern, key in ref_patterns:
+ matches = re.findall(pattern, text, re.IGNORECASE)
+ if matches:
+ result["reference_numbers"].extend([f"{key}: {m}" for m in matches])
+
+ return result
+
+ @staticmethod
+ def save_results(results: Dict[str, Any], output_path: str):
+ """Save results to JSON file"""
+ try:
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(results, f, indent=2, ensure_ascii=False)
+ logger.info(f"💾 Results saved to {output_path}")
+ except Exception as e:
+ logger.error(f"Failed to save results: {e}")
+
+ @staticmethod
+ def export_to_excel(results: Dict[str, Any], excel_path: str):
+ """Export results to Excel with improved formatting"""
+ try:
+ with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
+ # Summary sheet
+ summary_data = []
+ extraction_summary = results.get("extraction_summary", {})
+ for key, value in extraction_summary.items():
+ summary_data.append({"Metric": key.replace("_", " ").title(), "Value": value})
+ pd.DataFrame(summary_data).to_excel(writer, sheet_name='Summary', index=False)
+
+ # Key-value pairs
+ kv_pairs = results.get("extracted_data", {}).get("key_value_pairs", {})
+ if kv_pairs:
+ kv_df = pd.DataFrame(list(kv_pairs.items()), columns=['Key', 'Value'])
+ kv_df.to_excel(writer, sheet_name='Key_Value_Pairs', index=False)
+
+ # Vehicle registrations
+ vehicles = results.get("extracted_data", {}).get("vehicle_registrations", [])
+ if vehicles:
+ pd.DataFrame(vehicles).to_excel(writer, sheet_name='Vehicle_Registrations', index=False)
+
+ # Driver records
+ drivers = results.get("extracted_data", {}).get("driver_records", [])
+ if drivers:
+ pd.DataFrame(drivers).to_excel(writer, sheet_name='Driver_Records', index=False)
+
+ # Compliance summary
+ compliance = results.get("extracted_data", {}).get("compliance_summary", {})
+ if compliance.get("standards_compliance"):
+ comp_df = pd.DataFrame(list(compliance["standards_compliance"].items()),
+ columns=['Standard', 'Compliance_Code'])
+ comp_df.to_excel(writer, sheet_name='Compliance_Standards', index=False)
+
+ logger.info(f"📊 Results exported to Excel: {excel_path}")
+ except Exception as e:
+ logger.error(f"Failed to export to Excel: {e}")
+
+def main():
+ if len(sys.argv) < 2:
+        print("Usage: python extract_pdf_data.py <pdf_path>")
+ sys.exit(1)
+
+ pdf_path = Path(sys.argv[1])
+ if not pdf_path.exists():
+ print(f"❌ PDF not found: {pdf_path}")
+ sys.exit(1)
+
+ print("🚀 Fixed PDF Data Extractor")
+ print("=" * 50)
+
+ extractor = FixedPDFExtractor()
+ results = extractor.extract_everything(str(pdf_path))
+
+ base = pdf_path.stem
+ output_dir = pdf_path.parent
- # Step 1: Try to extract RAW text, record which pages need OCR
- with pdfplumber.open(pdf_path) as pdf:
- for i, page in enumerate(pdf.pages):
- print(f"Extracting text from page {i+1}...")
- text = page.extract_text() or ""
- if text.strip():
- raw_texts.append(f"\n--- PAGE {i+1} RAW TEXT ---\n{text.strip()}")
- else:
- raw_texts.append(None)
- # Mark that we need OCR for this page
- need_ocr.append(i)
+ # Save outputs
+ json_path = output_dir / f"{base}_comprehensive_data.json"
+ excel_path = output_dir / f"{base}_fixed_extraction.xlsx"
- # Step 2: OCR only those pages with no RAW text
- print("Running OCR where RAW text is missing...")
- images = convert_from_path(pdf_path, dpi=300)
- for idx in need_ocr:
- ocr_text = pytesseract.image_to_string(images[idx])
- raw_texts[idx] = f"\n--- PAGE {idx+1} OCR TEXT ---\n{ocr_text.strip()}"
+ FixedPDFExtractor.save_results(results, str(json_path))
+ FixedPDFExtractor.export_to_excel(results, str(excel_path))
- # Step 3: Save to file (skip any leftover Nones, but there shouldn't be any)
- result = [txt for txt in raw_texts if txt]
- with open(txt_path, "w", encoding="utf-8") as f:
- f.write("\n".join(result))
- print(f"✅ Saved deduped full text to {txt_path}")
+ print("\n💾 OUTPUT FILES:")
+ print(f" 📄 JSON Data: {json_path}")
+ print(f" 📊 Excel Data: {excel_path}")
+ print(f"\n✨ FIXED EXTRACTION COMPLETE!")
if __name__ == "__main__":
- import sys
- # Usage: python extract_pdf_data.py input.pdf output.txt
- input_pdf = sys.argv[1]
- output_txt = sys.argv[2]
- extract_pdf_full_text(input_pdf, output_txt)
\ No newline at end of file
+ main()
diff --git a/extract_red_text.py b/extract_red_text.py
index f74738a0b5fa7f703281a246f240e69155c3b1ec..584b7876db68ef237dcc238cd8145d50b942e541 100644
--- a/extract_red_text.py
+++ b/extract_red_text.py
@@ -6,6 +6,139 @@ from docx import Document
from docx.oxml.ns import qn
from master_key import TABLE_SCHEMAS, HEADING_PATTERNS, PARAGRAPH_PATTERNS
+def normalize_header_label(s: str) -> str:
+ """Normalize a header/label by stripping parentheticals & punctuation."""
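+    # e.g. "Daily Checks (each vehicle)" -> "Daily Checks" (illustrative input)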
+ s = re.sub(r"\s+", " ", s.strip())
+ # remove content in parentheses/brackets
+ s = re.sub(r"\([^)]*\)", "", s)
+ s = re.sub(r"\[[^]]*\]", "", s)
+ # unify slashes and hyphens, collapse spaces
+    s = s.replace("–", "-").replace("—", "-").replace("/", " / ").replace("  ", " ")
+ return s.strip()
+
+# Canonical label aliases for Vehicle/Maintenance/General headers
+LABEL_ALIASES = {
+ # Vehicle Registration (Maintenance)
+ "roadworthiness certificates": "Roadworthiness Certificates",
+ "maintenance records": "Maintenance Records",
+ "daily checks": "Daily Checks",
+ "fault recording / reporting": "Fault Recording/ Reporting",
+ "fault repair": "Fault Repair",
+
+ # Vehicle Registration (Mass)
+ "sub contracted vehicles statement of compliance": "Sub-contracted Vehicles Statement of Compliance",
+ "weight verification records": "Weight Verification Records",
+ "rfs suspension certification #": "RFS Suspension Certification #",
+ "suspension system maintenance": "Suspension System Maintenance",
+ "trip records": "Trip Records",
+ "fault recording/ reporting on suspension system": "Fault Recording/ Reporting on Suspension System",
+
+ # Common
+ "registration number": "Registration Number",
+ "no.": "No.",
+ "sub contractor": "Sub contractor",
+ "sub-contractor": "Sub contractor",
+}
+
+def looks_like_operator_declaration(context):
+ """True iff heading says Operator Declaration and headers include Print Name + Position Title."""
+ heading = (context.get("heading") or "").strip().lower()
+ headers = " ".join(context.get("headers") or []).lower()
+ return (
+ "operator declaration" in heading
+ and "print name" in headers
+ and "position" in headers
+ and "title" in headers
+ )
+
+def looks_like_auditor_declaration(context):
+ heading = (context.get("heading") or "").strip().lower()
+ headers = " ".join(context.get("headers") or []).lower()
+ return (
+ "auditor declaration" in heading
+ and "print name" in headers
+ and ("nhvr" in headers or "auditor registration number" in headers)
+ )
+
+# --- NEW: header-only fallback that ignores headings and just keys on the two column names
+def extract_operator_declaration_by_headers_from_end(doc):
+ """
+ Scan tables from the end; if a table's first row contains both
+ 'Print Name' AND 'Position Title' (case-insensitive), extract red text
+ from the data rows into:
+ {"Print Name": [...], "Position Title": [...]}
+ """
+ for tbl in reversed(doc.tables):
+ if len(tbl.rows) < 2:
+ continue # need header + at least one data row
+
+ headers_norm = [normalize_header_label(c.text).lower() for c in tbl.rows[0].cells]
+ has_print = any("print name" in h for h in headers_norm)
+ has_pos_tit = any(("position title" in h) or ("position" in h and "title" in h) for h in headers_norm)
+ if not (has_print and has_pos_tit):
+ continue
+
+ idx_print = next((i for i, h in enumerate(headers_norm) if "print name" in h), None)
+ idx_pos = next((i for i, h in enumerate(headers_norm) if "position title" in h), None)
+ if idx_pos is None:
+ idx_pos = next((i for i, h in enumerate(headers_norm) if ("position" in h and "title" in h)), None)
+
+ result = {"Print Name": [], "Position Title": []}
+ for row in tbl.rows[1:]:
+ if idx_print is not None and idx_print < len(row.cells):
+ cell = row.cells[idx_print]
+ reds = [r.text for p in cell.paragraphs for r in p.runs if is_red_font(r) and r.text]
+ reds = coalesce_numeric_runs(reds)
+ txt = normalize_text(" ".join(reds))
+ if txt:
+ result["Print Name"].append(txt)
+
+ if idx_pos is not None and idx_pos < len(row.cells):
+ cell = row.cells[idx_pos]
+ reds = [r.text for p in cell.paragraphs for r in p.runs if is_red_font(r) and r.text]
+ reds = coalesce_numeric_runs(reds)
+ txt = normalize_text(" ".join(reds))
+ if txt:
+ result["Position Title"].append(txt)
+
+ if result["Print Name"] or result["Position Title"]:
+ return {k: v for k, v in result.items() if v}
+
+ return None
+# --- end NEW helper
+
+def canonicalize_label(s: str) -> str:
+ key = normalize_header_label(s).lower()
+ key = re.sub(r"\s+", " ", key)
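+    # e.g. "Sub-contractor" -> "Sub contractor" via LABEL_ALIASES (illustrative)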
+ return LABEL_ALIASES.get(key, s)
+
+def bag_similarity(a: str, b: str) -> float:
+ """Loose bag-of-words similarity for header↔label matching."""
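+    # e.g. bag_similarity("Maintenance Records", "maintenance records (copies)") == 1.0 (illustrative)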
+ aw = {w for w in re.split(r"[^A-Za-z0-9#]+", normalize_header_label(a).lower()) if len(w) > 2 or w in {"#","no"}}
+ bw = {w for w in re.split(r"[^A-Za-z0-9#]+", normalize_header_label(b).lower()) if len(w) > 2 or w in {"#","no"}}
+ if not aw or not bw:
+ return 0.0
+ inter = len(aw & bw)
+ return inter / max(len(aw), len(bw))
+
+def coalesce_numeric_runs(text_list):
+ """
+ If a cell yields ['4','5','6','9','8','7','1','2','3'] etc., join continuous single-char digit runs.
+ Returns ['456987123'] instead of many singles. Non-digit tokens are preserved.
+ """
+ out, buf = [], []
+ for t in text_list:
+ if len(t) == 1 and t.isdigit():
+ buf.append(t)
+ else:
+ if buf:
+ out.append("".join(buf))
+ buf = []
+ out.append(t)
+ if buf:
+ out.append("".join(buf))
+ return out
+
def is_red_font(run):
"""Enhanced red font detection with better color checking"""
col = run.font.color
@@ -76,7 +209,6 @@ def calculate_schema_match_score(schema_name, spec, context):
if "Vehicle Registration" in schema_name:
vehicle_keywords = ["registration", "vehicle", "sub-contractor", "weight verification", "rfs suspension"]
table_text = " ".join(context['headers']).lower() + " " + context['heading'].lower()
-
keyword_matches = sum(1 for keyword in vehicle_keywords if keyword in table_text)
if keyword_matches >= 2:
score += 150 # Very high boost for vehicle tables
@@ -157,15 +289,12 @@ def calculate_schema_match_score(schema_name, spec, context):
labels = [normalize_text(lbl) for lbl in spec["labels"]]
matches = 0
for lbl in labels:
- # More flexible matching for vehicle tables
if any(lbl.upper() in h.upper() or h.upper() in lbl.upper() for h in context['headers']):
matches += 1
- # Also check for partial keyword matches
elif any(word.upper() in " ".join(context['headers']).upper() for word in lbl.split() if len(word) > 3):
matches += 0.5 # Partial credit
-
if matches > 0:
- score += (matches / len(labels)) * 40 # Higher weight for row1 tables
+ score += (matches / len(labels)) * 40
reasons.append(f"Row1 orientation header matches: {matches}/{len(labels)}")
# Special handling for Declaration tables (existing logic)
@@ -187,6 +316,16 @@ def calculate_schema_match_score(schema_name, spec, context):
def match_table_schema(tbl):
"""Improved table schema matching with scoring system"""
context = get_table_context(tbl)
+    # Prioritize declaration tables so they are not misclassified by the fuzzy scorer.
+    # Header-based Auditor Declaration check
+    if ("print name" in " ".join(context.get("headers", [])).lower() and
+        "auditor" in " ".join(context.get("headers", [])).lower()):
+        return "NHVAS Approved Auditor Declaration"
+    # Heading-based Auditor Declaration check
+    if looks_like_auditor_declaration(context):
+        return "NHVAS Approved Auditor Declaration"
+    # Hard-match Operator Declaration (avoids falling through to fuzzy scoring)
+    if looks_like_operator_declaration(context):
+        return "Operator Declaration"
best_match = None
best_score = 0
for name, spec in TABLE_SCHEMAS.items():
@@ -245,102 +384,256 @@ def extract_multi_schema_table(tbl, schemas):
return result
def extract_table_data(tbl, schema_name, spec):
- """Extract red text data from table based on schema - ENHANCED for Vehicle Registration"""
-
- # 🎯 SPECIAL HANDLING for Vehicle Registration tables
+ """Extract red text data from table based on schema – per-row repeats for specific tables."""
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # OPERATOR DECLARATION (row1 headers: Print Name | Position Title)
+ # ───────────────────────────────────────────────────────────────────────────
+ if schema_name == "Operator Declaration":
+ print(f" 🧾 EXTRACTION FIX: Processing Operator Declaration table")
+
+ labels = spec["labels"] # ["Print Name", "Position Title"]
+ canonical_labels = {canonicalize_label(lbl): lbl for lbl in labels}
+
+ collected = {lbl: [] for lbl in labels}
+
+ if len(tbl.rows) < 2:
+ print(f" ❌ Operator Declaration table has less than 2 rows")
+ return {}
+
+ # map header cells → labels (row1 orientation)
+ header_row = tbl.rows[0]
+ column_mapping = {}
+ print(f" 📋 Mapping {len(header_row.cells)} header cells to labels")
+
+ for col_idx, cell in enumerate(header_row.cells):
+ raw_h = normalize_text(cell.text)
+ header_text = normalize_header_label(raw_h)
+ if not header_text:
+ continue
+ print(f" Column {col_idx}: '{raw_h}'")
+
+ # alias/canonical first
+ canon = canonicalize_label(header_text)
+ if canon in canonical_labels:
+ best_label = canonical_labels[canon]
+ print(f" ✅ Mapped to: '{best_label}' (alias)")
+ column_mapping[col_idx] = best_label
+ continue
+
+ # else bag-of-words similarity
+ best_label, best_score = None, 0.0
+ for canon_lab, original_lab in canonical_labels.items():
+ s = bag_similarity(header_text, canon_lab)
+ if s > best_score:
+ best_score, best_label = s, original_lab
+
+ if best_label and best_score >= 0.40:
+ print(f" ✅ Mapped to: '{best_label}' (score: {best_score:.2f})")
+ column_mapping[col_idx] = best_label
+ else:
+ print(f" ⚠️ No mapping found for '{raw_h}'")
+
+ print(f" 📊 Total column mappings: {len(column_mapping)}")
+
+ # collect red text from the (usually single) data row
+ for row_idx in range(1, len(tbl.rows)):
+ row = tbl.rows[row_idx]
+ print(f" 📌 Processing data row {row_idx}")
+ for col_idx, cell in enumerate(row.cells):
+ if col_idx not in column_mapping:
+ continue
+ label = column_mapping[col_idx]
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
+ if not red_txt:
+ continue
+ print(f" 🔴 Found red text in '{label}': '{red_txt}'")
+ collected[label].append(red_txt)
+
+ result = {k: v for k, v in collected.items() if v}
+ print(f" ✅ Operator Declaration extracted: {len(result)} columns with data")
+ return result
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # A) Vehicle Registration tables (per-row accumulation; NO dedupe)
+ # ───────────────────────────────────────────────────────────────────────────
if "Vehicle Registration" in schema_name:
print(f" 🚗 EXTRACTION FIX: Processing Vehicle Registration table")
-
+
labels = spec["labels"]
- collected = {lbl: [] for lbl in labels}
- seen = {lbl: set() for lbl in labels}
-
- # For Vehicle Registration, orientation is "row1" - headers in first row
+ canonical_labels = {canonicalize_label(lbl): lbl for lbl in labels}
+
+ collected = {lbl: [] for lbl in labels} # ← keep every row value
+ unmapped_bucket = {}
+
if len(tbl.rows) < 2:
print(f" ❌ Vehicle table has less than 2 rows")
return {}
-
- # Map header cells to labels
+
header_row = tbl.rows[0]
column_mapping = {}
-
print(f" 📋 Mapping {len(header_row.cells)} header cells to labels")
-
+
for col_idx, cell in enumerate(header_row.cells):
- header_text = normalize_text(cell.text).strip()
+ raw_h = normalize_text(cell.text)
+ header_text = normalize_header_label(raw_h)
if not header_text:
continue
-
- print(f" Column {col_idx}: '{header_text}'")
-
- # Find best matching label
- best_match = None
- best_score = 0
-
- for label in labels:
- # Direct match
- if header_text.upper() == label.upper():
- best_match = label
- best_score = 1.0
- break
-
- # Partial keyword matching
- header_words = set(word.upper() for word in header_text.split() if len(word) > 2)
- label_words = set(word.upper() for word in label.split() if len(word) > 2)
-
- if header_words and label_words:
- common_words = header_words.intersection(label_words)
- if common_words:
- score = len(common_words) / max(len(header_words), len(label_words))
- if score > best_score and score >= 0.4: # Lower threshold for vehicle tables
- best_score = score
- best_match = label
-
- if best_match:
- column_mapping[col_idx] = best_match
- print(f" ✅ Mapped to: '{best_match}' (score: {best_score:.2f})")
+ print(f" Column {col_idx}: '{raw_h}'")
+
+ # Try alias/canonical first
+ canon = canonicalize_label(header_text)
+ if canon in canonical_labels:
+ best_label = canonical_labels[canon]
+ print(f" ✅ Mapped to: '{best_label}' (alias)")
+ column_mapping[col_idx] = best_label
+ continue
+
+ # Else bag-of-words similarity
+ best_label, best_score = None, 0.0
+ for canon_lab, original_lab in canonical_labels.items():
+ s = bag_similarity(header_text, canon_lab)
+ if s > best_score:
+ best_score, best_label = s, original_lab
+
+ if best_label and best_score >= 0.40:
+ print(f" ✅ Mapped to: '{best_label}' (score: {best_score:.2f})")
+ column_mapping[col_idx] = best_label
else:
- print(f" ⚠️ No mapping found for '{header_text}'")
-
+ print(f" ⚠️ No mapping found for '{raw_h}'")
+ unmapped_bucket[raw_h] = []
+
print(f" 📊 Total column mappings: {len(column_mapping)}")
-
- # Extract red text from data rows (skip header)
+
+ header_texts = [normalize_text(hc.text) for hc in header_row.cells]
for row_idx in range(1, len(tbl.rows)):
row = tbl.rows[row_idx]
print(f" 📌 Processing data row {row_idx}")
-
for col_idx, cell in enumerate(row.cells):
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
+ if not red_txt:
+ continue
+
if col_idx in column_mapping:
label = column_mapping[col_idx]
-
- # Extract red text
- red_txt = "".join(run.text for p in cell.paragraphs for run in p.runs if is_red_font(run)).strip()
-
- if red_txt:
- print(f" 🔴 Found red text in '{label}': '{red_txt}'")
-
- if red_txt not in seen[label]:
- seen[label].add(red_txt)
- collected[label].append(red_txt)
-
- # Return only non-empty collections
+ print(f" 🔴 Found red text in '{label}': '{red_txt}'")
+ collected[label].append(red_txt) # ← append every occurrence
+ else:
+ header_name = header_texts[col_idx] if col_idx < len(header_texts) else f"(unmapped col {col_idx})"
+ unmapped_bucket.setdefault(header_name, []).append(red_txt)
+
result = {k: v for k, v in collected.items() if v}
+ if unmapped_bucket:
+ result.update({f"UNMAPPED::{k}": v for k, v in unmapped_bucket.items() if v})
print(f" ✅ Vehicle Registration extracted: {len(result)} columns with data")
return result
-
- # 🎯 ORIGINAL CODE for all other tables (unchanged)
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # B) Driver / Scheduler Records Examined (per-row accumulation; NO dedupe)
+ # ───────────────────────────────────────────────────────────────────────────
+ if "Driver / Scheduler" in schema_name:
+ print(f" 👤 EXTRACTION FIX: Processing Driver / Scheduler table")
+
+ labels = spec["labels"]
+ canonical_labels = {canonicalize_label(lbl): lbl for lbl in labels}
+
+ collected = {lbl: [] for lbl in labels} # ← keep every row value
+ unmapped_bucket = {}
+
+ if len(tbl.rows) < 2:
+ print(f" ❌ Driver/Scheduler table has less than 2 rows")
+ return {}
+
+ header_row = tbl.rows[0]
+ column_mapping = {}
+ print(f" 📋 Mapping {len(header_row.cells)} header cells to labels")
+
+ for col_idx, cell in enumerate(header_row.cells):
+ raw_h = normalize_text(cell.text)
+ header_text = normalize_header_label(raw_h)
+ if not header_text:
+ continue
+ print(f" Column {col_idx}: '{raw_h}'")
+
+ # Try alias/canonical first (rarely used here, but safe)
+ canon = canonicalize_label(header_text)
+ if canon in canonical_labels:
+ best_label = canonical_labels[canon]
+ print(f" ✅ Mapped to: '{best_label}' (alias)")
+ column_mapping[col_idx] = best_label
+ continue
+
+ # Else bag-of-words similarity (good for long headings)
+ best_label, best_score = None, 0.0
+ for canon_lab, original_lab in canonical_labels.items():
+ s = bag_similarity(header_text, canon_lab)
+ if s > best_score:
+ best_score, best_label = s, original_lab
+
+ if best_label and best_score >= 0.40:
+ print(f" ✅ Mapped to: '{best_label}' (score: {best_score:.2f})")
+ column_mapping[col_idx] = best_label
+ else:
+ print(f" ⚠️ No mapping found for '{raw_h}'")
+ unmapped_bucket[raw_h] = []
+
+ print(f" 📊 Total column mappings: {len(column_mapping)}")
+
+ header_texts = [normalize_text(hc.text) for hc in header_row.cells]
+ for row_idx in range(1, len(tbl.rows)):
+ row = tbl.rows[row_idx]
+ print(f" 📌 Processing data row {row_idx}")
+ for col_idx, cell in enumerate(row.cells):
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
+ if not red_txt:
+ continue
+
+ if col_idx in column_mapping:
+ label = column_mapping[col_idx]
+ print(f" 🔴 Found red text in '{label}': '{red_txt}'")
+ collected[label].append(red_txt) # ← append every occurrence
+ else:
+ header_name = header_texts[col_idx] if col_idx < len(header_texts) else f"(unmapped col {col_idx})"
+ unmapped_bucket.setdefault(header_name, []).append(red_txt)
+
+ result = {k: v for k, v in collected.items() if v}
+ if unmapped_bucket:
+ result.update({f"UNMAPPED::{k}": v for k, v in unmapped_bucket.items() if v})
+ print(f" ✅ Driver / Scheduler extracted: {len(result)} columns with data")
+ return result
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # C) Generic tables (unchanged: WITH dedupe)
+ # ───────────────────────────────────────────────────────────────────────────
labels = spec["labels"] + [schema_name]
collected = {lbl: [] for lbl in labels}
seen = {lbl: set() for lbl in labels}
- by_col = (spec["orientation"] == "row1")
+ by_col = (spec.get("orientation") == "row1")
start_row = 1 if by_col else 0
rows = tbl.rows[start_row:]
-
+
for ri, row in enumerate(rows):
for ci, cell in enumerate(row.cells):
- red_txt = "".join(run.text for p in cell.paragraphs for run in p.runs if is_red_font(run)).strip()
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
if not red_txt:
continue
+
if by_col:
if ci < len(spec["labels"]):
lbl = spec["labels"][ci]
@@ -354,17 +647,19 @@ def extract_table_data(tbl, schema_name, spec):
lbl = spec_label
break
if not lbl:
+ a_raw = normalize_header_label(raw_label).upper()
for spec_label in spec["labels"]:
- spec_norm = normalize_text(spec_label).upper()
- raw_norm = raw_label.upper()
- if spec_norm in raw_norm or raw_norm in spec_norm:
+ a_spec = normalize_header_label(spec_label).upper()
+ if a_spec in a_raw or a_raw in a_spec:
lbl = spec_label
break
if not lbl:
lbl = schema_name
+
if red_txt not in seen[lbl]:
seen[lbl].add(red_txt)
collected[lbl].append(red_txt)
+
return {k: v for k, v in collected.items() if v}
def extract_red_text(input_doc):
@@ -405,6 +700,8 @@ def extract_red_text(input_doc):
out[schema][k] = v
else:
out[schema] = data
+
+ # paragraphs (FIX: do not return early; build full 'paras' then attach)
paras = {}
for idx, para in enumerate(doc.paragraphs):
red_txt = "".join(r.text for r in para.runs if is_red_font(r)).strip()
@@ -423,8 +720,16 @@ def extract_red_text(input_doc):
if not context:
context = "(para)"
paras.setdefault(context, []).append(red_txt)
+
if paras:
out["paragraphs"] = paras
+
+ # Fallback: ensure we capture the last-page Operator Declaration by headers
+ if "Operator Declaration" not in out:
+ op_dec = extract_operator_declaration_by_headers_from_end(doc)
+ if op_dec:
+ out["Operator Declaration"] = op_dec
+
return out
def extract_red_text_filelike(input_file, output_file):
diff --git a/fine_tuning_lightgbm_models.ipynb b/fine_tuning_lightgbm_models.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..69ce9081a83d62a9de4adfe2b769c9a90184f6fc
--- /dev/null
+++ b/fine_tuning_lightgbm_models.ipynb
@@ -0,0 +1,961 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d15dad13-9732-4e4c-bbd1-1a33545a4293",
+ "metadata": {},
+ "source": [
+ "## Overview"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f857fc7-d7fb-4b05-a242-de31fb1f086d",
+ "metadata": {},
+ "source": [
+ "In this notebook, we'll go through the process of fine-tuning the LightGBM models in the `pdf-document-layout-analysis` service."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0c96b645-eef0-47a2-8c4f-284cdc05e76d",
+ "metadata": {},
+ "source": [
+ "But before doing that, let's start with some basic concepts and introduce modules and methods to make the process easier and cleaner."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f1e5c19b-1920-4f2c-9994-943626cd8a58",
+ "metadata": {},
+ "source": [
+    "First, make sure that `Poppler` is installed on your system; we will use it to process PDFs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "5f198930-caf1-4cb4-bb1e-8ca063ad8587",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pdftohtml is already installed.\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%bash\n",
+ "\n",
+ "if ! command -v pdftohtml &> /dev/null\n",
+ "then\n",
+ " echo \"pdftohtml is not installed. Installing now...\"\n",
+ " sudo apt install pdftohtml\n",
+ "else\n",
+ " echo \"pdftohtml is already installed.\"\n",
+ "fi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5d971faa-e9a8-47d6-8c02-66be6f3a3c6c",
+ "metadata": {},
+ "source": [
+    "We use Poppler to convert PDFs to XMLs. To work with Poppler in Python, we have created the `PdfFeatures` module, which can be found in `pdf_features/PdfFeatures.py`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "f7ac5d42-fb70-4476-8e05-b159f18ae3dd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pdf_features.PdfFeatures import PdfFeatures"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e45522eb-6879-472a-a822-64b38041ccc3",
+ "metadata": {},
+ "source": [
+    "To open a PDF file with the PdfFeatures module, simply write:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e4ac53e5-b249-4dcd-beeb-e3009e17079b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Page-1\n",
+ "Page-2\n"
+ ]
+ }
+ ],
+ "source": [
+ "pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(\"test_pdfs/regular.pdf\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2c7c6241-9016-4416-a53e-644145f9063a",
+ "metadata": {},
+ "source": [
+ "When you open `pdf_features` like this, the XML file is saved in a temporary path and handled on the fly.\n",
+ "\n",
+ "If you want to save the XML file, you should provide a path where it can be saved:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "eb1056ee-2e45-4b12-b2bc-8d23553c2143",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Page-1\n",
+ "Page-2\n"
+ ]
+ }
+ ],
+ "source": [
+ "pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(\"test_pdfs/regular.pdf\", \"test_pdfs/regular.xml\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "703ec555-c3a5-4e7e-a6dd-886be67cb6de",
+ "metadata": {},
+ "source": [
+ "Here is a part of the XML to illustrate what it looks like:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5b6fcebd-f91b-43fe-b2d6-b9956c3fd173",
+ "metadata": {},
+ "source": [
+    "```\n",
+    "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n",
+    "<!DOCTYPE pdf2xml SYSTEM \"pdf2xml.dtd\">\n",
+    "\n",
+    "<pdf2xml producer=\"poppler\" version=\"...\">\n",
+    "<page number=\"1\" position=\"absolute\" top=\"0\" left=\"0\" height=\"...\" width=\"...\">\n",
+    "\t<fontspec id=\"0\" size=\"...\" family=\"...\" color=\"...\"/>\n",
+    "\t<fontspec id=\"1\" size=\"...\" family=\"...\" color=\"...\"/>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\">RESOLUCIÓN DE LA </text>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\">CORTE INTERAMERICANA DE DERECHOS HUMANOS </text>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\">DEL 29 DE JULIO DE 1991 </text>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\"> </text>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\"> </text>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\">MEDIDAS PROVISIONALES SOLICITADAS POR LA COMISIÓN </text>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\">INTERAMERICANA DE DERECHOS HUMANOS </text>\n",
+    "<text top=\"...\" left=\"...\" width=\"...\" height=\"...\" font=\"...\">RESPECTO DE GUATEMALA </text>\n",
+    "\n",
+    "...\n",
+    "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4be01120-c4ce-4e09-bc10-64b1742c9b0b",
+ "metadata": {},
+ "source": [
+ "When we convert PDFs to XMLs with Poppler, it creates `tokens`. These tokens are generally lines of text, but they can vary according to Poppler's heuristics and what has been extracted. \n",
+    "A token can be a single character, an empty string, or an entire line. Every `<text>` item you see above is a `token`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "00517165-bc84-4a6f-9a8b-91084cc603ab",
+ "metadata": {},
+ "source": [
+ "The PdfFeatures module provides basic capabilities for working with PDF files. Here are some features of this module. \n",
+ "You don't have to memorize them, but they can be useful for future reference:\n",
+ "\n",
+    "- Every PdfFeatures instance has a `pages` attribute: a list of `PdfPage` elements for working with each page.\n",
+    "- Every PdfPage element has attributes like `page_number`, `page_width`, `page_height` and `tokens`.\n",
+    "- The `tokens` attribute is a list of `PdfToken` elements for working with each token on that page.\n",
+    "- Every PdfToken element has attributes like `content`, `bounding_box`, `token_type` and `page_number`.\n",
+    "- The `content` attribute is, as the name implies, the string content of the given token.\n",
+    "- The `bounding_box` attribute stores the position of the given token on the page.\n",
+    "- `bounding_box` is a `Rectangle` element. For example, to get the left coordinate of a token, use `token.bounding_box.left`; it returns an integer value.\n",
+    "- The `token_type` attribute holds the type of the token. It is a `TokenType` element; you'll see more about it in the next sections.\n",
+    "- Like PdfPage items, tokens also have a `page_number` attribute indicating which page they are on, which is useful in some scenarios."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "63a71904-0ad3-4fca-830a-402d9334614a",
+ "metadata": {},
+ "source": [
+    "If you want to loop through the tokens of a file and check their contents, you can use something like this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "444d3778-c3f5-48fd-aa20-cfe1bf851aad",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001B[96mRESOLUCIÓN DE LA\u001B[0m \u001B[93m[Page: 1 || Coordinates: [244, 106, 355, 118]]\u001B[0m\n",
+ "\u001B[96mCORTE INTERAMERICANA DE DERECHOS HUMANOS\u001B[0m \u001B[93m[Page: 1 || Coordinates: [157, 118, 441, 130]]\u001B[0m\n",
+ "\u001B[96mDEL 29 DE JULIO DE 1991\u001B[0m \u001B[93m[Page: 1 || Coordinates: [227, 129, 372, 141]]\u001B[0m\n",
+ "\u001B[96mMEDIDAS PROVISIONALES SOLICITADAS POR LA COMISIÓN\u001B[0m \u001B[93m[Page: 1 || Coordinates: [132, 165, 466, 177]]\u001B[0m\n",
+ "\u001B[96mINTERAMERICANA DE DERECHOS HUMANOS\u001B[0m \u001B[93m[Page: 1 || Coordinates: [177, 177, 422, 189]]\u001B[0m\n",
+ "\u001B[96mRESPECTO DE GUATEMALA\u001B[0m \u001B[93m[Page: 1 || Coordinates: [225, 188, 374, 200]]\u001B[0m\n",
+ "\u001B[96mCASO CHUNIMA\u001B[0m \u001B[93m[Page: 1 || Coordinates: [254, 224, 344, 236]]\u001B[0m\n",
+ "\u001B[96mLA CORTE INTERAMERICANA DE DERECHOS HUMANOS,\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 259, 393, 271]]\u001B[0m\n",
+ "\u001B[96mVISTOS:\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 295, 137, 307]]\u001B[0m\n",
+ "\u001B[96m1.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 318, 101, 330]]\u001B[0m\n",
+ "\u001B[96mLa resolución del Presidente de la Corte Interamericana de Derechos Humanos\u001B[0m \u001B[93m[Page: 1 || Coordinates: [122, 318, 511, 330]]\u001B[0m\n",
+ "\u001B[96mde 15 de julio de 1991, sobre medidas provisionales solicitadas por la Comisión\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 330, 514, 342]]\u001B[0m\n",
+ "\u001B[96mInteramericana de Derechos Humanos respecto de Guatemala;\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 342, 401, 354]]\u001B[0m\n",
+ "\u001B[96m2.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 366, 102, 378]]\u001B[0m\n",
+ "\u001B[96mLa convocatoria a una audiencia pública para el día 29 de julio de 1991 a las\u001B[0m \u001B[93m[Page: 1 || Coordinates: [122, 366, 512, 378]]\u001B[0m\n",
+ "\u001B[96m3:00 p.m., contenida en la resolución citada;\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 378, 312, 390]]\u001B[0m\n",
+ "\u001B[96m3.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 401, 104, 413]]\u001B[0m\n",
+ "\u001B[96mLos escritos de fechas 24 y 26 de este mes de julio presentados por el\u001B[0m \u001B[93m[Page: 1 || Coordinates: [122, 401, 514, 413]]\u001B[0m\n",
+ "\u001B[96mGobierno de Guatemala en los cuales informa que, en atención a la resolución del\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 413, 513, 425]]\u001B[0m\n",
+ "\u001B[96mPresidente, ha tomado medidas dirigidas a la protección de las personas\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 425, 518, 437]]\u001B[0m\n",
+ "\u001B[96mmencionadas en esa resolución y solicita un aplazamiento de por lo menos 30 días de\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 437, 512, 449]]\u001B[0m\n",
+ "\u001B[96mla audiencia convocada por el Presidente para hoy, a fin de contar con un plazo que\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 448, 512, 460]]\u001B[0m\n",
+ "\u001B[96mle permita hacer una presentación adecuada ante la Corte.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 460, 380, 472]]\u001B[0m\n",
+ "\u001B[96mCONSIDERANDO:\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 484, 189, 496]]\u001B[0m\n",
+ "\u001B[96m1.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 508, 101, 520]]\u001B[0m\n",
+ "\u001B[96mQue, en virtud del artículo 23.4 de su Reglamento, la Corte Interamericana de\u001B[0m \u001B[93m[Page: 1 || Coordinates: [122, 508, 511, 520]]\u001B[0m\n",
+ "\u001B[96mDerechos Humanos debe pronunciarse sobre la resolución del Presidente del 15 de\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 519, 513, 531]]\u001B[0m\n",
+ "\u001B[96mjulio de 1991;\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 531, 160, 543]]\u001B[0m\n",
+ "\u001B[96m2.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 555, 104, 567]]\u001B[0m\n",
+ "\u001B[96mQue, habida cuenta de que la Corte se encuentra reunida, debe también\u001B[0m \u001B[93m[Page: 1 || Coordinates: [122, 555, 514, 567]]\u001B[0m\n",
+ "\u001B[96mdecidir sobre la petición de aplazamiento de la audiencia sobre medidas provisionales\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 567, 512, 579]]\u001B[0m\n",
+ "\u001B[96mformuladas por el Gobierno de Guatemala.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 578, 300, 590]]\u001B[0m\n",
+ "\u001B[96mPOR TANTO:\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 602, 159, 614]]\u001B[0m\n",
+ "\u001B[96mLA CORTE INTERAMERICANA DE DERECHOS HUMANOS,\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 626, 393, 638]]\u001B[0m\n",
+ "\u001B[96mRESUELVE:\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 649, 151, 661]]\u001B[0m\n",
+ "\u001B[96m1.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 673, 103, 685]]\u001B[0m\n",
+ "\u001B[96mConvocar a una audiencia pública para el 30 de julio de 1991 a las 15:00\u001B[0m \u001B[93m[Page: 1 || Coordinates: [122, 673, 513, 685]]\u001B[0m\n",
+ "\u001B[96mhoras con el objeto de conocer los puntos de vista del Gobierno de Guatemala y de la\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 685, 512, 697]]\u001B[0m\n",
+ "\u001B[96mComisión sobre la solicitud de prórroga formulada por el primero.\u001B[0m \u001B[93m[Page: 1 || Coordinates: [88, 697, 412, 709]]\u001B[0m\n",
+ "\u001B[96m2\u001B[0m \u001B[93m[Page: 2 || Coordinates: [294, 71, 300, 83]]\u001B[0m\n",
+ "\u001B[96m2.\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 106, 101, 118]]\u001B[0m\n",
+ "\u001B[96mConocer también, en dicha audiencia pública, de las medidas que, en atención\u001B[0m \u001B[93m[Page: 2 || Coordinates: [122, 106, 511, 118]]\u001B[0m\n",
+ "\u001B[96ma la resolución del Presidente del 15 de julio del presente año, ha tomado el\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 118, 515, 130]]\u001B[0m\n",
+ "\u001B[96mGobierno de Guatemala.\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 129, 211, 141]]\u001B[0m\n",
+ "\u001B[96m3.\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 153, 103, 165]]\u001B[0m\n",
+ "\u001B[96mReservarse el derecho de convocar a una audiencia pública para resolver la\u001B[0m \u001B[93m[Page: 2 || Coordinates: [122, 153, 513, 165]]\u001B[0m\n",
+ "\u001B[96mpetición de la Comisión sobre medidas provisionales respecto de Guatemala.\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 165, 467, 177]]\u001B[0m\n",
+ "\u001B[96mHéctor Fix-Zamudio\u001B[0m \u001B[93m[Page: 2 || Coordinates: [249, 200, 349, 212]]\u001B[0m\n",
+ "\u001B[96mPresidente\u001B[0m \u001B[93m[Page: 2 || Coordinates: [272, 212, 327, 224]]\u001B[0m\n",
+ "\u001B[96mOrlando\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 248, 161, 260]]\u001B[0m\n",
+ "\u001B[96mTovar\u001B[0m \u001B[93m[Page: 2 || Coordinates: [129, 248, 191, 260]]\u001B[0m\n",
+ "\u001B[96mTamayo\u001B[0m \u001B[93m[Page: 2 || Coordinates: [161, 248, 234, 260]]\u001B[0m\n",
+ "\u001B[96mThomas\u001B[0m \u001B[93m[Page: 2 || Coordinates: [225, 248, 436, 260]]\u001B[0m\n",
+ "\u001B[96mBuergenthal\u001B[0m \u001B[93m[Page: 2 || Coordinates: [405, 248, 499, 260]]\u001B[0m\n",
+ "\u001B[96mRafael Nieto Navia\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 283, 195, 295]]\u001B[0m\n",
+ "\u001B[96mPolicarpo Callejas Bonilla\u001B[0m \u001B[93m[Page: 2 || Coordinates: [329, 283, 481, 295]]\u001B[0m\n",
+ "\u001B[96mSonia\u001B[0m \u001B[93m[Page: 2 || Coordinates: [88, 318, 150, 330]]\u001B[0m\n",
+ "\u001B[96mPicado\u001B[0m \u001B[93m[Page: 2 || Coordinates: [118, 318, 184, 330]]\u001B[0m\n",
+ "\u001B[96mSotela\u001B[0m \u001B[93m[Page: 2 || Coordinates: [153, 318, 218, 330]]\u001B[0m\n",
+ "\u001B[96mJulio\u001B[0m \u001B[93m[Page: 2 || Coordinates: [191, 318, 419, 330]]\u001B[0m\n",
+ "\u001B[96mA.\u001B[0m \u001B[93m[Page: 2 || Coordinates: [388, 318, 433, 330]]\u001B[0m\n",
+ "\u001B[96mBarberis\u001B[0m \u001B[93m[Page: 2 || Coordinates: [402, 318, 477, 330]]\u001B[0m\n",
+ "\u001B[96mManuel E. Ventura Robles\u001B[0m \u001B[93m[Page: 2 || Coordinates: [235, 354, 364, 366]]\u001B[0m\n",
+ "\u001B[96mSecretario\u001B[0m \u001B[93m[Page: 2 || Coordinates: [273, 366, 326, 378]]\u001B[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "for page in pdf_features.pages:\n",
+ " for token in page.tokens:\n",
+ " coordinates = [token.bounding_box.left, token.bounding_box.top, token.bounding_box.right, token.bounding_box.bottom]\n",
+ " print(f\"\\033[96m{token.content}\\033[0m \\033[93m[Page: {page.page_number} || Coordinates: {coordinates}]\\033[0m\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4576ff4d-92fc-4e19-a947-ebfb3fd01060",
+ "metadata": {},
+ "source": [
+ "## Fine-Tuning Models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "01826a89-25c9-4385-a1e6-b65c0edbd0c6",
+ "metadata": {},
+ "source": [
+    "Now that we have an overview of the `PdfFeatures` module, we can start the fine-tuning process."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "586eba43-9138-4eff-a3fa-24553de04e82",
+ "metadata": {},
+ "source": [
+ "In the `pdf-document-layout-analysis` service, there are two LightGBM (i.e. fast) models.\n",
+ "\n",
+ "- The first model is used to determine the types of tokens. We call it `token_type_model`.\n",
+ "- The second model is used to identify the segments to which the tokens belong. We call this model `paragraph_extraction_model`.\n",
+ "\n",
+    "The second model uses the first model's predictions (the token types) as part of its features. So, let's start by fine-tuning the token type model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c326ccb1-a36b-40f9-b7e2-e83ba3c0e12b",
+ "metadata": {},
+ "source": [
+ "### Fine-Tuning Token Type Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7b3638eb-c512-4bd5-97f4-4df3ae984978",
+ "metadata": {},
+ "source": [
+ "#### Loading Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ab35b27c-8464-470c-9ef1-a9aef8945f6a",
+ "metadata": {},
+ "source": [
+ "To properly train a token type model, you should have a list of PdfFeatures items where the `token_type` attribute of their tokens is set correctly, as this attribute will be used as the label.\n",
+ "\n",
+    "To see which labels the model will use, you can check `pdf_token_type_labels/TokenType.py`. By default, we use the labels of the [DocLayNet](https://github.com/DS4SD/DocLayNet) dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "2ab3093c-6e67-4505-bac3-b7db73ef5372",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Assumption: ModelConfiguration and TokenTypeTrainer come from the pdf_tokens_type_trainer package; adjust these imports to your setup.\n",
+    "from pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration\n",
+    "from pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer\n",
+    "\n",
+    "def get_pdf_features_labels() -> list[PdfFeatures]:\n",
+ " # Assuming that you are loading your own labels in this part.\n",
+ " # I'm just going to put a list with a single file for demonstration.\n",
+ " pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(\"test_pdfs/regular.pdf\")\n",
+ " labeled_pdf_features_list: list[PdfFeatures] = [pdf_features]\n",
+ " return labeled_pdf_features_list\n",
+ "\n",
+ "def train_token_type_model():\n",
+ " model_configuration = ModelConfiguration()\n",
+ " labeled_pdf_features_list: list[PdfFeatures] = get_pdf_features_labels()\n",
+ " trainer = TokenTypeTrainer(labeled_pdf_features_list, model_configuration)\n",
+ " train_labels = [token.token_type.get_index() for token in trainer.loop_tokens()]\n",
+ " trainer.train(\"models/token_type_example_model.model\", train_labels) \n",
+ "\n",
+ "train_token_type_model()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "32db8aee-9d2c-45bf-b7af-ac6249081f32",
+ "metadata": {},
+ "source": "Don't forget to check what's inside the `model_configuration`. You might need to tune the hyperparameters."
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fda0c166-ac25-4084-974a-c73f1cb06f18",
+ "metadata": {},
+   "source": "If you want to use our trained models as a base and refit them with your own data, you can use this function:"
+ },
+ {
+ "cell_type": "code",
+ "id": "5acf2beb-f7a2-4e12-8f11-4bffff7efa74",
+ "metadata": {},
+ "source": [
+ "def refit_token_type_model():\n",
+ " model_configuration = ModelConfiguration()\n",
+ " model_configuration.resume_training = True\n",
+ " labeled_pdf_features_list: list[PdfFeatures] = get_pdf_features_labels()\n",
+ " trainer = TokenTypeTrainer(labeled_pdf_features_list, model_configuration)\n",
+ " train_labels = [token.token_type.get_index() for token in trainer.loop_tokens()]\n",
+ " trainer.train(\"models/token_type_lightgbm.model\", train_labels)\n"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7c50cbae-9841-4289-9097-7357a0c724a7",
+ "metadata": {},
+ "source": "Running this function will refit the same model with your data. Depending on your situation, it may or may not help you."
+ },
+ {
+ "cell_type": "markdown",
+ "id": "19abde59-7ba5-4e65-8ce7-6bb7fb2202d5",
+ "metadata": {},
+ "source": [
+    "If it does not help, you can try other fine-tuning strategies that LightGBM provides.\n",
+    "\n",
+    "In that case, all you need to do is change this part in `pdf_tokens_type_trainer/PdfTrainer.py` (lines 47-49):\n",
+ "\n",
+ "```\n",
+ " if self.model_configuration.resume_training and exists(model_path):\n",
+ " model = lgb.Booster(model_file=model_path)\n",
+ " gbm = model.refit(x_train, labels)\n",
+ "```"
+ ]
+ },
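+  {
+   "cell_type": "markdown",
+   "id": "9f2a7c31-58b4-4d2e-a6c0-1e4f8b7d9a55",
+   "metadata": {},
+   "source": [
+    "For example, one alternative (a rough sketch, not the library's built-in behavior) is to keep boosting on top of the saved model with `lgb.train` and `init_model` instead of `Booster.refit`. Here `x_train`, `labels` and `model_path` are assumed to be the same objects `PdfTrainer` already has at that point, and the parameters are placeholders you would take from your model configuration:\n",
+    "\n",
+    "```python\n",
+    "import lightgbm as lgb\n",
+    "\n",
+    "if self.model_configuration.resume_training and exists(model_path):\n",
+    "    # Continue adding trees to the saved booster instead of refitting its leaf values.\n",
+    "    params = {\"objective\": \"multiclass\", \"num_class\": 13}  # placeholder values\n",
+    "    train_set = lgb.Dataset(x_train, label=labels)\n",
+    "    gbm = lgb.train(params, train_set, num_boost_round=100, init_model=model_path)\n",
+    "```"
+   ]
+  },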
+ {
+ "cell_type": "markdown",
+ "id": "5379e82a-9fa7-4fea-9d6b-a11e672707bc",
+ "metadata": {},
+ "source": "To make predictions with the trained model, you can use this function:"
+ },
+ {
+ "cell_type": "code",
+ "id": "f5b7f4fb-7052-4e8c-856a-6b1d83e5ece4",
+ "metadata": {},
+ "source": [
+ "def get_predictions():\n",
+ " model_configuration = ModelConfiguration()\n",
+ " pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(\"test_pdfs/regular.pdf\")\n",
+ " trainer = TokenTypeTrainer([pdf_features], model_configuration)\n",
+ " trainer.set_token_types()\n",
+ " for token in pdf_features.pages[0].tokens[:20]:\n",
+ " print(f\"\\033[96m{token.content}\\033[0m \\033[93m[{token.token_type}]\\033[0m\")\n",
+ "\n",
+ "get_predictions() "
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e6808202-892d-43e0-9e7a-73ebc347901f",
+ "metadata": {},
+ "source": "### Fine-Tuning Paragraph Extraction Model"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0b31a859-7867-4bd0-be13-7ae4ff4c8a61",
+ "metadata": {},
+ "source": "#### Loading Data"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2778fd0b-5351-4c83-a15a-ecf8aac91397",
+ "metadata": {},
+   "source": "The second model in our pipeline is the paragraph extraction model. After finding the type of each token, we are going to \"group\" the tokens, that is, find the segment each token belongs to."
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8112645a-816d-4579-b6e9-14b505703fc9",
+ "metadata": {},
+   "source": "We are going to explain the process, but for this part we highly recommend placing your labeled data in the following file structure and using the existing methods. Otherwise, it can be a bit harder to use our modules:"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b96c7988-cd1e-492a-9990-84db9f7111d2",
+ "metadata": {},
+ "source": [
+ "```\n",
+ ".\n",
+ "└── pdf-labeled-data\n",
+ " ├── labeled_data\n",
+ " │ ├── token_type\n",
+ " │ │ ├── train_data\n",
+ " │ │ │ ├── example_document1\n",
+ " │ │ │ │ └── labels.json\n",
+ " │ │ │ ├── example_document2\n",
+ " │ │ │ │ └── labels.json\n",
+ " │ │ │ └── example_document3\n",
+ " │ │ │ └── labels.json\n",
+ " │ │ └── test_data\n",
+ " │ │ └── example_document4\n",
+ " │ │ └── labels.json\n",
+ " │ └── paragraph_extraction\n",
+ " │ ├── train_data\n",
+ " │ │ ├── example_document1\n",
+ " │ │ │ └── labels.json\n",
+ " │ │ ├── example_document2\n",
+ " │ │ │ └── labels.json\n",
+ " │ │ └── example_document3\n",
+ " │ │ └── labels.json\n",
+ " │ └── test_data\n",
+ " │ └── example_document4\n",
+ " │ └── labels.json\n",
+ " └── pdfs\n",
+ " ├── example_document1\n",
+ " │ ├── document.pdf\n",
+ " │ └── etree.xml\n",
+ " ├── example_document2\n",
+ " │ ├── document.pdf\n",
+ " │ └── etree.xml\n",
+ " ├── example_document3\n",
+ " │ ├── document.pdf\n",
+ " │ └── etree.xml\n",
+ " └── example_document4\n",
+ " ├── document.pdf\n",
+ " └── etree.xml\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6c40e426-af77-47fc-a82c-77b5ca4fddeb",
+ "metadata": {},
+ "source": [
+    "Some details about this structure:\n",
+    "\n",
+    "- Every detail of the token type labels file structure applies to this structure too.\n",
+    "- The `paragraph_extraction` directory is where your paragraph extraction datasets are located; its name must be exactly this.\n",
+    "- `token_type` labels are also shown in the structure because token types are used as a feature in the paragraph extraction model. If you do not have them, the pipeline will not break and the model will still train, but the token_type feature of every token will be `TokenType.TEXT` in the paragraph extraction model's features.\n",
+    "- If you do not have `token_type` labels, another option is to predict the token types with the token type model after loading the data (shown below).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1e234e69-7e50-4ffe-a31b-2dc8248a676f",
+ "metadata": {},
+   "source": "The `labels.json` files should have this structure:"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "472072a6-a02c-4b75-bbc0-f13bb7e357d2",
+ "metadata": {},
+ "source": [
+ "```\n",
+ "{\n",
+ " \"pages\": [\n",
+ " {\n",
+ " \"number\": 1,\n",
+ " \"labels\": [\n",
+ " {\n",
+ " \"top\": 86,\n",
+ " \"left\": 162,\n",
+ " \"width\": 292,\n",
+ " \"height\": 24,\n",
+ " \"label_type\": 0\n",
+ " },\n",
+ " {\n",
+ " \"top\": 122,\n",
+ " \"left\": 221,\n",
+ " \"width\": 174,\n",
+ " \"height\": 12,\n",
+ " \"label_type\": 0\n",
+ " }\n",
+ " ]\n",
+ " },\n",
+ " {\n",
+ " \"number\": 2,\n",
+ " \"labels\": [\n",
+ " {\n",
+ " \"top\": 36,\n",
+ " \"left\": 296,\n",
+ " \"width\": 22,\n",
+ " \"height\": 13,\n",
+ " \"label_type\": 0\n",
+ " },\n",
+ " {\n",
+ " \"top\": 72,\n",
+ " \"left\": 71,\n",
+ " \"width\": 473,\n",
+ " \"height\": 49,\n",
+ " \"label_type\": 0\n",
+ " }\n",
+ " ]\n",
+ " }\n",
+ " ]\n",
+ "}\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bb6e716b-b742-4186-9e1a-ac5ecea708ac",
+ "metadata": {},
+ "source": [
+    "Here you see a list of labels for each page. Each label includes the coordinates `top`, `left`, `width`, `height` of a segment/paragraph. So, this time the coordinates belong to the segments, not to the tokens.\n",
+    "\n",
+    "The \"label_type\" should always be 0, since there is only one type, \"paragraph\" (don't get confused by this field; it is not important, just put 0 and move on).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a2c8a9b3-6180-41f2-bb82-bea892a61f5e",
+ "metadata": {},
+ "source": "Using this information, you can load your data like this:"
+ },
+ {
+ "cell_type": "code",
+ "id": "cb6ae549-4f52-45b0-853a-6414ca8b4af3",
+ "metadata": {},
+ "source": [
+    "from os import listdir\n",
+    "from os.path import join\n",
+    "from paragraph_extraction_trainer.PdfParagraphTokens import PdfParagraphTokens\n",
+    "from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer\n",
+    "\n",
+    "\n",
+    "def load_paragraph_extraction_labels() -> list[PdfParagraphTokens]:\n",
+    "    # Point this to the root of the file structure shown above.\n",
+    "    pdf_labeled_data_root_path = \"path/to/pdf/labeled/data\"\n",
+    "    datasets_path = join(pdf_labeled_data_root_path, \"paragraph_extraction\")\n",
+    "    labeled_pdf_paragraph_tokens_list: list[PdfParagraphTokens] = []\n",
+    "\n",
+    "    for dataset in listdir(datasets_path):\n",
+    "        if \"train\" not in dataset:\n",
+    "            continue\n",
+    "        for pdf_name in listdir(join(datasets_path, dataset)):\n",
+    "            pdf_paragraph_tokens: PdfParagraphTokens = PdfParagraphTokens.from_labeled_data(pdf_labeled_data_root_path, dataset, pdf_name)\n",
+    "            labeled_pdf_paragraph_tokens_list.append(pdf_paragraph_tokens)\n",
+    "\n",
+    "    return labeled_pdf_paragraph_tokens_list\n",
+    "\n",
+    "\n",
+    "# In case you do not have token_type labels, you can predict them with the token type model after loading the data:\n",
+    "def load_paragraph_extraction_labels_with_predicted_token_types() -> list[PdfParagraphTokens]:\n",
+    "    labeled_pdf_paragraph_tokens_list = load_paragraph_extraction_labels()\n",
+    "    labeled_pdf_features_list = [pdf_paragraph_tokens.pdf_features for pdf_paragraph_tokens in labeled_pdf_paragraph_tokens_list]\n",
+    "    token_type_trainer = TokenTypeTrainer(labeled_pdf_features_list, ModelConfiguration())\n",
+    "    token_type_trainer.set_token_types()\n",
+    "    return labeled_pdf_paragraph_tokens_list"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cf3f6a6c-cba7-43c4-9f72-85cbe447cb6e",
+ "metadata": {},
+ "source": "#### Fine-Tuning the Model"
+ },
+ {
+ "cell_type": "markdown",
+ "id": "29dbaba4-d3d6-4985-be44-df872fe9b5d4",
+ "metadata": {},
+   "source": "Again, to be able to use our trained paragraph extraction model, you should download it from our Hugging Face repository. You can just run `download_models.py` and both models will be downloaded."
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8a82f6f6-cec9-48bc-9c64-b09aa65d2754",
+ "metadata": {},
+ "source": [
+ "If you want to download it manually, you can use this link: https://huggingface.co/HURIDOCS/pdf-document-layout-analysis/tree/main\n",
+ "\n",
+    "After downloading it, place it into the `models` directory. The path should be as follows: \n",
+ "`~/pdf-document-layout-analysis/models/paragraph_extraction_lightgbm.model`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b95cd2cd-0d41-4518-8576-b1a0d2adc21b",
+ "metadata": {},
+ "source": "To train the paragraph extraction model from scratch:"
+ },
+ {
+ "cell_type": "code",
+ "id": "67948603-80e6-4b42-9ba1-78868fd9f946",
+ "metadata": {},
+ "source": [
+ "from paragraph_extraction_trainer.model_configuration import MODEL_CONFIGURATION\n",
+ "\n",
+ "\n",
+ "def loop_pdf_paragraph_tokens(pdf_paragraph_tokens_list: list[PdfParagraphTokens]):\n",
+ " for pdf_paragraph_tokens in pdf_paragraph_tokens_list:\n",
+ " for page in pdf_paragraph_tokens.pdf_features.pages:\n",
+ " if not page.tokens:\n",
+ " continue\n",
+ " for token, next_token in zip(page.tokens, page.tokens[1:]):\n",
+ " yield pdf_paragraph_tokens, token, next_token\n",
+ " yield pdf_paragraph_tokens, page.tokens[-1], page.tokens[-1]\n",
+ "\n",
+ "\n",
+ "def train_paragraph_extraction_model():\n",
+ " labeled_pdf_paragraph_tokens_list: list[PdfParagraphTokens] = load_paragraph_extraction_labels()\n",
+ " labeled_pdf_features_list = [pdf_paragraph_tokens.pdf_features for pdf_paragraph_tokens in labeled_pdf_paragraph_tokens_list]\n",
+ " trainer = ParagraphExtractorTrainer(labeled_pdf_features_list, MODEL_CONFIGURATION)\n",
+ " \n",
+ " train_labels = []\n",
+    "    for pdf_paragraph_tokens, token, next_token in loop_pdf_paragraph_tokens(labeled_pdf_paragraph_tokens_list):\n",
+ " train_labels.append(pdf_paragraph_tokens.check_same_paragraph(token, next_token))\n",
+ "\n",
+ " trainer.train(\"models/paragraph_extraction_example_model.model\", train_labels) "
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2e7cd129-874e-415d-9855-401d8c5d0136",
+ "metadata": {},
+   "source": "And to refit the model with your own data, all you need to do is set the `resume_training` configuration to `True`:"
+ },
+ {
+ "cell_type": "code",
+ "id": "37b6b980-deaf-4ba4-baf0-7bf137af63a7",
+ "metadata": {},
+ "source": [
+ "def refit_paragraph_extraction_model():\n",
+ " labeled_pdf_paragraph_tokens_list: list[PdfParagraphTokens] = load_paragraph_extraction_labels()\n",
+ " labeled_pdf_features_list = [pdf_paragraph_tokens.pdf_features for pdf_paragraph_tokens in labeled_pdf_paragraph_tokens_list]\n",
+ " MODEL_CONFIGURATION.resume_training = True\n",
+ " trainer = ParagraphExtractorTrainer(labeled_pdf_features_list, MODEL_CONFIGURATION)\n",
+ " \n",
+ " train_labels = []\n",
+    "    for pdf_paragraph_tokens, token, next_token in loop_pdf_paragraph_tokens(labeled_pdf_paragraph_tokens_list):\n",
+ " train_labels.append(pdf_paragraph_tokens.check_same_paragraph(token, next_token))\n",
+ "\n",
+ " trainer.train(\"models/paragraph_extraction_example_model.model\", train_labels) "
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1389cf49-c163-4f90-ab0c-9606756b8ef9",
+ "metadata": {},
+   "source": "[IMPORTANT] If you want to use your own trained models in the pdf-document-layout-analysis service, make sure they are named `token_type_lightgbm.model` and `paragraph_extraction_lightgbm.model` and are placed in the `models` directory."
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1d4cf8c-65d2-4496-adcf-ab73acc5000f",
+ "metadata": {},
+   "source": "After finishing training, you can get the model's predictions as shown below:"
+ },
+ {
+ "cell_type": "code",
+ "id": "69e747aa-9b19-4e8d-acbb-f8d221dfe006",
+ "metadata": {},
+ "source": [
+ "from pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration\n",
+ "from fast_trainer.model_configuration import MODEL_CONFIGURATION as PARAGRAPH_EXTRACTION_CONFIGURATION\n",
+ "from domain.PdfSegment import PdfSegment\n",
+ "from adapters.ml.fast_trainer.ParagraphExtractorTrainer import ParagraphExtractorTrainer\n",
+ "\n",
+ "def get_predictions():\n",
+ " pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(\"test_pdfs/regular.pdf\")\n",
+ " # First, use token type model to find and set the types.\n",
+ " token_type_trainer = TokenTypeTrainer([pdf_features], ModelConfiguration())\n",
+ " token_type_trainer.set_token_types(\"models/token_type_lightgbm.model\")\n",
+ " trainer = ParagraphExtractorTrainer(pdfs_features=[pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION)\n",
+ " segments: list[PdfSegment] = trainer.get_pdf_segments(\"models/paragraph_extraction_lightgbm.model\")\n",
+ " for segment in segments[:20]:\n",
+ " print(f\"\\033[96m{segment.text_content}\\033[0m \\033[93m[{segment.segment_type}]\\033[0m \\033[91m{segment.bounding_box.to_dict()}\\033[0m\")"
+ ],
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e3af70a1-404e-4bac-a366-f7962636b1eb",
+ "metadata": {},
+   "source": "The output of the `paragraph_extraction_model` is a list of `PdfSegment` items. Every item includes information such as `page_number`, `text_content`, `segment_type`, `bounding_box`, and `pdf_name` for each segment."
+  }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/images/vgtexample1.png b/images/vgtexample1.png
new file mode 100644
index 0000000000000000000000000000000000000000..588dc5ddea4feacd953b5e8d7626c8ace73539ff
--- /dev/null
+++ b/images/vgtexample1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b68017bb1ff60317bc2575db44db7117a245321e2baa34efd24b115748a38ca
+size 240120
diff --git a/images/vgtexample2.png b/images/vgtexample2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b351938d3b3b9b5ce6df43a45463f500855f3315
--- /dev/null
+++ b/images/vgtexample2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb2bbb4a4ae5351cf7829b0ba217b21248fd0b92e510e3578c0130952b7573a1
+size 255601
diff --git a/images/vgtexample3.png b/images/vgtexample3.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf867225c6b7407dac16c59c0565658c24ca61fb
--- /dev/null
+++ b/images/vgtexample3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fae87bba8266250d03815b183f4c5ef3e839998bb9dcd187b99ea87e99384ff1
+size 127280
diff --git a/images/vgtexample4.png b/images/vgtexample4.png
new file mode 100644
index 0000000000000000000000000000000000000000..63fe29ed5ad27ff761cd14a85eeb5ee5c753f0ea
--- /dev/null
+++ b/images/vgtexample4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a7c9a4fe0d53c57cca52b56a1de98988b9d2ec0a7be25109d120e20f87fa118
+size 213474
diff --git a/justfile b/justfile
new file mode 100644
index 0000000000000000000000000000000000000000..a2a13b33d390532bbce5218aa5fc5c5d6ab45cfa
--- /dev/null
+++ b/justfile
@@ -0,0 +1,95 @@
+HAS_GPU := `command -v nvidia-smi > /dev/null && echo 1 || echo 0`
+
+install:
+ . .venv/bin/activate; pip install -Ur requirements.txt
+
+activate:
+ . .venv/bin/activate
+
+install_venv:
+ python3 -m venv .venv
+ . .venv/bin/activate; python -m pip install --upgrade pip
+ . .venv/bin/activate; python -m pip install -r dev-requirements.txt
+
+formatter:
+ . .venv/bin/activate; command black --line-length 125 .
+
+check_format:
+ . .venv/bin/activate; command black --line-length 125 . --check
+
+remove_docker_containers:
+ docker compose ps -q | xargs docker rm
+
+remove_docker_images:
+ docker compose config --images | xargs docker rmi
+
+start:
+ mkdir -p ./models
+ if [ {{HAS_GPU}} -eq 1 ]; then \
+ echo "NVIDIA GPU detected, using docker-compose-gpu.yml"; \
+ docker compose -f docker-compose-gpu.yml up --build; \
+ else \
+ echo "No NVIDIA GPU detected, using docker-compose.yml"; \
+ docker compose -f docker-compose.yml up --build; \
+ fi
+
+start_no_gpu:
+ mkdir -p ./models
+ docker compose up --build
+
+stop:
+ docker compose stop
+
+test:
+ . .venv/bin/activate; command cd src; command python -m pytest
+
+free_up_space:
+ df -h
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /opt/ghc
+ sudo rm -rf "/usr/local/share/boost"
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+ sudo apt-get remove -y '^llvm-.*' || true
+ sudo apt-get remove -y 'php.*' || true
+ sudo apt-get remove -y google-cloud-sdk hhvm google-chrome-stable firefox mono-devel || true
+ sudo apt-get autoremove -y
+ sudo apt-get clean
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /usr/local/lib/android
+ sudo rm -rf /opt/hostedtoolcache/CodeQL
+ sudo docker image prune --all --force
+ df -h
+
+start_detached:
+ mkdir -p ./models
+ docker compose up --build -d
+
+start_detached_gpu:
+ mkdir -p ./models
+ RESTART_IF_NO_GPU=true docker compose -f docker-compose-gpu.yml up --build -d
+
+upgrade:
+ . .venv/bin/activate; pip-upgrade
+
+tag:
+ #!/bin/bash
+ # Get current date
+ CURRENT_DATE=$(date +%Y.%-m.%-d)
+ echo "Current date: $CURRENT_DATE"
+
+ # Get the latest tag that matches today's date pattern
+ LATEST_TAG=$(git tag --list "${CURRENT_DATE}.*" --sort=-version:refname | head -n1)
+
+ if [ -z "$LATEST_TAG" ]; then
+ # No tag for today, start with revision 1
+ REVISION=1
+ else
+ # Extract revision number and increment
+ REVISION=$(echo $LATEST_TAG | cut -d. -f4)
+ REVISION=$((REVISION + 1))
+ fi
+
+ NEW_TAG="${CURRENT_DATE}.${REVISION}"
+ echo "Creating new tag: $NEW_TAG"
+ git tag $NEW_TAG
+ git push --tag
\ No newline at end of file
diff --git a/master_key.py b/master_key.py
index 9787bac11fe707cc58921b612b4d85de74007e24..749032a9f1bdb5c3e279285b2886ab10b6b10153 100644
--- a/master_key.py
+++ b/master_key.py
@@ -305,6 +305,7 @@ TABLE_SCHEMAS = {
"orientation": "row1",
"labels": ["Print Name", "NHVR or Exemplar Global Auditor Registration Number"],
"priority": 90,
+ "context_keywords": ["auditor declaration", "NHVR"],
"context_exclusions": ["manager", "operator declaration"]
},
"Audit Declaration dates": {
@@ -368,4 +369,4 @@ PARAGRAPH_PATTERNS = {
"declaration_text": r"I hereby acknowledge and agree with the findings.*",
"introductory_note": r"This audit assesses the.*",
"date_line": r"^\s*\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4}\s*$|^Date$"
-}
\ No newline at end of file
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..de1af463497d21dd4f3fa3230cfa0a2d4ecda912
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,39 @@
+[project]
+name = "pdf-document-layout-analysis"
+version = "2025.03.18.03"
+description = "This tool is for PDF document layout analysis"
+license = { file = "LICENSE" }
+authors = [{ name = "HURIDOCS" }]
+requires-python = ">= 3.10"
+dependencies = [
+ "fastapi==0.111.1",
+ "python-multipart==0.0.9",
+ "uvicorn==0.30.3",
+ "gunicorn==22.0.0",
+ "requests==2.32.3",
+ "torch==2.4.0",
+ "torchvision==0.19.0",
+ "timm==1.0.8",
+ "Pillow==10.4.0",
+ "pdf-annotate==0.12.0",
+ "scipy==1.14.0",
+ "opencv-python==4.10.0.84",
+ "Shapely==2.0.5",
+ "transformers==4.40.2",
+ "huggingface_hub==0.23.5",
+ "pdf2image==1.17.0",
+ "lxml==5.2.2",
+ "lightgbm==4.5.0",
+ "setuptools==75.4.0",
+ "roman==4.2",
+ "hydra-core==1.3.2",
+ "pypandoc==1.13",
+ "rapid-latex-ocr==0.0.9",
+ "struct_eqtable @ git+https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git@fd06078bfa9364849eb39330c075dd63cbed73ff"
+]
+
+[project.urls]
+HURIDOCS = "https://huridocs.org"
+GitHub = "https://github.com/huridocs/pdf-document-layout-analysis"
+HuggingFace = "https://huggingface.co/HURIDOCS/pdf-document-layout-analysis"
+DockerHub = "https://hub.docker.com/r/huridocs/pdf-document-layout-analysis"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 420c14e6dd03bf7b8982ef2ce8ab867358b172a2..9343132ab8d20fd49ba8cbd5d72294e9b6ac6991 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,30 @@
+fastapi==0.111.1
+pydantic==2.11.0
+python-multipart==0.0.9
+uvicorn==0.30.3
+gunicorn==22.0.0
+requests==2.32.3
+torch==2.4.0
+torchvision==0.19.0
+Pillow==10.4.0
+pdf-annotate==0.12.0
+scipy==1.14.0
+opencv-python==4.10.0.84
+Shapely==2.0.5
+transformers==4.40.2
+huggingface_hub==0.23.5
+pdf2image==1.17.0
+lightgbm==4.5.0
+setuptools==75.4.0
+roman==4.2
+hydra-core==1.3.2
+pypandoc==1.13
+rapid-table==2.0.3
+rapidocr==3.2.0
+pix2tex==0.1.4
+latex2mathml==3.78.0
+PyMuPDF==1.25.5
+git+https://github.com/huridocs/pdf-features.git@2025.7.30.1
gradio==4.44.1
pytesseract
python-docx
diff --git a/space-pdf/README.md b/space-pdf/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e6533cec175eb0197702cd8ab9cdc0af63e32b1c
--- /dev/null
+++ b/space-pdf/README.md
@@ -0,0 +1,910 @@
+# PDF Document Layout Analysis
+
+A Docker-powered microservice for intelligent PDF document layout analysis, OCR, and content extraction
+
+---
+
+## 🚀 Overview
+
+This project provides a powerful and flexible PDF analysis microservice built with **Clean Architecture** principles. The service enables OCR, segmentation, and classification of different parts of PDF pages, identifying elements such as texts, titles, pictures, tables, formulas, and more. Additionally, it determines the correct reading order of these identified elements and can convert PDFs to various formats including Markdown and HTML.
+
+### ✨ Key Features
+
+- 🔍 **Advanced PDF Layout Analysis** - Segment and classify PDF content with high accuracy
+- 🖼️ **Visual & Fast Models** - Choose between VGT (Vision Grid Transformer) for accuracy or LightGBM for speed
+- 📝 **Multi-format Output** - Export to JSON, Markdown, HTML, and visualize PDF segmentations
+- 🌐 **OCR Support** - 150+ language support with Tesseract OCR
+- 📊 **Table & Formula Extraction** - Extract tables as HTML and formulas as LaTeX
+- 🏗️ **Clean Architecture** - Modular, testable, and maintainable codebase
+- 🐳 **Docker-Ready** - Easy deployment with GPU support
+- ⚡ **RESTful API** - Comprehensive API with 10+ endpoints
+
+### 🔗 Project Links
+
+- **GitHub**: [pdf-document-layout-analysis](https://github.com/huridocs/pdf-document-layout-analysis)
+- **HuggingFace**: [pdf-document-layout-analysis](https://huggingface.co/HURIDOCS/pdf-document-layout-analysis)
+- **DockerHub**: [pdf-document-layout-analysis](https://hub.docker.com/r/huridocs/pdf-document-layout-analysis/)
+
+---
+
+## 🚀 Quick Start
+
+### 1. Start the Service
+
+**With GPU support (recommended for better performance):**
+```bash
+make start
+```
+
+**Without GPU support:**
+```bash
+make start_no_gpu
+```
+
+The service will be available at `http://localhost:5060`
+
+**Check service status:**
+
+```bash
+curl http://localhost:5060/info
+```
+
+### 2. Basic PDF Analysis
+
+**Analyze a PDF document (VGT model - high accuracy):**
+```bash
+curl -X POST -F 'file=@/path/to/your/document.pdf' http://localhost:5060
+```
+
+**Fast analysis (LightGBM models - faster processing):**
+```bash
+curl -X POST -F 'file=@/path/to/your/document.pdf' -F "fast=true" http://localhost:5060
+```
+
+### 3. Stop the Service
+
+```bash
+make stop
+```
+
+> 💡 **Tip**: Replace `/path/to/your/document.pdf` with the actual path to your PDF file. The service will return a JSON response with segmented content and metadata.
+
+
+## 📋 Table of Contents
+
+- [🚀 Quick Start](#-quick-start)
+- [⚙️ Dependencies](#-dependencies)
+- [📋 Requirements](#-requirements)
+- [📚 API Reference](#-api-reference)
+- [💡 Usage Examples](#-usage-examples)
+- [🏗️ Architecture](#-architecture)
+- [🤖 Models](#-models)
+- [📊 Data](#-data)
+- [🔧 Development](#-development)
+- [📈 Benchmarks](#-benchmarks)
+ - [Performance](#performance)
+ - [Speed](#speed)
+- [🌐 Installation of More Languages for OCR](#-installation-of-more-languages-for-ocr)
+- [🔗 Related Services](#-related-services)
+- [🤝 Contributing](#-contributing)
+
+
+
+## ⚙️ Dependencies
+
+### Required
+- **Docker Desktop 4.25.0+** - [Installation Guide](https://www.docker.com/products/docker-desktop/)
+- **Python 3.10+** (for local development)
+
+### Optional
+- **NVIDIA Container Toolkit** - [Installation Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) (for GPU support)
+
+## 📋 Requirements
+
+### System Requirements
+- **RAM**: 2 GB minimum
+- **GPU Memory**: 5 GB (optional, will fall back to CPU if unavailable)
+- **Disk Space**: 10 GB for models and dependencies
+- **CPU**: Multi-core recommended for better performance
+
+### Docker Requirements
+- Docker Engine 20.10+
+- Docker Compose 2.0+
+
+## 📚 API Reference
+
+The service provides a comprehensive RESTful API with the following endpoints:
+
+### Core Analysis Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/` | POST | Analyze PDF layout and extract segments | `file`, `fast`, `parse_tables_and_math` |
+| `/save_xml/{filename}` | POST | Analyze PDF and save XML output | `file`, `xml_file_name`, `fast` |
+| `/get_xml/{filename}` | GET | Retrieve saved XML analysis | `xml_file_name` |
+
+### Content Extraction Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/text` | POST | Extract text by content types | `file`, `fast`, `types` |
+| `/toc` | POST | Extract table of contents | `file`, `fast` |
+| `/toc_legacy_uwazi_compatible` | POST | Extract TOC (Uwazi compatible) | `file` |
+
+### Format Conversion Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/markdown` | POST | Convert PDF to Markdown (includes segmentation data in zip) | `file`, `fast`, `extract_toc`, `dpi`, `output_file` |
+| `/html` | POST | Convert PDF to HTML (includes segmentation data in zip) | `file`, `fast`, `extract_toc`, `dpi`, `output_file` |
+| `/visualize` | POST | Visualize segmentation results on the PDF | `file`, `fast` |
+
+### OCR & Utility Endpoints
+
+| Endpoint | Method | Description | Parameters |
+|----------|--------|-------------|------------|
+| `/ocr` | POST | Apply OCR to PDF | `file`, `language` |
+| `/info` | GET | Get service information | - |
+| `/` | GET | Health check and system info | - |
+| `/error` | GET | Test error handling | - |
+
+### Common Parameters
+
+- **`file`**: PDF file to process (multipart/form-data)
+- **`fast`**: Use LightGBM models instead of VGT (boolean, default: false)
+- **`parse_tables_and_math`**: Apply OCR to table regions and convert formulas to LaTeX (boolean, default: false)
+- **`language`**: OCR language code (string, default: "en")
+- **`types`**: Comma-separated content types to extract (string, default: "all")
+- **`extract_toc`**: Include table of contents at the beginning of the output (boolean, default: false)
+- **`dpi`**: Image resolution for conversion (integer, default: 120)
+
+## 💡 Usage Examples
+
+### Basic PDF Analysis
+
+**Standard analysis with VGT model:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060
+```
+
+**Fast analysis with LightGBM models:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'fast=true' \
+ http://localhost:5060
+```
+
+**Analysis with table and math parsing:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'parse_tables_and_math=true' \
+ http://localhost:5060
+```
+
+### Text Extraction
+
+**Extract all text:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'types=all' \
+ http://localhost:5060/text
+```
+
+**Extract specific content types:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'types=title,text,table' \
+ http://localhost:5060/text
+```
+
+### Format Conversion
+
+**Convert to Markdown:**
+```bash
+curl -X POST http://localhost:5060/markdown \
+ -F 'file=@document.pdf' \
+ -F 'extract_toc=true' \
+ -F 'output_file=document.md' \
+ --output 'document.zip'
+```
+
+**Convert to HTML:**
+```bash
+curl -X POST http://localhost:5060/html \
+ -F 'file=@document.pdf' \
+ -F 'extract_toc=true' \
+ -F 'output_file=document.html' \
+ --output 'document.zip'
+```
+
+> **📋 Segmentation Data**: Format conversion endpoints automatically include detailed segmentation data in the zip output. The resulting zip file contains a `{filename}_segmentation.json` file with information about each detected document segment including:
+> - **Coordinates**: `left`, `top`, `width`, `height`
+> - **Page information**: `page_number`, `page_width`, `page_height`
+> - **Content**: `text` content and segment `type` (e.g., "Title", "Text", "Table", "Picture")
+
+
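+A minimal Python sketch of reading that bundle (assuming the service is running locally on port 5060 and that the segmentation file is a JSON list of segment objects shaped like the response format shown later):
+
+```python
+import io
+import json
+import zipfile
+
+import requests
+
+# Convert a PDF to Markdown and pull the bundled segmentation data out of the returned zip.
+with open("document.pdf", "rb") as pdf_file:
+    response = requests.post(
+        "http://localhost:5060/markdown",
+        files={"file": pdf_file},
+        data={"extract_toc": "true", "output_file": "document.md"},
+    )
+
+with zipfile.ZipFile(io.BytesIO(response.content)) as bundle:
+    # The member name follows the {filename}_segmentation.json pattern described above.
+    segmentation_name = next(name for name in bundle.namelist() if name.endswith("_segmentation.json"))
+    segments = json.loads(bundle.read(segmentation_name))
+
+for segment in segments[:5]:
+    print(segment["type"], segment["page_number"], segment["text"][:60])
+```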
+### OCR Processing
+
+**OCR in English:**
+```bash
+curl -X POST \
+ -F 'file=@scanned_document.pdf' \
+ -F 'language=en' \
+ http://localhost:5060/ocr \
+ --output ocr_processed.pdf
+```
+
+**OCR in other languages:**
+```bash
+# French
+curl -X POST \
+ -F 'file=@document_french.pdf' \
+ -F 'language=fr' \
+ http://localhost:5060/ocr \
+ --output ocr_french.pdf
+
+# Spanish
+curl -X POST \
+ -F 'file=@document_spanish.pdf' \
+ -F 'language=es' \
+ http://localhost:5060/ocr \
+ --output ocr_spanish.pdf
+```
+
+### Visualization
+
+**Generate visualization PDF:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060/visualize \
+ --output visualization.pdf
+```
+
+### Table of Contents Extraction
+
+**Extract structured TOC:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060/toc
+```
+
+### XML Storage and Retrieval
+
+**Analyze and save XML:**
+```bash
+curl -X POST \
+ -F 'file=@document.pdf' \
+ http://localhost:5060/save_xml/my_analysis
+```
+
+**Retrieve saved XML:**
+```bash
+curl http://localhost:5060/get_xml/my_analysis.xml
+```
+
+### Service Information
+
+**Get service info and supported languages:**
+```bash
+curl http://localhost:5060/info
+```
+
+**Health check:**
+```bash
+curl http://localhost:5060/
+```
+
+### Response Format
+
+Most endpoints return JSON with segment information:
+
+```json
+[
+ {
+ "left": 72.0,
+ "top": 84.0,
+ "width": 451.2,
+ "height": 23.04,
+ "page_number": 1,
+ "page_width": 595.32,
+ "page_height": 841.92,
+ "text": "Document Title",
+ "type": "Title"
+ },
+ {
+ "left": 72.0,
+ "top": 120.0,
+ "width": 451.2,
+ "height": 200.0,
+ "page_number": 1,
+ "page_width": 595.32,
+ "page_height": 841.92,
+ "text": "This is the main text content...",
+ "type": "Text"
+ }
+]
+```
+
+### Supported Content Types
+
+- `Caption` - Image and table captions
+- `Footnote` - Footnote text
+- `Formula` - Mathematical formulas
+- `List item` - List items and bullet points
+- `Page footer` - Footer content
+- `Page header` - Header content
+- `Picture` - Images and figures
+- `Section header` - Section headings
+- `Table` - Table content
+- `Text` - Regular text paragraphs
+- `Title` - Document and section titles
+
+
+## 🏗️ Architecture
+
+This project follows **Clean Architecture** principles, ensuring separation of concerns, testability, and maintainability. The codebase is organized into distinct layers:
+
+### Directory Structure
+
+```
+src/
+├── domain/ # Enterprise Business Rules
+│ ├── PdfImages.py # PDF image handling domain logic
+│ ├── PdfSegment.py # PDF segment entity
+│ ├── Prediction.py # ML prediction entity
+│ └── SegmentBox.py # Core segment box entity
+├── use_cases/ # Application Business Rules
+│ ├── pdf_analysis/ # PDF analysis use case
+│ ├── text_extraction/ # Text extraction use case
+│ ├── toc_extraction/ # Table of contents extraction
+│ ├── visualization/ # PDF visualization use case
+│ ├── ocr/ # OCR processing use case
+│ ├── markdown_conversion/ # Markdown conversion use case
+│ └── html_conversion/ # HTML conversion use case
+├── adapters/ # Interface Adapters
+│ ├── infrastructure/ # External service adapters
+│ ├── ml/ # Machine learning model adapters
+│ ├── storage/ # File storage adapters
+│ └── web/ # Web framework adapters
+├── ports/ # Interface definitions
+│ ├── services/ # Service interfaces
+│ └── repositories/ # Repository interfaces
+└── drivers/ # Frameworks & Drivers
+ └── web/ # FastAPI application setup
+```
+
+### Layer Responsibilities
+
+- **Domain Layer**: Contains core business entities and rules independent of external concerns
+- **Use Cases Layer**: Orchestrates domain entities to fulfill specific application requirements
+- **Adapters Layer**: Implements interfaces defined by inner layers and adapts external frameworks
+- **Drivers Layer**: Contains frameworks, databases, and external agency configurations
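+
+As a rough illustration of how these layers fit together (the class names below are hypothetical, not the project's actual code), a use case depends only on a port, and an adapter supplies the concrete implementation:
+
+```python
+from abc import ABC, abstractmethod
+
+
+# Port (interface definition): only what the use case needs.
+class LayoutAnalyzerPort(ABC):
+    @abstractmethod
+    def analyze(self, pdf_path: str) -> list[dict]: ...
+
+
+# Adapter: wraps a concrete engine (e.g. the fast LightGBM pipeline) behind the port.
+class LightGBMLayoutAnalyzer(LayoutAnalyzerPort):
+    def analyze(self, pdf_path: str) -> list[dict]:
+        return []  # call the real model here
+
+
+# Use case: pure application logic, unaware of frameworks or model details.
+class AnalyzePdfUseCase:
+    def __init__(self, analyzer: LayoutAnalyzerPort):
+        self.analyzer = analyzer
+
+    def execute(self, pdf_path: str) -> list[dict]:
+        return self.analyzer.analyze(pdf_path)
+```
+
+Swapping VGT for LightGBM, or mocking the analyzer in tests, then only touches the adapter wiring, never the use case.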
+
+### Key Benefits
+
+- 🔄 **Dependency Inversion**: High-level modules don't depend on low-level modules
+- 🧪 **Testability**: Easy to unit test business logic in isolation
+- 🔧 **Maintainability**: Changes to external frameworks don't affect business rules
+- 📈 **Scalability**: Easy to add new features without modifying existing code
+
+
+## 🤖 Models
+
+The service offers two complementary model approaches, each optimized for different use cases:
+
+### 1. Vision Grid Transformer (VGT) - High Accuracy Model
+
+**Overview**: A state-of-the-art visual model developed by Alibaba Research Group that "sees" the entire page layout.
+
+**Key Features**:
+- 🎯 **High Accuracy**: Best-in-class performance on document layout analysis
+- 👁️ **Visual Understanding**: Analyzes the entire page context including spatial relationships
+- 📊 **Trained on DocLayNet**: Uses the comprehensive [DocLayNet dataset](https://github.com/DS4SD/DocLayNet)
+- 🔬 **Research-Backed**: Based on [Advanced Literate Machinery](https://github.com/AlibabaResearch/AdvancedLiterateMachinery)
+
+**Resource Requirements**:
+- GPU: 5GB+ VRAM (recommended)
+- CPU: Falls back automatically if GPU unavailable
+- Processing Speed: ~1.75 seconds/page (GPU [GTX 1070]) or ~13.5 seconds/page (CPU [i7-8700])
+
+### 2. LightGBM Models - Fast & Efficient
+
+**Overview**: Lightweight ensemble of two specialized models using XML-based features from Poppler.
+
+**Key Features**:
+- ⚡ **High Speed**: ~0.42 seconds per page on CPU (i7-8700)
+- 💾 **Low Resource Usage**: CPU-only, minimal memory footprint
+- 🔄 **Dual Model Approach**:
+ - **Token Type Classifier**: Identifies content types (title, text, table, etc.)
+ - **Segmentation Model**: Determines proper content boundaries
+- 📄 **XML-Based**: Uses Poppler's PDF-to-XML conversion for feature extraction
+
+**Trade-offs**:
+- Slightly lower accuracy compared to VGT
+- No visual context understanding
+- Excellent for batch processing and resource-constrained environments
+
+### OCR Integration
+
+Both models integrate seamlessly with OCR capabilities:
+
+- **Engine**: [Tesseract OCR](https://github.com/tesseract-ocr/tesseract)
+- **Processing**: [ocrmypdf](https://ocrmypdf.readthedocs.io/en/latest/index.html)
+- **Languages**: 150+ supported languages
+- **Output**: Searchable PDFs with preserved layout
+
+### Model Selection Guide
+
+| Use Case | Recommended Model | Reason |
+|----------|------------------|---------|
+| High accuracy requirements | VGT | Superior visual understanding |
+| Batch processing | LightGBM | Faster processing, lower resources |
+| GPU available | VGT | Leverages GPU acceleration |
+| CPU-only environment | LightGBM | Optimized for CPU processing |
+| Real-time applications | LightGBM | Consistent fast response times |
+| Research/analysis | VGT | Best accuracy for detailed analysis |
+
+## 📊 Data
+
+### Training Dataset
+
+Both model types are trained on the comprehensive [DocLayNet dataset](https://github.com/DS4SD/DocLayNet), a large-scale document layout analysis dataset containing over 80,000 document pages.
+
+### Document Categories
+
+The models can identify and classify 11 distinct content types:
+
+| ID | Category | Description |
+|----|----------|-------------|
+| 1 | **Caption** | Image and table captions |
+| 2 | **Footnote** | Footnote references and text |
+| 3 | **Formula** | Mathematical equations and formulas |
+| 4 | **List item** | Bulleted and numbered list items |
+| 5 | **Page footer** | Footer content and page numbers |
+| 6 | **Page header** | Header content and titles |
+| 7 | **Picture** | Images, figures, and graphics |
+| 8 | **Section header** | Section and subsection headings |
+| 9 | **Table** | Tabular data and structures |
+| 10 | **Text** | Regular paragraph text |
+| 11 | **Title** | Document and chapter titles |
+
+### Dataset Characteristics
+
+- **Domain Coverage**: Academic papers, technical documents, reports
+- **Language**: Primarily English with multilingual support
+- **Quality**: High-quality annotations with bounding boxes and labels
+- **Diversity**: Various document layouts, fonts, and formatting styles
+
+For detailed information about the dataset, visit the [DocLayNet repository](https://github.com/DS4SD/DocLayNet).
+
+## 🔧 Development
+
+### Local Development Setup
+
+1. **Clone the repository:**
+ ```bash
+ git clone https://github.com/huridocs/pdf-document-layout-analysis.git
+ cd pdf-document-layout-analysis
+ ```
+
+2. **Create virtual environment:**
+ ```bash
+ make install_venv
+ ```
+
+3. **Activate environment:**
+ ```bash
+ make activate
+ # or manually: source .venv/bin/activate
+ ```
+
+4. **Install dependencies:**
+ ```bash
+ make install
+ ```
+
+### Code Quality
+
+**Format code:**
+```bash
+make formatter
+```
+
+**Check formatting:**
+```bash
+make check_format
+```
+
+### Testing
+
+**Run tests:**
+```bash
+make test
+```
+
+**Integration tests:**
+```bash
+# Tests are located in src/tests/integration/
+python -m pytest src/tests/integration/test_end_to_end.py
+```
+
+### Docker Development
+
+**Build and start (detached mode):**
+```bash
+# With GPU
+make start_detached_gpu
+
+# Without GPU
+make start_detached
+```
+
+**Clean up Docker resources:**
+```bash
+# Remove containers
+make remove_docker_containers
+
+# Remove images
+make remove_docker_images
+```
+
+### Project Structure
+
+```
+pdf-document-layout-analysis/
+├── src/ # Source code
+│ ├── domain/ # Business entities
+│ ├── use_cases/ # Application logic
+│ ├── adapters/ # External integrations
+│ ├── ports/ # Interface definitions
+│ └── drivers/ # Framework configurations
+├── test_pdfs/ # Test PDF files
+├── models/ # ML model storage
+├── docker-compose.yml # Docker configuration
+├── Dockerfile # Container definition
+├── Makefile # Development commands
+├── pyproject.toml # Python project configuration
+└── requirements.txt # Python dependencies
+```
+
+### Environment Variables
+
+Key configuration options:
+
+```bash
+# OCR configuration
+OCR_SOURCE=/tmp/ocr_source
+
+# Model paths (auto-configured)
+MODELS_PATH=./models
+
+# Service configuration
+HOST=0.0.0.0
+PORT=5060
+```
+
+### Adding New Features
+
+1. **Domain Logic**: Add entities in `src/domain/`
+2. **Use Cases**: Implement business logic in `src/use_cases/`
+3. **Adapters**: Create integrations in `src/adapters/`
+4. **Ports**: Define interfaces in `src/ports/`
+5. **Controllers**: Add endpoints in `src/adapters/web/`
+
+### Debugging
+
+**View logs:**
+```bash
+docker compose logs -f
+```
+
+**Access container:**
+```bash
+docker exec -it pdf-document-layout-analysis /bin/bash
+```
+
+**Free up disk space:**
+```bash
+make free_up_space
+```
+
+### Order of Output Elements
+
+The service returns SegmentBox elements in a carefully determined reading order:
+
+#### Reading Order Algorithm
+
+1. **Poppler Integration**: Uses [Poppler](https://poppler.freedesktop.org) PDF-to-XML conversion to establish initial token reading order
+2. **Segment Averaging**: Calculates average reading order for multi-token segments
+3. **Type-Based Sorting**: Prioritizes content types:
+ - **Headers** placed first
+ - **Main content** in reading order
+ - **Footers and footnotes** placed last
+
+#### Non-Text Elements
+
+For segments without text (e.g., images):
+- Processed after text-based sorting
+- Positioned based on nearest text segment proximity
+- Uses spatial distance as the primary criterion
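+
+A simplified sketch of this ordering logic (illustrative only; the type priorities and the `reading_order` value, which stands in for the averaged Poppler token order, are assumptions rather than the service's internal code):
+
+```python
+# Illustrative reading-order sort over segment dicts shaped like the API response.
+TYPE_PRIORITY = {"Page header": 0, "Page footer": 2, "Footnote": 2}  # everything else: 1
+
+
+def sort_segments(segments: list[dict]) -> list[dict]:
+    return sorted(
+        segments,
+        key=lambda s: (s["page_number"], TYPE_PRIORITY.get(s["type"], 1), s["reading_order"]),
+    )
+```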
+
+### Advanced Table and Formula Extraction
+
+#### Default Behavior
+- **Formulas**: Automatically extracted as LaTeX format in the `text` property
+- **Tables**: Basic text extraction included by default
+
+#### Enhanced Table Extraction
+
+Parse tables and extract them in HTML format by setting `parse_tables_and_math=true`:
+
+```bash
+curl -X POST -F 'file=@document.pdf' -F 'parse_tables_and_math=true' http://localhost:5060
+```
+
+
+#### Extraction Engines
+- **Formulas**: [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
+- **Tables**: [RapidTable](https://github.com/RapidAI/RapidTable)
+
+
+## 📈 Benchmarks
+
+### Performance
+
+VGT model performance on PubLayNet dataset:
+
+| Metric | Overall | Text | Title | List | Table | Figure |
+|--------|---------|------|-------|------|-------|--------|
+| **F1 Score** | **0.962** | 0.950 | 0.939 | 0.968 | 0.981 | 0.971 |
+
+> 📊 **Comparison**: View comprehensive model comparisons at [Papers With Code](https://paperswithcode.com/sota/document-layout-analysis-on-publaynet-val)
+
+### Speed
+
+Performance benchmarks on 15-page academic documents:
+
+| Model | Hardware | Speed (sec/page) | Use Case |
+|-------|----------|------------------|----------|
+| **LightGBM** | CPU (i7-8700 3.2GHz) | **0.42** | Fast processing |
+| **VGT** | GPU (GTX 1070) | **1.75** | High accuracy |
+| **VGT** | CPU (i7-8700 3.2GHz) | 13.5 | CPU fallback |
+
+### Performance Recommendations
+
+- **GPU Available**: Use VGT for best accuracy-speed balance
+- **CPU Only**: Use LightGBM for optimal performance
+- **Batch Processing**: LightGBM for consistent throughput
+- **High Accuracy**: VGT with GPU for best results
+
+
+## 🌐 Installation of More Languages for OCR
+
+The service uses Tesseract OCR with support for 150+ languages. The Docker image includes only common languages to minimize image size.
+
+### Installing Additional Languages
+
+#### 1. Access the Container
+```bash
+docker exec -it --user root pdf-document-layout-analysis /bin/bash
+```
+
+#### 2. Install Language Packs
+```bash
+# Install specific language
+apt-get update
+apt-get install tesseract-ocr-[LANGCODE]
+```
+
+#### 3. Common Language Examples
+
+```bash
+# Korean
+apt-get install tesseract-ocr-kor
+
+# German
+apt-get install tesseract-ocr-deu
+
+# French
+apt-get install tesseract-ocr-fra
+
+# Spanish
+apt-get install tesseract-ocr-spa
+
+# Chinese Simplified
+apt-get install tesseract-ocr-chi-sim
+
+# Arabic
+apt-get install tesseract-ocr-ara
+
+# Japanese
+apt-get install tesseract-ocr-jpn
+```
+
+#### 4. Verify Installation
+
+```bash
+curl http://localhost:5060/info
+```
+
+### Language Code Reference
+
+Find Tesseract language codes in the [ISO to Tesseract mapping](https://github.com/huridocs/pdf-document-layout-analysis/blob/main/src/adapters/infrastructure/ocr/languages.py).
+
+### Supported Languages
+
+Common language codes:
+- `eng` - English
+- `fra` - French
+- `deu` - German
+- `spa` - Spanish
+- `ita` - Italian
+- `por` - Portuguese
+- `rus` - Russian
+- `chi-sim` - Chinese Simplified
+- `chi-tra` - Chinese Traditional
+- `jpn` - Japanese
+- `kor` - Korean
+- `ara` - Arabic
+- `hin` - Hindi
+
+### Usage with Multiple Languages
+
+```bash
+# OCR with specific language
+curl -X POST \
+ -F 'file=@document.pdf' \
+ -F 'language=fr' \
+ http://localhost:5060/ocr \
+ --output french_ocr.pdf
+```
+
+
+## 🔗 Related Services
+
+Explore our ecosystem of PDF processing services built on this foundation:
+
+### [PDF Table of Contents Extractor](https://github.com/huridocs/pdf-table-of-contents-extractor)
+🔍 **Purpose**: Intelligent extraction of structured table of contents from PDF documents
+
+**Key Features**:
+- Leverages layout analysis for accurate TOC identification
+- Hierarchical structure recognition
+- Multiple output formats supported
+- Integration-ready API
+
+### [PDF Text Extraction](https://github.com/huridocs/pdf-text-extraction)
+📝 **Purpose**: Advanced text extraction with layout awareness
+
+**Key Features**:
+- Content-type aware extraction
+- Preserves document structure
+- Reading order optimization
+- Clean text output with metadata
+
+### Integration Benefits
+
+These services work seamlessly together:
+- **Shared Analysis**: Reuse layout analysis results across services
+- **Consistent Output**: Standardized JSON format for easy integration
+- **Scalable Architecture**: Deploy services independently or together
+- **Docker Ready**: All services containerized for easy deployment
+
+## 🤝 Contributing
+
+We welcome contributions to improve the PDF Document Layout Analysis service!
+
+### How to Contribute
+
+1. **Fork the Repository**
+ ```bash
+ git clone https://github.com/your-username/pdf-document-layout-analysis.git
+ ```
+
+2. **Create a Feature Branch**
+ ```bash
+ git checkout -b feature/your-feature-name
+ ```
+
+3. **Set Up Development Environment**
+ ```bash
+ make install_venv
+ make install
+ ```
+
+4. **Make Your Changes**
+ - Follow the Clean Architecture principles
+ - Add tests for new features
+ - Update documentation as needed
+
+5. **Run Tests and Quality Checks**
+ ```bash
+ make test
+ make check_format
+ ```
+
+6. **Submit a Pull Request**
+ - Provide clear description of changes
+ - Include test results
+ - Reference any related issues
+
+### Contribution Guidelines
+
+#### Code Standards
+- **Python**: Follow PEP 8 with 125-character line length
+- **Architecture**: Maintain Clean Architecture boundaries
+- **Testing**: Include unit tests for new functionality
+- **Documentation**: Update README and docstrings
+
+#### Areas for Contribution
+
+- 🐛 **Bug Fixes**: Report and fix issues
+- ✨ **New Features**: Add new endpoints or functionality
+- 📚 **Documentation**: Improve guides and examples
+- 🧪 **Testing**: Expand test coverage
+- 🚀 **Performance**: Optimize processing speed
+- 🌐 **Internationalization**: Add language support
+
+#### Development Workflow
+
+1. **Issue First**: Create or comment on relevant issues
+2. **Small PRs**: Keep pull requests focused and manageable
+3. **Clean Commits**: Use descriptive commit messages
+4. **Documentation**: Update relevant documentation
+5. **Testing**: Ensure all tests pass
+
+### Getting Help
+
+- 📚 **Documentation**: Check this README and inline docs
+- 💬 **Issues**: Search existing issues or create new ones
+- 🔍 **Code**: Explore the codebase structure
+- 📧 **Contact**: Reach out to maintainers for guidance
+
+---
+
+### License
+
+This project is licensed under the terms specified in the [LICENSE](LICENSE) file.
diff --git a/space-pdf/app.py b/space-pdf/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14b6ca9e0cf38627e47707b268dcde01d41c52e
--- /dev/null
+++ b/space-pdf/app.py
@@ -0,0 +1,124 @@
+import gradio as gr
+import tempfile
+import os
+import shutil
+import subprocess
+from pathlib import Path
+
+SCRIPT_DIR = Path(__file__).resolve().parent
+
+def run_cmd(cmd, cwd=None, env=None):
+ """Run a command, print nice logs, and also save them to run.log in cwd."""
+ cwd = str(cwd or os.getcwd())
+ print(f"🟦 Running: {' '.join(cmd)} (cwd={cwd})")
+ proc = subprocess.run(
+ cmd,
+ cwd=cwd,
+ env=env,
+ capture_output=True,
+ text=True
+ )
+ if proc.stdout:
+ print("🟩 STDOUT:")
+ print(proc.stdout)
+ if proc.stderr:
+ print("🟥 STDERR:")
+ print(proc.stderr)
+ # Save to run.log for debugging
+ try:
+ runlog = Path(cwd) / "run.log"
+ with open(runlog, "a", encoding="utf-8") as f:
+ f.write(f"$ {' '.join(cmd)}\n")
+ if proc.stdout:
+ f.write(proc.stdout + "\n")
+ if proc.stderr:
+ f.write(proc.stderr + "\n")
+ print(f"🧾 Run log saved to: {runlog}")
+ except Exception as e:
+ print(f"⚠️ Could not write run.log: {e}")
+
+ if proc.returncode != 0:
+ # Let Gradio see the failure so it surfaces properly
+ raise subprocess.CalledProcessError(proc.returncode, cmd, proc.stdout, proc.stderr)
+ return proc
+
+def _locate_pdf_json(temp_dir: str) -> str:
+ """
+ Your extractor writes a JSON like _comprehensive_data.json.
+ Find it (and a few common fallbacks). Raise if not found.
+ """
+ td = Path(temp_dir)
+
+ # Prefer exactly-named file if present
+ candidates = [
+ td / "pdf_data.json", # legacy name (if ever created)
+ td / "input_comprehensive_data.json", # most common from your logs
+ td / "comprehensive_data.json", # another common alias
+ td / "output.json", # generic
+ ]
+ for p in candidates:
+ if p.exists():
+ print(f"✅ Using PDF JSON: {p}")
+ return str(p)
+
+ # Generic pattern: anything *_comprehensive_data.json
+ globs = list(td.glob("*_comprehensive_data.json"))
+ if globs:
+ print(f"✅ Using PDF JSON (glob): {globs[0]}")
+ return str(globs[0])
+
+ # If still not found, surface a helpful error
+ searched = ", ".join(str(p) for p in candidates) + ", " + str(td / "*_comprehensive_data.json")
+ raise FileNotFoundError(
+ f"PDF JSON not found. Looked for: {searched}\nTemp dir: {temp_dir}"
+ )
+
+def process_files(pdf_file, word_file):
+ # Create a unique temporary directory for this run
+ temp_dir = tempfile.mkdtemp(prefix="hf_redtext_")
+ print(f"📂 Temp dir: {temp_dir}")
+
+ # Define standard filenames for use in the pipeline
+ pdf_path = os.path.join(temp_dir, "input.pdf")
+ word_path = os.path.join(temp_dir, "input.docx")
+ word_json_path = os.path.join(temp_dir, "word_data.json")
+ updated_json_path = os.path.join(temp_dir, "updated_word_data.json")
+ final_docx_path = os.path.join(temp_dir, "updated.docx")
+
+ # Copy the uploaded files to the temp directory
+ shutil.copy(pdf_file, pdf_path)
+ print(f"📄 PDF copied to: {pdf_path}")
+ shutil.copy(word_file, word_path)
+ print(f"📝 DOCX copied to: {word_path}")
+
+ # 1) PDF → JSON (extractor writes _comprehensive_data.json into cwd)
+ run_cmd(["python", str(SCRIPT_DIR / "extract_pdf_data.py"), pdf_path], cwd=temp_dir)
+
+ # Find the JSON produced by the extractor
+ pdf_json_path = _locate_pdf_json(temp_dir)
+
+ # 2) DOCX red text → JSON
+ run_cmd(["python", str(SCRIPT_DIR / "extract_red_text.py"), word_path, word_json_path], cwd=temp_dir)
+
+ # 3) Merge JSON (uses the resolved pdf_json_path)
+ run_cmd(["python", str(SCRIPT_DIR / "update_docx_with_pdf.py"), word_json_path, pdf_json_path, updated_json_path], cwd=temp_dir)
+
+ # 4) Apply updates to DOCX
+ run_cmd(["python", str(SCRIPT_DIR / "updated_word.py"), word_path, updated_json_path, final_docx_path], cwd=temp_dir)
+
+ # Return the final .docx file
+ return final_docx_path
+
+iface = gr.Interface(
+ fn=process_files,
+ inputs=[
+ gr.File(label="Upload PDF File", type="filepath"),
+ gr.File(label="Upload Word File", type="filepath")
+ ],
+ outputs=gr.File(label="Download Updated Word File"),
+ title="Red Text Replacer",
+ description="Upload a PDF and Word document. Red-colored text in the Word doc will be replaced by matching content from the PDF."
+)
+
+if __name__ == "__main__":
+ iface.launch()
\ No newline at end of file
diff --git a/space-pdf/extract_pdf_data.py b/space-pdf/extract_pdf_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f7e2bb1206fc97da217db01de9bb8c2b34f526
--- /dev/null
+++ b/space-pdf/extract_pdf_data.py
@@ -0,0 +1,534 @@
+#!/usr/bin/env python3
+"""
+Fixed PDF Data Extractor - Addresses key issues in comprehensive_extract.py
+
+Key fixes:
+1. Better table extraction and cleaning
+2. Improved key-value pair extraction
+3. More robust text processing
+4. Enhanced vehicle registration extraction
+5. Better date/number pattern recognition
+"""
+
+import json
+import re
+import pandas as pd
+from typing import Dict, List, Any, Optional
+import logging
+from pathlib import Path
+import sys
+from datetime import datetime
+
+try:
+ import pdfplumber
+ HAS_PDFPLUMBER = True
+except ImportError:
+ HAS_PDFPLUMBER = False
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fixed_pdf_extractor")
+
+class FixedPDFExtractor:
+ def __init__(self):
+ logger.info("🚀 Initializing Fixed PDF Extractor")
+
+ def extract_everything(self, pdf_path: str) -> Dict[str, Any]:
+ if not HAS_PDFPLUMBER:
+ raise RuntimeError("pdfplumber is required. Install with: pip install pdfplumber")
+
+ logger.info(f"📖 Processing PDF: {pdf_path}")
+ result = {
+ "document_info": {
+ "filename": Path(pdf_path).name,
+ "total_pages": 0,
+ "extraction_timestamp": datetime.now().isoformat()
+ },
+ "extracted_data": {
+ "all_text_content": [],
+ "all_tables": [],
+ "key_value_pairs": {},
+ "audit_information": {},
+ "operator_information": {},
+ "vehicle_registrations": [],
+ "driver_records": [],
+ "compliance_summary": {},
+ "dates_and_numbers": {}
+ }
+ }
+
+ all_text_blocks, all_tables = [], []
+
+ with pdfplumber.open(pdf_path) as pdf:
+ result["document_info"]["total_pages"] = len(pdf.pages)
+
+ for page_num, page in enumerate(pdf.pages, 1):
+ logger.info(f"📄 Processing page {page_num}")
+
+ # Extract text with better handling
+ page_text = self._extract_page_text(page)
+ if page_text:
+ all_text_blocks.append({
+ "page": page_num,
+ "text": page_text,
+ "word_count": len(page_text.split())
+ })
+
+ # Extract tables with improved cleaning
+ tables = self._extract_page_tables(page, page_num)
+ all_tables.extend(tables)
+
+ result["extracted_data"]["all_text_content"] = all_text_blocks
+ result["extracted_data"]["all_tables"] = all_tables
+
+ # Process extracted data with improved methods
+ combined_text = "\n\n".join(b["text"] for b in all_text_blocks)
+
+ result["extracted_data"]["key_value_pairs"] = self._extract_key_value_pairs_improved(combined_text)
+ result["extracted_data"]["audit_information"] = self._extract_audit_info(combined_text, all_tables)
+ result["extracted_data"]["operator_information"] = self._extract_operator_info(combined_text, all_tables)
+ result["extracted_data"]["vehicle_registrations"] = self._extract_vehicle_registrations(all_tables)
+ result["extracted_data"]["driver_records"] = self._extract_driver_records(all_tables)
+ result["extracted_data"]["compliance_summary"] = self._extract_compliance_summary(combined_text, all_tables)
+ result["extracted_data"]["dates_and_numbers"] = self._extract_dates_and_numbers_improved(combined_text)
+
+ # Generate summary
+ result["extraction_summary"] = {
+ "text_blocks_found": len(all_text_blocks),
+ "tables_found": len(all_tables),
+ "key_value_pairs_found": len(result["extracted_data"]["key_value_pairs"]),
+ "vehicle_registrations_found": len(result["extracted_data"]["vehicle_registrations"]),
+ "driver_records_found": len(result["extracted_data"]["driver_records"]),
+ "total_characters": len(combined_text),
+ "processing_timestamp": datetime.now().isoformat()
+ }
+
+ logger.info("✅ Extraction completed!")
+ return result
+
+ def _extract_page_text(self, page) -> Optional[str]:
+ """Extract text from page with better handling"""
+ try:
+ text = page.extract_text()
+ if text:
+ # Clean up text
+ text = re.sub(r'[ \t]+', ' ', text.strip())
+ text = re.sub(r'\n\s*\n', '\n', text)
+ return text
+ except Exception as e:
+ logger.warning(f"Failed to extract text from page: {e}")
+ return None
+
+ def _extract_page_tables(self, page, page_num: int) -> List[Dict]:
+ """Extract tables with improved processing"""
+ tables = []
+ try:
+ raw_tables = page.extract_tables()
+ if raw_tables:
+ for table_idx, table in enumerate(raw_tables):
+ cleaned_table = self._clean_table_improved(table)
+ if cleaned_table and len(cleaned_table) > 0:
+ tables.append({
+ "page": page_num,
+ "table_index": table_idx + 1,
+ "headers": cleaned_table[0] if cleaned_table else [],
+ "data": cleaned_table[1:] if len(cleaned_table) > 1 else [],
+ "raw_data": cleaned_table,
+ "row_count": len(cleaned_table) - 1 if len(cleaned_table) > 1 else 0,
+ "column_count": len(cleaned_table[0]) if cleaned_table else 0
+ })
+ except Exception as e:
+ logger.warning(f"Failed to extract tables from page {page_num}: {e}")
+
+ return tables
+
+ def _clean_table_improved(self, table: List[List]) -> List[List[str]]:
+ """Improved table cleaning with better cell processing"""
+ if not table:
+ return []
+
+ cleaned = []
+ for row in table:
+ cleaned_row = []
+ for cell in row:
+ if cell is None:
+ cleaned_cell = ""
+ else:
+ cleaned_cell = str(cell).strip()
+ cleaned_cell = re.sub(r'\s+', ' ', cleaned_cell)
+ cleaned_cell = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', cleaned_cell)
+ cleaned_row.append(cleaned_cell)
+ if any(cell.strip() for cell in cleaned_row):
+ cleaned.append(cleaned_row)
+
+        return cleaned
+
+ def _extract_key_value_pairs_improved(self, text: str) -> Dict[str, str]:
+ """Improved key-value pair extraction with better cleaning"""
+ pairs: Dict[str, str] = {}
+
+ # Normalize text a bit for regex stability
+ t = text.replace('\r', '\n')
+
+ # Pattern 1: colon-separated pairs (key: value)
+ pattern1 = re.compile(
+ r'([A-Za-z][\w\s()/\-.]{2,80}?):\s*([^\n\r:][^\n\r]*)'
+ )
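+    # Illustrative: "Date of Audit: 12 March 2024" yields ("Date of Audit", "12 March 2024").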
+ for key, val in pattern1.findall(t):
+ k = key.strip()
+ v = val.strip()
+ # Filter junk: very long values, pure separators, or obvious headers
+ if not v or len(v) > 200:
+ continue
+ if re.fullmatch(r'[-_/\.]+', v):
+ continue
+ # Avoid capturing the next key as value by trimming trailing key-like tokens
+ v = re.sub(r'\s+[A-Z][\w\s()/\-.]{2,40}:$', '', v).strip()
+ # Skip values that are just long digit runs (likely id lists without meaning)
+ if re.fullmatch(r'\d{6,}', v):
+ continue
+ pairs[k] = v
+
+ # Pattern 2: inline “Key – Value” or “Key — Value”
+ pattern2 = re.compile(r'([A-Za-z][\w\s()/\-.]{2,80}?)\s*[–—-]\s*([^\n\r]+)')
+ for key, val in pattern2.findall(t):
+ k = key.strip()
+ v = val.strip()
+ if v and len(v) <= 200 and not re.fullmatch(r'\d{6,}', v):
+ pairs.setdefault(k, v)
+
+ return pairs
+
+ def _extract_audit_info(self, text: str, tables: List[Dict]) -> Dict[str, Any]:
+ """Extract audit-specific information with better filtering"""
+ audit_info: Dict[str, Any] = {}
+
+ # Prefer tables
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+ joined = ' '.join(headers)
+ if "audit information" in joined or "auditinformation" in joined:
+ data = table.get("data", [])
+ for row in data:
+ if len(row) >= 2 and row[0] and row[1]:
+ key = str(row[0]).strip()
+ value = str(row[1]).strip()
+ # Skip numbered list rows (e.g., "1.", "2)")
+ if re.match(r'^\s*\d+\s*[.)]\s*$', key):
+ continue
+ if key and value:
+ audit_info[key] = value
+
+ # Backup from text
+ candidates = {
+ "Date of Audit": r'Date\s+of\s+Audit[:\s]*([^\n\r]+)',
+ "Location of audit": r'Location\s+of\s+audit[:\s]*([^\n\r]+)',
+ "Auditor name": r'Auditor\s+name[:\s]*([^\n\r]+)',
+ "Audit Matrix Identifier (Name or Number)": r'Audit\s+Matrix\s+Identifier.*?[:\s]*([^\n\r]+)',
+ }
+ for k, pat in candidates.items():
+ if k not in audit_info:
+ m = re.search(pat, text, re.IGNORECASE)
+ if m:
+ audit_info[k] = m.group(1).strip()
+
+ return audit_info
+
+ def _extract_operator_info(self, text: str, tables: List[Dict]) -> Dict[str, Any]:
+ """Extract operator information with better table parsing"""
+ operator_info: Dict[str, Any] = {}
+
+ # Look for operator information in tables first
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+ if ("operatorinformation" in ' '.join(headers) or
+ "operator information" in ' '.join(headers) or
+ "operatorcontactdetails" in ' '.join(headers)):
+
+ data = table.get("data", [])
+ for row in data:
+ if len(row) >= 2 and row[0] and row[1]:
+ key = str(row[0]).strip()
+ value = str(row[1]).strip()
+ if key and value:
+ # Clean up key names
+ kl = key.lower()
+ if "operator name" in kl:
+ operator_info["operator_name"] = value
+ elif "trading name" in kl:
+ operator_info["trading_name"] = value
+ elif "company number" in kl:
+ if len(row) > 2:
+ company_parts = [str(r).strip() for r in row[1:] if str(r).strip()]
+ operator_info["company_number"] = "".join(company_parts)
+ else:
+ operator_info["company_number"] = value
+ elif "business address" in kl:
+ operator_info["business_address"] = value
+ elif "postal address" in kl:
+ operator_info["postal_address"] = value
+ elif "email" in kl:
+ operator_info["email"] = value
+ elif "telephone" in kl or "phone" in kl:
+ operator_info["phone"] = value
+ elif "nhvas accreditation" in kl:
+ operator_info["nhvas_accreditation"] = value
+ elif "nhvas manual" in kl:
+ operator_info["nhvas_manual"] = value
+
+ # Extract from text patterns as backup
+ patterns = {
+ 'operator_name': r'Operator\s*name[:\s\(]*([^\n\r\)]+?)(?=\s*NHVAS|\s*Registered|$)',
+ 'trading_name': r'Registered\s*trading\s*name[:\s\/]*([^\n\r]+?)(?=\s*Australian|$)',
+ 'company_number': r'Australian\s*Company\s*Number[:\s]*([0-9\s]+?)(?=\s*NHVAS|$)',
+ 'business_address': r'Operator\s*business\s*address[:\s]*([^\n\r]+?)(?=\s*Operator\s*Postal|$)',
+ 'postal_address': r'Operator\s*Postal\s*address[:\s]*([^\n\r]+?)(?=\s*Email|$)',
+ 'email': r'Email\s*address[:\s]*([^\s\n\r]+)',
+ 'phone': r'Operator\s*Telephone\s*Number[:\s]*([^\s\n\r]+)',
+ 'nhvas_accreditation': r'NHVAS\s*Accreditation\s*No\.[:\s\(]*([^\n\r\)]+)',
+ }
+
+ for key, pattern in patterns.items():
+ if key not in operator_info: # Only use text if not found in tables
+ match = re.search(pattern, text, re.IGNORECASE)
+ if match:
+ value = match.group(1).strip()
+ if value and len(value) < 200:
+ if key == 'company_number':
+ value = re.sub(r'\s+', '', value)
+ operator_info[key] = value
+
+ return operator_info
+
+ def _extract_vehicle_registrations(self, tables: List[Dict]) -> List[Dict]:
+ """Extract vehicle registration information from tables"""
+ vehicles: List[Dict[str, Any]] = []
+
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+
+ # Look for vehicle registration tables
+ if any(keyword in ' '.join(headers) for keyword in ['registration', 'vehicle', 'number']):
+ reg_col = None
+ for i, header in enumerate(headers):
+ if 'registration' in header and 'number' in header:
+ reg_col = i
+ break
+
+ if reg_col is not None:
+ data = table.get("data", [])
+ for row in data:
+ if len(row) > reg_col and row[reg_col]:
+ reg_num = str(row[reg_col]).strip()
+ # Validate registration format (letters/numbers)
+ if re.match(r'^[A-Z]{1,3}\s*\d{1,3}\s*[A-Z]{0,3}$', reg_num):
+ vehicle_info = {"registration_number": reg_num}
+
+ # Add other columns as additional info
+ for i, header in enumerate(table.get("headers", [])):
+ if i < len(row) and i != reg_col:
+ vehicle_info[str(header)] = str(row[i]).strip()
+
+ vehicles.append(vehicle_info)
+
+ return vehicles
+
+ def _extract_driver_records(self, tables: List[Dict]) -> List[Dict]:
+ """Extract driver records from tables"""
+ drivers: List[Dict[str, Any]] = []
+
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+
+ # Look for driver/scheduler tables
+ if any(keyword in ' '.join(headers) for keyword in ['driver', 'scheduler', 'name']):
+ name_col = None
+ for i, header in enumerate(headers):
+ if 'name' in header:
+ name_col = i
+ break
+
+ if name_col is not None:
+ data = table.get("data", [])
+ for row in data:
+ if len(row) > name_col and row[name_col]:
+ name = str(row[name_col]).strip()
+ # Basic name validation
+ if re.match(r'^[A-Za-z\s]{2,}$', name) and len(name.split()) >= 2:
+ driver_info = {"name": name}
+
+ # Add other columns
+ for i, header in enumerate(table.get("headers", [])):
+ if i < len(row) and i != name_col:
+ driver_info[str(header)] = str(row[i]).strip()
+
+ drivers.append(driver_info)
+
+ return drivers
+
+ def _extract_compliance_summary(self, text: str, tables: List[Dict]) -> Dict[str, Any]:
+ """Extract compliance information"""
+ compliance = {
+ "standards_compliance": {},
+ "compliance_codes": {},
+ "audit_results": []
+ }
+
+ # Look for compliance tables
+ for table in tables:
+ headers = [str(h).lower() for h in table.get("headers", [])]
+
+ if any(keyword in ' '.join(headers) for keyword in ['compliance', 'standard', 'requirement']):
+ data = table.get("data", [])
+ for row in data:
+ if len(row) >= 2:
+ standard = str(row[0]).strip()
+ code = str(row[1]).strip()
+ if standard.startswith('Std') and code in ['V', 'NC', 'SFI', 'NAP', 'NA']:
+ compliance["standards_compliance"][standard] = code
+
+ # Extract compliance codes definitions
+ code_patterns = {
+ 'V': r'\bV\b\s+([^\n\r]+)',
+ 'NC': r'\bNC\b\s+([^\n\r]+)',
+ 'SFI': r'\bSFI\b\s+([^\n\r]+)',
+ 'NAP': r'\bNAP\b\s+([^\n\r]+)',
+ 'NA': r'\bNA\b\s+([^\n\r]+)',
+ }
+
+ for code, pattern in code_patterns.items():
+ match = re.search(pattern, text, re.IGNORECASE)
+ if match:
+ compliance["compliance_codes"][code] = match.group(1).strip()
+
+ return compliance
+
+ def _extract_dates_and_numbers_improved(self, text: str) -> Dict[str, Any]:
+ """Improved date and number extraction"""
+ result = {
+ "dates": [],
+ "registration_numbers": [],
+ "phone_numbers": [],
+ "email_addresses": [],
+ "reference_numbers": []
+ }
+
+ # Date patterns
+ date_patterns = [
+ r'\b(\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})\b',
+ r'\b(\d{1,2}/\d{1,2}/\d{4})\b',
+ r'\b(\d{1,2}-\d{1,2}-\d{4})\b',
+ r'\b(\d{1,2}\.\d{1,2}\.\d{4})\b',
+ ]
+ for pattern in date_patterns:
+ result["dates"].extend(re.findall(pattern, text))
+
+ # Registration numbers (Australian format-ish)
+ reg_pattern = r'\b([A-Z]{1,3}\s*\d{1,3}\s*[A-Z]{0,3})\b'
+ result["registration_numbers"] = list(set(re.findall(reg_pattern, text)))
+
+ # Phone numbers (AU)
+ phone_pattern = r'\b((?:\+61|0)[2-9]\s?\d{4}\s?\d{4})\b'
+ result["phone_numbers"] = list(set(re.findall(phone_pattern, text)))
+
+ # Email addresses
+ email_pattern = r'\b([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})\b'
+ result["email_addresses"] = list(set(re.findall(email_pattern, text)))
+
+ # Reference numbers
+ ref_patterns = [
+ (r'RF(?:S)?\s*#?\s*(\d+)', 'RFS_Certifications'),
+ (r'NHVAS\s+Accreditation\s+No\.?\s*(\d+)', 'NHVAS_Numbers'),
+ (r'Registration\s+Number\s*#?\s*(\d+)', 'Registration_Numbers'),
+ ]
+ for pattern, key in ref_patterns:
+ matches = re.findall(pattern, text, re.IGNORECASE)
+ if matches:
+ result["reference_numbers"].extend([f"{key}: {m}" for m in matches])
+
+ return result
+
+ @staticmethod
+ def save_results(results: Dict[str, Any], output_path: str):
+ """Save results to JSON file"""
+ try:
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(results, f, indent=2, ensure_ascii=False)
+ logger.info(f"💾 Results saved to {output_path}")
+ except Exception as e:
+ logger.error(f"Failed to save results: {e}")
+
+ @staticmethod
+ def export_to_excel(results: Dict[str, Any], excel_path: str):
+ """Export results to Excel with improved formatting"""
+ try:
+ with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
+ # Summary sheet
+ summary_data = []
+ extraction_summary = results.get("extraction_summary", {})
+ for key, value in extraction_summary.items():
+ summary_data.append({"Metric": key.replace("_", " ").title(), "Value": value})
+ pd.DataFrame(summary_data).to_excel(writer, sheet_name='Summary', index=False)
+
+ # Key-value pairs
+ kv_pairs = results.get("extracted_data", {}).get("key_value_pairs", {})
+ if kv_pairs:
+ kv_df = pd.DataFrame(list(kv_pairs.items()), columns=['Key', 'Value'])
+ kv_df.to_excel(writer, sheet_name='Key_Value_Pairs', index=False)
+
+ # Vehicle registrations
+ vehicles = results.get("extracted_data", {}).get("vehicle_registrations", [])
+ if vehicles:
+ pd.DataFrame(vehicles).to_excel(writer, sheet_name='Vehicle_Registrations', index=False)
+
+ # Driver records
+ drivers = results.get("extracted_data", {}).get("driver_records", [])
+ if drivers:
+ pd.DataFrame(drivers).to_excel(writer, sheet_name='Driver_Records', index=False)
+
+ # Compliance summary
+ compliance = results.get("extracted_data", {}).get("compliance_summary", {})
+ if compliance.get("standards_compliance"):
+ comp_df = pd.DataFrame(list(compliance["standards_compliance"].items()),
+ columns=['Standard', 'Compliance_Code'])
+ comp_df.to_excel(writer, sheet_name='Compliance_Standards', index=False)
+
+ logger.info(f"📊 Results exported to Excel: {excel_path}")
+ except Exception as e:
+ logger.error(f"Failed to export to Excel: {e}")
+
+def main():
+ if len(sys.argv) < 2:
+        print("Usage: python extract_pdf_data.py <pdf_path>")
+ sys.exit(1)
+
+ pdf_path = Path(sys.argv[1])
+ if not pdf_path.exists():
+ print(f"❌ PDF not found: {pdf_path}")
+ sys.exit(1)
+
+ print("🚀 Fixed PDF Data Extractor")
+ print("=" * 50)
+
+ extractor = FixedPDFExtractor()
+ results = extractor.extract_everything(str(pdf_path))
+
+ base = pdf_path.stem
+ output_dir = pdf_path.parent
+
+ # Save outputs
+ json_path = output_dir / f"{base}_comprehensive_data.json"
+ excel_path = output_dir / f"{base}_fixed_extraction.xlsx"
+
+ FixedPDFExtractor.save_results(results, str(json_path))
+ FixedPDFExtractor.export_to_excel(results, str(excel_path))
+
+ print("\n💾 OUTPUT FILES:")
+ print(f" 📄 JSON Data: {json_path}")
+ print(f" 📊 Excel Data: {excel_path}")
+ print(f"\n✨ FIXED EXTRACTION COMPLETE!")
+
+if __name__ == "__main__":
+ main()
diff --git a/space-pdf/extract_red_text.py b/space-pdf/extract_red_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..584b7876db68ef237dcc238cd8145d50b942e541
--- /dev/null
+++ b/space-pdf/extract_red_text.py
@@ -0,0 +1,764 @@
+#!/usr/bin/env python3
+import re
+import json
+import sys
+from docx import Document
+from docx.oxml.ns import qn
+from master_key import TABLE_SCHEMAS, HEADING_PATTERNS, PARAGRAPH_PATTERNS
+
+def normalize_header_label(s: str) -> str:
+ """Normalize a header/label by stripping parentheticals & punctuation."""
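+    # Illustrative: "Fault Recording/ Reporting (on suspension)" -> "Fault Recording / Reporting"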
+ s = re.sub(r"\s+", " ", s.strip())
+ # remove content in parentheses/brackets
+ s = re.sub(r"\([^)]*\)", "", s)
+ s = re.sub(r"\[[^]]*\]", "", s)
+ # unify slashes and hyphens, collapse spaces
+    s = s.replace("–", "-").replace("—", "-").replace("/", " / ")
+    s = re.sub(r"\s+", " ", s)
+ return s.strip()
+
+# Canonical label aliases for Vehicle/Maintenance/General headers
+LABEL_ALIASES = {
+ # Vehicle Registration (Maintenance)
+ "roadworthiness certificates": "Roadworthiness Certificates",
+ "maintenance records": "Maintenance Records",
+ "daily checks": "Daily Checks",
+ "fault recording / reporting": "Fault Recording/ Reporting",
+ "fault repair": "Fault Repair",
+
+ # Vehicle Registration (Mass)
+ "sub contracted vehicles statement of compliance": "Sub-contracted Vehicles Statement of Compliance",
+ "weight verification records": "Weight Verification Records",
+ "rfs suspension certification #": "RFS Suspension Certification #",
+ "suspension system maintenance": "Suspension System Maintenance",
+ "trip records": "Trip Records",
+ "fault recording/ reporting on suspension system": "Fault Recording/ Reporting on Suspension System",
+
+ # Common
+ "registration number": "Registration Number",
+ "no.": "No.",
+ "sub contractor": "Sub contractor",
+ "sub-contractor": "Sub contractor",
+}
+
+def looks_like_operator_declaration(context):
+ """True iff heading says Operator Declaration and headers include Print Name + Position Title."""
+ heading = (context.get("heading") or "").strip().lower()
+ headers = " ".join(context.get("headers") or []).lower()
+ return (
+ "operator declaration" in heading
+ and "print name" in headers
+ and "position" in headers
+ and "title" in headers
+ )
+
+def looks_like_auditor_declaration(context):
+ heading = (context.get("heading") or "").strip().lower()
+ headers = " ".join(context.get("headers") or []).lower()
+ return (
+ "auditor declaration" in heading
+ and "print name" in headers
+ and ("nhvr" in headers or "auditor registration number" in headers)
+ )
+
+# --- NEW: header-only fallback that ignores headings and just keys on the two column names
+def extract_operator_declaration_by_headers_from_end(doc):
+ """
+ Scan tables from the end; if a table's first row contains both
+ 'Print Name' AND 'Position Title' (case-insensitive), extract red text
+ from the data rows into:
+ {"Print Name": [...], "Position Title": [...]}
+ """
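+    # Illustrative result (names are made up, not taken from any document):
+    #   {"Print Name": ["Jane Citizen"], "Position Title": ["Operations Manager"]}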
+ for tbl in reversed(doc.tables):
+ if len(tbl.rows) < 2:
+ continue # need header + at least one data row
+
+ headers_norm = [normalize_header_label(c.text).lower() for c in tbl.rows[0].cells]
+ has_print = any("print name" in h for h in headers_norm)
+ has_pos_tit = any(("position title" in h) or ("position" in h and "title" in h) for h in headers_norm)
+ if not (has_print and has_pos_tit):
+ continue
+
+ idx_print = next((i for i, h in enumerate(headers_norm) if "print name" in h), None)
+ idx_pos = next((i for i, h in enumerate(headers_norm) if "position title" in h), None)
+ if idx_pos is None:
+ idx_pos = next((i for i, h in enumerate(headers_norm) if ("position" in h and "title" in h)), None)
+
+ result = {"Print Name": [], "Position Title": []}
+ for row in tbl.rows[1:]:
+ if idx_print is not None and idx_print < len(row.cells):
+ cell = row.cells[idx_print]
+ reds = [r.text for p in cell.paragraphs for r in p.runs if is_red_font(r) and r.text]
+ reds = coalesce_numeric_runs(reds)
+ txt = normalize_text(" ".join(reds))
+ if txt:
+ result["Print Name"].append(txt)
+
+ if idx_pos is not None and idx_pos < len(row.cells):
+ cell = row.cells[idx_pos]
+ reds = [r.text for p in cell.paragraphs for r in p.runs if is_red_font(r) and r.text]
+ reds = coalesce_numeric_runs(reds)
+ txt = normalize_text(" ".join(reds))
+ if txt:
+ result["Position Title"].append(txt)
+
+ if result["Print Name"] or result["Position Title"]:
+ return {k: v for k, v in result.items() if v}
+
+ return None
+# --- end NEW helper
+
+def canonicalize_label(s: str) -> str:
+ key = normalize_header_label(s).lower()
+ key = re.sub(r"\s+", " ", key)
+ return LABEL_ALIASES.get(key, s)
+
+def bag_similarity(a: str, b: str) -> float:
+ """Loose bag-of-words similarity for header↔label matching."""
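+    # Illustrative: "Registration Number" vs "Vehicle Registration Number" -> 2/3
+    # (two of the three significant words overlap).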
+ aw = {w for w in re.split(r"[^A-Za-z0-9#]+", normalize_header_label(a).lower()) if len(w) > 2 or w in {"#","no"}}
+ bw = {w for w in re.split(r"[^A-Za-z0-9#]+", normalize_header_label(b).lower()) if len(w) > 2 or w in {"#","no"}}
+ if not aw or not bw:
+ return 0.0
+ inter = len(aw & bw)
+ return inter / max(len(aw), len(bw))
+
+def coalesce_numeric_runs(text_list):
+ """
+ If a cell yields ['4','5','6','9','8','7','1','2','3'] etc., join continuous single-char digit runs.
+ Returns ['456987123'] instead of many singles. Non-digit tokens are preserved.
+ """
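+    # Illustrative: ['4', '5', 'ABC', '1'] -> ['45', 'ABC', '1']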
+ out, buf = [], []
+ for t in text_list:
+ if len(t) == 1 and t.isdigit():
+ buf.append(t)
+ else:
+ if buf:
+ out.append("".join(buf))
+ buf = []
+ out.append(t)
+ if buf:
+ out.append("".join(buf))
+ return out
+
+def is_red_font(run):
+ """Enhanced red font detection with better color checking"""
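+    # Illustrative: explicit RGB FF0000 or a w:color val of "C00000" passes the thresholds
+    # below; dark greys and black do not.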
+ col = run.font.color
+ if col and col.rgb:
+ r, g, b = col.rgb
+ if r > 150 and g < 100 and b < 100 and (r-g) > 30 and (r-b) > 30:
+ return True
+ rPr = getattr(run._element, "rPr", None)
+ if rPr is not None:
+ clr = rPr.find(qn('w:color'))
+ if clr is not None:
+ val = clr.get(qn('w:val'))
+ if val and re.fullmatch(r"[0-9A-Fa-f]{6}", val):
+ rr, gg, bb = int(val[:2], 16), int(val[2:4], 16), int(val[4:], 16)
+ if rr > 150 and gg < 100 and bb < 100 and (rr-gg) > 30 and (rr-bb) > 30:
+ return True
+ return False
+
+def _prev_para_text(tbl):
+ """Get text from previous paragraph before table"""
+ prev = tbl._tbl.getprevious()
+ while prev is not None and not prev.tag.endswith("}p"):
+ prev = prev.getprevious()
+ if prev is None:
+ return ""
+ return "".join(node.text for node in prev.iter() if node.tag.endswith("}t") and node.text).strip()
+
+def normalize_text(text):
+ """Normalize text for better matching"""
+ return re.sub(r'\s+', ' ', text.strip())
+
+def fuzzy_match_heading(heading, patterns):
+ """Check if heading matches any pattern with fuzzy matching"""
+ heading_norm = normalize_text(heading.upper())
+ for pattern in patterns:
+ if re.search(pattern, heading_norm, re.IGNORECASE):
+ return True
+ return False
+
+def get_table_context(tbl):
+ """Get comprehensive context information for table"""
+ heading = normalize_text(_prev_para_text(tbl))
+ headers = [normalize_text(c.text) for c in tbl.rows[0].cells if c.text.strip()]
+ col0 = [normalize_text(r.cells[0].text) for r in tbl.rows if r.cells[0].text.strip()]
+ first_cell = normalize_text(tbl.rows[0].cells[0].text) if tbl.rows else ""
+ all_cells = []
+ for row in tbl.rows:
+ for cell in row.cells:
+ text = normalize_text(cell.text)
+ if text:
+ all_cells.append(text)
+ return {
+ 'heading': heading,
+ 'headers': headers,
+ 'col0': col0,
+ 'first_cell': first_cell,
+ 'all_cells': all_cells,
+ 'num_rows': len(tbl.rows),
+ 'num_cols': len(tbl.rows[0].cells) if tbl.rows else 0
+ }
+
+def calculate_schema_match_score(schema_name, spec, context):
+ """Enhanced calculate match score - IMPROVED for Vehicle Registration tables"""
+ score = 0
+ reasons = []
+
+ # 🎯 VEHICLE REGISTRATION BOOST
+ if "Vehicle Registration" in schema_name:
+ vehicle_keywords = ["registration", "vehicle", "sub-contractor", "weight verification", "rfs suspension"]
+ table_text = " ".join(context['headers']).lower() + " " + context['heading'].lower()
+ keyword_matches = sum(1 for keyword in vehicle_keywords if keyword in table_text)
+ if keyword_matches >= 2:
+ score += 150 # Very high boost for vehicle tables
+ reasons.append(f"Vehicle Registration keywords: {keyword_matches}/5")
+ elif keyword_matches >= 1:
+ score += 75 # Medium boost
+ reasons.append(f"Some Vehicle Registration keywords: {keyword_matches}/5")
+
+ # 🎯 SUMMARY TABLE BOOST (existing logic)
+ if "Summary" in schema_name and "details" in " ".join(context['headers']).lower():
+ score += 100
+ reasons.append(f"Summary schema with DETAILS column - perfect match")
+
+ if "Summary" not in schema_name and "details" in " ".join(context['headers']).lower():
+ score -= 75
+ reasons.append(f"Non-summary schema penalized for DETAILS column presence")
+
+ # Context exclusions
+ if spec.get("context_exclusions"):
+ table_text = " ".join(context['headers']).lower() + " " + context['heading'].lower()
+ for exclusion in spec["context_exclusions"]:
+ if exclusion.lower() in table_text:
+ score -= 50
+ reasons.append(f"Context exclusion penalty: '{exclusion}' found")
+
+ # Context keywords
+ if spec.get("context_keywords"):
+ table_text = " ".join(context['headers']).lower() + " " + context['heading'].lower()
+ keyword_matches = 0
+ for keyword in spec["context_keywords"]:
+ if keyword.lower() in table_text:
+ keyword_matches += 1
+
+ if keyword_matches > 0:
+ score += keyword_matches * 15
+ reasons.append(f"Context keyword matches: {keyword_matches}/{len(spec['context_keywords'])}")
+
+ # Direct first cell match
+ if context['first_cell'] and context['first_cell'].upper() == schema_name.upper():
+ score += 100
+ reasons.append(f"Direct first cell match: '{context['first_cell']}'")
+
+ # Heading pattern matching
+ if spec.get("headings"):
+ for h in spec["headings"]:
+ if fuzzy_match_heading(context['heading'], [h["text"]]):
+ score += 50
+ reasons.append(f"Heading match: '{context['heading']}'")
+ break
+
+ # Column header matching
+ if spec.get("columns"):
+ cols = [normalize_text(col) for col in spec["columns"]]
+ matches = 0
+ for col in cols:
+ if any(col.upper() in h.upper() for h in context['headers']):
+ matches += 1
+ if matches == len(cols):
+ score += 60
+ reasons.append(f"All column headers match: {cols}")
+ elif matches > 0:
+ score += matches * 20
+ reasons.append(f"Partial column matches: {matches}/{len(cols)}")
+
+ # Label matching for left-oriented tables
+ if spec.get("orientation") == "left":
+ labels = [normalize_text(lbl) for lbl in spec["labels"]]
+ matches = 0
+ for lbl in labels:
+ if any(lbl.upper() in c.upper() or c.upper() in lbl.upper() for c in context['col0']):
+ matches += 1
+ if matches > 0:
+ score += (matches / len(labels)) * 30
+ reasons.append(f"Left orientation label matches: {matches}/{len(labels)}")
+
+ # 🎯 ENHANCED Label matching for row1-oriented tables (Vehicle Registration)
+ elif spec.get("orientation") == "row1":
+ labels = [normalize_text(lbl) for lbl in spec["labels"]]
+ matches = 0
+ for lbl in labels:
+ if any(lbl.upper() in h.upper() or h.upper() in lbl.upper() for h in context['headers']):
+ matches += 1
+ elif any(word.upper() in " ".join(context['headers']).upper() for word in lbl.split() if len(word) > 3):
+ matches += 0.5 # Partial credit
+ if matches > 0:
+ score += (matches / len(labels)) * 40
+ reasons.append(f"Row1 orientation header matches: {matches}/{len(labels)}")
+
+ # Special handling for Declaration tables (existing logic)
+ if schema_name == "Operator Declaration" and context['first_cell'].upper() == "PRINT NAME":
+ if "OPERATOR DECLARATION" in context['heading'].upper():
+ score += 80
+ reasons.append("Operator Declaration context match")
+ elif any("MANAGER" in cell.upper() for cell in context['all_cells']):
+ score += 60
+ reasons.append("Manager found in cells (likely Operator Declaration)")
+
+ if schema_name == "NHVAS Approved Auditor Declaration" and context['first_cell'].upper() == "PRINT NAME":
+ if any("MANAGER" in cell.upper() for cell in context['all_cells']):
+ score -= 50
+ reasons.append("Penalty: Manager found (not auditor)")
+
+ return score, reasons
+
+def match_table_schema(tbl):
+ """Improved table schema matching with scoring system"""
+ context = get_table_context(tbl)
+ # Auditor Declaration first
+ if ("print name" in " ".join(context.get("headers", [])).lower() and
+ "auditor" in " ".join(context.get("headers", [])).lower()):
+ return "NHVAS Approved Auditor Declaration"
+ # NEW: prioritize Auditor Declaration to avoid misclassification
+ if looks_like_auditor_declaration(context):
+ return "NHVAS Approved Auditor Declaration"
+ # hard-match Operator Declaration first (high priority, avoids misclassification)
+ if looks_like_operator_declaration(context):
+ return "Operator Declaration"
+ best_match = None
+ best_score = 0
+ for name, spec in TABLE_SCHEMAS.items():
+ score, reasons = calculate_schema_match_score(name, spec, context)
+ if score > best_score:
+ best_score = score
+ best_match = name
+ if best_score >= 20:
+ return best_match
+ return None
+
+def check_multi_schema_table(tbl):
+ """Check if table contains multiple schemas and split appropriately"""
+ context = get_table_context(tbl)
+ operator_labels = ["Operator name (Legal entity)", "NHVAS Accreditation No.", "Registered trading name/s",
+ "Australian Company Number", "NHVAS Manual"]
+ contact_labels = ["Operator business address", "Operator Postal address", "Email address", "Operator Telephone Number"]
+ has_operator = any(any(op_lbl.upper() in cell.upper() for op_lbl in operator_labels) for cell in context['col0'])
+ has_contact = any(any(cont_lbl.upper() in cell.upper() for cont_lbl in contact_labels) for cell in context['col0'])
+ if has_operator and has_contact:
+ return ["Operator Information", "Operator contact details"]
+ return None
+
+def extract_multi_schema_table(tbl, schemas):
+ """Extract data from table with multiple schemas"""
+ result = {}
+ for schema_name in schemas:
+ if schema_name not in TABLE_SCHEMAS:
+ continue
+ spec = TABLE_SCHEMAS[schema_name]
+ schema_data = {}
+ for ri, row in enumerate(tbl.rows):
+ if ri == 0:
+ continue
+ row_label = normalize_text(row.cells[0].text)
+ belongs_to_schema = False
+ matched_label = None
+ for spec_label in spec["labels"]:
+ spec_norm = normalize_text(spec_label).upper()
+ row_norm = row_label.upper()
+ if spec_norm == row_norm or spec_norm in row_norm or row_norm in spec_norm:
+ belongs_to_schema = True
+ matched_label = spec_label
+ break
+ if not belongs_to_schema:
+ continue
+ for ci, cell in enumerate(row.cells):
+ red_txt = "".join(run.text for p in cell.paragraphs for run in p.runs if is_red_font(run)).strip()
+ if red_txt:
+ if matched_label not in schema_data:
+ schema_data[matched_label] = []
+ if red_txt not in schema_data[matched_label]:
+ schema_data[matched_label].append(red_txt)
+ if schema_data:
+ result[schema_name] = schema_data
+ return result
+
+def extract_table_data(tbl, schema_name, spec):
+ """Extract red text data from table based on schema – per-row repeats for specific tables."""
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # OPERATOR DECLARATION (row1 headers: Print Name | Position Title)
+ # ───────────────────────────────────────────────────────────────────────────
+ if schema_name == "Operator Declaration":
+ print(f" 🧾 EXTRACTION FIX: Processing Operator Declaration table")
+
+ labels = spec["labels"] # ["Print Name", "Position Title"]
+ canonical_labels = {canonicalize_label(lbl): lbl for lbl in labels}
+
+ collected = {lbl: [] for lbl in labels}
+
+ if len(tbl.rows) < 2:
+ print(f" ❌ Operator Declaration table has less than 2 rows")
+ return {}
+
+ # map header cells → labels (row1 orientation)
+ header_row = tbl.rows[0]
+ column_mapping = {}
+ print(f" 📋 Mapping {len(header_row.cells)} header cells to labels")
+
+ for col_idx, cell in enumerate(header_row.cells):
+ raw_h = normalize_text(cell.text)
+ header_text = normalize_header_label(raw_h)
+ if not header_text:
+ continue
+ print(f" Column {col_idx}: '{raw_h}'")
+
+ # alias/canonical first
+ canon = canonicalize_label(header_text)
+ if canon in canonical_labels:
+ best_label = canonical_labels[canon]
+ print(f" ✅ Mapped to: '{best_label}' (alias)")
+ column_mapping[col_idx] = best_label
+ continue
+
+ # else bag-of-words similarity
+ best_label, best_score = None, 0.0
+ for canon_lab, original_lab in canonical_labels.items():
+ s = bag_similarity(header_text, canon_lab)
+ if s > best_score:
+ best_score, best_label = s, original_lab
+
+ if best_label and best_score >= 0.40:
+ print(f" ✅ Mapped to: '{best_label}' (score: {best_score:.2f})")
+ column_mapping[col_idx] = best_label
+ else:
+ print(f" ⚠️ No mapping found for '{raw_h}'")
+
+ print(f" 📊 Total column mappings: {len(column_mapping)}")
+
+ # collect red text from the (usually single) data row
+ for row_idx in range(1, len(tbl.rows)):
+ row = tbl.rows[row_idx]
+ print(f" 📌 Processing data row {row_idx}")
+ for col_idx, cell in enumerate(row.cells):
+ if col_idx not in column_mapping:
+ continue
+ label = column_mapping[col_idx]
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
+ if not red_txt:
+ continue
+ print(f" 🔴 Found red text in '{label}': '{red_txt}'")
+ collected[label].append(red_txt)
+
+ result = {k: v for k, v in collected.items() if v}
+ print(f" ✅ Operator Declaration extracted: {len(result)} columns with data")
+ return result
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # A) Vehicle Registration tables (per-row accumulation; NO dedupe)
+ # ───────────────────────────────────────────────────────────────────────────
+ if "Vehicle Registration" in schema_name:
+ print(f" 🚗 EXTRACTION FIX: Processing Vehicle Registration table")
+
+ labels = spec["labels"]
+ canonical_labels = {canonicalize_label(lbl): lbl for lbl in labels}
+
+ collected = {lbl: [] for lbl in labels} # ← keep every row value
+ unmapped_bucket = {}
+
+ if len(tbl.rows) < 2:
+ print(f" ❌ Vehicle table has less than 2 rows")
+ return {}
+
+ header_row = tbl.rows[0]
+ column_mapping = {}
+ print(f" 📋 Mapping {len(header_row.cells)} header cells to labels")
+
+ for col_idx, cell in enumerate(header_row.cells):
+ raw_h = normalize_text(cell.text)
+ header_text = normalize_header_label(raw_h)
+ if not header_text:
+ continue
+ print(f" Column {col_idx}: '{raw_h}'")
+
+ # Try alias/canonical first
+ canon = canonicalize_label(header_text)
+ if canon in canonical_labels:
+ best_label = canonical_labels[canon]
+ print(f" ✅ Mapped to: '{best_label}' (alias)")
+ column_mapping[col_idx] = best_label
+ continue
+
+ # Else bag-of-words similarity
+ best_label, best_score = None, 0.0
+ for canon_lab, original_lab in canonical_labels.items():
+ s = bag_similarity(header_text, canon_lab)
+ if s > best_score:
+ best_score, best_label = s, original_lab
+
+ if best_label and best_score >= 0.40:
+ print(f" ✅ Mapped to: '{best_label}' (score: {best_score:.2f})")
+ column_mapping[col_idx] = best_label
+ else:
+ print(f" ⚠️ No mapping found for '{raw_h}'")
+ unmapped_bucket[raw_h] = []
+
+ print(f" 📊 Total column mappings: {len(column_mapping)}")
+
+ header_texts = [normalize_text(hc.text) for hc in header_row.cells]
+ for row_idx in range(1, len(tbl.rows)):
+ row = tbl.rows[row_idx]
+ print(f" 📌 Processing data row {row_idx}")
+ for col_idx, cell in enumerate(row.cells):
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
+ if not red_txt:
+ continue
+
+ if col_idx in column_mapping:
+ label = column_mapping[col_idx]
+ print(f" 🔴 Found red text in '{label}': '{red_txt}'")
+ collected[label].append(red_txt) # ← append every occurrence
+ else:
+ header_name = header_texts[col_idx] if col_idx < len(header_texts) else f"(unmapped col {col_idx})"
+ unmapped_bucket.setdefault(header_name, []).append(red_txt)
+
+ result = {k: v for k, v in collected.items() if v}
+ if unmapped_bucket:
+ result.update({f"UNMAPPED::{k}": v for k, v in unmapped_bucket.items() if v})
+ print(f" ✅ Vehicle Registration extracted: {len(result)} columns with data")
+ return result
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # B) Driver / Scheduler Records Examined (per-row accumulation; NO dedupe)
+ # ───────────────────────────────────────────────────────────────────────────
+ if "Driver / Scheduler" in schema_name:
+ print(f" 👤 EXTRACTION FIX: Processing Driver / Scheduler table")
+
+ labels = spec["labels"]
+ canonical_labels = {canonicalize_label(lbl): lbl for lbl in labels}
+
+ collected = {lbl: [] for lbl in labels} # ← keep every row value
+ unmapped_bucket = {}
+
+ if len(tbl.rows) < 2:
+ print(f" ❌ Driver/Scheduler table has less than 2 rows")
+ return {}
+
+ header_row = tbl.rows[0]
+ column_mapping = {}
+ print(f" 📋 Mapping {len(header_row.cells)} header cells to labels")
+
+ for col_idx, cell in enumerate(header_row.cells):
+ raw_h = normalize_text(cell.text)
+ header_text = normalize_header_label(raw_h)
+ if not header_text:
+ continue
+ print(f" Column {col_idx}: '{raw_h}'")
+
+ # Try alias/canonical first (rarely used here, but safe)
+ canon = canonicalize_label(header_text)
+ if canon in canonical_labels:
+ best_label = canonical_labels[canon]
+ print(f" ✅ Mapped to: '{best_label}' (alias)")
+ column_mapping[col_idx] = best_label
+ continue
+
+ # Else bag-of-words similarity (good for long headings)
+ best_label, best_score = None, 0.0
+ for canon_lab, original_lab in canonical_labels.items():
+ s = bag_similarity(header_text, canon_lab)
+ if s > best_score:
+ best_score, best_label = s, original_lab
+
+ if best_label and best_score >= 0.40:
+ print(f" ✅ Mapped to: '{best_label}' (score: {best_score:.2f})")
+ column_mapping[col_idx] = best_label
+ else:
+ print(f" ⚠️ No mapping found for '{raw_h}'")
+ unmapped_bucket[raw_h] = []
+
+ print(f" 📊 Total column mappings: {len(column_mapping)}")
+
+ header_texts = [normalize_text(hc.text) for hc in header_row.cells]
+ for row_idx in range(1, len(tbl.rows)):
+ row = tbl.rows[row_idx]
+ print(f" 📌 Processing data row {row_idx}")
+ for col_idx, cell in enumerate(row.cells):
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
+ if not red_txt:
+ continue
+
+ if col_idx in column_mapping:
+ label = column_mapping[col_idx]
+ print(f" 🔴 Found red text in '{label}': '{red_txt}'")
+ collected[label].append(red_txt) # ← append every occurrence
+ else:
+ header_name = header_texts[col_idx] if col_idx < len(header_texts) else f"(unmapped col {col_idx})"
+ unmapped_bucket.setdefault(header_name, []).append(red_txt)
+
+ result = {k: v for k, v in collected.items() if v}
+ if unmapped_bucket:
+ result.update({f"UNMAPPED::{k}": v for k, v in unmapped_bucket.items() if v})
+ print(f" ✅ Driver / Scheduler extracted: {len(result)} columns with data")
+ return result
+
+ # ───────────────────────────────────────────────────────────────────────────
+ # C) Generic tables (unchanged: WITH dedupe)
+ # ───────────────────────────────────────────────────────────────────────────
+ labels = spec["labels"] + [schema_name]
+ collected = {lbl: [] for lbl in labels}
+ seen = {lbl: set() for lbl in labels}
+ by_col = (spec.get("orientation") == "row1")
+ start_row = 1 if by_col else 0
+ rows = tbl.rows[start_row:]
+
+ for ri, row in enumerate(rows):
+ for ci, cell in enumerate(row.cells):
+ reds = [run.text for p in cell.paragraphs for run in p.runs if is_red_font(run) and run.text]
+ if not reds:
+ continue
+ reds = coalesce_numeric_runs(reds)
+ red_txt = normalize_text(" ".join(reds))
+ if not red_txt:
+ continue
+
+ if by_col:
+ if ci < len(spec["labels"]):
+ lbl = spec["labels"][ci]
+ else:
+ lbl = schema_name
+ else:
+ raw_label = normalize_text(row.cells[0].text)
+ lbl = None
+ for spec_label in spec["labels"]:
+ if normalize_text(spec_label).upper() == raw_label.upper():
+ lbl = spec_label
+ break
+ if not lbl:
+ a_raw = normalize_header_label(raw_label).upper()
+ for spec_label in spec["labels"]:
+ a_spec = normalize_header_label(spec_label).upper()
+ if a_spec in a_raw or a_raw in a_spec:
+ lbl = spec_label
+ break
+ if not lbl:
+ lbl = schema_name
+
+ if red_txt not in seen[lbl]:
+ seen[lbl].add(red_txt)
+ collected[lbl].append(red_txt)
+
+ return {k: v for k, v in collected.items() if v}
+
+def extract_red_text(input_doc):
+ # input_doc: docx.Document object or file path
+ if isinstance(input_doc, str):
+ doc = Document(input_doc)
+ else:
+ doc = input_doc
+ out = {}
+ table_count = 0
+ for tbl in doc.tables:
+ table_count += 1
+ multi_schemas = check_multi_schema_table(tbl)
+ if multi_schemas:
+ multi_data = extract_multi_schema_table(tbl, multi_schemas)
+ for schema_name, schema_data in multi_data.items():
+ if schema_data:
+ if schema_name in out:
+ for k, v in schema_data.items():
+ if k in out[schema_name]:
+ out[schema_name][k].extend(v)
+ else:
+ out[schema_name][k] = v
+ else:
+ out[schema_name] = schema_data
+ continue
+ schema = match_table_schema(tbl)
+ if not schema:
+ continue
+ spec = TABLE_SCHEMAS[schema]
+ data = extract_table_data(tbl, schema, spec)
+ if data:
+ if schema in out:
+ for k, v in data.items():
+ if k in out[schema]:
+ out[schema][k].extend(v)
+ else:
+ out[schema][k] = v
+ else:
+ out[schema] = data
+
+ # paragraphs (FIX: do not return early; build full 'paras' then attach)
+ paras = {}
+ for idx, para in enumerate(doc.paragraphs):
+ red_txt = "".join(r.text for r in para.runs if is_red_font(r)).strip()
+ if not red_txt:
+ continue
+ context = None
+ for j in range(idx-1, -1, -1):
+ txt = normalize_text(doc.paragraphs[j].text)
+ if txt:
+ all_patterns = HEADING_PATTERNS["main"] + HEADING_PATTERNS["sub"]
+ if any(re.search(p, txt, re.IGNORECASE) for p in all_patterns):
+ context = txt
+ break
+ if not context and re.fullmatch(PARAGRAPH_PATTERNS["date_line"], red_txt):
+ context = "Date"
+ if not context:
+ context = "(para)"
+ paras.setdefault(context, []).append(red_txt)
+
+ if paras:
+ out["paragraphs"] = paras
+
+ # Fallback: ensure we capture the last-page Operator Declaration by headers
+ if "Operator Declaration" not in out:
+ op_dec = extract_operator_declaration_by_headers_from_end(doc)
+ if op_dec:
+ out["Operator Declaration"] = op_dec
+
+ return out
+
+def extract_red_text_filelike(input_file, output_file):
+ """
+ Accepts:
+ input_file: file-like object (BytesIO/File) or path
+ output_file: file-like object (opened for writing text) or path
+ """
+ if hasattr(input_file, "seek"):
+ input_file.seek(0)
+ doc = Document(input_file)
+ result = extract_red_text(doc)
+ if hasattr(output_file, "write"):
+ json.dump(result, output_file, indent=2, ensure_ascii=False)
+ output_file.flush()
+ else:
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(result, f, indent=2, ensure_ascii=False)
+ return result
+
+if __name__ == "__main__":
+ # Support both script and app/file-like usage
+ if len(sys.argv) == 3:
+ input_docx = sys.argv[1]
+ output_json = sys.argv[2]
+ doc = Document(input_docx)
+ word_data = extract_red_text(doc)
+ with open(output_json, 'w', encoding='utf-8') as f:
+ json.dump(word_data, f, indent=2, ensure_ascii=False)
+ print(json.dumps(word_data, indent=2, ensure_ascii=False))
+ else:
+ print("To use as a module: extract_red_text_filelike(input_file, output_file)")
\ No newline at end of file
diff --git a/space-pdf/master_key.py b/space-pdf/master_key.py
new file mode 100644
index 0000000000000000000000000000000000000000..749032a9f1bdb5c3e279285b2886ab10b6b10153
--- /dev/null
+++ b/space-pdf/master_key.py
@@ -0,0 +1,372 @@
+"""
+Improved Master Key for NHVAS Audit extraction:
+- TABLE_SCHEMAS: Enhanced definitions with better matching criteria for Summary vs Basic tables
+- HEADING_PATTERNS: Improved regex patterns for main/sub headings
+- PARAGRAPH_PATTERNS: Enhanced patterns for key narrative sections
+"""
+
+# 1. Enhanced table schemas with better matching logic
+TABLE_SCHEMAS = {
+ "Tick as appropriate": {
+ "headings": [
+ {"level": 1, "text": "NHVAS Audit Summary Report"},
+ ],
+ "orientation": "left",
+ "labels": [
+ "Mass",
+ "Entry Audit",
+ "Maintenance",
+ "Initial Compliance Audit",
+ "Basic Fatigue",
+ "Compliance Audit",
+ "Advanced Fatigue",
+ "Spot Check",
+ "Triggered Audit"
+ ],
+ "priority": 90 # High priority for direct match
+ },
+ "Audit Information": {
+ "orientation": "left",
+ "labels": [
+ "Date of Audit",
+ "Location of audit",
+ "Auditor name",
+ "Audit Matrix Identifier (Name or Number)",
+ "Auditor Exemplar Global Reg No.",
+ "expiry Date:",
+ "NHVR Auditor Registration Number",
+ "expiry Date:"
+ ],
+ "priority": 80
+ },
+ "Operator Information": {
+ "headings": [
+ {"level": 1, "text": "Operator Information"}
+ ],
+ "orientation": "left",
+ "labels": [
+ "Operator name (Legal entity)",
+ "NHVAS Accreditation No. (If applicable)",
+ "Registered trading name/s",
+ "Australian Company Number",
+ "NHVAS Manual (Policies and Procedures) developed by"
+ ],
+ "priority": 85
+ },
+ "Operator contact details": {
+ "orientation": "left",
+ "labels": [
+ "Operator business address",
+ "Operator Postal address",
+ "Email address",
+ "Operator Telephone Number"
+ ],
+ "priority": 75,
+ "context_keywords": ["contact", "address", "email", "telephone"]
+ },
+ "Attendance List (Names and Position Titles)": {
+ "headings": [
+ {"level": 1, "text": "NHVAS Audit Summary Report"}
+ ],
+ "orientation": "row1",
+ "labels": ["Attendance List (Names and Position Titles)"],
+ "priority": 90
+ },
+ "Nature of the Operators Business (Summary)": {
+ "orientation": "row1",
+ "labels": ["Nature of the Operators Business (Summary):"],
+ "split_labels": ["Accreditation Number:", "Expiry Date:"],
+ "priority": 85
+ },
+ "Accreditation Vehicle Summary": {
+ "orientation": "left",
+ "labels": ["Number of powered vehicles", "Number of trailing vehicles"],
+ "priority": 80
+ },
+ "Accreditation Driver Summary": {
+ "orientation": "left",
+ "labels": ["Number of drivers in BFM", "Number of drivers in AFM"],
+ "priority": 80
+ },
+ "Compliance Codes": {
+ "orientation": "left",
+ "labels": ["V", "NC", "TNC", "SFI", "NAP", "NA"],
+ "priority": 70,
+ "context_exclusions": ["MASS MANAGEMENT", "MAINTENANCE MANAGEMENT", "FATIGUE MANAGEMENT"]
+ },
+ "Corrective Action Request Identification": {
+ "orientation": "row1",
+ "labels": ["Title", "Abbreviation", "Description"],
+ "priority": 80
+ },
+
+ # 🎯 BASIC MANAGEMENT SCHEMAS (Compliance Tables - Lower Priority)
+ "Maintenance Management": {
+ "headings": [
+ {"level": 1, "text": "NHVAS AUDIT SUMMARY REPORT"}
+ ],
+ "orientation": "left",
+ "labels": [
+ "Std 1. Daily Check",
+ "Std 2. Fault Recording and Reporting",
+ "Std 3. Fault Repair",
+ "Std 4. Maintenance Schedules and Methods",
+ "Std 5. Records and Documentation",
+ "Std 6. Responsibilities",
+ "Std 7. Internal Review",
+ "Std 8. Training and Education"
+ ],
+ "priority": 60,
+ "context_keywords": ["maintenance"],
+ "context_exclusions": ["summary", "details", "audit findings"] # Exclude Summary tables
+ },
+ "Mass Management": {
+ "headings": [
+ {"level": 1, "text": "NHVAS AUDIT SUMMARY REPORT"}
+ ],
+ "orientation": "left",
+ "labels": [
+ "Std 1. Responsibilities",
+ "Std 2. Vehicle Control",
+ "Std 3. Vehicle Use",
+ "Std 4. Records and Documentation",
+ "Std 5. Verification",
+ "Std 6. Internal Review",
+ "Std 7. Training and Education",
+ "Std 8. Maintenance of Suspension"
+ ],
+ "priority": 60,
+ "context_keywords": ["mass"],
+ "context_exclusions": ["summary", "details", "audit findings"] # Exclude Summary tables
+ },
+ "Fatigue Management": {
+ "headings": [
+ {"level": 1, "text": "NHVAS AUDIT SUMMARY REPORT"}
+ ],
+ "orientation": "left",
+ "labels": [
+ "Std 1. Scheduling and Rostering",
+ "Std 2. Health and wellbeing for performed duty",
+ "Std 3. Training and Education",
+ "Std 4. Responsibilities and management practices",
+ "Std 5. Internal Review",
+ "Std 6. Records and Documentation",
+ "Std 7. Workplace conditions"
+ ],
+ "priority": 60,
+ "context_keywords": ["fatigue"],
+ "context_exclusions": ["summary", "details", "audit findings"] # Exclude Summary tables
+ },
+
+ # 🎯 SUMMARY MANAGEMENT SCHEMAS (Detailed Tables with DETAILS column - Higher Priority)
+ "Maintenance Management Summary": {
+ "headings": [
+ {"level": 1, "text": "Audit Observations and Comments"},
+ {"level": 2, "text": "Maintenance Management Summary of Audit findings"}
+ ],
+ "orientation": "left",
+ "columns": ["MAINTENANCE MANAGEMENT", "DETAILS"],
+ "labels": [
+ "Std 1. Daily Check",
+ "Std 2. Fault Recording and Reporting",
+ "Std 3. Fault Repair",
+ "Std 4. Maintenance Schedules and Methods",
+ "Std 5. Records and Documentation",
+ "Std 6. Responsibilities",
+ "Std 7. Internal Review",
+ "Std 8. Training and Education"
+ ],
+ "priority": 85, # Higher priority than basic Maintenance Management
+ "context_keywords": ["maintenance", "summary", "details", "audit findings"]
+ },
+ "Mass Management Summary": {
+ "headings": [
+ {"level": 1, "text": "Mass Management Summary of Audit findings"}
+ ],
+ "orientation": "left",
+ "columns": ["MASS MANAGEMENT", "DETAILS"],
+ "labels": [
+ "Std 1. Responsibilities",
+ "Std 2. Vehicle Control",
+ "Std 3. Vehicle Use",
+ "Std 4. Records and Documentation",
+ "Std 5. Verification",
+ "Std 6. Internal Review",
+ "Std 7. Training and Education",
+ "Std 8. Maintenance of Suspension"
+ ],
+ "priority": 85, # Higher priority than basic Mass Management
+ "context_keywords": ["mass", "summary", "details", "audit findings"]
+ },
+ "Fatigue Management Summary": {
+ "headings": [
+ {"level": 1, "text": "Fatigue Management Summary of Audit findings"}
+ ],
+ "orientation": "left",
+ "columns": ["FATIGUE MANAGEMENT", "DETAILS"],
+ "labels": [
+ "Std 1. Scheduling and Rostering",
+ "Std 2. Health and wellbeing for performed duty",
+ "Std 3. Training and Education",
+ "Std 4. Responsibilities and management practices",
+ "Std 5. Internal Review",
+ "Std 6. Records and Documentation",
+ "Std 7. Workplace conditions"
+ ],
+ "priority": 85, # Higher priority than basic Fatigue Management
+ "context_keywords": ["fatigue", "summary", "details", "audit findings"]
+ },
+
+ # Vehicle Registration Tables
+ "Vehicle Registration Numbers Mass": {
+ "headings": [
+ {"level": 1, "text": "Vehicle Registration Numbers of Records Examined"},
+ {"level": 2, "text": "MASS MANAGEMENT"}
+ ],
+ "orientation": "row1",
+ "labels": [
+ "No.", "Registration Number", "Sub contractor",
+ "Sub-contracted Vehicles Statement of Compliance",
+ "Weight Verification Records",
+ "RFS Suspension Certification #",
+ "Suspension System Maintenance", "Trip Records",
+ "Fault Recording/ Reporting on Suspension System"
+ ],
+ "priority": 90, # Higher priority
+ "context_keywords": ["mass", "vehicle registration", "rfs suspension", "weight verification"],
+ "context_exclusions": ["maintenance", "roadworthiness", "daily checks"] # Exclude maintenance-specific terms
+    },
+    "Vehicle Registration Numbers Maintenance": {
+ "headings": [
+ {"level": 1, "text": "Vehicle Registration Numbers of Records Examined"},
+ {"level": 2, "text": "Maintenance Management"}
+ ],
+ "orientation": "row1",
+ "labels": [
+ "No.", "Registration Number", "Roadworthiness Certificates",
+ "Maintenance Records", "Daily Checks",
+ "Fault Recording/ Reporting", "Fault Repair"
+ ],
+ "priority": 85, # Lower priority
+ "context_keywords": ["maintenance", "vehicle registration", "roadworthiness", "daily checks"],
+ "context_exclusions": ["mass", "rfs suspension", "weight verification"] # Exclude mass-specific terms
+    },
+ "Driver / Scheduler Records Examined": {
+ "headings": [
+ {"level": 1, "text": "Driver / Scheduler Records Examined"},
+ {"level": 2, "text": "FATIGUE MANAGEMENT"},
+ ],
+ "orientation": "row1",
+ "labels": [
+ "No.",
+ "Driver / Scheduler Name",
+ "Driver TLIF Course # Completed",
+ "Scheduler TLIF Course # Completed",
+ "Medical Certificates (Current Yes/No) Date of expiry",
+ "Roster / Schedule / Safe Driving Plan (Date Range)",
+ "Fit for Duty Statement Completed (Yes/No)",
+ "Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)"
+ ],
+ "priority": 80,
+ "context_keywords": ["driver", "scheduler", "fatigue"]
+ },
+
+ # Other Tables
+ "Operator's Name (legal entity)": {
+ "headings": [
+ {"level": 1, "text": "CORRECTIVE ACTION REQUEST (CAR)"}
+ ],
+ "orientation": "left",
+ "labels": ["Operator's Name (legal entity)"],
+ "priority": 85
+ },
+ "Non-conformance and CAR details": {
+ "orientation": "left",
+ "labels": [
+ "Non-conformance agreed close out date",
+ "Module and Standard",
+ "Corrective Action Request (CAR) Number",
+ "Observed Non-conformance:",
+ "Corrective Action taken or to be taken by operator:",
+ "Operator or Representative Signature",
+ "Position",
+ "Date",
+ "Comments:",
+ "Auditor signature",
+ "Date"
+ ],
+ "priority": 75,
+ "context_keywords": ["non-conformance", "corrective action"]
+ },
+ "NHVAS Approved Auditor Declaration": {
+ "headings": [
+ {"level": 1, "text": "NHVAS APPROVED AUDITOR DECLARATION"}
+ ],
+ "orientation": "row1",
+ "labels": ["Print Name", "NHVR or Exemplar Global Auditor Registration Number"],
+ "priority": 90,
+ "context_keywords": ["auditor declaration", "NHVR"],
+ "context_exclusions": ["manager", "operator declaration"]
+ },
+ "Audit Declaration dates": {
+ "headings": [
+ {"level": 1, "text": "Audit Declaration dates"}
+ ],
+ "orientation": "left",
+ "labels": [
+ "Audit was conducted on",
+ "Unconditional CARs closed out on:",
+ "Conditional CARs to be closed out by:"
+ ],
+ "priority": 80
+ },
+ "Print accreditation name": {
+ "headings": [
+ {"level": 1, "text": "(print accreditation name)"}
+ ],
+ "orientation": "left",
+ "labels": ["(print accreditation name)"],
+ "priority": 85
+ },
+ "Operator Declaration": {
+ "headings": [
+ {"level": 1, "text": "Operator Declaration"}
+ ],
+ "orientation": "row1",
+ "labels": ["Print Name", "Position Title"],
+ "priority": 90,
+ "context_keywords": ["operator declaration", "manager"],
+ "context_exclusions": ["auditor", "nhvas approved"]
+ }
+}
+
+# 2. Enhanced heading detection patterns
+HEADING_PATTERNS = {
+ "main": [
+ r"NHVAS\s+Audit\s+Summary\s+Report",
+ r"NATIONAL\s+HEAVY\s+VEHICLE\s+ACCREDITATION\s+AUDIT\s+SUMMARY\s+REPORT",
+ r"NHVAS\s+AUDIT\s+SUMMARY\s+REPORT"
+ ],
+ "sub": [
+ r"AUDIT\s+OBSERVATIONS\s+AND\s+COMMENTS",
+ r"MAINTENANCE\s+MANAGEMENT",
+ r"MASS\s+MANAGEMENT",
+ r"FATIGUE\s+MANAGEMENT",
+ r"Fatigue\s+Management\s+Summary\s+of\s+Audit\s+findings",
+ r"MAINTENANCE\s+MANAGEMENT\s+SUMMARY\s+OF\s+AUDIT\s+FINDINGS",
+ r"MASS\s+MANAGEMENT\s+SUMMARY\s+OF\s+AUDIT\s+FINDINGS",
+ r"Vehicle\s+Registration\s+Numbers\s+of\s+Records\s+Examined",
+ r"CORRECTIVE\s+ACTION\s+REQUEST\s+\(CAR\)",
+ r"NHVAS\s+APPROVED\s+AUDITOR\s+DECLARATION",
+ r"Operator\s+Declaration",
+ r"Operator\s+Information"
+ ]
+}
+
+# 3. Enhanced paragraph patterns for key narrative sections
+PARAGRAPH_PATTERNS = {
+ "findings_summary": r"Provide a summary of findings based on the evidence gathered during the audit\.",
+ "declaration_text": r"I hereby acknowledge and agree with the findings.*",
+ "introductory_note": r"This audit assesses the.*",
+ "date_line": r"^\s*\d{1,2}(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4}\s*$|^Date$"
+}
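+
+# e.g. date_line is written to accept strings like "9th August 2024" as well as the bare placeholder "Date".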
diff --git a/space-pdf/packages.txt b/space-pdf/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..15f6a51cbff6603c1d8c3d80135eb9c9f469afdb
--- /dev/null
+++ b/space-pdf/packages.txt
@@ -0,0 +1,2 @@
+poppler-utils
+tesseract-ocr
\ No newline at end of file
diff --git a/space-pdf/requirements.txt b/space-pdf/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9343132ab8d20fd49ba8cbd5d72294e9b6ac6991
--- /dev/null
+++ b/space-pdf/requirements.txt
@@ -0,0 +1,34 @@
+fastapi==0.111.1
+pydantic==2.11.0
+python-multipart==0.0.9
+uvicorn==0.30.3
+gunicorn==22.0.0
+requests==2.32.3
+torch==2.4.0
+torchvision==0.19.0
+Pillow==10.4.0
+pdf-annotate==0.12.0
+scipy==1.14.0
+opencv-python==4.10.0.84
+Shapely==2.0.5
+transformers==4.40.2
+huggingface_hub==0.23.5
+pdf2image==1.17.0 # for fallback OCR on images
+lightgbm==4.5.0
+setuptools==75.4.0
+roman==4.2
+hydra-core==1.3.2
+pypandoc==1.13
+rapid-table==2.0.3
+rapidocr==3.2.0
+pix2tex==0.1.4
+latex2mathml==3.78.0
+PyMuPDF==1.25.5
+git+https://github.com/huridocs/pdf-features.git@2025.7.30.1
+gradio==4.44.1
+pytesseract
+python-docx
+camelot-py[cv] # for digital-table parsing
+rapidfuzz
+pdfplumber
+openai
\ No newline at end of file
diff --git a/space-pdf/update_docx_with_pdf.py b/space-pdf/update_docx_with_pdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b534c2c2628dd417b9e750438c5fd8ef7691565
--- /dev/null
+++ b/space-pdf/update_docx_with_pdf.py
@@ -0,0 +1,1470 @@
+#!/usr/bin/env python3
+"""
+Enhanced NHVAS PDF to DOCX JSON Merger
+Comprehensive extraction and mapping from PDF to DOCX structure
+(keep pipeline intact; fix spacing, operator info mapping, vehicle-reg header mapping, date fallback)
+"""
+import json
+import re
+import sys
+from pathlib import Path
+from typing import Dict, List, Any, Optional
+from collections import OrderedDict  # preserve first-seen order of vehicles keyed by registration
+
+
+def _nz(x):
+ return x if isinstance(x, str) and x.strip() else ""
+
+SUMMARY_SECTIONS = {
+ "MAINTENANCE MANAGEMENT": "Maintenance Management Summary",
+ "MASS MANAGEMENT": "Mass Management Summary",
+ "FATIGUE MANAGEMENT": "Fatigue Management Summary",
+}
+
+# ───────────────────────────── helpers: text cleanup & label matching ─────────────────────────────
+def _canon_header(s: str) -> str:
+ if not s: return ""
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = s.replace("–", "-").replace("—", "-")
+ s = re.sub(r"[/]+", " / ", s)
+ s = re.sub(r"[^a-z0-9#/ ]+", " ", s)
+ s = re.sub(r"\s+", " ", s).strip()
+ return s
+
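+# Illustrative sketches of the normalisation above (not exhaustive):
+#   _canon_header("Fault Recording/ Reporting")     -> "fault recording / reporting"
+#   _canon_header("Sub-contractor")                 -> "sub contractor"
+#   _canon_header("RFS Suspension Certification #") -> "rfs suspension certification #"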
+
+# Header aliases -> internal keys we already use later during mapping
+_VEH_HEADER_ALIASES = {
+ # common
+ "registration number": "registration",
+ "reg no": "registration",
+ "reg.#": "registration",
+ "no.": "no",
+ "no": "no",
+
+ # maintenance table
+ "roadworthiness certificates": "roadworthiness",
+ "maintenance records": "maintenance_records",
+ "daily checks": "daily_checks",
+ "fault recording reporting": "fault_recording",
+ "fault recording / reporting": "fault_recording",
+ "fault repair": "fault_repair",
+
+ # mass table
+ "sub contractor": "sub_contractor",
+ "sub-contractor": "sub_contractor",
+ "sub contracted vehicles statement of compliance": "sub_comp",
+ "sub-contracted vehicles statement of compliance": "sub_comp",
+ "weight verification records": "weight_verification",
+ "rfs suspension certification #": "rfs_certification",
+ "rfs suspension certification number": "rfs_certification",
+ "suspension system maintenance": "suspension_maintenance",
+ "trip records": "trip_records",
+ "fault recording reporting on suspension system": "fault_reporting_suspension",
+ "fault recording / reporting on suspension system": "fault_reporting_suspension",
+}
+
+# --- helpers ---
+def build_vehicle_sections(extracted: dict) -> dict:
+ """Build arrays for Maintenance and Mass tables. Maintenance uses recorded rows to include ALL entries."""
+ maint = {
+ "Registration Number": [],
+ "Roadworthiness Certificates": [],
+ "Maintenance Records": [],
+ "Daily Checks": [],
+ "Fault Recording/ Reporting": [],
+ "Fault Repair": [],
+ }
+ mass = {
+ "Registration Number": [],
+ "Weight Verification Records": [],
+ "RFS Suspension Certification #": [],
+ "Suspension System Maintenance": [],
+ "Trip Records": [],
+ "Fault Recording/ Reporting on Suspension System": [],
+ }
+
+ # Prefer authoritative maintenance rows captured during parsing (spans all pages)
+ if extracted.get("_maint_rows"):
+ for row in extracted["_maint_rows"]:
+ maint["Registration Number"].append(_smart_space(row.get("registration", "")))
+ maint["Roadworthiness Certificates"].append(_nz(row.get("roadworthiness", "")))
+ maint["Maintenance Records"].append(_nz(row.get("maintenance_records", "")))
+ maint["Daily Checks"].append(_nz(row.get("daily_checks", "")))
+ maint["Fault Recording/ Reporting"].append(_nz(row.get("fault_recording", "")))
+ maint["Fault Repair"].append(_nz(row.get("fault_repair", "")))
+ else:
+ # Fallback to vehicles map (older behavior)
+ for v in extracted.get("vehicles", []) or []:
+ if not v.get("registration"): continue
+ if v.get("seen_in_maintenance") or any(v.get(k) for k in ["roadworthiness","maintenance_records","daily_checks","fault_recording","fault_repair"]):
+ rw = _nz(v.get("roadworthiness", "")); mr = _nz(v.get("maintenance_records", "")); dc = _nz(v.get("daily_checks", ""))
+ fr = _nz(v.get("fault_recording", "")); rp = _nz(v.get("fault_repair", ""))
+ if not mr and dc: mr = dc
+ if not rp and fr: rp = fr
+ if not fr and rp: fr = rp
+ maint["Registration Number"].append(_smart_space(v["registration"]))
+ maint["Roadworthiness Certificates"].append(rw)
+ maint["Maintenance Records"].append(mr)
+ maint["Daily Checks"].append(dc)
+ maint["Fault Recording/ Reporting"].append(fr)
+ maint["Fault Repair"].append(rp)
+
+ # Mass stays as-is (from vehicles)
+ for v in extracted.get("vehicles", []) or []:
+ if not v.get("registration"): continue
+ if v.get("seen_in_mass") or any(v.get(k) for k in ["weight_verification","rfs_certification","suspension_maintenance","trip_records","fault_reporting_suspension"]):
+ mass["Registration Number"].append(_smart_space(v["registration"]))
+ mass["Weight Verification Records"].append(_nz(v.get("weight_verification", "")))
+ mass["RFS Suspension Certification #"].append(_nz(v.get("rfs_certification", "")))
+ mass["Suspension System Maintenance"].append(_nz(v.get("suspension_maintenance", "")))
+ mass["Trip Records"].append(_nz(v.get("trip_records", "")))
+ mass["Fault Recording/ Reporting on Suspension System"].append(_nz(v.get("fault_reporting_suspension", "")))
+
+ return {
+ "Vehicle Registration Numbers Maintenance": maint,
+ "Vehicle Registration Numbers Mass": mass,
+ }
+
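+# Shape of the returned structure (a rough sketch; the registration values are illustrative,
+# real values depend on the parsed PDF):
+#   {"Vehicle Registration Numbers Maintenance": {"Registration Number": ["XQ12PL", ...],
+#                                                 "Roadworthiness Certificates": ["20/02/23", ...], ...},
+#    "Vehicle Registration Numbers Mass": {...}}
+# Each column list is index-aligned, so row i of the source table sits at position i in every list.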
+
+def _map_header_indices(headers: list[str]) -> dict:
+ """Return {internal_key: column_index} by matching/aliasing header text."""
+ idx = {}
+ for i, h in enumerate(headers or []):
+ ch = _canon_header(h)
+ # try direct alias
+ if ch in _VEH_HEADER_ALIASES:
+ idx[_VEH_HEADER_ALIASES[ch]] = i
+ continue
+ # relax a little for 'registration number' variants
+ if "registration" in ch and "number" in ch:
+ idx["registration"] = i
+ continue
+ if "roadworthiness" in ch:
+ idx["roadworthiness"] = i
+ continue
+ if "maintenance" in ch and "records" in ch:
+ idx["maintenance_records"] = i
+ continue
+ if "daily" in ch and "check" in ch:
+ idx["daily_checks"] = i
+ continue
+ if "fault" in ch and "record" in ch and "suspension" not in ch:
+ # maintenance fault-recording column
+ if "repair" in ch:
+ idx["fault_repair"] = i
+ else:
+ idx["fault_recording"] = i
+ continue
+ if "weight" in ch and "verification" in ch:
+ idx["weight_verification"] = i
+ continue
+ if "rfs" in ch and "certification" in ch:
+ idx["rfs_certification"] = i
+ continue
+ if "suspension" in ch and "maintenance" in ch:
+ idx["suspension_maintenance"] = i
+ continue
+ if "trip" in ch and "record" in ch:
+ idx["trip_records"] = i
+ continue
+ if "fault" in ch and "report" in ch and "suspension" in ch:
+ idx["fault_reporting_suspension"] = i
+ continue
+ return idx
+
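+# Hypothetical usage (header text is normalised via _canon_header before matching):
+#   _map_header_indices(["No.", "Registration Number", "Daily Checks", "Fault Repair"])
+#   -> {"no": 0, "registration": 1, "daily_checks": 2, "fault_repair": 3}
+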
+def _canon(s: str) -> str:
+ if not s: return ""
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = re.sub(r"[^a-z0-9#]+", " ", s)
+ return re.sub(r"\s+", " ", s).strip()
+
+def _smart_space(s: str) -> str:
+ if not s: return s
+ s = str(s)
+
+ # Insert spaces at typical OCR glue points
+ s = re.sub(r'([a-z])([A-Z])', r'\1 \2', s)
+ s = re.sub(r'([A-Za-z])(\d)', r'\1 \2', s)
+ s = re.sub(r'(\d)([A-Za-z])', r'\1 \2', s)
+ s = re.sub(r'([A-Z]{2,})(\d)', r'\1 \2', s)
+
+ # Fix common glued tokens
+ s = s.replace("POBox", "PO Box")
+
+ # Re-join split ordinals: "9 th" -> "9th" (text following the ordinal is left untouched)
+ s = re.sub(r'\b(\d+)\s*(st|nd|rd|th)\b', r'\1\2', s)
+
+ s = re.sub(r"\s+", " ", s).strip()
+ return s
+
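+# Rough examples of the de-gluing heuristics above (OCR output varies, so treat as a sketch;
+# the sample strings are made up):
+#   _smart_space("KangarooTransport")  -> "Kangaroo Transport"
+#   _smart_space("POBox123")           -> "PO Box 123"
+#   _smart_space("9 th August 2024")   -> "9th August 2024"
+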
+def looks_like_plate(s: str) -> bool:
+ if not s: return False
+ t = re.sub(r"[\s-]", "", str(s).upper())
+ if not (5 <= len(t) <= 8): return False
+ if not re.fullmatch(r"[A-Z0-9]+", t): return False
+ if sum(c.isalpha() for c in t) < 2: return False
+ if sum(c.isdigit() for c in t) < 2: return False
+ if t in {"ENTRY","YES","NO","N/A","NA"}: return False
+ return True
+
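+# Quick sanity examples (spaces/hyphens are stripped before checking; plates are illustrative):
+#   looks_like_plate("XQ 12 PL") -> True    # 4 letters + 2 digits, length 6
+#   looks_like_plate("123456")   -> False   # no letters
+#   looks_like_plate("YES")      -> False   # too short (and an excluded word)
+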
+def is_dateish(s: str) -> bool:
+ if not s: return False
+ s = _smart_space(s)
+ # tokens like 03/22, 20/02/2023, 01.02.21, 2023-02-20
+ return bool(re.search(r"\b\d{1,4}(?:[./-]\d{1,4}){1,2}\b", s))
+
+def extract_date_tokens(s: str) -> list[str]:
+ if not s: return []
+ s = _smart_space(s)
+ # allow a trailing 2-4 digit year so "20/02/2023" is captured whole
+ return re.findall(r"\b\d{1,4}(?:[./-]\d{1,4}){1,2}\b", s)
+
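+# Sketch of typical extractions:
+#   extract_date_tokens("Serviced 03/22 and 2023-02-20") -> ["03/22", "2023-02-20"]
+#   extract_date_tokens("RWC issued 20/02/2023")          -> ["20/02/2023"]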
+
+def _clean_list(vals: List[str]) -> List[str]:
+ out = []
+ for v in vals:
+ v = _smart_space(v)
+ if v:
+ out.append(v)
+ return out
+
+def _looks_like_manual_value(s: str) -> bool:
+ if not s: return False
+ s = s.strip()
+ # reject pure digits (e.g., "51902") and very short tokens
+ if re.fullmatch(r"\d{3,}", s):
+ return False
+ # accept if it has any letters or typical version hints
+ return bool(re.search(r"[A-Za-z]", s))
+
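+# Quick examples:
+#   _looks_like_manual_value("51902")                        -> False  # bare accreditation-style number
+#   _looks_like_manual_value("Version 3 - Transport Manager") -> True
+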
+def _looks_like_company(s: str) -> bool:
+ """Very light validation to avoid capturing labels as values."""
+ if not s: return False
+ s = _smart_space(s)
+ # at least two words containing letters (e.g., "Kangaroo Transport")
+ return bool(re.search(r"[A-Za-z]{2,}\s+[A-Za-z&]{2,}", s))
+
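+# e.g. (company names here are made up):
+#   _looks_like_company("KangarooTransport Pty Ltd") -> True   # two+ words with letters after spacing fix
+#   _looks_like_company("Date")                      -> False
+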
+# ───────────────────────────── label index (non-summary only; no values) ─────────────────────────────
+LABEL_INDEX: Dict[str, Dict[str, Dict[str, Any]]] = {
+ "Audit Information": {
+ "Date of Audit": {"alts": ["Date of Audit"]},
+ "Location of audit": {"alts": ["Location of audit", "Location"]},
+ "Auditor name": {"alts": ["Auditor name", "Auditor"]},
+ "Audit Matrix Identifier (Name or Number)": {"alts": ["Audit Matrix Identifier (Name or Number)", "Audit Matrix Identifier"]},
+ "Auditor Exemplar Global Reg No.": {"alts": ["Auditor Exemplar Global Reg No."]},
+ "NHVR Auditor Registration Number": {"alts": ["NHVR Auditor Registration Number"]},
+ "expiry Date:": {"alts": ["expiry Date:", "Expiry Date:"]},
+ },
+ "Operator Information": {
+ "Operator name (Legal entity)": {"alts": ["Operator name (Legal entity)", "Operator's Name (legal entity)"]},
+ "NHVAS Accreditation No. (If applicable)": {"alts": ["NHVAS Accreditation No. (If applicable)", "NHVAS Accreditation No."]},
+ "Registered trading name/s": {"alts": ["Registered trading name/s", "Trading name/s"]},
+ "Australian Company Number": {"alts": ["Australian Company Number", "ACN"]},
+ "NHVAS Manual (Policies and Procedures) developed by": {"alts": [
+ "NHVAS Manual (Policies and Procedures) developed by",
+ "NHVAS Manual developed by",
+ "Manual developed by"
+ ]},
+ },
+ "Operator contact details": {
+ "Operator business address": {"alts": ["Operator business address", "Business address"]},
+ "Operator Postal address": {"alts": ["Operator Postal address", "Postal address"]},
+ "Email address": {"alts": ["Email address", "Email"]},
+ "Operator Telephone Number": {"alts": ["Operator Telephone Number", "Telephone", "Phone"]},
+ },
+ "Attendance List (Names and Position Titles)": {
+ "Attendance List (Names and Position Titles)": {"alts": ["Attendance List (Names and Position Titles)", "Attendance List"]},
+ },
+ "Nature of the Operators Business (Summary)": {
+ "Nature of the Operators Business (Summary):": {"alts": ["Nature of the Operators Business (Summary):"]},
+ },
+ "Accreditation Vehicle Summary": {
+ "Number of powered vehicles": {"alts": ["Number of powered vehicles"]},
+ "Number of trailing vehicles": {"alts": ["Number of trailing vehicles"]},
+ },
+ "Accreditation Driver Summary": {
+ "Number of drivers in BFM": {"alts": ["Number of drivers in BFM"]},
+ "Number of drivers in AFM": {"alts": ["Number of drivers in AFM"]},
+ },
+ "Vehicle Registration Numbers Maintenance": {
+ "No.": {"alts": ["No.", "No"]},
+ "Registration Number": {"alts": ["Registration Number", "Registration"]},
+ "Roadworthiness Certificates": {"alts": ["Roadworthiness Certificates", "Roadworthiness"]},
+ "Maintenance Records": {"alts": ["Maintenance Records"]},
+ "Daily Checks": {"alts": ["Daily Checks", "Daily Check"]},
+ "Fault Recording/ Reporting": {"alts": ["Fault Recording/ Reporting", "Fault Recording / Reporting"]},
+ "Fault Repair": {"alts": ["Fault Repair"]},
+ },
+ "Vehicle Registration Numbers Mass": {
+ "No.": {"alts": ["No.", "No"]},
+ "Registration Number": {"alts": ["Registration Number", "Registration"]},
+ "Sub contractor": {"alts": ["Sub contractor", "Sub-contractor"]},
+ "Sub-contracted Vehicles Statement of Compliance": {"alts": ["Sub-contracted Vehicles Statement of Compliance"]},
+ "Weight Verification Records": {"alts": ["Weight Verification Records"]},
+ "RFS Suspension Certification #": {"alts": ["RFS Suspension Certification #", "RFS Suspension Certification Number"]},
+ "Suspension System Maintenance": {"alts": ["Suspension System Maintenance"]},
+ "Trip Records": {"alts": ["Trip Records"]},
+ "Fault Recording/ Reporting on Suspension System": {"alts": ["Fault Recording/ Reporting on Suspension System"]},
+ },
+ "Driver / Scheduler Records Examined": {
+ "No.": {"alts": ["No.", "No"]},
+ "Driver / Scheduler Name": {"alts": ["Driver / Scheduler Name"]},
+ "Driver TLIF Course # Completed": {"alts": ["Driver TLIF Course # Completed"]},
+ "Scheduler TLIF Course # Completed": {"alts": ["Scheduler TLIF Course # Completed"]},
+ "Medical Certificates (Current Yes/No) Date of expiry": {"alts": ["Medical Certificates (Current Yes/No) Date of expiry"]},
+ "Roster / Schedule / Safe Driving Plan (Date Range)": {"alts": ["Roster / Schedule / Safe Driving Plan (Date Range)"]},
+ "Fit for Duty Statement Completed (Yes/No)": {"alts": ["Fit for Duty Statement Completed (Yes/No)"]},
+ "Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)": {"alts": ["Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)"]},
+ },
+ "NHVAS Approved Auditor Declaration": {
+ "Print Name": {"alts": ["Print Name"]},
+ "NHVR or Exemplar Global Auditor Registration Number": {"alts": ["NHVR or Exemplar Global Auditor Registration Number"]},
+ },
+ "Audit Declaration dates": {
+ "Audit was conducted on": {"alts": ["Audit was conducted on"]},
+ "Unconditional CARs closed out on:": {"alts": ["Unconditional CARs closed out on:"]},
+ "Conditional CARs to be closed out by:": {"alts": ["Conditional CARs to be closed out by:"]},
+ },
+ "Print accreditation name": {
+ "(print accreditation name)": {"alts": ["(print accreditation name)"]},
+ },
+ "Operator Declaration": {
+ "Print Name": {"alts": ["Print Name"]},
+ "Position Title": {"alts": ["Position Title"]},
+ },
+}
+
+class NHVASMerger:
+ def __init__(self):
+ self.debug_mode = True
+ self._vehicle_by_reg = OrderedDict()
+
+ def log_debug(self, msg: str):
+ if self.debug_mode:
+ print(f"🔍 {msg}")
+
+ def normalize_std_label(self, label: str) -> str:
+ if not label: return ""
+ base = re.sub(r"\([^)]*\)", "", label)
+ base = re.sub(r"\s+", " ", base).strip()
+ m = re.match(r"^(Std\s*\d+\.\s*[^:]+?)\s*$", base, flags=re.IGNORECASE)
+ return m.group(1).strip() if m else base
+
+ def _pick_nearby(self, row, anchor_idx: int | None, want: str = "plate", window: int = 3) -> str:
+ """Return the best cell for a field by looking at the anchor index and nearby columns.
+ want ∈ {"plate","date","rf","yn"}"""
+ def cell(i):
+ if i is None or i < 0 or i >= len(row): return ""
+ v = row[i]
+ return v.strip() if isinstance(v, str) else str(v).strip()
+
+ # 1) try the anchor cell
+ cand = cell(anchor_idx)
+ if want == "plate" and looks_like_plate(cand): return _smart_space(cand)
+ if want == "date" and is_dateish(cand): return _smart_space(cand)
+ if want == "rf" and re.search(r"\bRF\s*\d+\b", cand, re.I): return _smart_space(re.search(r"\bRF\s*\d+\b", cand, re.I).group(0))
+ if want == "yn" and cand.strip().lower() in {"yes","no"}: return cand.strip().title()
+
+ # 2) scan a window around the anchor
+ if anchor_idx is not None:
+ for offset in range(1, window+1):
+ for i in (anchor_idx - offset, anchor_idx + offset):
+ c = cell(i)
+ if not c: continue
+ if want == "plate" and looks_like_plate(c): return _smart_space(c)
+ if want == "date" and is_dateish(c): return _smart_space(c)
+ if want == "rf":
+ m = re.search(r"\bRF\s*\d+\b", c, re.I)
+ if m: return _smart_space(m.group(0))
+ if want == "yn" and c.strip().lower() in {"yes","no"}: return c.strip().title()
+
+ # 3) last resort: scan whole row
+ joined = " ".join(str(c or "") for c in row)
+ if want == "plate":
+ for tok in joined.split():
+ if looks_like_plate(tok): return _smart_space(tok)
+ if want == "date":
+ tok = extract_date_tokens(joined)
+ return tok[0] if tok else ""
+ if want == "rf":
+ m = re.search(r"\bRF\s*\d+\b", joined, re.I)
+ return _smart_space(m.group(0)) if m else ""
+ if want == "yn":
+ j = f" {joined.lower()} "
+ if " yes " in j: return "Yes"
+ if " no " in j: return "No"
+ return ""
+
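+ # Hypothetical row with misaligned columns: the anchor cell is empty, so the window
+ # scan returns the nearest date-looking cell instead.
+ #   self._pick_nearby(["1.", "XQ12PL", "", "20/02/23"], anchor_idx=2, want="date") -> "20/02/23"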
+
+ def _force_fill_maintenance_from_tables(self, pdf_data: Dict, merged: Dict) -> None:
+ """Overwrite Maintenance arrays by scanning ALL maintenance tables across pages."""
+ maint = merged.get("Vehicle Registration Numbers Maintenance")
+ if not isinstance(maint, dict):
+ return
+
+ tables = (pdf_data.get("extracted_data") or {}).get("all_tables") or []
+ regs, rw, mr, dc, fr, rp = [], [], [], [], [], []
+
+ for t in tables:
+ hdrs = [_canon_header(h or "") for h in t.get("headers") or []]
+ if not hdrs:
+ continue
+ # detect a maintenance table
+ txt = " ".join(hdrs)
+ if ("registration" not in txt) or not any(
+ k in txt for k in ["maintenance records", "daily", "fault recording", "fault repair", "roadworthiness"]
+ ):
+ continue
+
+ def fidx(pred):
+ for i, h in enumerate(hdrs):
+ if pred(h):
+ return i
+ return None
+
+ reg_i = fidx(lambda h: "registration" in h)
+ rw_i = fidx(lambda h: "roadworthiness" in h)
+ mr_i = fidx(lambda h: "maintenance" in h and "record" in h)
+ dc_i = fidx(lambda h: "daily" in h and "check" in h)
+ fr_i = fidx(lambda h: "fault" in h and "record" in h and "suspension" not in h)
+ rp_i = fidx(lambda h: "fault" in h and "repair" in h)
+
+ for r in t.get("data") or []:
+ def cell(i):
+ if i is None or i >= len(r): return ""
+ v = r[i]
+ return v.strip() if isinstance(v, str) else str(v).strip()
+
+ plate = _smart_space(cell(reg_i))
+ if not plate or not looks_like_plate(plate):
+ continue
+
+ v_rw = _nz(cell(rw_i))
+ v_mr = _nz(cell(mr_i))
+ v_dc = _nz(cell(dc_i))
+ v_fr = _nz(cell(fr_i))
+ v_rp = _nz(cell(rp_i))
+
+ # sensible fallbacks
+ if not v_mr and v_dc: v_mr = v_dc
+ if not v_rp and v_fr: v_rp = v_fr
+ if not v_fr and v_rp: v_fr = v_rp
+
+ regs.append(plate); rw.append(v_rw); mr.append(v_mr)
+ dc.append(v_dc); fr.append(v_fr); rp.append(v_rp)
+
+ if regs: # overwrite arrays only if we found rows
+ maint["Registration Number"] = regs
+ maint["Roadworthiness Certificates"] = rw
+ maint["Maintenance Records"] = mr
+ maint["Daily Checks"] = dc
+ maint["Fault Recording/ Reporting"] = fr
+ maint["Fault Repair"] = rp
+
+ def _collapse_multiline_headers(self, headers: List[str], data_rows: List[List[str]]):
+ """
+ Merge header continuation rows (when first data rows are not numeric '1.', '2.', …)
+ into the main headers, then return (merged_headers, remaining_data_rows).
+ """
+ merged = [_smart_space(h or "") for h in (headers or [])]
+ consumed = 0
+ header_frags: List[List[str]] = []
+
+ # Collect up to 5 leading rows that look like header fragments
+ for r in data_rows[:5]:
+ first = (str(r[0]).strip() if r else "")
+ if re.match(r"^\d+\.?$", first):
+ break # real data starts
+ consumed += 1
+ header_frags.append(r)
+
+ # Merge every collected fragment row into merged
+ for frag in header_frags:
+ for i, cell in enumerate(frag):
+ cell_txt = _smart_space(str(cell or "").strip())
+ if not cell_txt:
+ continue
+ if i >= len(merged):
+ merged.append(cell_txt)
+ else:
+ merged[i] = (merged[i] + " " + cell_txt).strip()
+
+ return merged, data_rows[consumed:]
+
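+ # Sketch: a header split across the first data row gets merged back, e.g.
+ #   headers   = ["No.", "Registration"]
+ #   data_rows = [["", "Number"], ["1.", "XQ12PL"]]
+ #   -> (["No.", "Registration Number"], [["1.", "XQ12PL"]])
+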
+ def _first_attendance_name_title(self, att_list: List[str]) -> Optional[tuple[str, str]]:
+ """Return (print_name, position_title) from the first 'Name - Title' in attendance."""
+ if not att_list:
+ return None
+ # First "Name - Title", stop before next "Name -"
+ pat = re.compile(
+ r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,3})\s*-\s*(.*?)(?=(?:\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,3}\s*-\s*)|$)'
+ )
+ for item in att_list:
+ s = _smart_space(str(item))
+ m = pat.search(s)
+ if m:
+ name = _smart_space(m.group(1))
+ title = _smart_space(m.group(2))
+ return name, title
+ return None
+
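+ # e.g. (name and title are illustrative):
+ #   self._first_attendance_name_title(["John Smith - Operations Manager"])
+ #   -> ("John Smith", "Operations Manager")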
+
+ # ───────────────────────────── summary tables (unchanged logic) ─────────────────────────────
+ def build_summary_maps(self, pdf_json: dict) -> dict:
+ out = {v: {} for v in SUMMARY_SECTIONS.values()}
+ try:
+ tables = pdf_json["extracted_data"]["all_tables"]
+ except Exception:
+ return out
+
+ for t in tables:
+ headers = [re.sub(r"\s+", " ", (h or "")).strip().upper() for h in t.get("headers", [])]
+ if "DETAILS" not in headers:
+ continue
+ section_key_raw = next((h for h in headers if h in SUMMARY_SECTIONS), None)
+ if not section_key_raw:
+ continue
+ section_name = SUMMARY_SECTIONS[section_key_raw]
+ for row in t.get("data", []):
+ if not row: continue
+ left = str(row[0]) if len(row) >= 1 else ""
+ right = str(row[1]) if len(row) >= 2 else ""
+ left_norm = self.normalize_std_label(left)
+ if left_norm and right:
+ prev = out[section_name].get(left_norm, "")
+ merged_text = (prev + " " + right).strip() if prev else right.strip()
+ out[section_name][left_norm] = merged_text
+
+ for sec in out:
+ out[sec] = {k: [_smart_space(v)] for k, v in out[sec].items() if v}
+ return out
+
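+ # Resulting shape (sketch; the Std label is illustrative):
+ #   {"Maintenance Management Summary": {"Std 1. Daily Check": ["..."]},
+ #    "Mass Management Summary": {...}, "Fatigue Management Summary": {...}}
+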
+ # ───────────────────────────── NEW: find cell by label in tables ─────────────────────────────
+ def _find_table_value(self, tables: List[Dict], label_variants: List[str]) -> Optional[str]:
+ targets = {_canon(v) for v in label_variants}
+ for t in tables:
+ data = t.get("data", [])
+ if not data: continue
+ for row in data:
+ if not row: continue
+ key = _canon(str(row[0]))
+ if key in targets:
+ vals = [str(c).strip() for c in row[1:] if str(c).strip()]
+ if vals:
+ return _smart_space(" ".join(vals))
+ return None
+
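+ # Hypothetical lookup: a table row ["Date of Audit", "9th August 2024"] matches the
+ # "Date of Audit" alternative and returns "9th August 2024".
+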
+ # ───────────────────────────── comprehensive extraction (minimal changes) ─────────────────────────────
+ def extract_from_pdf_comprehensive(self, pdf_data: Dict) -> Dict[str, Any]:
+ self._vehicle_by_reg.clear()
+ extracted = {}
+ extracted_data = pdf_data.get("extracted_data", {})
+ tables = extracted_data.get("all_tables", [])
+
+ # Capture "Audit was conducted on" from tables; ignore placeholder "Date"
+ awd = self._find_table_value(
+ tables,
+ LABEL_INDEX["Audit Declaration dates"]["Audit was conducted on"]["alts"]
+ )
+ if awd:
+ awd = _smart_space(awd)
+ if re.search(r"\d", awd) and not re.fullmatch(r"date", awd, re.I):
+ extracted["audit_conducted_date"] = awd
+
+
+
+ # 1) Audit Information (table first)
+ audit_info = extracted_data.get("audit_information", {})
+ if audit_info:
+ extracted["audit_info"] = {
+ "date_of_audit": _smart_space(audit_info.get("DateofAudit", "")),
+ "location": _smart_space(audit_info.get("Locationofaudit", "")),
+ "auditor_name": _smart_space(audit_info.get("Auditorname", "")),
+ "matrix_id": _smart_space(audit_info.get("AuditMatrixIdentifier (Name or Number)", "")),
+ }
+ # If missing, try generic table lookup
+ for label, meta in LABEL_INDEX.get("Audit Information", {}).items():
+ if label == "expiry Date:": # not used in your DOCX example
+ continue
+ val = self._find_table_value(tables, meta.get("alts", [label]))
+ if val:
+ extracted.setdefault("audit_info", {})
+ if _canon(label) == _canon("Date of Audit"): extracted["audit_info"]["date_of_audit"] = val
+ elif _canon(label) == _canon("Location of audit"): extracted["audit_info"]["location"] = val
+ elif _canon(label) == _canon("Auditor name"): extracted["audit_info"]["auditor_name"] = val
+ elif _canon(label) == _canon("Audit Matrix Identifier (Name or Number)"): extracted["audit_info"]["matrix_id"] = val
+
+ # 2) Operator Information (prefer table rows)
+ operator_info = extracted_data.get("operator_information", {})
+ if operator_info:
+ extracted["operator_info"] = {
+ "name": "",
+ "trading_name": _smart_space(operator_info.get("trading_name", "")),
+ "acn": _smart_space(operator_info.get("company_number", "")),
+ "manual": _smart_space(operator_info.get("nhvas_accreditation", "")),
+ "business_address": _smart_space(operator_info.get("business_address", "")),
+ "postal_address": _smart_space(operator_info.get("postal_address", "")),
+ "email": operator_info.get("email", ""),
+ "phone": _smart_space(operator_info.get("phone", "")),
+ }
+
+ # Fill operator info via table lookup
+ for label, meta in LABEL_INDEX.get("Operator Information", {}).items():
+ val = self._find_table_value(tables, meta.get("alts", [label]))
+ if not val: continue
+ if _canon(label) == _canon("Operator name (Legal entity)") and _looks_like_company(val):
+ extracted.setdefault("operator_info", {})
+ extracted["operator_info"]["name"] = val
+ elif _canon(label) == _canon("Registered trading name/s"):
+ extracted.setdefault("operator_info", {})
+ extracted["operator_info"]["trading_name"] = val
+ elif _canon(label) == _canon("Australian Company Number"):
+ extracted.setdefault("operator_info", {})
+ extracted["operator_info"]["acn"] = val
+ elif _canon(label) == _canon("NHVAS Manual (Policies and Procedures) developed by"):
+ extracted.setdefault("operator_info", {})
+ if _looks_like_manual_value(val):
+ extracted["operator_info"]["manual"] = val
+
+ # 3) Generic table parsing (unchanged logic for other sections)
+ self._extract_table_data(tables, extracted)
+
+ # 4) Text parsing (kept, but spacing applied)
+ self._extract_text_content(extracted_data.get("all_text_content", []), extracted)
+ # Vehicle tables sometimes fail to land in all_tables; parse from text as a fallback
+ self._extract_vehicle_tables_from_text(extracted_data.get("all_text_content", []), extracted)
+
+ # 5) Vehicle/Driver data (kept)
+ self._extract_vehicle_driver_data(extracted_data, extracted)
+
+ # 6) Detailed mgmt (kept)
+ self._extract_detailed_management_data(extracted_data, extracted)
+
+ return extracted
+
+ # ───────────────────────────── table classifiers ─────────────────────────────
+ def _extract_table_data(self, tables: List[Dict], extracted: Dict):
+ for table in tables:
+ headers = table.get("headers", []) or []
+ data_rows = table.get("data", []) or []
+ if not data_rows:
+ continue
+
+ page_num = table.get("page", 0)
+ self.log_debug(f"Processing table on page {page_num} with headers: {headers[:3]}...")
+
+ # Collapse possible multi-line headers once up front
+ collapsed_headers, collapsed_rows = self._collapse_multiline_headers(headers, data_rows)
+
+ # Try vehicle tables first, using either the raw or the collapsed headers
+ if self._is_vehicle_registration_table(headers) or self._is_vehicle_registration_table(collapsed_headers):
+ # always extract with the collapsed header/rows so we see "Registration Number", etc.
+ self._extract_vehicle_registration_table(collapsed_headers, collapsed_rows, extracted, page_num)
+ continue
+
+ # the rest keep their existing order/logic (use the original headers/rows)
+ if self._is_audit_info_table(headers):
+ self._extract_audit_info_table(data_rows, extracted)
+ elif self._is_operator_info_table(headers):
+ self._extract_operator_info_table(data_rows, extracted)
+ elif self._is_attendance_table(headers):
+ self._extract_attendance_table(data_rows, extracted)
+ elif self._is_vehicle_summary_table(headers):
+ self._extract_vehicle_summary_table(data_rows, extracted)
+ elif self._is_driver_table(headers):
+ self._extract_driver_table(headers, data_rows, extracted)
+ elif self._is_management_compliance_table(headers):
+ self._extract_management_table(data_rows, extracted, headers)
+
+
+ def _is_audit_info_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["audit", "date", "location", "auditor"])
+
+ def _is_operator_info_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["operator", "company", "trading", "address"])
+
+ def _is_attendance_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return "attendance" in txt
+
+ def _is_vehicle_summary_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["powered vehicles", "trailing vehicles", "drivers in bfm"])
+
+ def _is_vehicle_registration_table(self, headers: List[str]) -> bool:
+ if not headers: return False
+ ch = [_canon_header(h) for h in headers]
+ has_reg = any(
+ ("registration" in h) or re.search(r"\breg(?:istration)?\b", h) or ("reg" in h and "no" in h)
+ for h in ch
+ )
+ others = ["roadworthiness","maintenance records","daily checks","fault recording","fault repair",
+ "sub contractor","sub-contractor","weight verification","rfs suspension","suspension system maintenance",
+ "trip records","fault recording reporting on suspension system","fault reporting suspension"]
+ has_signal = any(any(tok in h for tok in others) for h in ch)
+ return has_reg and has_signal
+
+ def _is_driver_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["driver", "scheduler", "tlif", "medical"])
+
+ def _is_management_compliance_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["maintenance management", "mass management", "fatigue management"])
+
+ def _extract_vehicle_tables_from_text(self, text_pages: List[Dict], extracted: Dict):
+ # flatten text
+ lines = []
+ for p in text_pages or []:
+ for ln in re.split(r"\s*\n\s*", p.get("text", "")):
+ ln = _smart_space(ln)
+ if ln: lines.append(ln)
+
+ maint_rows, mass_rows = [], []
+ rf_pat = re.compile(r"\bRF\s*\d+\b", re.IGNORECASE)
+
+ for ln in lines:
+ # find first token that looks like a rego
+ tokens = ln.split()
+ reg = next((t for t in tokens if looks_like_plate(t)), None)
+ if not reg:
+ continue
+
+ # everything after the reg on that line
+ tail = _smart_space(ln.split(reg, 1)[1]) if reg in ln else ""
+ dates = extract_date_tokens(tail)
+ has_rf = bool(rf_pat.search(ln)) or "suspension" in ln.lower()
+
+ if has_rf:
+ rfs = (rf_pat.search(ln).group(0).upper().replace(" ", "") if rf_pat.search(ln) else "")
+ wv = dates[0] if len(dates) > 0 else ""
+ rest = dates[1:]
+ mass_rows.append({
+ "registration": reg,
+ "sub_contractor": "Yes" if " yes " in f" {ln.lower()} " else ("No" if " no " in f" {ln.lower()} " else ""),
+ "sub_comp": "Yes" if " yes " in f" {ln.lower()} " else ("No" if " no " in f" {ln.lower()} " else ""),
+ "weight_verification": wv,
+ "rfs_certification": rfs or ("N/A" if "n/a" in ln.lower() else ""),
+ "suspension_maintenance": rest[0] if len(rest) > 0 else "",
+ "trip_records": rest[1] if len(rest) > 1 else "",
+ "fault_reporting_suspension": rest[2] if len(rest) > 2 else "",
+ })
+ else:
+ # map first 5 date-like tokens in sensible order; fallbacks keep table consistent
+ rw = dates[0] if len(dates) > 0 else ""
+ mr = dates[1] if len(dates) > 1 else ""
+ dc = dates[2] if len(dates) > 2 else ""
+ fr = dates[3] if len(dates) > 3 else ""
+ rp = dates[4] if len(dates) > 4 else ""
+ maint_rows.append({
+ "registration": reg,
+ "roadworthiness": rw,
+ "maintenance_records": mr or dc,
+ "daily_checks": dc,
+ "fault_recording": fr or rp,
+ "fault_repair": rp or fr,
+ })
+
+ # Merge the parsed rows into the shared vehicles list, tagged by section
+ vlist = extracted.setdefault("vehicles", []) # ensure it always exists
+
+ if maint_rows or mass_rows:
+ for r in maint_rows:
+ r["section"] = "maintenance"
+ vlist.append(r)
+ for r in mass_rows:
+ r["section"] = "mass"
+ vlist.append(r)
+ self.log_debug(f"Vehicle rows (text fallback): maint={len(maint_rows)} mass={len(mass_rows)} total={len(vlist)}")
+ else:
+ self.log_debug("Vehicle rows (text fallback): none detected.")
+
+
+ # ───────────────────────────── simple extractors (spacing applied) ─────────────────────────────
+ def _extract_audit_info_table(self, data_rows: List[List], extracted: Dict):
+ ai = extracted.setdefault("audit_info", {})
+ for row in data_rows:
+ if len(row) < 2: continue
+ key = _canon(row[0])
+ val = _smart_space(" ".join(str(c).strip() for c in row[1:] if str(c).strip()))
+ if not val: continue
+ if "date" in key and "audit" in key: ai["date_of_audit"] = val
+ elif "location" in key: ai["location"] = val
+ elif "auditor" in key and "name" in key: ai["auditor_name"] = val
+ elif "matrix" in key: ai["matrix_id"] = val
+
+ def _extract_operator_info_table(self, data_rows: List[List], extracted: Dict):
+ oi = extracted.setdefault("operator_info", {})
+ for row in data_rows:
+ if len(row) < 2: continue
+ key = _canon(row[0])
+ val = _smart_space(" ".join(str(c).strip() for c in row[1:] if str(c).strip()))
+ if not val: continue
+ if "operator" in key and "name" in key and _looks_like_company(val): oi["name"] = val
+ elif "trading" in key: oi["trading_name"] = val
+ elif "australian" in key and "company" in key: oi["acn"] = val
+ elif "business" in key and "address" in key: oi["business_address"] = val
+ elif "postal" in key and "address" in key: oi["postal_address"] = val
+ elif "email" in key: oi["email"] = val
+ elif "telephone" in key or "phone" in key: oi["phone"] = val
+ elif "manual" in key or ("nhvas" in key and "manual" in key) or "developed" in key:
+ if _looks_like_manual_value(val):
+ oi["manual"] = val
+
+ def _extract_attendance_table(self, data_rows: List[List], extracted: Dict):
+ lst = []
+ for row in data_rows:
+ if not row: continue
+ cells = [str(c).strip() for c in row if str(c).strip()]
+ if not cells: continue
+ lst.append(_smart_space(" ".join(cells)))
+ if lst:
+ extracted["attendance"] = lst
+
+ def _extract_vehicle_summary_table(self, data_rows: List[List], extracted: Dict):
+ vs = extracted.setdefault("vehicle_summary", {})
+ for row in data_rows:
+ if len(row) < 2: continue
+ key = _canon(row[0])
+ value = ""
+ for c in row[1:]:
+ if str(c).strip():
+ value = _smart_space(str(c).strip()); break
+ if not value: continue
+ if "powered" in key and "vehicle" in key: vs["powered_vehicles"] = value
+ elif "trailing" in key and "vehicle" in key: vs["trailing_vehicles"] = value
+ elif "drivers" in key and "bfm" in key: vs["drivers_bfm"] = value
+ elif "drivers" in key and "afm" in key: vs["drivers_afm"] = value
+
+ # Column mapping driven by header text (tolerant of misaligned data cells)
+ def _extract_vehicle_registration_table(self, headers, rows, extracted, page_num):
+ ch = [_canon_header(h) for h in (headers or [])]
+ alias = _map_header_indices(headers or [])
+
+ # header indices (may be misaligned vs data; that's OK, we’ll search near them)
+ def idx_of(*needles):
+ for i, h in enumerate(ch):
+ if all(n in h for n in needles): return i
+ return None
+
+ reg_i = alias.get("registration") or idx_of("registration number") or idx_of("registration") or idx_of("reg","no")
+ rw_i = alias.get("roadworthiness") or idx_of("roadworthiness")
+ maint_i = alias.get("maintenance_records") or idx_of("maintenance","records")
+ daily_i = alias.get("daily_checks") or idx_of("daily","check")
+ fr_i = alias.get("fault_recording") or idx_of("fault","recording")
+ rep_i = alias.get("fault_repair") or idx_of("fault","repair")
+
+ weight_i = alias.get("weight_verification") or idx_of("weight","verification")
+ rfs_i = alias.get("rfs_certification") or idx_of("rfs","certification")
+ susp_i = alias.get("suspension_maintenance") or idx_of("suspension","maintenance")
+ trip_i = alias.get("trip_records") or idx_of("trip","records")
+ frs_i = alias.get("fault_reporting_suspension") or idx_of("fault","reporting","suspension")
+
+ # classify table type by header signals
+ is_maint = any("roadworthiness" in h or "maintenance records" in h or ("daily" in h and "check" in h) or "fault repair" in h for h in ch)
+ is_mass = any("weight verification" in h or "rfs" in h or "suspension system" in h or "trip records" in h or "reporting on suspension" in h for h in ch)
+
+ maint_rows = extracted.setdefault("_maint_rows", []) if is_maint else None
+ added = 0
+
+ for r in rows or []:
+ # tolerant plate pick (handles misaligned columns)
+ reg = self._pick_nearby(r, reg_i, "plate", window=4)
+ if not reg or not looks_like_plate(reg):
+ continue
+
+ # collect values using tolerant picks
+ if is_maint:
+ rw = self._pick_nearby(r, rw_i, "date", window=4)
+ mr = self._pick_nearby(r, maint_i, "date", window=4)
+ dc = self._pick_nearby(r, daily_i, "date", window=4)
+ fr = self._pick_nearby(r, fr_i, "date", window=4)
+ rep = self._pick_nearby(r, rep_i, "date", window=4)
+
+ # sensible fallbacks
+ if not mr and dc: mr = dc
+ if not rep and fr: rep = fr
+ if not fr and rep: fr = rep
+
+ else: # mass or mixed
+ wv = self._pick_nearby(r, weight_i, "date", window=4)
+ rfs = self._pick_nearby(r, rfs_i, "rf", window=5)
+ sm = self._pick_nearby(r, susp_i, "date", window=4)
+ tr = self._pick_nearby(r, trip_i, "date", window=4)
+ frs = self._pick_nearby(r, frs_i, "date", window=4)
+ yn1 = self._pick_nearby(r, idx_of("sub","contractor"), "yn", window=3) or ""
+ yn2 = self._pick_nearby(r, idx_of("sub contracted vehicles statement of compliance"), "yn", window=3) or yn1
+
+ # merge into vehicle map
+ v = self._vehicle_by_reg.get(reg)
+ if v is None:
+ v = {"registration": reg}
+ self._vehicle_by_reg[reg] = v
+ added += 1
+
+ if is_maint:
+ v["seen_in_maintenance"] = True
+ if rw: v.setdefault("roadworthiness", rw)
+ if mr: v.setdefault("maintenance_records", mr)
+ if dc: v.setdefault("daily_checks", dc)
+ if fr: v.setdefault("fault_recording", fr)
+ if rep: v.setdefault("fault_repair", rep)
+
+ if maint_rows is not None:
+ maint_rows.append({
+ "registration": reg,
+ "roadworthiness": rw,
+ "maintenance_records": mr or dc,
+ "daily_checks": dc,
+ "fault_recording": fr or rep,
+ "fault_repair": rep or fr,
+ })
+ else:
+ v["seen_in_mass"] = True
+ if yn1: v.setdefault("sub_contractor", yn1)
+ if yn2: v.setdefault("sub_comp", yn2)
+ if wv: v.setdefault("weight_verification", wv)
+ if rfs: v.setdefault("rfs_certification", _smart_space(rfs).upper().replace(" ", ""))
+ if sm: v.setdefault("suspension_maintenance", sm)
+ if tr: v.setdefault("trip_records", tr)
+ if frs: v.setdefault("fault_reporting_suspension", frs)
+
+ extracted["vehicles"] = list(self._vehicle_by_reg.values())
+ return added
+
+ def _extract_driver_table(self, headers: List[str], data_rows: List[List], extracted: Dict):
+ """Header-driven extraction for Driver / Scheduler Records."""
+ drivers = []
+ ch = [_canon_header(h) for h in headers or []]
+
+ # helpers
+ def find_col(needles: list[str]) -> Optional[int]:
+ for i, h in enumerate(ch):
+ if any(n in h for n in needles):
+ return i
+ return None
+
+ def find_col_rx(patterns: list[str]) -> Optional[int]:
+ for i, h in enumerate(ch):
+ if any(re.search(p, h) for p in patterns):
+ return i
+ return None
+
+ name_idx = find_col_rx([r"\bdriver\s*/\s*scheduler\s*name\b",
+ r"\bdriver\s+name\b", r"\bscheduler\s+name\b", r"\bname\b"])
+ tlif_d_idx = find_col(["driver tlif"])
+ tlif_s_idx = find_col(["scheduler tlif"])
+ medical_idx= find_col(["medical", "expiry"])
+ roster_idx = find_col_rx([r"\broster\b", r"\bsafe\s+driving\s+plan\b", r"\bschedule\b(?!r\b)"])
+ fit_idx = find_col(["fit for duty"])
+ diary_idx = find_col(["work diary", "electronic work diary", "page numbers"])
+
+ for row in data_rows:
+ if not row:
+ continue
+
+ name = None
+ if name_idx is not None and name_idx < len(row):
+ name = _smart_space(str(row[name_idx]).strip())
+ if not name:
+ continue
+
+ d = {"name": name}
+
+ if tlif_d_idx is not None and tlif_d_idx < len(row):
+ d["driver_tlif"] = _smart_space(str(row[tlif_d_idx]).strip())
+ if tlif_s_idx is not None and tlif_s_idx < len(row):
+ d["scheduler_tlif"] = _smart_space(str(row[tlif_s_idx]).strip())
+ if medical_idx is not None and medical_idx < len(row):
+ d["medical_expiry"] = _smart_space(str(row[medical_idx]).strip())
+
+ # Roster/Schedule/SDP: prefer the detected column; accept only date/range-like, not the name
+ if roster_idx is not None and roster_idx < len(row):
+ raw_roster = _smart_space(str(row[roster_idx]).strip())
+ if raw_roster and re.search(r"[0-9/–-]", raw_roster) and raw_roster.lower() != name.lower():
+ d["roster_schedule"] = raw_roster
+
+ # Fallback: scan the row for the first date/range-like cell that's not the name cell
+ if "roster_schedule" not in d:
+ for j, cell in enumerate(row):
+ if j == name_idx:
+ continue
+ s = _smart_space(str(cell).strip())
+ if s and re.search(r"[0-9/–-]", s) and s.lower() != name.lower():
+ d["roster_schedule"] = s
+ break
+
+ if fit_idx is not None and fit_idx < len(row):
+ d["fit_for_duty"] = _smart_space(str(row[fit_idx]).strip())
+ if diary_idx is not None and diary_idx < len(row):
+ d["work_diary"] = _smart_space(str(row[diary_idx]).strip())
+
+ drivers.append(d)
+
+ if drivers:
+ extracted["drivers_detailed"] = drivers
+ self.log_debug(f"Driver rows extracted (header-based): {len(drivers)}")
+
+
+ def _extract_management_table(self, data_rows: List[List], extracted: Dict, headers: List[str]):
+ txt = " ".join(str(h) for h in headers).lower()
+ comp = {}
+ for row in data_rows:
+ if len(row) < 2: continue
+ std = str(row[0]).strip()
+ val = _smart_space(str(row[1]).strip())
+ if std.startswith("Std") and val:
+ comp[std] = val
+ if comp:
+ if "maintenance" in txt: extracted["maintenance_compliance"] = comp
+ elif "mass" in txt: extracted["mass_compliance"] = comp
+ elif "fatigue" in txt: extracted["fatigue_compliance"] = comp
+
+ def _extract_text_content(self, text_pages: List[Dict], extracted: Dict):
+ all_text = " ".join(page.get("text", "") for page in text_pages)
+ all_text = _smart_space(all_text)
+
+ # business summary
+ patt = [
+ r"Nature of the Operators? Business.*?:\s*(.*?)(?:Accreditation Number|Expiry Date|$)",
+ r"Nature of.*?Business.*?Summary.*?:\s*(.*?)(?:Accreditation|$)"
+ ]
+ for p in patt:
+ m = re.search(p, all_text, re.IGNORECASE | re.DOTALL)
+ if m:
+ txt = re.sub(r'\s+', ' ', m.group(1).strip())
+ txt = re.sub(r'\s*(Accreditation Number.*|Expiry Date.*)', '', txt, flags=re.IGNORECASE)
+ if len(txt) > 50:
+ extracted["business_summary"] = txt
+ break
+
+ # audit conducted date
+ for p in [
+ r"Audit was conducted on\s+([0-9]+(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})",
+ r"DATE\s+([0-9]+(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})",
+ r"AUDITOR SIGNATURE\s+DATE\s+([0-9]+(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})"
+ ]:
+ m = re.search(p, all_text, re.IGNORECASE)
+ if m:
+ extracted["audit_conducted_date"] = _smart_space(m.group(1).strip())
+ break
+
+ # print accreditation name
+ for p in [
+ r"\(print accreditation name\)\s*([A-Za-z0-9\s&().,'/\-]+?)(?:\s+DOES|\s+does|\n|$)",
+ r"print accreditation name.*?\n\s*([A-Za-z0-9\s&().,'/\-]+?)(?:\s+DOES|\s+does|\n|$)"
+ ]:
+ m = re.search(p, all_text, re.IGNORECASE)
+ if m:
+ extracted["print_accreditation_name"] = _smart_space(m.group(1).strip())
+ break
+
+ # numbers in text (optional)
+ for p in [
+ r"Number of powered vehicles\s+(\d+)",
+ r"powered vehicles\s+(\d+)",
+ r"Number of trailing vehicles\s+(\d+)",
+ r"trailing vehicles\s+(\d+)",
+ r"Number of drivers in BFM\s+(\d+)",
+ r"drivers in BFM\s+(\d+)"
+ ]:
+ m = re.search(p, all_text, re.IGNORECASE)
+ if m:
+ val = m.group(1)
+ if "powered" in p: extracted.setdefault("vehicle_summary", {})["powered_vehicles"] = val
+ elif "trailing" in p: extracted.setdefault("vehicle_summary", {})["trailing_vehicles"] = val
+ elif "bfm" in p.lower(): extracted.setdefault("vehicle_summary", {})["drivers_bfm"] = val
+
+ def _extract_detailed_management_data(self, extracted_data: Dict, extracted: Dict):
+ all_tables = extracted_data.get("all_tables", [])
+ for table in all_tables:
+ headers = table.get("headers", [])
+ data_rows = table.get("data", [])
+ page_num = table.get("page", 0)
+ if self._has_details_column(headers):
+ section = self._identify_management_section(headers)
+ if section:
+ self._extract_management_details(data_rows, extracted, section)
+ elif 6 <= page_num <= 15:
+ self._extract_summary_by_content(data_rows, headers, extracted, page_num)
+
+ def _extract_summary_by_content(self, data_rows: List[List], headers: List[str], extracted: Dict, page_num: int):
+ section_type = "maintenance" if 6 <= page_num <= 9 else "mass" if 10 <= page_num <= 12 else "fatigue" if 13 <= page_num <= 15 else None
+ if not section_type: return
+ details_key = f"{section_type}_summary_details"
+ extracted[details_key] = {}
+ for row in data_rows:
+ if len(row) < 2: continue
+ standard = str(row[0]).strip()
+ details = _smart_space(str(row[1]).strip())
+ if standard.startswith("Std") and details and len(details) > 10:
+ m = re.search(r"Std\s+(\d+)\.\s*([^(]+)", standard)
+ if m:
+ key = f"Std {m.group(1)}. {m.group(2).strip()}"
+ extracted[details_key][key] = details
+
+ def _has_details_column(self, headers: List[str]) -> bool:
+ return "details" in " ".join(str(h) for h in headers).lower()
+
+ def _identify_management_section(self, headers: List[str]) -> Optional[str]:
+ txt = " ".join(str(h) for h in headers).lower()
+ if "maintenance" in txt: return "maintenance"
+ if "mass" in txt: return "mass"
+ if "fatigue" in txt: return "fatigue"
+ return None
+
+ def _extract_management_details(self, data_rows: List[List], extracted: Dict, section: str):
+ details_key = f"{section}_details"
+ extracted[details_key] = {}
+ for row in data_rows:
+ if len(row) < 2: continue
+ standard = str(row[0]).strip()
+ details = _smart_space(str(row[1]).strip())
+ if standard.startswith("Std") and details and details != "V" and len(details) > 10:
+ m = re.search(r"Std\s+\d+\.\s*([^(]+)", standard)
+ if m:
+ extracted[details_key][m.group(1).strip()] = details
+
+ def _extract_vehicle_driver_data(self, extracted_data: Dict, extracted: Dict):
+ vehicle_regs = extracted_data.get("vehicle_registrations", [])
+ if vehicle_regs:
+ extracted["vehicle_registrations"] = vehicle_regs
+ driver_records = extracted_data.get("driver_records", [])
+ if driver_records:
+ extracted["driver_records"] = driver_records
+
+
+ def map_vehicle_registration_arrays(self, pdf_extracted: Dict, merged: Dict):
+ """Extract and map vehicle registration data (Maintenance + Mass) to DOCX arrays."""
+ vehicles_src = []
+
+ # Prefer rows we parsed ourselves (header-based). Fall back to curated list if present.
+ if "vehicles" in pdf_extracted and isinstance(pdf_extracted["vehicles"], list):
+ vehicles_src = pdf_extracted["vehicles"]
+ elif "vehicle_registrations" in pdf_extracted and isinstance(pdf_extracted["vehicle_registrations"], list):
+ # Normalize curated structure (list of dicts with keys like 'registration_number', etc.)
+ for row in pdf_extracted["vehicle_registrations"]:
+ if not isinstance(row, dict):
+ continue
+ v = {
+ "registration": _smart_space(row.get("registration_number") or row.get("registration") or ""),
+ # Maintenance table columns (names as seen in curated JSON)
+ "roadworthiness": _smart_space(row.get("roadworthiness_certificates", "")),
+ "maintenance_records": _smart_space(row.get("maintenance_records", "")),
+ "daily_checks": _smart_space(row.get("daily_checks", "")),
+ "fault_recording": _smart_space(row.get("fault_recording_reporting", "")),
+ "fault_repair": _smart_space(row.get("fault_repair", "")),
+ # Mass table columns (in case the curated list ever includes them)
+ "sub_contractor": _smart_space(row.get("sub_contractor", "")),
+ "sub_comp": _smart_space(row.get("sub_contracted_vehicles_statement_of_compliance", "")),
+ "weight_verification": _smart_space(row.get("weight_verification_records", "")),
+ "rfs_certification": _smart_space(row.get("rfs_suspension_certification", row.get("rfs_suspension_certification_#", ""))),
+ "suspension_maintenance": _smart_space(row.get("suspension_system_maintenance", "")),
+ "trip_records": _smart_space(row.get("trip_records", "")),
+ "fault_reporting_suspension": _smart_space(row.get("fault_recording_reporting_on_suspension_system", "")),
+ }
+ if v["registration"]:
+ vehicles_src.append(v)
+
+ if not vehicles_src:
+ return # nothing to map
+
+ # Build column arrays
+ regs = []
+ roadworthiness = []
+ maint_records = []
+ daily_checks = []
+ fault_recording = []
+ fault_repair = []
+
+ sub_contractors = []
+ weight_verification = []
+ rfs_certification = []
+ suspension_maintenance = []
+ trip_records = []
+ fault_reporting_suspension = []
+
+ for v in vehicles_src:
+ reg = _smart_space(v.get("registration", "")).strip()
+ if not reg:
+ continue
+ regs.append(reg)
+
+ roadworthiness.append(_smart_space(v.get("roadworthiness", "")).strip())
+ maint_records.append(_smart_space(v.get("maintenance_records", "")).strip())
+ daily_checks.append(_smart_space(v.get("daily_checks", "")).strip())
+ fault_recording.append(_smart_space(v.get("fault_recording", "")).strip())
+ fault_repair.append(_smart_space(v.get("fault_repair", "")).strip())
+
+ sub_contractors.append(_smart_space(v.get("sub_contractor", "")).strip())
+ weight_verification.append(_smart_space(v.get("weight_verification", "")).strip())
+ rfs_certification.append(_smart_space(v.get("rfs_certification", "")).strip())
+ suspension_maintenance.append(_smart_space(v.get("suspension_maintenance", "")).strip())
+ trip_records.append(_smart_space(v.get("trip_records", "")).strip())
+ fault_reporting_suspension.append(_smart_space(v.get("fault_reporting_suspension", "")).strip())
+
+ # Update Maintenance table arrays (if present in template)
+ if "Vehicle Registration Numbers Maintenance" in merged and regs:
+ m = merged["Vehicle Registration Numbers Maintenance"]
+ m["Registration Number"] = regs
+ m["Roadworthiness Certificates"] = roadworthiness
+ m["Maintenance Records"] = maint_records
+ m["Daily Checks"] = daily_checks
+ m["Fault Recording/ Reporting"] = fault_recording
+ m["Fault Repair"] = fault_repair
+
+ # Update Mass table arrays (if present in template)
+ if "Vehicle Registration Numbers Mass" in merged and regs:
+ ms = merged["Vehicle Registration Numbers Mass"]
+ ms["Registration Number"] = regs
+ ms["Sub contractor"] = sub_contractors
+ ms["Weight Verification Records"] = weight_verification
+ ms["RFS Suspension Certification #"] = rfs_certification
+ ms["Suspension System Maintenance"] = suspension_maintenance
+ ms["Trip Records"] = trip_records
+ ms["Fault Recording/ Reporting on Suspension System"] = fault_reporting_suspension
+
+ self.log_debug(f"Updated vehicle registration arrays for {len(regs)} vehicles")
+ # ───────────────────────────── map to DOCX (apply spacing + safe fallbacks) ─────────────────────────────
+ def map_to_docx_structure(self, pdf_extracted: Dict, docx_data: Dict, pdf_data: Dict) -> Dict:
+ merged = json.loads(json.dumps(docx_data))
+
+ # Audit Information
+ if "audit_info" in pdf_extracted and "Audit Information" in merged:
+ ai = pdf_extracted["audit_info"]
+ if ai.get("date_of_audit"):
+ merged["Audit Information"]["Date of Audit"] = [_smart_space(ai["date_of_audit"])]
+ if ai.get("location"):
+ merged["Audit Information"]["Location of audit"] = [_smart_space(ai["location"])]
+ if ai.get("auditor_name"):
+ merged["Audit Information"]["Auditor name"] = [_smart_space(ai["auditor_name"])]
+ if ai.get("matrix_id"):
+ merged["Audit Information"]["Audit Matrix Identifier (Name or Number)"] = [_smart_space(ai["matrix_id"])]
+
+ # Operator Information
+ if "operator_info" in pdf_extracted and "Operator Information" in merged:
+ op = pdf_extracted["operator_info"]
+ if op.get("name") and _looks_like_company(op["name"]):
+ merged["Operator Information"]["Operator name (Legal entity)"] = [_smart_space(op["name"])]
+ if op.get("trading_name"):
+ merged["Operator Information"]["Registered trading name/s"] = [_smart_space(op["trading_name"])]
+ if op.get("acn"):
+ merged["Operator Information"]["Australian Company Number"] = [_smart_space(op["acn"])]
+ if op.get("manual"):
+ merged["Operator Information"]["NHVAS Manual (Policies and Procedures) developed by"] = [_smart_space(op["manual"])]
+
+ # Contact details
+ if "operator_info" in pdf_extracted and "Operator contact details" in merged:
+ op = pdf_extracted["operator_info"]
+ if op.get("business_address"):
+ merged["Operator contact details"]["Operator business address"] = [_smart_space(op["business_address"])]
+ if op.get("postal_address"):
+ merged["Operator contact details"]["Operator Postal address"] = [_smart_space(op["postal_address"])]
+ if op.get("email"):
+ merged["Operator contact details"]["Email address"] = [op["email"]]
+ if op.get("phone"):
+ merged["Operator contact details"]["Operator Telephone Number"] = [_smart_space(op["phone"])]
+
+ # Attendance
+ if "attendance" in pdf_extracted and "Attendance List (Names and Position Titles)" in merged:
+ merged["Attendance List (Names and Position Titles)"]["Attendance List (Names and Position Titles)"] = _clean_list(pdf_extracted["attendance"])
+
+ # Business summary
+ if "business_summary" in pdf_extracted and "Nature of the Operators Business (Summary)" in merged:
+ merged["Nature of the Operators Business (Summary)"]["Nature of the Operators Business (Summary):"] = [_smart_space(pdf_extracted["business_summary"])]
+
+ # Vehicle summary
+ if "vehicle_summary" in pdf_extracted:
+ vs = pdf_extracted["vehicle_summary"]
+ if "Accreditation Vehicle Summary" in merged:
+ if vs.get("powered_vehicles"):
+ merged["Accreditation Vehicle Summary"]["Number of powered vehicles"] = [vs["powered_vehicles"]]
+ if vs.get("trailing_vehicles"):
+ merged["Accreditation Vehicle Summary"]["Number of trailing vehicles"] = [vs["trailing_vehicles"]]
+ if "Accreditation Driver Summary" in merged:
+ if vs.get("drivers_bfm"):
+ merged["Accreditation Driver Summary"]["Number of drivers in BFM"] = [vs["drivers_bfm"]]
+ if vs.get("drivers_afm"):
+ merged["Accreditation Driver Summary"]["Number of drivers in AFM"] = [vs["drivers_afm"]]
+
+ # Summary sections (unchanged behavior)
+ summary_maps = self.build_summary_maps(pdf_data)
+ for section_name, std_map in summary_maps.items():
+ if section_name in merged and std_map:
+ for detail_key, details_list in std_map.items():
+ if detail_key in merged[section_name]:
+ merged[section_name][detail_key] = details_list
+ continue
+ for docx_key in list(merged[section_name].keys()):
+ m1 = re.search(r"Std\s+(\d+)", detail_key)
+ m2 = re.search(r"Std\s+(\d+)", docx_key)
+ if m1 and m2 and m1.group(1) == m2.group(1):
+ merged[section_name][docx_key] = details_list
+ break
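+ # Illustrative example (assumed keys): a PDF key "Std 7. Internal Review" and a DOCX key
+ # "Std 7 Internal Review" both match on the number 7, so the details list is copied across.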
+
+ # Vehicle registration arrays via consolidated builder
+ sections = build_vehicle_sections(pdf_extracted)
+ if "Vehicle Registration Numbers Maintenance" in merged:
+ merged["Vehicle Registration Numbers Maintenance"].update(
+ sections["Vehicle Registration Numbers Maintenance"]
+ )
+ if "Vehicle Registration Numbers Mass" in merged:
+ merged["Vehicle Registration Numbers Mass"].update(
+ sections["Vehicle Registration Numbers Mass"]
+ )
+
+ # Driver / Scheduler Records Examined
+ if "drivers_detailed" in pdf_extracted and "Driver / Scheduler Records Examined" in merged:
+ drivers = pdf_extracted["drivers_detailed"]
+
+ def _looks_like_range(s):
+ return bool(re.search(r"[0-9]{1,2}[/-]", s or ""))
+
+ merged["Driver / Scheduler Records Examined"]["Roster / Schedule / Safe Driving Plan (Date Range)"] = [d.get("roster_schedule","") for d in drivers]
+ merged["Driver / Scheduler Records Examined"]["Fit for Duty Statement Completed (Yes/No)"] = [d.get("fit_for_duty","") for d in drivers]
+ merged["Driver / Scheduler Records Examined"]["Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)"] = [d.get("work_diary","") for d in drivers]
+
+
+ # --- Print accreditation name (robust, no UnboundLocalError) ---
+ if "Print accreditation name" in merged:
+ acc_name = _smart_space(pdf_extracted.get("print_accreditation_name") or "")
+ if not acc_name:
+ oi = pdf_extracted.get("operator_info") or {}
+ acc_name = _smart_space(oi.get("name") or "") or _smart_space(oi.get("trading_name") or "")
+ if acc_name:
+ merged["Print accreditation name"]["(print accreditation name)"] = [acc_name]
+
+ # Audit Declaration dates: prefer explicit extracted date; fallback to audit_info; ignore literal "Date"
+ if "Audit Declaration dates" in merged:
+ def _real_date(s: Optional[str]) -> bool:
+ return bool(s and re.search(r"\d", s) and not re.fullmatch(r"date", s.strip(), re.I))
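+ # e.g. (illustrative): _real_date("Date") -> False, _real_date("21st March 2024") -> True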
+
+ val = pdf_extracted.get("audit_conducted_date")
+ if not _real_date(val):
+ val = (pdf_extracted.get("audit_info", {}) or {}).get("date_of_audit")
+
+ if _real_date(val):
+ merged["Audit Declaration dates"]["Audit was conducted on"] = [_smart_space(val)]
+
+
+ # Operator Declaration: page 22 image missing → derive from first Attendance "Name - Title"
+ if "Operator Declaration" in merged:
+ # If an explicit operator declaration exists, use it
+ if "operator_declaration" in pdf_extracted:
+ od = pdf_extracted["operator_declaration"]
+ pn = _smart_space(od.get("print_name", ""))
+ pt = _smart_space(od.get("position_title", ""))
+ if pn:
+ merged["Operator Declaration"]["Print Name"] = [pn]
+ if pt:
+ merged["Operator Declaration"]["Position Title"] = [pt]
+ else:
+ # Fallback: first "Name - Title" from Attendance
+ nt = self._first_attendance_name_title(pdf_extracted.get("attendance", []))
+ if nt:
+ merged["Operator Declaration"]["Print Name"] = [nt[0]]
+ merged["Operator Declaration"]["Position Title"] = [nt[1]]
+
+
+ # Paragraphs: fill company name for the 3 management headings; set the 2 dates
+ if "paragraphs" in merged:
+ paras = merged["paragraphs"]
+
+ audit_date = (
+ pdf_extracted.get("audit_conducted_date")
+ or pdf_extracted.get("audit_info", {}).get("date_of_audit")
+ )
+
+ # Prefer accreditation name, else operator legal name, else trading name
+ company_name = (
+ _smart_space(pdf_extracted.get("print_accreditation_name") or "")
+ or _smart_space(pdf_extracted.get("operator_info", {}).get("name") or "")
+ or _smart_space(pdf_extracted.get("operator_info", {}).get("trading_name") or "")
+ )
+
+ # Update the three layered headings
+ for key in ("MAINTENANCE MANAGEMENT", "MASS MANAGEMENT", "FATIGUE MANAGEMENT"):
+ if key in paras and company_name:
+ paras[key] = [company_name]
+
+ # Second-last page: date under page heading
+ if "NHVAS APPROVED AUDITOR DECLARATION" in paras and audit_date:
+ paras["NHVAS APPROVED AUDITOR DECLARATION"] = [_smart_space(audit_date)]
+
+ # Last page: date under long acknowledgement paragraph
+ ack_key = ("I hereby acknowledge and agree with the findings detailed in this NHVAS Audit Summary Report. "
+ "I have read and understand the conditions applicable to the Scheme, including the NHVAS Business Rules and Standards.")
+ if ack_key in paras and audit_date:
+ paras[ack_key] = [_smart_space(audit_date)]
+
+ self._force_fill_maintenance_from_tables(pdf_data, merged)
+ return merged
+
+ # ───────────────────────────── merge & CLI (unchanged) ─────────────────────────────
+ def merge_pdf_to_docx(self, docx_data: Dict, pdf_data: Dict) -> Dict:
+ self.log_debug("Starting comprehensive PDF extraction...")
+ pdf_extracted = self.extract_from_pdf_comprehensive(pdf_data)
+ self.log_debug(f"Extracted PDF data keys: {list(pdf_extracted.keys())}")
+
+ self.log_debug("Mapping to DOCX structure...")
+ merged_data = self.map_to_docx_structure(pdf_extracted, docx_data, pdf_data)
+
+ for section_name, section_data in docx_data.items():
+ if isinstance(section_data, dict):
+ for label in section_data:
+ if (section_name in merged_data and
+ label in merged_data[section_name] and
+ merged_data[section_name][label] != docx_data[section_name][label]):
+ print(f"✓ Updated {section_name}.{label}: {merged_data[section_name][label]}")
+ return merged_data
+
+ def process_files(self, docx_file: str, pdf_file: str, output_file: str):
+ try:
+ print(f"Loading DOCX JSON from: {docx_file}")
+ with open(docx_file, 'r', encoding='utf-8') as f:
+ docx_data = json.load(f)
+ print(f"Loading PDF JSON from: {pdf_file}")
+ with open(pdf_file, 'r', encoding='utf-8') as f:
+ pdf_data = json.load(f)
+
+ print("Merging PDF data into DOCX structure...")
+ merged_data = self.merge_pdf_to_docx(docx_data, pdf_data)
+
+ print(f"Saving merged data to: {output_file}")
+ with open(output_file, 'w', encoding='utf-8') as f:
+ json.dump(merged_data, f, indent=2, ensure_ascii=False)
+
+ print("✅ Merge completed successfully!")
+ return merged_data
+ except Exception as e:
+ print(f"❌ Error processing files: {str(e)}")
+ import traceback
+ traceback.print_exc()
+ raise
+
+def main():
+ if len(sys.argv) != 4:
+ print("Usage: python nhvas_merger.py <docx_json> <pdf_json> <output_json>")
+ print("Example: python nhvas_merger.py docx_template.json pdf_extracted.json merged_output.json")
+ sys.exit(1)
+
+ docx_file = sys.argv[1]
+ pdf_file = sys.argv[2]
+ output_file = sys.argv[3]
+
+ for file_path in [docx_file, pdf_file]:
+ if not Path(file_path).exists():
+ print(f"❌ File not found: {file_path}")
+ sys.exit(1)
+
+ merger = NHVASMerger()
+ merger.process_files(docx_file, pdf_file, output_file)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/space-pdf/updated_word.py b/space-pdf/updated_word.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc13cf6770ac2717695632cfd12464a8a2206df3
--- /dev/null
+++ b/space-pdf/updated_word.py
@@ -0,0 +1,1189 @@
+#!/usr/bin/env python3
+# updated_word.py: update a DOCX template from merged JSON
+import sys, json, re
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
+from docx import Document
+from docx.shared import RGBColor, Pt
+from docx.table import _Cell, Table
+from docx.text.paragraph import Paragraph
+from copy import deepcopy
+from docx.oxml.ns import qn
+from docx.oxml.table import CT_Tbl
+from docx.oxml.text.paragraph import CT_P
+
+BLACK = RGBColor(0, 0, 0)
+RED = RGBColor(0xFF, 0x00, 0x00)
+
+# ----------------------------- text helpers -----------------------------
+def _find_table_with_headers(doc: Document, must_have: list[str]) -> Optional[Table]:
+ for t in doc.tables:
+ if not t.rows:
+ continue
+ head = canon(" ".join(cell_text(c) for c in t.rows[0].cells))
+ if all(canon_label(x) in head for x in must_have):
+ return t
+ return None
+
+def ensure_auditor_decl_headers(doc: Document) -> bool:
+ """
+ Second-last page table under 'NHVAS APPROVED AUDITOR DECLARATION'.
+ Force the HEADER row to read exactly:
+ [ Print Name | NHVR or Exemplar Global Auditor Registration Number ]
+ Never touch the bottom (values) row.
+ """
+ changed = False
+ expected_left = "Print Name"
+ expected_right = "NHVR or Exemplar Global Auditor Registration Number"
+
+ for t in doc.tables:
+ if not t.rows or not t.rows[0].cells:
+ continue
+ # must look like the auditor table: header left says "Print Name", 2+ cols, 2+ rows
+ head_left = canon_label(cell_text(t.rows[0].cells[0]))
+ if head_left == "print name" and len(t.rows[0].cells) >= 2 and len(t.rows) >= 2:
+ # fix left header if needed
+ if canon_label(cell_text(t.rows[0].cells[0])) != canon_label(expected_left) or \
+ any(is_red_run(r) for p in t.rows[0].cells[0].paragraphs for r in p.runs):
+ _set_cell_text_black(t.rows[0].cells[0], expected_left)
+ changed = True
+ # reset the RIGHT header text if it differs or contains a red placeholder (values can end up here)
+ if canon_label(cell_text(t.rows[0].cells[1])) != canon_label(expected_right) or \
+ any(is_red_run(r) for p in t.rows[0].cells[1].paragraphs for r in p.runs):
+ _set_cell_text_black(t.rows[0].cells[1], expected_right)
+ changed = True
+ # found and fixed the table; no need to continue
+ break
+
+ return changed
+
+
+def fill_operator_declaration(doc: Document, print_name: str, position_title: str) -> bool:
+ """Last page table: write values ONLY into the bottom row (red placeholders)."""
+ t = _find_table_with_headers(doc, ["Print Name", "Position Title"])
+ if not t or len(t.rows) < 2 or len(t.rows[0].cells) < 2:
+ return False
+ bot_left = t.rows[1].cells[0]
+ bot_right = t.rows[1].cells[1]
+
+ # only replace if that cell has a red placeholder
+ if any(is_red_run(r) for p in bot_left.paragraphs for r in p.runs):
+ _set_cell_text_black(bot_left, print_name)
+ if any(is_red_run(r) for p in bot_right.paragraphs for r in p.runs):
+ _set_cell_text_black(bot_right, position_title)
+ return True
+
+def find_heading_index_from_end(doc: Document, heading: str) -> Optional[int]:
+ key = canon(heading)
+ allp = iter_paragraphs(doc)
+ for i in range(len(allp) - 1, -1, -1):
+ if key in canon(para_text(allp[i])):
+ return i
+ return None
+
+def set_date_by_heading_from_end(doc: Document, heading: str, date_text: str, max_scan: int = 60) -> bool:
+ """Find the LAST occurrence of `heading`, then replace the FIRST red run in the next paragraphs."""
+ if not date_text:
+ return False
+ allp = iter_paragraphs(doc)
+ idx = find_heading_index_from_end(doc, heading)
+ if idx is None:
+ return False
+ for p in allp[idx + 1 : min(idx + 1 + max_scan, len(allp))]:
+ if replace_red_in_paragraph(p, date_text): # writes in black
+ return True
+ return False
+
+def set_date_by_paragraph_from_end(doc: Document, paragraph_text: str, date_text: str, max_scan: int = 60) -> bool:
+ """Find the LAST paragraph matching `paragraph_text`, then set the FIRST red run after it."""
+ if not date_text:
+ return False
+ key = canon(paragraph_text)
+ allp = iter_paragraphs(doc)
+ hit = None
+ for i in range(len(allp) - 1, -1, -1):
+ if key in canon(para_text(allp[i])):
+ hit = i
+ break
+ if hit is None:
+ return False
+ # date placeholder is on the LAST page, right after this long paragraph
+ for p in allp[hit + 1 : min(hit + 1 + max_scan, len(allp))]:
+ if replace_red_in_paragraph(p, date_text): # writes in black
+ return True
+ return False
+
+def set_layer3_name_after_management_heading(doc: Document, mid_heading: str, allowed_prev_titles: List[str], name: str) -> bool:
+ if not name:
+ return False
+
+ allp = iter_paragraphs(doc)
+ wrote = False
+ mid = canon(mid_heading)
+ allowed_prev = {canon(t) for t in allowed_prev_titles}
+
+ for i, p in enumerate(allp):
+ if canon(para_text(p)) != mid:
+ continue
+
+ # previous non-empty must be one of the allowed titles
+ j = i - 1
+ while j >= 0 and not nz(para_text(allp[j])):
+ j -= 1
+ if j < 0 or canon(para_text(allp[j])) not in allowed_prev:
+ continue
+
+ # next non-empty is the 3rd line we overwrite
+ k = i + 1
+ while k < len(allp) and not nz(para_text(allp[k])):
+ k += 1
+ if k >= len(allp):
+ continue
+
+ # compute target size from the middle heading; fall back to a sensible bump
+ target_size = _para_effective_font_size(allp[i]) or Pt(16)
+
+ _clear_para_and_write_black(allp[k], name)
+
+ # apply size to all runs explicitly (overrides style)
+ for r in allp[k].runs:
+ r.font.size = target_size
+
+ wrote = True
+
+ return wrote
+
+def _para_effective_font_size(p: Paragraph):
+ # try explicit run sizes first
+ for r in p.runs:
+ if r.font.size:
+ return r.font.size
+ # then the paragraph style
+ if p.style and p.style.font and p.style.font.size:
+ return p.style.font.size
+ return None
+
+# --- helpers for summary tables / summary overwrite ---
+def _std_key(s: str) -> str:
+ """
+ Normalize a label to match a 'Std N' key.
+ e.g. 'Std 7. Internal Review' -> 'std 7'
+ """
+ t = canon_label(s)
+ m = re.match(r"(std\s+\d+)", t)
+ return m.group(1) if m else t
+
+def _looks_like_summary_table(table: Table) -> Optional[Tuple[int, int]]:
+ """
+ Return (label_col_idx, details_col_idx) if this is a Summary table
+ with a DETAILS column; otherwise None.
+ """
+ if not table.rows:
+ return None
+ first = table.rows[0]
+ cols = len(first.cells)
+ if cols < 2:
+ return None
+
+ # header texts for first row
+ head = [canon(cell_text(c)) for c in first.cells]
+
+ # find DETAILS column
+ details_col = None
+ for j, t in enumerate(head):
+ if "detail" in t:
+ details_col = j
+ break
+ if details_col is None:
+ return None
+
+ # find the label column (left-hand standards column)
+ label_col = None
+ for j, t in enumerate(head):
+ if any(k in t for k in ["maintenance management", "mass management", "fatigue management"]):
+ label_col = j
+ break
+ if label_col is None:
+ # fallback: assume the first non-DETAILS column is the label column
+ label_col = 0 if details_col != 0 else 1
+
+ return (label_col, details_col)
+def _header_col_texts(table: Table, scan_rows: int = 5) -> List[str]:
+ scan_rows = min(scan_rows, len(table.rows))
+ if scan_rows == 0:
+ return []
+ # pick the row with the most cells as base
+ base_row = max(range(scan_rows), key=lambda i: len(table.rows[i].cells))
+ base_cols = len(table.rows[base_row].cells)
+ cols = []
+ for j in range(base_cols):
+ parts = []
+ for i in range(scan_rows):
+ row = table.rows[i]
+ if j < len(row.cells):
+ parts.append(cell_text(row.cells[j]))
+ cols.append(canon(" ".join(parts)))
+ return cols
+
+def count_header_rows(table: Table, scan_up_to: int = 6) -> int:
+ """Header ends right before the first row whose 1st cell looks like '1.'"""
+ limit = min(scan_up_to, len(table.rows))
+ for i in range(limit):
+ first = cell_text(table.rows[i].cells[0]).strip()
+ if re.match(r"^\d+\.?$", first):
+ return i
+ # fallback to 1 header row
+ return 1
+
+def map_cols_mass_strict(table: Table) -> Dict[str, int]:
+ cols = _header_col_texts(table, 5)
+ def first_col(*needles):
+ for j, t in enumerate(cols):
+ if all(n in t for n in needles):
+ return j
+ return None
+ idx = {
+ "no": first_col("no"),
+ "reg": first_col("registration", "number") or first_col("registration"),
+ "wv": first_col("weight", "verification"),
+ "rfs": first_col("rfs", "cert") or first_col("rfs", "certification"),
+ "susp": first_col("suspension", "maintenance"),
+ "trip": first_col("trip", "record"),
+ "frs": first_col("fault", "suspension") or first_col("fault", "reporting", "suspension"),
+ }
+ return {k: v for k, v in idx.items() if v is not None}
+
+def find_mass_vehicle_numbers_table(doc: Document) -> Optional[Table]:
+ """Pick the Mass vehicle-number table by matching its column set (not the Summary table)."""
+ best = None
+ best_score = -1
+ for t in iter_tables(doc):
+ cols = _header_col_texts(t, 5)
+ allhdr = " ".join(cols)
+ # must look like the vehicle numbers table
+ hits = 0
+ hits += int(any("registration" in c and "number" in c for c in cols))
+ hits += int(any("weight" in c and "verification" in c for c in cols))
+ hits += int(any("rfs" in c and ("cert" in c or "certification" in c) for c in cols))
+ hits += int(any("suspension" in c and "maintenance" in c for c in cols))
+ hits += int(any("trip" in c and "record" in c for c in cols))
+ hits += int(any("fault" in c and "suspension" in c for c in cols))
+ # reject obvious Summary tables
+ if "details" in allhdr:
+ continue
+ # prefer tables with numbering column and many rows
+ score = hits + (0.5 if any("no" == c or c.startswith("no ") for c in cols) else 0) + (len(t.rows) / 100.0)
+ if hits >= 4 and score > best_score:
+ best, best_score = t, score
+ return best
+
+def update_operator_declaration(doc: Document, print_name: str, position_title: str) -> bool:
+ """
+ First try strict table label mapping for 'Print Name' and 'Position Title'.
+ If not found, fallback to the first two red placeholders under the 'Operator Declaration' heading.
+ """
+ changed = False
+ # 1) Table label approach
+ for lbl, val in (("Print Name", print_name), ("Position Title", position_title)):
+ if not val:
+ continue
+ loc = find_label_cell(doc, lbl)
+ if not loc:
+ # tolerate odd spacing/colon/camelcase
+ for alt in ("PrintName", "Print Name", "Print Name:", "PositionTitle", "Position Title", "Position Title:"):
+ loc = find_label_cell(doc, alt)
+ if loc:
+ break
+ if loc:
+ t, r, c = loc
+ cell = get_adjacent_value_cell(t, r, c)
+ if not replace_red_in_cell(cell, val):
+ _set_cell_text_black(cell, val)
+ changed = True
+
+ if changed:
+ return True
+
+ # 2) Fallback: heading-scoped red placeholders
+ head = "OPERATOR DECLARATION"
+ p = find_heading_paragraph(doc, head) or find_heading_paragraph(doc, head.title())
+ if not p:
+ return False
+ allp = iter_paragraphs(doc)
+ try:
+ i = allp.index(p)
+ except ValueError:
+ i = 0
+ red_targets = []
+ for q in allp[i+1:i+1+20]:
+ reds = [r for r in q.runs if is_red_run(r)]
+ if reds:
+ red_targets.extend(reds)
+ if len(red_targets) >= 2:
+ break
+ wrote = False
+ if print_name and red_targets:
+ _set_text_and_black(red_targets[0], print_name); wrote = True
+ if position_title and len(red_targets) >= 2:
+ _set_text_and_black(red_targets[1], position_title); wrote = True
+ return wrote
+
+
+def fill_mass_vehicle_table_preserve_headers(table: Table, arrays: Dict[str, List[str]]):
+ colmap = map_cols_mass_strict(table)
+ if "reg" not in colmap:
+ return
+ hdr_rows = count_header_rows(table, 6)
+ regs = arrays.get("Registration Number", [])
+ n = len(regs)
+
+ # clear data rows only
+ while len(table.rows) > hdr_rows:
+ table._tbl.remove(table.rows[-1]._tr)
+ # ensure enough rows
+ while len(table.rows) < hdr_rows + n:
+ table.add_row()
+
+ def put(row, key, arr_key, i):
+ if key in colmap:
+ vals = arrays.get(arr_key, [])
+ val = nz(vals[i]) if i < len(vals) else ""
+ replace_red_in_cell(row.cells[colmap[key]], val)
+
+ for i in range(n):
+ row = table.rows[hdr_rows + i]
+ replace_red_in_cell(row.cells[colmap["reg"]], nz(regs[i]))
+ put(row, "wv", "Weight Verification Records", i)
+ put(row, "rfs", "RFS Suspension Certification #", i)
+ put(row, "susp", "Suspension System Maintenance", i)
+ put(row, "trip", "Trip Records", i)
+ put(row, "frs", "Fault Recording/ Reporting on Suspension System", i)
+
+def overwrite_summary_details_cells(doc: Document, section_name: str, section_dict: Dict[str, List[str]]) -> int:
+ """For a Summary table (Maintenance/Mass/Fatigue), replace the entire DETAILS cell
+ for each Std N row with the JSON text (written in black)."""
+ # build desired texts
+ desired: Dict[str, str] = { _std_key(k): join_value(v) for k, v in section_dict.items() }
+
+ # pick which tables belong to this section by header sniff
+ wanted_prefix = canon_label(section_name.split()[0]) # "maintenance" | "mass" | "fatigue"
+
+ updated = 0
+ for t in doc.tables:
+ cols = _looks_like_summary_table(t)
+ if not cols:
+ continue
+ label_col, details_col = cols
+
+ head_txt = table_header_text(t, up_to_rows=2)
+ if wanted_prefix not in head_txt: # keep to the correct section
+ continue
+
+ # walk body rows
+ for i in range(1, len(t.rows)):
+ row = t.rows[i]
+ key = _std_key(cell_text(row.cells[label_col]))
+
+ # exact match or "std N" prefix match
+ cand = desired.get(key)
+ if not cand:
+ m = re.match(r"(std\s+\d+)", key)
+ if m:
+ for k2, v2 in desired.items():
+ if k2.startswith(m.group(1)):
+ cand = v2
+ break
+ if not cand:
+ continue
+
+ _set_cell_text_black(row.cells[details_col], cand) # full overwrite, black
+ updated += 1
+ return updated
+
+SPLIT_SENT_PAT = re.compile(r"(?<=\.|\?|!)\s+")
+ORDINAL_DATE_PAT = re.compile(r"\b(\d{1,2}(?:st|nd|rd|th)\s+[A-Za-z]+\s+\d{4})\b", re.I)
+
+def split_sentences_keep(text: str) -> List[str]:
+ s = " ".join(str(text or "").split())
+ if not s:
+ return []
+ out = []
+ start = 0
+ for m in SPLIT_SENT_PAT.finditer(s):
+ out.append(s[start:m.start()].strip())
+ start = m.end()
+ last = s[start:].strip()
+ if last:
+ out.append(last)
+ return out
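+# e.g. (illustrative): split_sentences_keep("Records sighted. Next service booked.") ->
+# ["Records sighted.", "Next service booked."]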
+
+_sent_split = re.compile(r'(?<=[.!?])\s+|\n+')
+_date_pat = re.compile(r'\b(?:\d{1,2}(?:st|nd|rd|th)\s+[A-Za-z]+\s+\d{4}|\d{1,2}/\d{1,2}/\d{2,4}|[A-Za-z]+\s+\d{1,2},\s*\d{4})\b')
+
+def extract_summary_snippets(desired_text: str):
+ sents = _sentences(desired_text)
+ dates = [m.group(0) for m in _date_pat.finditer(desired_text)]
+ pick = lambda rx: next((s for s in sents if re.search(rx, s, re.I)), None)
+ return {
+ "sheet_sent": pick(r'\b(daily\s+check|sheet)\b'),
+ "sheet_phrase": _extract_sheet_phrase_from_desired(desired_text),
+ "review": pick(r'\binternal\s+review\b'),
+ "qcs": pick(r'\bquarterly\b.*\bcompliance\b') or pick(r'\bquarterly\b'),
+ "dates": dates,
+ "sents": sents,
+ }
+
+def fill_management_summary_tables(doc: Document, section_key: str, section_data: Dict[str, List[str]]):
+ """
+ Fill ALL summary tables for the given section_key ('maintenance'|'mass'|'fatigue')
+ by matching each row label (left column) against keys in section_data and
+ patching only the red text inside the DETAILS cell.
+ """
+ targets = [x for x in find_all_summary_tables(doc) if x[0] == section_key]
+ if not targets:
+ return
+
+ # build list of (normalized label, original label, desired_text)
+ desired = []
+ for label, vals in section_data.items():
+ want = canon_label(label)
+ if not want:
+ continue
+ desired.append((want, label, join_value(vals)))
+
+ for _, table, lcol, dcol in targets:
+ # iterate data rows (skip header)
+ for i in range(1, len(table.rows)):
+ left_txt_norm = canon_label(cell_text(table.rows[i].cells[lcol]))
+ if not left_txt_norm:
+ continue
+ for want_norm, _orig_lbl, value in desired:
+ # loose contains match handles minor punctuation differences
+ if want_norm and want_norm in left_txt_norm:
+ patch_details_cell_from_json(table.rows[i].cells[dcol], value)
+
+def _set_text_and_black(run, new_text: str):
+ """Replace a run's text and force color to black (clears theme color too)."""
+ if new_text is None:
+ new_text = ""
+ run.text = str(new_text)
+ run.font.color.rgb = BLACK
+ try:
+ # clear any theme color so rgb sticks
+ run.font.color.theme_color = None
+ except Exception:
+ pass
+
+def update_business_summary_once(doc: Document, value) -> bool:
+ """Replace only the red summary paragraph; keep 'Accreditation Number' and 'Expiry Date' lines."""
+ loc = (find_label_cell(doc, "Nature of the Operators Business (Summary)")
+ or find_label_cell(doc, "Nature of the Operators Business (Summary):"))
+ if not loc:
+ return False
+
+ t, r, c = loc
+ cell = get_adjacent_value_cell(t, r, c)
+ if not cell.paragraphs:
+ cell.add_paragraph("")
+
+ txt = join_value(value)
+
+ # find paragraphs with any red runs (the placeholders for the summary)
+ red_paras = [p for p in cell.paragraphs if any(is_red_run(run) for run in p.runs)]
+
+ if red_paras:
+ # write the summary into the first red paragraph (in black)
+ _clear_para_and_write_black(red_paras[0], txt)
+ # clear any extra red placeholders
+ for p in red_paras[1:]:
+ _clear_para_and_write_black(p, "")
+ else:
+ # no red placeholder found: just put the summary into the first paragraph, leave others
+ _clear_para_and_write_black(cell.paragraphs[0], txt)
+
+ return True
+
+
+def _nuke_cell_paragraphs(cell: _Cell):
+ """Remove ALL paragraphs from a cell (true delete, not just emptying runs)."""
+ for p in list(cell.paragraphs):
+ p._element.getparent().remove(p._element)
+
+def _clear_para_and_write_black(paragraph, text: str):
+ """Clear a whole paragraph and write fresh black text."""
+ # wipe existing runs
+ for r in list(paragraph.runs):
+ r.text = ""
+ r = paragraph.add_run(str(text or ""))
+ r.font.color.rgb = BLACK
+ try:
+ r.font.color.theme_color = None
+ except Exception:
+ pass
+
+def _set_cell_text_black(cell, text: str):
+ """Clear a table cell and insert black text."""
+ # remove text from all runs in all paragraphs
+ for p in cell.paragraphs:
+ for r in p.runs:
+ r.text = ""
+ p = cell.paragraphs[0] if cell.paragraphs else cell.add_paragraph()
+ r = p.add_run(str(text or ""))
+ r.font.color.rgb = BLACK
+ try:
+ r.font.color.theme_color = None
+ except Exception:
+ pass
+
+def nz(x: Optional[str]) -> str:
+ return (x or "").strip()
+
+def canon(s: str) -> str:
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = s.replace("–", "-").replace("—", "-")
+ return re.sub(r"[^a-z0-9/#()+,.\- ]+", "", s)
+
+def canon_label(s: str) -> str:
+ # labels often vary by punctuation/casing; keep digits/letters
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = s.replace("–", "-").replace("—", "-")
+ s = re.sub(r"[^a-z0-9 ]+", " ", s)
+ return re.sub(r"\s+", " ", s).strip()
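+# Illustrative examples (assumed label text):
+#   canon("RFS Suspension Certification #")       -> "rfs suspension certification #"
+#   canon_label("RFS Suspension Certification #") -> "rfs suspension certification"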
+
+def join_value(value) -> str:
+ if isinstance(value, list):
+ # Keep multi-line when list provided
+ return "\n".join([str(v) for v in value if nz(v)])
+ return str(value)
+
+def split_digits(s: str) -> List[str]:
+ return re.findall(r"\d", s)
+
+def para_text(p: Paragraph) -> str:
+ return "".join(run.text for run in p.runs)
+
+def cell_text(c: _Cell) -> str:
+ return "\n".join(para_text(p) for p in c.paragraphs)
+
+def is_red_run(run) -> bool:
+ col = run.font.color
+ if not col:
+ return False
+ if col.rgb is not None:
+ return col.rgb == RED
+ # Some templates use theme colors; treat explicit red text snippets only
+ return False
+
+def replace_red_in_paragraph(p: Paragraph, new_text: str) -> bool:
+ replaced = False
+ red_runs = [r for r in p.runs if is_red_run(r)]
+ if not red_runs:
+ return False
+ # collapse all red runs into one and write value (in black)
+ first = red_runs[0]
+ _set_text_and_black(first, new_text)
+ for r in red_runs[1:]:
+ r.text = ""
+ replaced = True
+ return replaced
+
+def replace_red_in_cell(cell: _Cell, new_text: str) -> bool:
+ # replace only red runs; if none, replace whole cell with a single run (fallback)
+ any_red = False
+ for p in cell.paragraphs:
+ if replace_red_in_paragraph(p, new_text):
+ any_red = True
+ if any_red:
+ return True
+ # fallback: clear cell, set single paragraph text in black
+ _set_cell_text_black(cell, new_text)
+ return True
+
+def parse_attendance_lines(value) -> List[str]:
+ """
+ Parse strings like:
+ "Peter Sheppard - Compliance Greg Dyer - Auditor"
+ into:
+ ["Peter Sheppard - Compliance", "Greg Dyer - Auditor"]
+ Handles lists, newlines, semicolons, and pipes too.
+ """
+ if isinstance(value, list):
+ s = " ".join(str(v) for v in value if v)
+ else:
+ s = str(value or "")
+ s = re.sub(r"\s+", " ", s).strip()
+ if not s:
+ return []
+
+ # First split on explicit separators; then within each chunk, extract Name - Title pairs.
+ chunks = re.split(r"\s*[\n;|]\s*", s)
+ items: List[str] = []
+
+ pair_pat = re.compile(
+ r"([A-Z][A-Za-z.'-]+(?:\s+[A-Z][A-Za-z.'-]+){0,3})\s*-\s*"
+ r"([^-\n]+?)(?=\s+[A-Z][A-Za-z.'-]+(?:\s+[A-Z][A-Za-z.'-]+){0,3}\s*-\s*|$)"
+ )
+
+ for chunk in chunks:
+ chunk = chunk.strip()
+ if not chunk:
+ continue
+ found = False
+ for m in pair_pat.finditer(chunk):
+ name = m.group(1).strip()
+ title = m.group(2).strip()
+ items.append(f"{name} - {title}")
+ found = True
+ if not found:
+ # Fallback: single "Name - Title"
+ if " - " in chunk:
+ a, b = chunk.split(" - ", 1)
+ items.append(f"{a.strip()} - {b.strip()}")
+ elif chunk:
+ items.append(chunk)
+
+ return items
+
+def fill_attendance_block(doc: Document, value) -> bool:
+ items = parse_attendance_lines(value)
+ if not items:
+ return False
+
+ loc = find_label_cell(doc, "Attendance List (Names and Position Titles)")
+ if not loc:
+ return False
+
+ t, r, c = loc
+ # value cell: usually directly under the heading cell
+ target = (
+ t.rows[r + 1].cells[c]
+ if r + 1 < len(t.rows) and c < len(t.rows[r + 1].cells)
+ else get_adjacent_value_cell(t, r, c)
+ )
+
+ # ---- read ONLY the target cell (don’t touch the row)
+ def is_red_para(p): return any(is_red_run(run) for run in p.runs)
+ def looks_like_pair(s: str) -> bool:
+ if " - " not in s: return False
+ a, b = s.split(" - ", 1)
+ return bool(a.strip()) and bool(b.strip())
+
+ paras = list(target.paragraphs)
+ red_count = sum(1 for p in paras if is_red_para(p))
+ existing_black = [para_text(p).strip() for p in paras
+ if (not is_red_para(p)) and looks_like_pair(para_text(p))]
+
+ # compose final lines
+ out_lines: List[str] = []
+ out_lines.extend(items[:red_count]) # replace red placeholders
+ out_lines.extend(existing_black) # keep black lines
+ norm = lambda s: re.sub(r"\s+", " ", s.strip().lower())
+ seen = {norm(x) for x in out_lines}
+ for extra in items[red_count:]:
+ k = norm(extra)
+ if k not in seen:
+ out_lines.append(extra); seen.add(k)
+
+ # ---- hard clear target cell and write fresh (all black)
+ _nuke_cell_paragraphs(target)
+ # first line
+ p = target.add_paragraph()
+ _clear_para_and_write_black(p, out_lines[0] if out_lines else "")
+ # remaining lines
+ for line in out_lines[1:]:
+ p = target.add_paragraph()
+ _clear_para_and_write_black(p, line)
+
+ return True
+
+# ----------------------------- document search -----------------------------
+def iter_tables(doc: Document) -> List[Table]:
+ return list(doc.tables)
+
+def iter_paragraphs(doc: Document) -> List[Paragraph]:
+ # paragraphs at doc level + inside tables
+ out = list(doc.paragraphs)
+ for t in doc.tables:
+ for row in t.rows:
+ for cell in row.cells:
+ out.extend(cell.paragraphs)
+ return out
+
+def find_heading_paragraph(doc: Document, heading_text: str, window: int = 60) -> Optional[Paragraph]:
+ key = canon(heading_text)
+ for p in iter_paragraphs(doc):
+ if canon(para_text(p)).startswith(key):
+ return p
+ # fuzzy contains
+ for p in iter_paragraphs(doc):
+ if key in canon(para_text(p)):
+ return p
+ return None
+
+def find_label_cell_in_table(table: Table, label: str) -> Optional[Tuple[int, int]]:
+ target = canon_label(label)
+ for r_i, row in enumerate(table.rows):
+ for c_i, cell in enumerate(row.cells):
+ if canon_label(cell_text(cell)) == target:
+ return (r_i, c_i)
+ # allow contains (safe-ish)
+ for r_i, row in enumerate(table.rows):
+ for c_i, cell in enumerate(row.cells):
+ if target and target in canon_label(cell_text(cell)):
+ return (r_i, c_i)
+ return None
+
+def find_label_cell(doc: Document, label: str) -> Optional[Tuple[Table, int, int]]:
+ for t in iter_tables(doc):
+ pos = find_label_cell_in_table(t, label)
+ if pos:
+ return (t, pos[0], pos[1])
+ return None
+
+def get_adjacent_value_cell(table: Table, r: int, c: int) -> _Cell:
+ # Prefer right cell, otherwise next row same col, otherwise this cell
+ cols = len(table.rows[0].cells)
+ if c + 1 < cols:
+ return table.rows[r].cells[c+1]
+ if r + 1 < len(table.rows):
+ return table.rows[r+1].cells[c]
+ return table.rows[r].cells[c]
+
+# ----------------------------- label/value updates -----------------------------
+def update_label_value_in_tables(doc: Document, label: str, value) -> bool:
+ tup = find_label_cell(doc, label)
+ val = join_value(value)
+ if not tup:
+ return False
+ t, r, c = tup
+ target_cell = get_adjacent_value_cell(t, r, c)
+ return replace_red_in_cell(target_cell, val)
+
+def update_heading_followed_red(doc: Document, heading: str, value, max_scan: int = 12) -> bool:
+ """Find heading paragraph, then replace the first red run found within next N paragraphs (including inside tables)"""
+ start = find_heading_paragraph(doc, heading)
+ if not start:
+ return False
+ # Build a linear list of paragraphs across whole doc to get an index
+ allp = iter_paragraphs(doc)
+ try:
+ idx = allp.index(start)
+ except ValueError:
+ idx = 0
+ new_text = join_value(value)
+ # Scan forward
+ for p in allp[idx+1: idx+1+max_scan]:
+ if replace_red_in_paragraph(p, new_text):
+ return True
+ # Also check any red in table cells inside this paragraph's parent (already covered via iter_paragraphs)
+ return False
+
+# ----------------------------- ACN per-digit fill -----------------------------
+def fill_acn_digits(doc: Document, acn_value: str) -> bool:
+ digits = split_digits(acn_value)
+ if not digits:
+ return False
+ loc = find_label_cell(doc, "Australian Company Number")
+ if not loc:
+ return False
+
+ t, r, c = loc
+
+ # Collect cells to the RIGHT in the same row first
+ targets: List[_Cell] = [t.rows[r].cells[j] for j in range(c + 1, len(t.rows[r].cells))]
+
+ # If not enough, continue row-by-row below (left→right)
+ rr = r + 1
+ while len(targets) < len(digits) and rr < len(t.rows):
+ targets.extend(list(t.rows[rr].cells))
+ rr += 1
+
+ targets = targets[:len(digits)]
+ if not targets:
+ return False
+
+ # Clear each target cell and write ONE digit in black
+ for d, cell in zip(digits, targets):
+ _set_cell_text_black(cell, d)
+
+ return True
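+# e.g. (illustrative): an ACN value of "123 456 789" yields nine digits that are written one per
+# cell, first into the cells to the right of the label, then row by row beneath it.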
+
+
+# ----------------------------- vehicle tables -----------------------------
+def table_header_text(table: Table, up_to_rows: int = 3) -> str:
+ heads = []
+ for i, row in enumerate(table.rows[:up_to_rows]):
+ for cell in row.cells:
+ heads.append(cell_text(cell))
+ return canon(" ".join(heads))
+
+def find_vehicle_table(doc: Document, want: str) -> Optional[Table]:
+ """
+ want = "maintenance" or "mass"
+ """
+ MAINT_KEYS = ["registration number", "maintenance records", "daily checks", "fault recording", "fault repair"]
+ MASS_KEYS = ["registration number", "weight verification", "rfs suspension", "suspension system maintenance", "trip records", "reporting on suspension"]
+ candidates = []
+ for t in iter_tables(doc):
+ htxt = table_header_text(t)
+ if want == "maintenance":
+ if all(k in htxt for k in ["registration", "maintenance", "fault"]) and "suspension" not in htxt:
+ candidates.append(t)
+ elif want == "mass":
+ if "suspension" in htxt and "weight" in htxt:
+ candidates.append(t)
+ # Prefer the one with most rows
+ if not candidates:
+ return None
+ return max(candidates, key=lambda tb: len(tb.rows))
+
+def map_cols(table: Table, want: str) -> Dict[str, int]:
+ # map header columns by keywords from the first 2 rows that contain headers
+ header_rows = table.rows[:2]
+ col_texts = []
+ cols = len(table.rows[0].cells)
+ for j in range(cols):
+ txt = " ".join(cell_text(r.cells[j]) for r in header_rows if j < len(r.cells))
+ col_texts.append(canon(txt))
+ idx = {}
+ def first_col(*needles) -> Optional[int]:
+ for j, t in enumerate(col_texts):
+ if all(n in t for n in needles):
+ return j
+ return None
+ if want == "maintenance":
+ idx["reg"] = first_col("registration")
+ idx["rw"] = first_col("roadworthiness")
+ idx["mr"] = first_col("maintenance", "records")
+ idx["daily"] = first_col("daily", "check")
+ idx["fr"] = first_col("fault", "recording")
+ idx["rep"] = first_col("fault", "repair")
+ else:
+ idx["reg"] = first_col("registration")
+ idx["wv"] = first_col("weight", "verification")
+ idx["rfs"] = first_col("rfs", "cert")
+ idx["susp"] = first_col("suspension", "maintenance")
+ idx["trip"] = first_col("trip", "record")
+ idx["frs"] = first_col("fault", "suspension")
+ return {k:v for k,v in idx.items() if v is not None}
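+# Illustrative example (assumed maintenance header row
+# "Registration Number | Roadworthiness | Maintenance Records | Daily Checks | Fault Recording | Fault Repair"):
+#   map_cols(table, "maintenance") -> {"reg": 0, "rw": 1, "mr": 2, "daily": 3, "fr": 4, "rep": 5}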
+
+def clear_data_rows_keep_headers(table: Table, header_rows: int = 1):
+ # Keep first header_rows, drop everything else
+ while len(table.rows) > header_rows:
+ table._tbl.remove(table.rows[-1]._tr)
+
+def ensure_rows(table: Table, need_rows: int):
+ # assumes 1 header row; add rows to reach need_rows + 1 total
+ while len(table.rows) < need_rows + 1:
+ table.add_row()
+
+def fill_vehicle_table(table: Table, want: str, arrays: Dict[str, List[str]]):
+ colmap = map_cols(table, want)
+ if "reg" not in colmap:
+ return
+ if want == "maintenance":
+ regs = arrays.get("Registration Number", [])
+ rw = arrays.get("Roadworthiness Certificates", [])
+ mr = arrays.get("Maintenance Records", [])
+ daily= arrays.get("Daily Checks", [])
+ fr = arrays.get("Fault Recording/ Reporting", [])
+ rep = arrays.get("Fault Repair", [])
+ n = len(regs)
+ # keep header row(s), then fill N rows
+ clear_data_rows_keep_headers(table, header_rows=1)
+ ensure_rows(table, n)
+ for i in range(n):
+ row = table.rows[i+1]
+ def put(col_key, vals):
+ if col_key not in colmap or i >= len(vals): return
+ c = row.cells[colmap[col_key]]
+ replace_red_in_cell(c, nz(vals[i]))
+ # write each col
+ c_reg = row.cells[colmap["reg"]]; replace_red_in_cell(c_reg, nz(regs[i]))
+ put("rw", rw)
+ put("mr", mr)
+ put("daily",daily)
+ put("fr", fr)
+ put("rep", rep)
+ else:
+ regs = arrays.get("Registration Number", [])
+ wv = arrays.get("Weight Verification Records", [])
+ rfs = arrays.get("RFS Suspension Certification #", [])
+ susp = arrays.get("Suspension System Maintenance", [])
+ trip = arrays.get("Trip Records", [])
+ frs = arrays.get("Fault Recording/ Reporting on Suspension System", [])
+ n = len(regs)
+ clear_data_rows_keep_headers(table, header_rows=1)
+ ensure_rows(table, n)
+ for i in range(n):
+ row = table.rows[i+1]
+ def put(col_key, vals):
+ if col_key not in colmap or i >= len(vals): return
+ c = row.cells[colmap[col_key]]
+ replace_red_in_cell(c, nz(vals[i]))
+ c_reg = row.cells[colmap["reg"]]; replace_red_in_cell(c_reg, nz(regs[i]))
+ put("wv", wv)
+ put("rfs", rfs)
+ put("susp", susp)
+ put("trip", trip)
+ put("frs", frs)
+
+# ----------------------------- driver table -----------------------------
+def find_driver_table(doc: Document) -> Optional[Table]:
+ for t in iter_tables(doc):
+ h = table_header_text(t)
+ if "driver / scheduler" in h and ("fit for duty" in h or "work diary" in h):
+ return t
+ return None
+
+def map_driver_cols(table: Table) -> Dict[str,int]:
+ header_rows = table.rows[:2]
+ cols = len(table.rows[0].cells)
+ col_texts = []
+ for j in range(cols):
+ txt = " ".join(cell_text(r.cells[j]) for r in header_rows if j < len(r.cells))
+ col_texts.append(canon(txt))
+ idx = {}
+ def first_col(*needles):
+ for j, t in enumerate(col_texts):
+ if all(n in t for n in needles):
+ return j
+ return None
+ idx["name"] = first_col("driver", "name")
+ idx["roster"]= first_col("roster", "safe")
+ idx["fit"] = first_col("fit for duty")
+ # Work diary might be split across two headers; match "work diary" OR "electronic work diary"
+ wd = first_col("work diary") or first_col("electronic work diary")
+ if wd is not None: idx["wd"] = wd
+ return {k:v for k,v in idx.items() if v is not None}
+
+def fill_driver_table(table: Table, arrays: Dict[str, List[str]]):
+ colmap = map_driver_cols(table)
+ if not colmap:
+ return
+
+ names = arrays.get("Driver / Scheduler Name", [])
+ rosters = arrays.get("Roster / Schedule / Safe Driving Plan (Date Range)", [])
+ fit = arrays.get("Fit for Duty Statement Completed (Yes/No)", [])
+ wd = arrays.get("Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)", [])
+
+ n = max(len(rosters), len(fit), len(wd), len(names))
+ clear_data_rows_keep_headers(table, header_rows=1)
+ ensure_rows(table, n)
+
+ has_any_name = any(str(x).strip() for x in names)
+
+ for i in range(n):
+ row = table.rows[i+1]
+ if "name" in colmap and has_any_name:
+ replace_red_in_cell(row.cells[colmap["name"]], names[i] if i < len(names) else "")
+ if "roster" in colmap:
+ replace_red_in_cell(row.cells[colmap["roster"]], rosters[i] if i < len(rosters) else "")
+ if "fit" in colmap:
+ replace_red_in_cell(row.cells[colmap["fit"]], fit[i] if i < len(fit) else "")
+ if "wd" in colmap:
+ replace_red_in_cell(row.cells[colmap["wd"]], wd[i] if i < len(wd) else "")
+
+
+
+# ----------------------------- main mapping -----------------------------
+def flatten_simple_sections(data: Dict) -> Dict[str, str]:
+ """Collect simple label->single value mappings from top-level sections other than tables."""
+ out = {}
+ skip_sections = {
+ "Vehicle Registration Numbers Maintenance",
+ "Vehicle Registration Numbers Mass",
+ "Driver / Scheduler Records Examined",
+ "paragraphs",
+ "Attendance List (Names and Position Titles)",
+ "Nature of the Operators Business (Summary)",
+ "Maintenance Management Summary",
+ "Mass Management Summary",
+ "Fatigue Management Summary",
+ }
+ for sec, kv in data.items():
+ if sec in skip_sections: continue
+ if not isinstance(kv, dict): continue
+ for label, val in kv.items():
+ out[f"{sec}::{label}"] = join_value(val)
+ return out
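+# e.g. (illustrative): {"Audit Information": {"Date of Audit": ["21st March 2024"]}}
+# flattens to {"Audit Information::Date of Audit": "21st March 2024"}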
+
+def run(input_json: Path, template_docx: Path, output_docx: Path):
+ with open(input_json, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ doc = Document(str(template_docx))
+
+ # 1) simple label/value tables
+ simple = flatten_simple_sections(data)
+
+ # Map by (section::label). We try: (a) find exact label cell somewhere and write in the adjacent cell;
+ # (b) if not found, search by heading then the next red run below the heading.
+ for k, v in simple.items():
+ # use the part after '::' as the label
+ label = k.split("::", 1)[1] if "::" in k else k
+
+ # SPECIAL: skip ACN here; we'll fill per-digit later
+ if canon_label(label) == "australian company number":
+ continue
+
+ ok = update_label_value_in_tables(doc, label, v)
+ if not ok:
+ sec = k.split("::", 1)[0] if "::" in k else k
+ update_heading_followed_red(doc, sec, v)
+
+
+ # 2) paragraphs block
+ paras = data.get("paragraphs", {})
+
+ # 2a) the three management headings → replace the next red run below each heading
+ for head in ("MAINTENANCE MANAGEMENT", "MASS MANAGEMENT", "FATIGUE MANAGEMENT"):
+ name_val = join_value(paras.get(head, ""))
+ if name_val:
+ update_heading_followed_red(doc, head, name_val, max_scan=6)
+
+ # 2b) second-last page: date under the NHVAS APPROVED AUDITOR DECLARATION heading
+ aud_head = "NHVAS APPROVED AUDITOR DECLARATION"
+ aud_date = join_value(paras.get(aud_head, ""))
+ if aud_date:
+ set_date_by_heading_from_end(doc, aud_head, aud_date, max_scan=40)
+
+ # last page: date under the long acknowledgement paragraph
+ ack_head = ("I hereby acknowledge and agree with the findings detailed in this NHVAS Audit Summary Report. "
+ "I have read and understand the conditions applicable to the Scheme, including the NHVAS Business Rules and Standards.")
+ ack_date = join_value(paras.get(ack_head, ""))
+ if ack_date:
+ set_date_by_paragraph_from_end(doc, ack_head, ack_date, max_scan=40)
+
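+ # 2c) the three layered headings → overwrite the 3rd line (company name) only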
+ maint_name = join_value(paras.get("MAINTENANCE MANAGEMENT", ""))
+ if maint_name:
+ set_layer3_name_after_management_heading(
+ doc,
+ "MAINTENANCE MANAGEMENT",
+ ["Vehicle Registration Numbers of Records Examined"],
+ maint_name,
+ )
+
+ mass_name = join_value(paras.get("MASS MANAGEMENT", ""))
+ if mass_name:
+ set_layer3_name_after_management_heading(
+ doc,
+ "MASS MANAGEMENT",
+ ["Vehicle Registration Numbers of Records Examined"],
+ mass_name,
+ )
+
+ fat_name = join_value(paras.get("FATIGUE MANAGEMENT", ""))
+ if fat_name:
+ set_layer3_name_after_management_heading(
+ doc,
+ "FATIGUE MANAGEMENT",
+ ["Driver / Scheduler Records Examined"],
+ fat_name,
+ )
+
+
+ # 3) ACN digits
+ op_info = data.get("Operator Information", {})
+ acn_val = join_value(op_info.get("Australian Company Number", ""))
+ if acn_val:
+ fill_acn_digits(doc, acn_val)
+
+ # 4) Vehicle tables
+ maint = data.get("Vehicle Registration Numbers Maintenance", {})
+ mass = data.get("Vehicle Registration Numbers Mass", {})
+ t_m = find_vehicle_table(doc, "maintenance")
+ if t_m and maint:
+ fill_vehicle_table(t_m, "maintenance", maint)
+ t_ms = find_mass_vehicle_numbers_table(doc)
+ if t_ms and mass:
+ fill_mass_vehicle_table_preserve_headers(t_ms, mass)
+
+ # 5) Driver table
+ drivers = data.get("Driver / Scheduler Records Examined", {})
+ t_d = find_driver_table(doc)
+ if t_d and drivers:
+ fill_driver_table(t_d, drivers)
+
+ # 6) Special: Audit Declaration dates via heading
+ decl = data.get("Audit Declaration dates", {})
+ if decl.get("Audit was conducted on"):
+ update_heading_followed_red(doc, "Audit was conducted on", decl["Audit was conducted on"])
+
+ # 7) Operator Declaration (last page, bottom row only), and fix Auditor table header
+ op_decl = data.get("Operator Declaration", {})
+ if op_decl:
+ fill_operator_declaration(
+ doc,
+ join_value(op_decl.get("Print Name", "")),
+ join_value(op_decl.get("Position Title", "")),
+ )
+
+ # make sure the second-last page “NHVAS APPROVED AUDITOR DECLARATION” table header row contains the label text
+ ensure_auditor_decl_headers(doc)
+
+
+ # 8) Attendance List
+ # Attendance: replace red lines only
+ atts = data.get("Attendance List (Names and Position Titles)", {})
+ att_val = atts.get("Attendance List (Names and Position Titles)")
+ if att_val:
+ fill_attendance_block(doc, att_val)
+
+ # 9) Nature of the Operators Business (Summary): write once (no duplicates)
+ biz = data.get("Nature of the Operators Business (Summary)", {})
+ if biz:
+ val = biz.get("Nature of the Operators Business (Summary):") or next(iter(biz.values()), "")
+ if val:
+ update_business_summary_once(doc, val)
+
+ # 10) Summary tables: FULL OVERWRITE of DETAILS from JSON
+ mm_sum = data.get("Maintenance Management Summary", {})
+ if mm_sum:
+ overwrite_summary_details_cells(doc, "Maintenance Management Summary", mm_sum)
+
+ mass_sum = data.get("Mass Management Summary", {})
+ if mass_sum:
+ overwrite_summary_details_cells(doc, "Mass Management Summary", mass_sum)
+
+ fat_sum = data.get("Fatigue Management Summary", {})
+ if fat_sum:
+ overwrite_summary_details_cells(doc, "Fatigue Management Summary", fat_sum)
+
+
+ doc.save(str(output_docx))
+
+# ----------------------------- CLI -----------------------------
+if __name__ == "__main__":
+
+ if len(sys.argv) != 4:
+ print("Usage: python updated_word.py <data.json> <template.docx> <output.docx>")
+ sys.exit(1)
+
+ a, b, c = map(Path, sys.argv[1:4])
+ files = [a, b, c]
+
+ json_path = next((p for p in files if p.suffix.lower() == ".json"), None)
+ docx_paths = [p for p in files if p.suffix.lower() == ".docx"]
+
+ if not json_path or len(docx_paths) < 2:
+ print("Error: provide one .json and two .docx (template + output).")
+ sys.exit(1)
+
+ # Template = the .docx that already exists; Output = the other .docx
+ template_docx = next((p for p in docx_paths if p.exists()), docx_paths[0])
+ output_docx = docx_paths[1] if docx_paths[0] == template_docx else docx_paths[0]
+
+ run(json_path, template_docx, output_docx)
\ No newline at end of file
diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/__init__.py b/src/adapters/infrastructure/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/format_conversion_service_adapter.py b/src/adapters/infrastructure/format_conversion_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..42d2390ccb5db536a4d27dd7d098121fffbc6a3d
--- /dev/null
+++ b/src/adapters/infrastructure/format_conversion_service_adapter.py
@@ -0,0 +1,13 @@
+from domain.PdfImages import PdfImages
+from domain.PdfSegment import PdfSegment
+from ports.services.format_conversion_service import FormatConversionService
+from adapters.infrastructure.format_converters.convert_table_to_html import extract_table_format
+from adapters.infrastructure.format_converters.convert_formula_to_latex import extract_formula_format
+
+
+class FormatConversionServiceAdapter(FormatConversionService):
+ def convert_table_to_html(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
+ extract_table_format(pdf_images, segments)
+
+ def convert_formula_to_latex(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
+ extract_formula_format(pdf_images, segments)
diff --git a/src/adapters/infrastructure/format_converters/__init__.py b/src/adapters/infrastructure/format_converters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/format_converters/convert_formula_to_latex.py b/src/adapters/infrastructure/format_converters/convert_formula_to_latex.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4d071796d2e3bc40f45d42d62fe2a489967d2a
--- /dev/null
+++ b/src/adapters/infrastructure/format_converters/convert_formula_to_latex.py
@@ -0,0 +1,43 @@
+from PIL.Image import Image
+from pix2tex.cli import LatexOCR
+from domain.PdfImages import PdfImages
+from domain.PdfSegment import PdfSegment
+from pdf_token_type_labels import TokenType
+import latex2mathml.converter
+
+
+def has_arabic(text: str) -> bool:
+ return any("\u0600" <= char <= "\u06FF" or "\u0750" <= char <= "\u077F" for char in text)
+
+
+def is_valid_latex(formula: str) -> bool:
+ try:
+ latex2mathml.converter.convert(formula)
+ return True
+ except Exception:
+ return False
+
+
+def extract_formula_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment]):
+ formula_segments = [segment for segment in predicted_segments if segment.segment_type == TokenType.FORMULA]
+ if not formula_segments:
+ return
+
+ model = LatexOCR()
+ model.args.temperature = 1e-8
+
+ for formula_segment in formula_segments:
+ if has_arabic(formula_segment.text_content):
+ continue
+ page_image: Image = pdf_images.pdf_images[formula_segment.page_number - 1]
+ left, top = formula_segment.bounding_box.left, formula_segment.bounding_box.top
+ right, bottom = formula_segment.bounding_box.right, formula_segment.bounding_box.bottom
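+ # segment bounding boxes are in PDF points (72 per inch); scale to pixel coordinates at the image dpi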
+ left = int(left * pdf_images.dpi / 72)
+ top = int(top * pdf_images.dpi / 72)
+ right = int(right * pdf_images.dpi / 72)
+ bottom = int(bottom * pdf_images.dpi / 72)
+ formula_image = page_image.crop((left, top, right, bottom))
+ formula_result = model(formula_image)
+ if not is_valid_latex(formula_result):
+ continue
+ formula_segment.text_content = f"$${formula_result}$$"
diff --git a/src/adapters/infrastructure/format_converters/convert_table_to_html.py b/src/adapters/infrastructure/format_converters/convert_table_to_html.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aa13f5b4c560103a87b15daa621885a5b1410d3
--- /dev/null
+++ b/src/adapters/infrastructure/format_converters/convert_table_to_html.py
@@ -0,0 +1,33 @@
+from PIL import Image
+from domain.PdfImages import PdfImages
+from domain.PdfSegment import PdfSegment
+from pdf_token_type_labels import TokenType
+from rapidocr import RapidOCR
+from rapid_table import ModelType, RapidTable, RapidTableInput
+
+
+def extract_table_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment]):
+ table_segments = [segment for segment in predicted_segments if segment.segment_type == TokenType.TABLE]
+ if not table_segments:
+ return
+
+ input_args = RapidTableInput(model_type=ModelType["SLANETPLUS"])
+
+ ocr_engine = RapidOCR()
+ table_engine = RapidTable(input_args)
+
+ for table_segment in table_segments:
+ page_image: Image = pdf_images.pdf_images[table_segment.page_number - 1]
+ left, top = table_segment.bounding_box.left, table_segment.bounding_box.top
+ right, bottom = table_segment.bounding_box.right, table_segment.bounding_box.bottom
+ left = int(left * pdf_images.dpi / 72)
+ top = int(top * pdf_images.dpi / 72)
+ right = int(right * pdf_images.dpi / 72)
+ bottom = int(bottom * pdf_images.dpi / 72)
+ table_image = page_image.crop((left, top, right, bottom))
+ ori_ocr_res = ocr_engine(table_image)
+ if not ori_ocr_res.txts:
+ continue
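+ # hand the recognised boxes, texts and scores to the table-structure engine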
+ ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
+ table_result = table_engine(table_image, ocr_results=ocr_results)
+ table_segment.text_content = table_result.pred_html
diff --git a/src/adapters/infrastructure/html_conversion_service_adapter.py b/src/adapters/infrastructure/html_conversion_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b41e2022835b3e2a84f710d46c53221d358c499c
--- /dev/null
+++ b/src/adapters/infrastructure/html_conversion_service_adapter.py
@@ -0,0 +1,23 @@
+from typing import Optional, Union
+from starlette.responses import Response
+
+from domain.SegmentBox import SegmentBox
+from ports.services.html_conversion_service import HtmlConversionService
+from adapters.infrastructure.markup_conversion.pdf_to_markup_service_adapter import PdfToMarkupServiceAdapter
+from adapters.infrastructure.markup_conversion.OutputFormat import OutputFormat
+
+
+class HtmlConversionServiceAdapter(HtmlConversionService, PdfToMarkupServiceAdapter):
+
+ def __init__(self):
+ PdfToMarkupServiceAdapter.__init__(self, OutputFormat.HTML)
+
+ def convert_to_html(
+ self,
+ pdf_content: bytes,
+ segments: list[SegmentBox],
+ extract_toc: bool = False,
+ dpi: int = 120,
+ output_file: Optional[str] = None,
+ ) -> Union[str, Response]:
+ return self.convert_to_format(pdf_content, segments, extract_toc, dpi, output_file)
diff --git a/src/adapters/infrastructure/markdown_conversion_service_adapter.py b/src/adapters/infrastructure/markdown_conversion_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b371150f4b0e60eca7e645297bbc5ea3ced7777
--- /dev/null
+++ b/src/adapters/infrastructure/markdown_conversion_service_adapter.py
@@ -0,0 +1,23 @@
+from typing import Optional, Union
+from starlette.responses import Response
+
+from domain.SegmentBox import SegmentBox
+from ports.services.markdown_conversion_service import MarkdownConversionService
+from adapters.infrastructure.markup_conversion.pdf_to_markup_service_adapter import PdfToMarkupServiceAdapter
+from adapters.infrastructure.markup_conversion.OutputFormat import OutputFormat
+
+
+class MarkdownConversionServiceAdapter(MarkdownConversionService, PdfToMarkupServiceAdapter):
+
+ def __init__(self):
+ PdfToMarkupServiceAdapter.__init__(self, OutputFormat.MARKDOWN)
+
+ def convert_to_markdown(
+ self,
+ pdf_content: bytes,
+ segments: list[SegmentBox],
+ extract_toc: bool = False,
+ dpi: int = 120,
+ output_file: Optional[str] = None,
+ ) -> Union[str, Response]:
+ return self.convert_to_format(pdf_content, segments, extract_toc, dpi, output_file)
diff --git a/src/adapters/infrastructure/markup_conversion/ExtractedImage.py b/src/adapters/infrastructure/markup_conversion/ExtractedImage.py
new file mode 100644
index 0000000000000000000000000000000000000000..9395fca648212b9b6e0a875efea9272c798744cc
--- /dev/null
+++ b/src/adapters/infrastructure/markup_conversion/ExtractedImage.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class ExtractedImage(BaseModel):
+ image_data: bytes
+ filename: str
diff --git a/src/adapters/infrastructure/markup_conversion/Link.py b/src/adapters/infrastructure/markup_conversion/Link.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e37dcb3498f7c307d94bdb1f797dfef569f49d2
--- /dev/null
+++ b/src/adapters/infrastructure/markup_conversion/Link.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+from domain.SegmentBox import SegmentBox
+
+
+class Link(BaseModel):
+ source_segment: SegmentBox
+ destination_segment: SegmentBox
+ text: str
diff --git a/src/adapters/infrastructure/markup_conversion/OutputFormat.py b/src/adapters/infrastructure/markup_conversion/OutputFormat.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e89a585f71b8c352f84f60ca4602e733129a0e8
--- /dev/null
+++ b/src/adapters/infrastructure/markup_conversion/OutputFormat.py
@@ -0,0 +1,6 @@
+from enum import StrEnum
+
+
+class OutputFormat(StrEnum):
+ HTML = "html"
+ MARKDOWN = "markdown"
diff --git a/src/adapters/infrastructure/markup_conversion/__init__.py b/src/adapters/infrastructure/markup_conversion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/markup_conversion/pdf_to_markup_service_adapter.py b/src/adapters/infrastructure/markup_conversion/pdf_to_markup_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4da8695cd5f8d1899182f933c0d40cc95b36654
--- /dev/null
+++ b/src/adapters/infrastructure/markup_conversion/pdf_to_markup_service_adapter.py
@@ -0,0 +1,361 @@
+import fitz
+import tempfile
+import zipfile
+import io
+import json
+from fitz import Page
+from pathlib import Path
+from typing import Optional, Union
+from PIL.Image import Image
+from pdf2image import convert_from_path
+from starlette.responses import Response
+
+from domain.SegmentBox import SegmentBox
+from pdf_features.PdfFeatures import PdfFeatures
+from pdf_features.PdfToken import PdfToken
+from pdf_features.Rectangle import Rectangle
+from pdf_token_type_labels.Label import Label
+from pdf_token_type_labels.PageLabels import PageLabels
+from pdf_token_type_labels.PdfLabels import PdfLabels
+from pdf_token_type_labels.TokenType import TokenType
+
+from adapters.infrastructure.markup_conversion.OutputFormat import OutputFormat
+from adapters.infrastructure.markup_conversion.Link import Link
+from adapters.infrastructure.markup_conversion.ExtractedImage import ExtractedImage
+
+
+class PdfToMarkupServiceAdapter:
+ def __init__(self, output_format: OutputFormat):
+ self.output_format = output_format
+
+ def convert_to_format(
+ self,
+ pdf_content: bytes,
+ segments: list[SegmentBox],
+ extract_toc: bool = False,
+ dpi: int = 120,
+ output_file: Optional[str] = None,
+ ) -> Union[str, Response]:
+ with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
+ temp_file.write(pdf_content)
+ temp_pdf_path = Path(temp_file.name)
+
+ try:
+ extracted_images: list[ExtractedImage] = [] if output_file else None
+ user_base_name = Path(output_file).stem if output_file else None
+
+ content = self._generate_content(temp_pdf_path, segments, extract_toc, dpi, extracted_images, user_base_name)
+
+ if output_file:
+ return self._create_zip_response(content, extracted_images, output_file, segments)
+
+ return content
+ finally:
+ if temp_pdf_path.exists():
+ temp_pdf_path.unlink()
+
+ def _create_zip_response(
+ self,
+ content: str,
+ extracted_images: list[ExtractedImage],
+ output_filename: str,
+ segments: list[SegmentBox],
+ ) -> Response:
+ zip_buffer = io.BytesIO()
+
+ with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
+ zip_file.writestr(output_filename, content.encode("utf-8"))
+
+ if extracted_images:
+ base_name = Path(output_filename).stem
+ pictures_dir = f"{base_name}_pictures/"
+
+ for image in extracted_images:
+ zip_file.writestr(f"{pictures_dir}{image.filename}", image.image_data)
+
+ base_name = Path(output_filename).stem
+ segmentation_filename = f"{base_name}_segmentation.json"
+ segmentation_data = self._create_segmentation_json(segments)
+ zip_file.writestr(segmentation_filename, segmentation_data)
+
+ zip_buffer.seek(0)
+
+ zip_filename = f"{Path(output_filename).stem}.zip"
+ return Response(
+ content=zip_buffer.getvalue(),
+ media_type="application/zip",
+ headers={"Content-Disposition": f"attachment; filename={zip_filename}"},
+ )
+
+ def _create_segmentation_json(self, segments: list[SegmentBox]) -> str:
+ segmentation_data = []
+ for segment in segments:
+ segmentation_data.append(segment.to_dict())
+ return json.dumps(segmentation_data, indent=4, ensure_ascii=False)
+
+ def _create_pdf_labels_from_segments(self, vgt_segments: list[SegmentBox]) -> PdfLabels:
+ page_numbers = sorted(set(segment.page_number for segment in vgt_segments))
+ page_labels: list[PageLabels] = []
+ for page_number in page_numbers:
+ segments_in_page = [s for s in vgt_segments if s.page_number == page_number]
+ labels: list[Label] = []
+ for segment in segments_in_page:
+ rect = Rectangle.from_width_height(segment.left, segment.top, segment.width, segment.height)
+ label = Label.from_rectangle(rect, TokenType.from_text(segment.type).get_index())
+ labels.append(label)
+ page_labels.append(PageLabels(number=page_number, labels=labels))
+ return PdfLabels(pages=page_labels)
+
+ def _find_closest_segment(self, bounding_box: Rectangle, segments: list[SegmentBox]) -> Optional[SegmentBox]:
+ if not segments:
+ return None
+
+ def intersection_key(segment: SegmentBox) -> float:
+ segment_rect = Rectangle.from_width_height(segment.left, segment.top, segment.width, segment.height)
+ return bounding_box.get_intersection_percentage(segment_rect)
+
+ closest = max(segments, key=intersection_key)
+ max_intersection = intersection_key(closest)
+ if max_intersection > 0:
+ return closest
+
+ candidates = [s for s in segments if s.top > bounding_box.top]
+ if not candidates:
+ return None
+
+ def distance_key(segment: SegmentBox) -> tuple[float, float]:
+ vertical_dist = segment.top - bounding_box.top
+ segment_center_x = segment.left + segment.width / 2
+ box_center_x = bounding_box.left + bounding_box.width / 2
+ horizontal_dist = abs(segment_center_x - box_center_x)
+ return (vertical_dist, horizontal_dist)
+
+ return min(candidates, key=distance_key)
+
+ def _get_link_segments(
+ self, link: dict, page: Page, segments_by_page: dict[int, list[SegmentBox]]
+ ) -> Optional[tuple[SegmentBox, SegmentBox]]:
+ rect = link["from"]
+ source_box = Rectangle.from_coordinates(rect[0], rect[1], rect[2], rect[3])
+ source_page_num = page.number + 1
+ source_segments = segments_by_page.get(source_page_num, [])
+ source_segment = self._find_closest_segment(source_box, source_segments)
+ if not source_segment:
+ return None
+
+ dest_page_num = link.get("page", -1) + 1
+ dest_segments = segments_by_page.get(dest_page_num, [])
+ if not dest_segments:
+ return None
+
+ if "to" not in link:
+ dest_box = Rectangle.from_coordinates(0, 0, 20, 20)
+ else:
+ dest = link["to"] * page.transformation_matrix
+ dest_box = Rectangle.from_coordinates(dest[0], dest[1], dest[0] + 20, dest[1] + 20)
+
+ dest_segment = self._find_closest_segment(dest_box, dest_segments)
+ if not dest_segment:
+ return None
+
+ return source_segment, dest_segment
+
+ def _extract_links_by_segments(
+ self, pdf_path: Path, vgt_segments: list[SegmentBox]
+ ) -> tuple[dict[SegmentBox, list[Link]], dict[SegmentBox, list[Link]]]:
+ links_by_source: dict[SegmentBox, list[Link]] = {}
+ links_by_dest: dict[SegmentBox, list[Link]] = {}
+
+ segments_by_page: dict[int, list[SegmentBox]] = {}
+ for segment in vgt_segments:
+ segments_by_page.setdefault(segment.page_number, []).append(segment)
+
+ doc = fitz.open(pdf_path)
+ try:
+ for page_num in range(len(doc)):
+ page: Page = doc[page_num]
+ links = page.get_links()
+ for link in links:
+ if "page" not in link:
+ continue
+ rect = link["from"]
+ text = page.get_text("text", clip=rect).strip()
+ if not text:
+ continue
+ segments_pair = self._get_link_segments(link, page, segments_by_page)
+ if not segments_pair:
+ continue
+ source, dest = segments_pair
+ new_link = Link(source_segment=source, destination_segment=dest, text=text)
+ links_by_source.setdefault(source, []).append(new_link)
+ links_by_dest.setdefault(dest, []).append(new_link)
+ finally:
+ doc.close()
+
+ return links_by_source, links_by_dest
+
+ def _insert_reference_links(self, segment_text: str, links: list[Link]) -> str:
+ offset = 0
+ for link in links:
+ start_idx = segment_text.find(link.text, offset)
+ if start_idx == -1:
+ continue
+ escaped_text = link.text.replace("[", "\\[").replace("]", "\\]")
+ md_link = f"[{escaped_text}](#{link.destination_segment.id})"
+ segment_text = segment_text[:start_idx] + md_link + segment_text[start_idx + len(link.text) :]
+ offset = start_idx + len(md_link)
+ return segment_text
+
+ def _process_picture_segment(
+ self,
+ segment: SegmentBox,
+ pdf_images: list[Image],
+ pdf_path: Path,
+ picture_id: int,
+ dpi: int = 72,
+ extracted_images: Optional[list[ExtractedImage]] = None,
+ user_base_name: Optional[str] = None,
+ ) -> str:
+
+ if extracted_images is None:
+ return ""
+
+ segment_box = Rectangle.from_width_height(segment.left, segment.top, segment.width, segment.height)
+ image = pdf_images[segment.page_number - 1]
+ left, top, right, bottom = segment_box.left, segment_box.top, segment_box.right, segment_box.bottom
+ if dpi != 72:
+ left = left * dpi / 72
+ top = top * dpi / 72
+ right = right * dpi / 72
+ bottom = bottom * dpi / 72
+ cropped = image.crop((left, top, right, bottom))
+
+ base_name = user_base_name if user_base_name else pdf_path.stem
+ image_name = f"{base_name}_{segment.page_number}_{picture_id}.png"
+
+ img_buffer = io.BytesIO()
+ cropped.save(img_buffer, format="PNG")
+ extracted_images.append(ExtractedImage(image_data=img_buffer.getvalue(), filename=image_name))
+ return f"\n" + f"
\n\n"
+
+ def _process_table_segment(self, segment: SegmentBox) -> str:
+ return f"\n" + segment.text + "\n\n"
+
+ def _get_token_content(self, token: PdfToken) -> str:
+ if self.output_format == OutputFormat.HTML:
+ return token.content_html
+ else:
+ return token.content_markdown
+
+ def _get_styled_content(self, token: PdfToken, content: str) -> str:
+ if self.output_format == OutputFormat.HTML:
+ styled = token.token_style.get_styled_content_html(content)
+ styled = token.token_style.script_type.get_styled_content(styled)
+ styled = token.token_style.list_level.get_styled_content_html(styled)
+ return token.token_style.hyperlink_style.get_styled_content_html(styled)
+ else:
+ styled = token.token_style.get_styled_content_markdown(content)
+ styled = token.token_style.script_type.get_styled_content(styled)
+ styled = token.token_style.list_level.get_styled_content_markdown(styled)
+ return token.token_style.hyperlink_style.get_styled_content_markdown(styled)
+
+ def _process_title_segment(self, tokens: list[PdfToken], segment: SegmentBox) -> str:
+ if not tokens:
+ return ""
+
+ title_type = tokens[0].token_style.title_type
+ content = " ".join([self._get_styled_content(token, token.content) for token in tokens])
+ if self.output_format == OutputFormat.HTML:
+ content = title_type.get_styled_content_html(content)
+ else:
+ content = title_type.get_styled_content_markdown(content)
+ anchor = f"\n"
+ return anchor + content + "\n\n"
+
+ def _process_regular_segment(
+ self,
+ tokens: list[PdfToken],
+ segment: SegmentBox,
+ links_by_source: dict[SegmentBox, list[Link]],
+ links_by_dest: dict[SegmentBox, list[Link]],
+ ) -> str:
+ if not tokens:
+ return ""
+ content = " ".join(self._get_token_content(t) for t in tokens)
+ if segment in links_by_source:
+ content = self._insert_reference_links(content, links_by_source[segment])
+ if segment in links_by_dest:
+ content = f"\n" + content
+ return content + "\n\n"
+
+ def _get_table_of_contents(self, vgt_segments: list[SegmentBox]) -> str:
+ title_segments = [s for s in vgt_segments if s.type in {TokenType.TITLE, TokenType.SECTION_HEADER}]
+ table_of_contents = "# Table of Contents\n\n"
+ for segment in title_segments:
+ if not segment.text.strip():
+ continue
+ first_word = segment.text.split()[0]
+ indentation = max(0, first_word.count(".") - 1)
+ content = " " * indentation + "- [" + segment.text + "](#" + segment.id + ")\n"
+ table_of_contents += content
+ table_of_contents += "\n"
+ return table_of_contents + "\n\n"
+
+ def _set_segment_ids(self, vgt_segments: list[SegmentBox]) -> None:
+ segments_by_page: dict[int, list[SegmentBox]] = {}
+ for segment in vgt_segments:
+ segments_by_page.setdefault(segment.page_number, []).append(segment)
+ for page_number, segments in segments_by_page.items():
+ for segment_index, segment in enumerate(segments):
+ segment.id = f"page-{page_number}-{segment_index}"
+
+ def _generate_content(
+ self,
+ pdf_path: Path,
+ vgt_segments: list[SegmentBox],
+ extract_toc: bool = False,
+ dpi: int = 120,
+ extracted_images: Optional[list[ExtractedImage]] = None,
+ user_base_name: Optional[str] = None,
+ ) -> str:
+ pdf_labels: PdfLabels = self._create_pdf_labels_from_segments(vgt_segments)
+ pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path)
+ pdf_features.set_token_types(pdf_labels)
+ pdf_features.set_token_styles()
+
+ self._set_segment_ids(vgt_segments)
+ content_parts: list[str] = []
+ if extract_toc:
+ content_parts.append(self._get_table_of_contents(vgt_segments))
+
+ links_by_source, links_by_dest = self._extract_links_by_segments(pdf_path, vgt_segments)
+
+ picture_segments = [s for s in vgt_segments if s.type == TokenType.PICTURE]
+ pdf_images: list[Image] = convert_from_path(pdf_path, dpi=dpi) if picture_segments else []
+
+ for page in pdf_features.pages:
+ segments_in_page = [s for s in vgt_segments if s.page_number == page.page_number]
+ picture_id = 0
+ for segment in segments_in_page:
+ seg_box = Rectangle.from_width_height(segment.left, segment.top, segment.width, segment.height)
+ tokens_in_seg = [t for t in page.tokens if t.bounding_box.get_intersection_percentage(seg_box) > 50]
+
+ if segment.type == TokenType.PICTURE:
+ content_parts.append(
+ self._process_picture_segment(
+ segment, pdf_images, pdf_path, picture_id, dpi, extracted_images, user_base_name
+ )
+ )
+ picture_id += 1
+ elif segment.type == TokenType.TABLE:
+ content_parts.append(self._process_table_segment(segment))
+ elif segment.type in {TokenType.TITLE, TokenType.SECTION_HEADER}:
+ content_parts.append(self._process_title_segment(tokens_in_seg, segment))
+ elif segment.type == TokenType.FORMULA:
+ content_parts.append(segment.text + "\n\n")
+ else:
+ content_parts.append(
+ self._process_regular_segment(tokens_in_seg, segment, links_by_source, links_by_dest)
+ )
+
+ return "".join(content_parts)
diff --git a/src/adapters/infrastructure/ocr/__init__.py b/src/adapters/infrastructure/ocr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/ocr/languages.py b/src/adapters/infrastructure/ocr/languages.py
new file mode 100644
index 0000000000000000000000000000000000000000..6499a060b60c8e36555c741852775db479e30300
--- /dev/null
+++ b/src/adapters/infrastructure/ocr/languages.py
@@ -0,0 +1,174 @@
+import subprocess
+
+iso_to_tesseract = {
+ "af": "afr", # Afrikaans
+ "all": "all", # Allar
+ "am": "amh", # Amharic
+ "ar": "ara", # Arabic
+ "as": "asm", # Assamese
+ "az": "aze", # Azerbaijani
+ "aze-cyrl": "aze-cyrl", # Azerbaijani (Cyrillic)
+ "be": "bel", # Belarusian
+ "bn": "ben", # Bangla
+ "bo": "bod", # Tibetan
+ "bs": "bos", # Bosnian
+ "br": "bre", # Breton
+ "bg": "bul", # Bulgarian
+ "ca": "cat", # Catalan
+ "ceb": "ceb", # Cebuano
+ "cs": "ces", # Czech
+ "zh-Hans": "chi_sim", # Chinese (Simplified)
+ "chi-sim-vert": "chi-sim-vert", # Chinese (Simplified) vertical
+ "zh-Hant": "chi_tra", # Chinese (Traditional)
+ "chi-tra-vert": "chi-tra-vert", # Chinese (Traditional) vertical
+ "chr": "chr", # Cherokee
+ "co": "cos", # Corsican
+ "cy": "cym", # Welsh
+ "da": "dan", # Danish
+ "de": "deu", # German
+ "dv": "div", # Divehi
+ "dz": "dzo", # Dzongkha
+ "el": "ell", # Greek
+ "en": "eng", # English
+ "enm": "enm", # Middle English
+ "eo": "epo", # Esperanto
+ "et": "est", # Estonian
+ "eu": "eus", # Basque
+ "fo": "fao", # Faroese
+ "fa": "fas", # Persian
+ "fil": "fil", # Filipino
+ "fi": "fin", # Finnish
+ "fr": "fra", # French
+ "frk": "frk", # Frankish
+ "frm": "frm", # Middle French
+ "fy": "fry", # Western Frisian
+ "gd": "gla", # Scottish Gaelic
+ "ga": "gle", # Irish
+ "gl": "glg", # Galician
+ "grc": "grc", # Ancient Greek
+ "gu": "guj", # Gujarati
+ "ht": "hat", # Haitian Creole
+ "he": "heb", # Hebrew
+ "hi": "hin", # Hindi
+ "hr": "hrv", # Croatian
+ "hu": "hun", # Hungarian
+ "hy": "hye", # Armenian
+ "iu": "iku", # Inuktitut
+ "id": "ind", # Indonesian
+ "is": "isl", # Icelandic
+ "it": "ita", # Italian
+ "ita-old": "ita-old", # Old Italian
+ "jv": "jav", # Javanese
+ "ja": "jpn", # Japanese
+ "jpn-vert": "jpn-vert", # Japanese vertical
+ "kn": "kan", # Kannada
+ "ka": "kat", # Georgian
+ "kat-old": "kat-old", # Old Georgian
+ "kk": "kaz", # Kazakh
+ "km": "khm", # Khmer
+ "ky": "kir", # Kyrgyz
+ "kmr": "kmr", # Northern Kurdish
+ "ko": "kor", # Korean
+ "kor-vert": "kor_vert", # Korean vertical
+ "lo": "lao", # Lao
+ "la": "lat", # Latin
+ "lv": "lav", # Latvian
+ "lt": "lit", # Lithuanian
+ "lb": "ltz", # Luxembourgish
+ "ml": "mal", # Malayalam
+ "mr": "mar", # Marathi
+ "mk": "mkd", # Macedonian
+ "mt": "mlt", # Maltese
+ "mn": "mon", # Mongolian
+ "mi": "mri", # Māori
+ "ms": "msa", # Malay
+ "my": "mya", # Burmese
+ "ne": "nep", # Nepali
+ "nl": "nld", # Dutch
+ "no": "nor", # Norwegian
+ "oc": "oci", # Occitan
+ "or": "ori", # Odia
+ "osd": "osd", # Unknown language [osd]
+ "pa": "pan", # Punjabi
+ "pl": "pol", # Polish
+ "pt": "por", # Portuguese
+ "ps": "pus", # Pashto
+ "qu": "que", # Quechua
+ "ro": "ron", # Romanian
+ "ru": "rus", # Russian
+ "sa": "san", # Sanskrit
+ "script-arab": "script-arab", # Arabic script
+ "script-armn": "script-armn", # Armenian script
+ "script-beng": "script-beng", # Bengali script
+ "script-cans": "script-cans", # Canadian Aboriginal script
+ "script-cher": "script-cher", # Cherokee script
+ "script-cyrl": "script-cyrl", # Cyrillic script
+ "script-deva": "script-deva", # Devanagari script
+ "script-ethi": "script-ethi", # Ethiopic script
+ "script-frak": "script-frak", # Frankish script
+ "script-geor": "script-geor", # Georgian script
+ "script-grek": "script-grek", # Greek script
+ "script-gujr": "script-gujr", # Gujarati script
+ "script-guru": "script-guru", # Gurmukhi script
+ "script-hang": "script-hang", # Hangul script
+ "script-hang-vert": "script-hang-vert", # Hangul script vertical
+ "script-hans": "script-hans",
+ "script-hans-vert": "script-hans-vert",
+ "script-hant": "script-hant",
+ "script-hant-vert": "script-hant-vert",
+ "script-hebr": "script-hebr", # Hebrew script
+ "script-jpan": "script-jpan", # Japanese script
+ "script-jpan-vert": "script-jpan-vert", # Japanese script vertical
+ "script-khmr": "script-khmr", # Khmer script
+ "script-knda": "script-knda", # Kannada script
+ "script-laoo": "script-laoo", # Lao script
+ "script-latn": "script-latn",
+ "script-mlym": "script-mlym", # Malayalam script
+ "script-mymr": "script-mymr", # Myanmar script
+ "script-orya": "script-orya", # Odia script
+ "script-sinh": "script-sinh", # Sinhala script
+ "script-syrc": "script-syrc", # Syriac script
+ "script-taml": "script-taml", # Tamil script
+ "script-telu": "script-telu", # Telugu script
+ "script-thaa": "script-thaa", # Thaana script
+ "script-thai": "script-thai", # Thai script
+ "script-tibt": "script-tibt", # Tibetan script
+ "script-viet": "script-viet", # Vietnamese script
+ "si": "sin", # Sinhala
+ "sk": "slk", # Slovak
+ "sl": "slv", # Slovenian
+ "sd": "snd", # Sindhi
+ "es": "spa", # Spanish
+ "spa-old": "spa-old", # Old Spanish
+ "sq": "sqi", # Albanian
+ "sr": "srp", # Serbian
+ "srp-latn": "srp-latn", # Serbian (Latin)
+ "su": "sun", # Sundanese
+ "sw": "swa", # Swahili
+ "sv": "swe", # Swedish
+ "syr": "syr", # Syriac
+ "ta": "tam", # Tamil
+ "tt": "tat", # Tatar
+ "te": "tel", # Telugu
+ "tg": "tgk", # Tajik
+ "th": "tha", # Thai
+ "ti": "tir", # Tigrinya
+ "to": "ton", # Tongan
+ "tr": "tur", # Turkish
+ "ug": "uig", # Uyghur
+ "uk": "ukr", # Ukrainian
+ "ur": "urd", # Urdu
+ "uz": "uzb", # Uzbek
+ "uzb-cyrl": "uzb-cyrl", # Uzbek (Cyrillic)
+ "vi": "vie", # Vietnamese
+ "yi": "yid", # Yiddish
+ "yo": "yor", # Yoruba
+}
+
+
+def supported_languages():
+ cmd = "tesseract --list-langs | grep -v osd | awk '{if(NR>1)print}'"
+ sp = subprocess.Popen(["/bin/bash", "-c", cmd], stdout=subprocess.PIPE)
+ tesseract_langs = [line.strip().decode("utf-8") for line in sp.stdout.readlines()]
+ inverted_iso_dict = {v: k for k, v in iso_to_tesseract.items()}
+ return list({tesseract_key: inverted_iso_dict[tesseract_key] for tesseract_key in tesseract_langs}.values())
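
A short sketch of how this module is meant to be used: ISO codes coming from the API are translated to Tesseract codes before OCR, while `supported_languages()` inverts the mapping against whatever packs `tesseract --list-langs` reports as installed. The requested codes below are illustrative:

```python
from adapters.infrastructure.ocr.languages import iso_to_tesseract, supported_languages

# ISO 639 codes requested by a client, translated to Tesseract language codes.
requested = ["en", "de", "zh-Hans"]
print([iso_to_tesseract[code] for code in requested])  # ['eng', 'deu', 'chi_sim']

# ISO codes for which a Tesseract language pack is actually installed.
print(supported_languages())
```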
diff --git a/src/adapters/infrastructure/ocr_service_adapter.py b/src/adapters/infrastructure/ocr_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c279df9972051a116a08df4b6be035ae6ab38b01
--- /dev/null
+++ b/src/adapters/infrastructure/ocr_service_adapter.py
@@ -0,0 +1,41 @@
+import os
+import shutil
+import subprocess
+from pathlib import Path
+from ports.services.ocr_service import OCRService
+from configuration import OCR_SOURCE, OCR_OUTPUT, OCR_FAILED
+from adapters.infrastructure.ocr.languages import iso_to_tesseract, supported_languages
+
+
+class OCRServiceAdapter(OCRService):
+ def process_pdf_ocr(self, filename: str, namespace: str, language: str = "en") -> Path | bool:
+ source_pdf_filepath, processed_pdf_filepath, failed_pdf_filepath = self._get_paths(namespace, filename)
+ os.makedirs(processed_pdf_filepath.parent, exist_ok=True)
+
+ result = subprocess.run(
+ [
+ "ocrmypdf",
+ "-l",
+ iso_to_tesseract[language],
+ source_pdf_filepath,
+ processed_pdf_filepath,
+ "--force-ocr",
+ ]
+ )
+
+ if result.returncode == 0:
+ return processed_pdf_filepath
+
+ os.makedirs(failed_pdf_filepath.parent, exist_ok=True)
+ shutil.move(source_pdf_filepath, failed_pdf_filepath)
+ return False
+
+ def get_supported_languages(self) -> list[str]:
+ return supported_languages()
+
+ def _get_paths(self, namespace: str, pdf_file_name: str) -> tuple[Path, Path, Path]:
+ file_name = "".join(pdf_file_name.split(".")[:-1]) if "." in pdf_file_name else pdf_file_name
+ source_pdf_filepath = Path(OCR_SOURCE, namespace, pdf_file_name)
+ processed_pdf_filepath = Path(OCR_OUTPUT, namespace, f"{file_name}.pdf")
+ failed_pdf_filepath = Path(OCR_FAILED, namespace, pdf_file_name)
+ return source_pdf_filepath, processed_pdf_filepath, failed_pdf_filepath
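
A usage sketch, under the assumption that the PDF has already been placed in `OCR_SOURCE/<namespace>/` (the file and namespace names below are hypothetical). On success `process_pdf_ocr` returns the path of the OCRed copy; on failure it moves the source into the failed folder and returns `False`:

```python
from adapters.infrastructure.ocr_service_adapter import OCRServiceAdapter

ocr = OCRServiceAdapter()
print(ocr.get_supported_languages())  # ISO codes with an installed Tesseract pack

result = ocr.process_pdf_ocr("scan.pdf", namespace="uploads", language="en")  # hypothetical names
if result:
    print(f"OCR output written to {result}")
else:
    print("OCR failed; the source PDF was moved to the failed folder")
```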
diff --git a/src/adapters/infrastructure/pdf_analysis_service_adapter.py b/src/adapters/infrastructure/pdf_analysis_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9ec52cde499fa06ad38dac30d0236fa0e3ce6e
--- /dev/null
+++ b/src/adapters/infrastructure/pdf_analysis_service_adapter.py
@@ -0,0 +1,68 @@
+from typing import AnyStr
+from domain.PdfImages import PdfImages
+from domain.SegmentBox import SegmentBox
+from ports.services.pdf_analysis_service import PDFAnalysisService
+from ports.services.ml_model_service import MLModelService
+from ports.services.format_conversion_service import FormatConversionService
+from ports.repositories.file_repository import FileRepository
+from configuration import service_logger
+
+
+class PDFAnalysisServiceAdapter(PDFAnalysisService):
+ def __init__(
+ self,
+ vgt_model_service: MLModelService,
+ fast_model_service: MLModelService,
+ format_conversion_service: FormatConversionService,
+ file_repository: FileRepository,
+ ):
+ self.vgt_model_service = vgt_model_service
+ self.fast_model_service = fast_model_service
+ self.format_conversion_service = format_conversion_service
+ self.file_repository = file_repository
+
+ def analyze_pdf_layout(
+ self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
+ ) -> list[dict]:
+ pdf_path = self.file_repository.save_pdf(pdf_content)
+ service_logger.info("Creating PDF images")
+
+ pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_filename)]
+
+ predicted_segments = self.vgt_model_service.predict_document_layout(pdf_images_list)
+
+ if parse_tables_and_math:
+ pdf_images_200_dpi = PdfImages.from_pdf_path(pdf_path, "", xml_filename, dpi=200)
+ self.format_conversion_service.convert_formula_to_latex(pdf_images_200_dpi, predicted_segments)
+ self.format_conversion_service.convert_table_to_html(pdf_images_200_dpi, predicted_segments)
+
+ if not keep_pdf:
+ self.file_repository.delete_file(pdf_path)
+
+ return [
+ SegmentBox.from_pdf_segment(pdf_segment, pdf_images_list[0].pdf_features.pages).to_dict()
+ for pdf_segment in predicted_segments
+ ]
+
+ def analyze_pdf_layout_fast(
+ self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
+ ) -> list[dict]:
+ pdf_path = self.file_repository.save_pdf(pdf_content)
+ service_logger.info("Creating PDF images for fast analysis")
+
+ pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_filename)]
+
+ predicted_segments = self.fast_model_service.predict_layout_fast(pdf_images_list)
+
+ if parse_tables_and_math:
+ pdf_images_200_dpi = PdfImages.from_pdf_path(pdf_path, "", xml_filename, dpi=200)
+ self.format_conversion_service.convert_formula_to_latex(pdf_images_200_dpi, predicted_segments)
+ self.format_conversion_service.convert_table_to_html(pdf_images_list[0], predicted_segments)
+
+ if not keep_pdf:
+ self.file_repository.delete_file(pdf_path)
+
+ return [
+ SegmentBox.from_pdf_segment(pdf_segment, pdf_images_list[0].pdf_features.pages).to_dict()
+ for pdf_segment in predicted_segments
+ ]
diff --git a/src/adapters/infrastructure/text_extraction_adapter.py b/src/adapters/infrastructure/text_extraction_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc07ec5dee645cb67fbd9b14b884b0b29854878
--- /dev/null
+++ b/src/adapters/infrastructure/text_extraction_adapter.py
@@ -0,0 +1,20 @@
+from pdf_token_type_labels import TokenType
+from ports.services.text_extraction_service import TextExtractionService
+from configuration import service_logger
+
+
+class TextExtractionAdapter(TextExtractionService):
+ def extract_text_by_types(self, segment_boxes: list[dict], token_types: list[TokenType]) -> str:
+ service_logger.info(f"Extracted types: {[t.name for t in token_types]}")
+ text = "\n".join(
+ [
+ segment_box["text"]
+ for segment_box in segment_boxes
+ if TokenType.from_text(segment_box["type"].replace(" ", "_")) in token_types
+ ]
+ )
+ return text
+
+ def extract_all_text(self, segment_boxes: list[dict]) -> str:
+ all_types = [t for t in TokenType]
+ return self.extract_text_by_types(segment_boxes, all_types)
diff --git a/src/adapters/infrastructure/toc/MergeTwoSegmentsTitles.py b/src/adapters/infrastructure/toc/MergeTwoSegmentsTitles.py
new file mode 100644
index 0000000000000000000000000000000000000000..18ffcd6fa2f20d177adb4c65058f45144fc50057
--- /dev/null
+++ b/src/adapters/infrastructure/toc/MergeTwoSegmentsTitles.py
@@ -0,0 +1,48 @@
+from adapters.infrastructure.toc.TitleFeatures import TitleFeatures
+from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
+
+
+class MergeTwoSegmentsTitles:
+ def __init__(self, pdf_segmentation: PdfSegmentation):
+ self.title_features_list: list[TitleFeatures] = TitleFeatures.from_pdf_segmentation(pdf_segmentation)
+ self.titles_merged: list[TitleFeatures] = list()
+ self.merge()
+
+ def merge(self):
+ index = 0
+ while index < len(self.title_features_list):
+ if index == len(self.title_features_list) - 1:
+ self.titles_merged.append(self.title_features_list[index])
+ break
+
+ if not self.should_merge(self.title_features_list[index], self.title_features_list[index + 1]):
+ self.titles_merged.append(self.title_features_list[index])
+ index += 1
+ continue
+
+ self.title_features_list[index + 1] = self.title_features_list[index + 1].append(self.title_features_list[index])
+ index += 1
+
+ @staticmethod
+ def should_merge(title: TitleFeatures, other_title: TitleFeatures):
+ same_page = other_title.pdf_segment.page_number == title.pdf_segment.page_number
+
+ if not same_page:
+ return False
+
+ if abs(other_title.top - title.bottom) > 15:
+ return False
+
+ if abs(other_title.left - title.right) > 15 or abs(other_title.right - title.left) > 15:
+ return False
+
+ if title.first_characters_type in [1, 2, 3] and other_title.first_characters_type in [1, 2, 3]:
+ return False
+
+ if title.bullet_points_type and other_title.bullet_points_type:
+ return False
+
+ if title.get_features_to_merge() != other_title.get_features_to_merge():
+ return False
+
+ return True
diff --git a/src/adapters/infrastructure/toc/PdfSegmentation.py b/src/adapters/infrastructure/toc/PdfSegmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa111972faadf1583b1dab51d4f150f2a7a929b0
--- /dev/null
+++ b/src/adapters/infrastructure/toc/PdfSegmentation.py
@@ -0,0 +1,32 @@
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfFeatures
+from pdf_features import PdfToken
+
+
+class PdfSegmentation:
+ def __init__(self, pdf_features: PdfFeatures, pdf_segments: list[PdfSegment]):
+ self.pdf_features: PdfFeatures = pdf_features
+ self.pdf_segments: list[PdfSegment] = pdf_segments
+ self.tokens_by_segments: dict[PdfSegment, list[PdfToken]] = self.find_tokens_by_segments()
+
+ @staticmethod
+ def find_segment_for_token(token: PdfToken, segments: list[PdfSegment], tokens_by_segments):
+ best_score: float = 0
+ most_probable_segment: PdfSegment | None = None
+ for segment in segments:
+ intersection_percentage = token.bounding_box.get_intersection_percentage(segment.bounding_box)
+ if intersection_percentage > best_score:
+ best_score = intersection_percentage
+ most_probable_segment = segment
+ if best_score >= 99:
+ break
+ if most_probable_segment:
+ tokens_by_segments.setdefault(most_probable_segment, list()).append(token)
+
+ def find_tokens_by_segments(self):
+ tokens_by_segments: dict[PdfSegment, list[PdfToken]] = {}
+ for page in self.pdf_features.pages:
+ page_segments = [segment for segment in self.pdf_segments if segment.page_number == page.page_number]
+ for token in page.tokens:
+ self.find_segment_for_token(token, page_segments, tokens_by_segments)
+ return tokens_by_segments
diff --git a/src/adapters/infrastructure/toc/TOCExtractor.py b/src/adapters/infrastructure/toc/TOCExtractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2c212e69a4c66d916babcd735a0aa6dbfacafb4
--- /dev/null
+++ b/src/adapters/infrastructure/toc/TOCExtractor.py
@@ -0,0 +1,67 @@
+from adapters.infrastructure.toc.MergeTwoSegmentsTitles import MergeTwoSegmentsTitles
+from adapters.infrastructure.toc.TitleFeatures import TitleFeatures
+from adapters.infrastructure.toc.data.TOCItem import TOCItem
+from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
+
+
+class TOCExtractor:
+ def __init__(self, pdf_segmentation: PdfSegmentation):
+ self.pdf_segmentation = pdf_segmentation
+ self.titles_features_sorted = MergeTwoSegmentsTitles(self.pdf_segmentation).titles_merged
+ self.toc: list[TOCItem] = list()
+ self.set_toc()
+
+ def set_toc(self):
+ for index, title_features in enumerate(self.titles_features_sorted):
+ indentation = self.get_indentation(index, title_features)
+ self.toc.append(title_features.to_toc_item(indentation))
+
+ def __str__(self):
+ return "\n".join([f'{" " * x.indentation} * {x.label}' for x in self.toc])
+
+ def get_indentation(self, title_index: int, title_features: TitleFeatures):
+ if title_index == 0:
+ return 0
+
+ for index in reversed(range(title_index)):
+ if self.toc[index].point_closed:
+ continue
+
+ if self.same_indentation(self.titles_features_sorted[index], title_features):
+ self.close_toc_items(self.toc[index].indentation)
+ return self.toc[index].indentation
+
+ return self.toc[title_index - 1].indentation + 1
+
+ def close_toc_items(self, indentation):
+ for toc in self.toc:
+ if toc.indentation > indentation:
+ toc.point_closed = True
+
+ @staticmethod
+ def same_indentation(previous_title_features: TitleFeatures, title_features: TitleFeatures):
+ if previous_title_features.first_characters in title_features.get_possible_previous_point():
+ return True
+
+ if previous_title_features.get_features_toc() == title_features.get_features_toc():
+ return True
+
+ return False
+
+ def to_dict(self):
+ toc: list[dict[str, any]] = list()
+
+ for toc_item in self.toc:
+ toc_element_dict = dict()
+ toc_element_dict["indentation"] = toc_item.indentation
+ toc_element_dict["label"] = toc_item.label
+ rectangle = dict()
+ rectangle["left"] = int(toc_item.selection_rectangle.left)
+ rectangle["top"] = int(toc_item.selection_rectangle.top)
+ rectangle["width"] = int(toc_item.selection_rectangle.width)
+ rectangle["height"] = int(toc_item.selection_rectangle.height)
+ rectangle["page"] = str(toc_item.selection_rectangle.page_number)
+ toc_element_dict["bounding_box"] = rectangle
+ toc.append(toc_element_dict)
+
+ return toc
diff --git a/src/adapters/infrastructure/toc/TitleFeatures.py b/src/adapters/infrastructure/toc/TitleFeatures.py
new file mode 100755
index 0000000000000000000000000000000000000000..eccb5353ff90c6bf3fb640584ebb787be41306dd
--- /dev/null
+++ b/src/adapters/infrastructure/toc/TitleFeatures.py
@@ -0,0 +1,171 @@
+import string
+import roman
+import numpy as np
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfToken
+from pdf_features import Rectangle
+from domain.SegmentBox import SegmentBox
+from adapters.infrastructure.toc.data.TOCItem import TOCItem
+from adapters.infrastructure.toc.methods.two_models_v3_segments_context_2.Modes import Modes
+from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
+
+
+class TitleFeatures:
+ SPECIAL_MARKERS = [".", "(", ")", "\\", "/", ":", ";", "-", "_", "[", "]", "•", "◦", "*", ","]
+ ALPHABET = list(string.ascii_lowercase)
+ ALPHABET_UPPERCASE = list(string.ascii_uppercase)
+ ROMAN_NUMBERS = [roman.toRoman(i) for i in range(1, 151)]
+ ROMAN_NUMBERS_LOWERCASE = [x.lower() for x in ROMAN_NUMBERS]
+ BULLET_POINTS = [ALPHABET, ALPHABET_UPPERCASE, ROMAN_NUMBERS, ROMAN_NUMBERS_LOWERCASE]
+
+ def __init__(self, pdf_segment: PdfSegment, segment_tokens: list[PdfToken], pdf_features, modes: Modes):
+ self.modes = modes
+ self.pdf_segment = pdf_segment
+ self.pdf_features = pdf_features
+
+ self.segment_tokens: list[PdfToken] = segment_tokens
+ self.first_characters: str = ""
+ self.first_characters_special_markers_count: int = 0
+ self.font_size: float = 0.0
+ self.text_content: str = ""
+ self.width: float = 0
+ self.font_family: str = ""
+ self.font_color: str = ""
+ self.line_height: float = 0.0
+ self.uppercase: bool = False
+ self.bold: float = 0.0
+ self.italics: float = 0.0
+ self.first_characters_type = 0
+ self.bullet_points_type = 0
+ self.text_centered: int = 0
+ self.is_left: bool = False
+ self.indentation: int = -1
+ self.left: int = self.pdf_segment.bounding_box.left
+ self.top: int = self.pdf_segment.bounding_box.top
+ self.right: int = self.pdf_segment.bounding_box.right
+ self.bottom: int = self.pdf_segment.bounding_box.bottom
+
+ self.initialize_text_properties()
+ self.process_first_characters()
+ self.process_font_properties()
+ self.process_positional_properties()
+
+ def initialize_text_properties(self):
+ words = [token.content for token in self.segment_tokens]
+ self.text_content = " ".join(words)
+
+ def process_first_characters(self):
+ self.first_characters = self.text_content.split(" ")[0].split("\n")[0].split("\t")[0]
+ clean_first_characters = [x for x in self.first_characters if x not in self.SPECIAL_MARKERS]
+ characters_checker = {
+ 1: lambda x_list: len(x_list) == len([letter for letter in x_list if letter in "IVXL"]),
+ 2: lambda x_list: len(x_list) == len([letter for letter in x_list if letter in "IVXL".lower()]),
+ 3: lambda x_list: len(x_list) == len([letter for letter in x_list if letter in "1234567890"]),
+ 4: lambda x_list: len(x_list) == len([letter for letter in x_list if letter == letter.upper()]),
+ }
+
+ self.first_characters_type = next(
+ (index for index, type_checker in characters_checker.items() if type_checker(clean_first_characters)), 0
+ )
+
+ self.bullet_points_type = (
+ self.SPECIAL_MARKERS.index(self.first_characters[-1]) + 1
+ if self.first_characters[-1] in self.SPECIAL_MARKERS
+ else 0
+ )
+ self.first_characters_special_markers_count = len(
+ [x for x in self.first_characters[:-1] if x in self.SPECIAL_MARKERS]
+ )
+
+ def process_font_properties(self):
+ self.font_family = self.segment_tokens[0].font.font_id
+ self.font_color = self.segment_tokens[0].font.color
+ self.bold = sum(token.font.bold for token in self.segment_tokens) / len(self.segment_tokens)
+ self.italics = sum(token.font.italics for token in self.segment_tokens) / len(self.segment_tokens)
+ self.uppercase = self.text_content.upper() == self.text_content
+ font_sizes = [token.font.font_size for token in self.segment_tokens]
+ self.font_size = np.mean(font_sizes)
+
+ def process_positional_properties(self):
+ self.line_height = self.segment_tokens[0].font.font_size
+ page_width = self.pdf_features.pages[self.pdf_segment.page_number - 1].page_width
+ self.text_centered = 1 if abs(self.left - (page_width - self.right)) < 10 else 0
+ self.is_left = self.left < page_width - self.right if not self.text_centered else False
+ self.indentation = int((self.left - self.modes.left_space_mode) / 15) if self.is_left else -1
+
+ def get_features_to_merge(self) -> np.array:
+ return (
+ 1 if self.bold else 0,
+ 1 if self.italics else 0,
+ )
+
+ def get_features_toc(self) -> np.array:
+ return (
+ 1 if self.bold else 0,
+ 1 if self.italics else 0,
+ self.first_characters_type,
+ self.first_characters_special_markers_count,
+ self.bullet_points_type,
+ )
+
+ def get_possible_previous_point(self) -> list[str]:
+ previous_characters = self.first_characters
+ final_special_markers = ""
+ last_part = ""
+ for letter in list(reversed(previous_characters)):
+ if not last_part and letter in self.SPECIAL_MARKERS:
+ final_special_markers = previous_characters[-1] + final_special_markers
+ previous_characters = previous_characters[:-1]
+ continue
+
+ if last_part and letter in self.SPECIAL_MARKERS:
+ break
+
+ last_part = letter + last_part
+ previous_characters = previous_characters[:-1]
+
+ previous_items = self.get_previous_items(last_part)
+
+ if not previous_items and len(self.first_characters) >= 4:
+ return [self.first_characters]
+
+ return [previous_characters + x + final_special_markers for x in previous_items]
+
+ def get_previous_items(self, item: str):
+ previous_items = []
+
+ for bullet_points in self.BULLET_POINTS:
+ if item in bullet_points and bullet_points.index(item):
+ previous_items.append(bullet_points[bullet_points.index(item) - 1])
+
+ if item.isnumeric():
+ previous_items.append(str(int(item) - 1))
+
+ return previous_items
+
+ @staticmethod
+ def from_pdf_segmentation(pdf_segmentation: PdfSegmentation) -> list["TitleFeatures"]:
+ titles_features = list()
+ modes = Modes(pdf_features=pdf_segmentation.pdf_features)
+ for pdf_segment in pdf_segmentation.pdf_segments:
+ segment_tokens = pdf_segmentation.tokens_by_segments[pdf_segment]
+ titles_features.append(TitleFeatures(pdf_segment, segment_tokens, pdf_segmentation.pdf_features, modes))
+
+ return titles_features
+
+ def to_toc_item(self, indentation):
+ return TOCItem(
+ indentation=indentation,
+ label=self.text_content,
+ selection_rectangle=SegmentBox.from_pdf_segment(self.pdf_segment, self.pdf_features.pages),
+ )
+
+ def append(self, other_title_features: "TitleFeatures"):
+ other_segment = other_title_features.pdf_segment
+ merged_bounding_box = Rectangle.merge_rectangles([self.pdf_segment.bounding_box, other_segment.bounding_box])
+ merged_content = self.pdf_segment.text_content + other_segment.text_content
+ merged_segment = PdfSegment(
+ self.pdf_segment.page_number, merged_bounding_box, merged_content, self.pdf_segment.segment_type
+ )
+ segment_tokens = self.segment_tokens + other_title_features.segment_tokens
+ return TitleFeatures(merged_segment, segment_tokens, pdf_features=self.pdf_features, modes=self.modes)
diff --git a/src/adapters/infrastructure/toc/__init__.py b/src/adapters/infrastructure/toc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/toc/data/TOCItem.py b/src/adapters/infrastructure/toc/data/TOCItem.py
new file mode 100644
index 0000000000000000000000000000000000000000..faff3c552e55ce2c09e3ed417b168a3c74f4b290
--- /dev/null
+++ b/src/adapters/infrastructure/toc/data/TOCItem.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+from domain.SegmentBox import SegmentBox
+
+
+class TOCItem(BaseModel):
+ indentation: int
+ label: str = ""
+ selection_rectangle: SegmentBox
+ point_closed: bool = False
diff --git a/src/adapters/infrastructure/toc/data/__init__.py b/src/adapters/infrastructure/toc/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/toc/extract_table_of_contents.py b/src/adapters/infrastructure/toc/extract_table_of_contents.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad5bd6e0caad2df065ccfe28ffb067a235c6b65e
--- /dev/null
+++ b/src/adapters/infrastructure/toc/extract_table_of_contents.py
@@ -0,0 +1,73 @@
+import tempfile
+import uuid
+from os.path import join
+from pathlib import Path
+from typing import AnyStr
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfFeatures
+from pdf_features import Rectangle
+from pdf_token_type_labels import TokenType
+from adapters.infrastructure.toc.TOCExtractor import TOCExtractor
+from configuration import service_logger
+from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
+
+TITLE_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER}
+SKIP_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.PAGE_HEADER, TokenType.PICTURE}
+
+
+def get_file_path(file_name, extension):
+ return join(tempfile.gettempdir(), file_name + "." + extension)
+
+
+def pdf_content_to_pdf_path(file_content):
+ file_id = str(uuid.uuid1())
+
+ pdf_path = Path(get_file_path(file_id, "pdf"))
+ pdf_path.write_bytes(file_content)
+
+ return pdf_path
+
+
+def skip_name_of_the_document(pdf_segments: list[PdfSegment], title_segments: list[PdfSegment]):
+ segments_to_remove = []
+ last_segment = None
+ for segment in pdf_segments:
+ if segment.segment_type not in SKIP_TYPES:
+ break
+ if segment.segment_type == TokenType.PAGE_HEADER or segment.segment_type == TokenType.PICTURE:
+ continue
+ if not last_segment:
+ last_segment = segment
+ else:
+ if segment.bounding_box.right < last_segment.bounding_box.left + last_segment.bounding_box.width * 0.66:
+ break
+ last_segment = segment
+ if segment.segment_type in TITLE_TYPES:
+ segments_to_remove.append(segment)
+ for segment in segments_to_remove:
+ title_segments.remove(segment)
+
+
+def get_pdf_segments_from_segment_boxes(pdf_features: PdfFeatures, segment_boxes: list[dict]) -> list[PdfSegment]:
+ pdf_segments: list[PdfSegment] = []
+ for segment_box in segment_boxes:
+ left, top, width, height = segment_box["left"], segment_box["top"], segment_box["width"], segment_box["height"]
+ bounding_box = Rectangle.from_width_height(left, top, width, height)
+ segment_type = TokenType.from_value(segment_box["type"])
+ pdf_name = pdf_features.file_name
+ segment = PdfSegment(segment_box["page_number"], bounding_box, segment_box["text"], segment_type, pdf_name)
+ pdf_segments.append(segment)
+ return pdf_segments
+
+
+def extract_table_of_contents(file: AnyStr, segment_boxes: list[dict], skip_document_name=False):
+ service_logger.info("Getting TOC")
+ pdf_path = pdf_content_to_pdf_path(file)
+ pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path)
+ pdf_segments: list[PdfSegment] = get_pdf_segments_from_segment_boxes(pdf_features, segment_boxes)
+ title_segments = [segment for segment in pdf_segments if segment.segment_type in TITLE_TYPES]
+ if skip_document_name:
+ skip_name_of_the_document(pdf_segments, title_segments)
+ pdf_segmentation: PdfSegmentation = PdfSegmentation(pdf_features, title_segments)
+ toc_instance: TOCExtractor = TOCExtractor(pdf_segmentation)
+ return toc_instance.to_dict()
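
A sketch of calling the module-level helper with segment boxes in the dictionary shape it expects; all values below are made up, and `type` must be a value accepted by `TokenType.from_value`:

```python
from pathlib import Path

from adapters.infrastructure.toc.extract_table_of_contents import extract_table_of_contents

pdf_bytes = Path("example.pdf").read_bytes()  # hypothetical input file
segment_boxes = [
    {
        "left": 72, "top": 90, "width": 200, "height": 18,        # illustrative values
        "page_number": 1, "type": "Section header", "text": "1. Introduction",
    },
]

for item in extract_table_of_contents(pdf_bytes, segment_boxes, skip_document_name=False):
    print(item["indentation"], item["label"], item["bounding_box"])
```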
diff --git a/src/adapters/infrastructure/toc/methods/__init__.py b/src/adapters/infrastructure/toc/methods/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/Modes.py b/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/Modes.py
new file mode 100644
index 0000000000000000000000000000000000000000..c536b1f3f805db2b028b28336da937e5dad5994b
--- /dev/null
+++ b/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/Modes.py
@@ -0,0 +1,44 @@
+import dataclasses
+import hashlib
+from statistics import mode
+
+from pdf_features import PdfFeatures
+
+
+@dataclasses.dataclass
+class Modes:
+ lines_space_mode: float
+ left_space_mode: float
+ right_space_mode: float
+ font_size_mode: float
+ font_family_name_mode: str
+ font_family_mode: int
+ font_family_mode_normalized: float
+ pdf_features: PdfFeatures
+
+ def __init__(self, pdf_features: PdfFeatures):
+ self.pdf_features = pdf_features
+ self.set_modes()
+
+ def set_modes(self):
+ line_spaces, right_spaces, left_spaces = [0], [0], [0]
+ for page, token in self.pdf_features.loop_tokens():
+ right_spaces.append(self.pdf_features.pages[0].page_width - token.bounding_box.right)
+ left_spaces.append(token.bounding_box.left)
+ line_spaces.append(token.bounding_box.bottom)
+
+ self.lines_space_mode = mode(line_spaces)
+ self.left_space_mode = mode(left_spaces)
+ self.right_space_mode = mode(right_spaces)
+
+ font_sizes = [token.font.font_size for page, token in self.pdf_features.loop_tokens() if token.font]
+ self.font_size_mode = mode(font_sizes) if font_sizes else 0
+ font_ids = [token.font.font_id for page, token in self.pdf_features.loop_tokens() if token.font]
+ self.font_family_name_mode = mode(font_ids) if font_ids else ""
+ self.font_family_mode = abs(
+ int(
+ str(hashlib.sha256(self.font_family_name_mode.encode("utf-8")).hexdigest())[:8],
+ 16,
+ )
+ )
+ self.font_family_mode_normalized = float(f"{str(self.font_family_mode)[0]}.{str(self.font_family_mode)[1:]}")
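
The last two assignments turn the most common font family name into a stable numeric feature: the first eight hex digits of its SHA-256 hash become an integer, which is then reinterpreted as a float with a single leading digit. A standalone sketch of the same arithmetic (the font name is an arbitrary example):

```python
import hashlib

font_family_name_mode = "Helvetica"  # example value; in Modes this is the modal font_id

# Integer in [0, 2**32) derived from the first 8 hex digits of the SHA-256 hash.
font_family_mode = abs(int(hashlib.sha256(font_family_name_mode.encode("utf-8")).hexdigest()[:8], 16))

# Same digits reinterpreted as a float in [1, 10) with one leading integer digit.
font_family_mode_normalized = float(f"{str(font_family_mode)[0]}.{str(font_family_mode)[1:]}")

print(font_family_mode, font_family_mode_normalized)
```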
diff --git a/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/__init__.py b/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/infrastructure/toc_service_adapter.py b/src/adapters/infrastructure/toc_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..708b44a66f9d26f7d09f2c7b7769ac9587fb7f97
--- /dev/null
+++ b/src/adapters/infrastructure/toc_service_adapter.py
@@ -0,0 +1,83 @@
+import tempfile
+import uuid
+from os.path import join
+from pathlib import Path
+from typing import AnyStr
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfFeatures, Rectangle
+from pdf_token_type_labels import TokenType
+from ports.services.toc_service import TOCService
+from configuration import service_logger
+from adapters.infrastructure.toc.TOCExtractor import TOCExtractor
+from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
+
+TITLE_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER}
+SKIP_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.PAGE_HEADER, TokenType.PICTURE}
+
+
+class TOCServiceAdapter(TOCService):
+
+ def extract_table_of_contents(
+ self, pdf_content: AnyStr, segment_boxes: list[dict], skip_document_name=False
+ ) -> list[dict]:
+ service_logger.info("Getting TOC")
+ pdf_path = self._pdf_content_to_pdf_path(pdf_content)
+ pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path)
+ pdf_segments: list[PdfSegment] = self._get_pdf_segments_from_segment_boxes(pdf_features, segment_boxes)
+ title_segments = [segment for segment in pdf_segments if segment.segment_type in TITLE_TYPES]
+ if skip_document_name:
+ self._skip_name_of_the_document(pdf_segments, title_segments)
+ pdf_segmentation: PdfSegmentation = PdfSegmentation(pdf_features, title_segments)
+ toc_instance: TOCExtractor = TOCExtractor(pdf_segmentation)
+ return toc_instance.to_dict()
+
+ def format_toc_for_uwazi(self, toc_items: list[dict]) -> list[dict]:
+ toc_compatible = []
+ for toc_item in toc_items:
+ toc_compatible.append(toc_item.copy())
+ toc_compatible[-1]["bounding_box"]["left"] = int(toc_item["bounding_box"]["left"] / 0.75)
+ toc_compatible[-1]["bounding_box"]["top"] = int(toc_item["bounding_box"]["top"] / 0.75)
+ toc_compatible[-1]["bounding_box"]["width"] = int(toc_item["bounding_box"]["width"] / 0.75)
+ toc_compatible[-1]["bounding_box"]["height"] = int(toc_item["bounding_box"]["height"] / 0.75)
+ toc_compatible[-1]["selectionRectangles"] = [toc_compatible[-1]["bounding_box"]]
+ del toc_compatible[-1]["bounding_box"]
+ return toc_compatible
+
+ def _get_file_path(self, file_name: str, extension: str) -> str:
+ return join(tempfile.gettempdir(), file_name + "." + extension)
+
+ def _pdf_content_to_pdf_path(self, file_content: AnyStr) -> Path:
+ file_id = str(uuid.uuid1())
+ pdf_path = Path(self._get_file_path(file_id, "pdf"))
+ pdf_path.write_bytes(file_content)
+ return pdf_path
+
+ def _skip_name_of_the_document(self, pdf_segments: list[PdfSegment], title_segments: list[PdfSegment]) -> None:
+ segments_to_remove = []
+ last_segment = None
+ for segment in pdf_segments:
+ if segment.segment_type not in SKIP_TYPES:
+ break
+ if segment.segment_type == TokenType.PAGE_HEADER or segment.segment_type == TokenType.PICTURE:
+ continue
+ if not last_segment:
+ last_segment = segment
+ else:
+ if segment.bounding_box.right < last_segment.bounding_box.left + last_segment.bounding_box.width * 0.66:
+ break
+ last_segment = segment
+ if segment.segment_type in TITLE_TYPES:
+ segments_to_remove.append(segment)
+ for segment in segments_to_remove:
+ title_segments.remove(segment)
+
+ def _get_pdf_segments_from_segment_boxes(self, pdf_features: PdfFeatures, segment_boxes: list[dict]) -> list[PdfSegment]:
+ pdf_segments: list[PdfSegment] = []
+ for segment_box in segment_boxes:
+ left, top, width, height = segment_box["left"], segment_box["top"], segment_box["width"], segment_box["height"]
+ bounding_box = Rectangle.from_width_height(left, top, width, height)
+ segment_type = TokenType.from_value(segment_box["type"])
+ pdf_name = pdf_features.file_name
+ segment = PdfSegment(segment_box["page_number"], bounding_box, segment_box["text"], segment_type, pdf_name)
+ pdf_segments.append(segment)
+ return pdf_segments
diff --git a/src/adapters/infrastructure/visualization_service_adapter.py b/src/adapters/infrastructure/visualization_service_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a79d5a111fec333c384dca06ac2a56d2b5aa8f7
--- /dev/null
+++ b/src/adapters/infrastructure/visualization_service_adapter.py
@@ -0,0 +1,81 @@
+from pathlib import Path
+from os import makedirs
+from os.path import join
+from pdf_annotate import PdfAnnotator, Location, Appearance
+from starlette.responses import FileResponse
+from ports.services.visualization_service import VisualizationService
+from configuration import ROOT_PATH
+
+DOCLAYNET_COLOR_BY_TYPE = {
+ "Caption": "#FFC300",
+ "Footnote": "#581845",
+ "Formula": "#FF5733",
+ "List item": "#008B8B",
+ "Page footer": "#FF5733",
+ "Page header": "#581845",
+ "Picture": "#C70039",
+ "Section header": "#C70039",
+ "Table": "#FF8C00",
+ "Text": "#606060",
+ "Title": "#EED400",
+}
+
+
+class VisualizationServiceAdapter(VisualizationService):
+ def create_pdf_visualization(self, pdf_path: Path, segment_boxes: list[dict]) -> Path:
+ pdf_outputs_path = join(ROOT_PATH, "pdf_outputs")
+ makedirs(pdf_outputs_path, exist_ok=True)
+ annotator = PdfAnnotator(str(pdf_path))
+ segment_index = 0
+ current_page = 1
+
+ for segment_box in segment_boxes:
+ if int(segment_box["page_number"]) != current_page:
+ segment_index = 0
+ current_page += 1
+ page_height = int(segment_box["page_height"])
+ self._add_prediction_annotation(annotator, segment_box, segment_index, page_height)
+ segment_index += 1
+
+ annotator.write(str(pdf_path))
+ return pdf_path
+
+ def get_visualization_response(self, pdf_path: Path) -> FileResponse:
+ return FileResponse(path=pdf_path, media_type="application/pdf", filename=pdf_path.name)
+
+ def _hex_color_to_rgb(self, color: str) -> tuple:
+ r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
+ alpha = 1
+ return r / 255, g / 255, b / 255, alpha
+
+ def _add_prediction_annotation(
+ self, annotator: PdfAnnotator, segment_box: dict, segment_index: int, page_height: int
+ ) -> None:
+ predicted_type = segment_box["type"]
+ color = DOCLAYNET_COLOR_BY_TYPE[predicted_type]
+ left, top, right, bottom = (
+ segment_box["left"],
+ page_height - segment_box["top"],
+ segment_box["left"] + segment_box["width"],
+ page_height - (segment_box["top"] + segment_box["height"]),
+ )
+ text_box_size = len(predicted_type) * 8 + 8
+
+ annotator.add_annotation(
+ "square",
+ Location(x1=left, y1=bottom, x2=right, y2=top, page=int(segment_box["page_number"]) - 1),
+ Appearance(stroke_color=self._hex_color_to_rgb(color)),
+ )
+
+ annotator.add_annotation(
+ "square",
+ Location(x1=left, y1=top, x2=left + text_box_size, y2=top + 10, page=int(segment_box["page_number"]) - 1),
+ Appearance(fill=self._hex_color_to_rgb(color)),
+ )
+
+ content = predicted_type.capitalize() + f" [{str(segment_index+1)}]"
+ annotator.add_annotation(
+ "text",
+ Location(x1=left, y1=top, x2=left + text_box_size, y2=top + 10, page=int(segment_box["page_number"]) - 1),
+ Appearance(content=content, font_size=8, fill=(1, 1, 1), stroke_width=3),
+ )
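
A sketch of annotating a PDF with predicted boxes and returning it as a download; the dictionary keys mirror what `_add_prediction_annotation` reads, and the concrete values and file name are placeholders:

```python
from pathlib import Path

from adapters.infrastructure.visualization_service_adapter import VisualizationServiceAdapter

service = VisualizationServiceAdapter()
segment_boxes = [
    {"left": 72, "top": 90, "width": 200, "height": 18,   # illustrative values
     "page_number": 1, "page_height": 842, "type": "Title"},
]

# Annotations are written back into the same file, whose path is returned.
annotated_path = service.create_pdf_visualization(Path("example.pdf"), segment_boxes)
response = service.get_visualization_response(annotated_path)  # FileResponse usable from a FastAPI route
```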
diff --git a/src/adapters/ml/__init__.py b/src/adapters/ml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/ml/fast_trainer/Paragraph.py b/src/adapters/ml/fast_trainer/Paragraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..29fc4e132c52bdd1ce08182c05d5c81db1c3d0b3
--- /dev/null
+++ b/src/adapters/ml/fast_trainer/Paragraph.py
@@ -0,0 +1,10 @@
+from pdf_features import PdfToken
+
+
+class Paragraph:
+ def __init__(self, tokens: list[PdfToken], pdf_name: str = ""):
+ self.tokens = tokens
+ self.pdf_name = pdf_name
+
+ def add_token(self, token: PdfToken):
+ self.tokens.append(token)
diff --git a/src/adapters/ml/fast_trainer/ParagraphExtractorTrainer.py b/src/adapters/ml/fast_trainer/ParagraphExtractorTrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e22f6c860665b31ee507c2a414910786549d62de
--- /dev/null
+++ b/src/adapters/ml/fast_trainer/ParagraphExtractorTrainer.py
@@ -0,0 +1,61 @@
+from pathlib import Path
+
+from adapters.ml.fast_trainer.Paragraph import Paragraph
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfToken
+from pdf_token_type_labels import TokenType
+from adapters.ml.pdf_tokens_type_trainer.TokenFeatures import TokenFeatures
+from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
+
+
+class ParagraphExtractorTrainer(TokenTypeTrainer):
+ def get_context_features(self, token_features: TokenFeatures, page_tokens: list[PdfToken], token_index: int):
+ token_row_features = list()
+ first_token_from_context = token_index - self.model_configuration.context_size
+ for i in range(self.model_configuration.context_size * 2):
+ first_token = page_tokens[first_token_from_context + i]
+ second_token = page_tokens[first_token_from_context + i + 1]
+ features = token_features.get_features(first_token, second_token, page_tokens)
+ features += self.get_paragraph_extraction_features(first_token, second_token)
+ token_row_features.extend(features)
+
+ return token_row_features
+
+ @staticmethod
+ def get_paragraph_extraction_features(first_token: PdfToken, second_token: PdfToken) -> list[int]:
+ one_hot_token_type_1 = [1 if token_type == first_token.token_type else 0 for token_type in TokenType]
+ one_hot_token_type_2 = [1 if token_type == second_token.token_type else 0 for token_type in TokenType]
+ return one_hot_token_type_1 + one_hot_token_type_2
+
+ def loop_token_next_token(self):
+ for pdf_features in self.pdfs_features:
+ for page in pdf_features.pages:
+ if not page.tokens:
+ continue
+ if len(page.tokens) == 1:
+ yield page, page.tokens[0], page.tokens[0]
+ for token, next_token in zip(page.tokens, page.tokens[1:]):
+ yield page, token, next_token
+
+ def get_pdf_segments(self, paragraph_extractor_model_path: str | Path) -> list[PdfSegment]:
+ paragraphs = self.get_paragraphs(paragraph_extractor_model_path)
+ pdf_segments = [PdfSegment.from_pdf_tokens(paragraph.tokens, paragraph.pdf_name) for paragraph in paragraphs]
+
+ return pdf_segments
+
+ def get_paragraphs(self, paragraph_extractor_model_path) -> list[Paragraph]:
+ self.predict(paragraph_extractor_model_path)
+ paragraphs: list[Paragraph] = []
+ last_page = None
+ for page, token, next_token in self.loop_token_next_token():
+ if last_page != page:
+ last_page = page
+ paragraphs.append(Paragraph([token], page.pdf_name))
+ if token == next_token:
+ continue
+ if token.prediction:
+ paragraphs[-1].add_token(next_token)
+ continue
+ paragraphs.append(Paragraph([next_token], page.pdf_name))
+
+ return paragraphs
diff --git a/src/adapters/ml/fast_trainer/__init__.py b/src/adapters/ml/fast_trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/ml/fast_trainer/model_configuration.py b/src/adapters/ml/fast_trainer/model_configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c78a676355bdbe8d403888d5f5ce7daf7324c82
--- /dev/null
+++ b/src/adapters/ml/fast_trainer/model_configuration.py
@@ -0,0 +1,25 @@
+from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration
+
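+# LightGBM hyperparameters for the paragraph extraction model: num_class=2 corresponds to the binary
+# decision per token pair of continuing the current paragraph or starting a new one.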
+config_json = {
+ "boosting_type": "gbdt",
+ "verbose": -1,
+ "learning_rate": 0.1,
+ "num_class": 2,
+ "context_size": 1,
+ "num_boost_round": 400,
+ "num_leaves": 191,
+ "bagging_fraction": 0.9166599392739231,
+ "bagging_freq": 7,
+ "feature_fraction": 0.3116707710163228,
+ "lambda_l1": 0.0006901861637621734,
+ "lambda_l2": 1.1886914989632197e-05,
+ "min_data_in_leaf": 50,
+ "feature_pre_filter": True,
+ "seed": 22,
+ "deterministic": True,
+}
+
+MODEL_CONFIGURATION = ModelConfiguration(**config_json)
+
+if __name__ == "__main__":
+ print(MODEL_CONFIGURATION)
diff --git a/src/adapters/ml/fast_trainer_adapter.py b/src/adapters/ml/fast_trainer_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..617d31c987ed66e2a5a58ff1a54b94abea07fe12
--- /dev/null
+++ b/src/adapters/ml/fast_trainer_adapter.py
@@ -0,0 +1,31 @@
+from os.path import join
+from domain.PdfImages import PdfImages
+from domain.PdfSegment import PdfSegment
+from ports.services.ml_model_service import MLModelService
+from adapters.ml.fast_trainer.ParagraphExtractorTrainer import ParagraphExtractorTrainer
+from adapters.ml.fast_trainer.model_configuration import MODEL_CONFIGURATION as PARAGRAPH_EXTRACTION_CONFIGURATION
+from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
+from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration
+from configuration import ROOT_PATH, service_logger
+
+
+class FastTrainerAdapter(MLModelService):
+ def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
+ return self.predict_layout_fast(pdf_images)
+
+ def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
+ service_logger.info("Creating Paragraph Tokens [fast]")
+
+ pdf_images_obj = pdf_images[0]
+
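+        # Two-stage pipeline: first tag every token with a type using the token type model,
+        # then group the typed tokens into paragraph segments with the paragraph extraction model.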
+ token_type_trainer = TokenTypeTrainer([pdf_images_obj.pdf_features], ModelConfiguration())
+ token_type_trainer.set_token_types(join(ROOT_PATH, "models", "token_type_lightgbm.model"))
+
+ trainer = ParagraphExtractorTrainer(
+ pdfs_features=[pdf_images_obj.pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION
+ )
+ segments = trainer.get_pdf_segments(join(ROOT_PATH, "models", "paragraph_extraction_lightgbm.model"))
+
+ pdf_images_obj.remove_images()
+
+ return segments
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/ModelConfiguration.py b/src/adapters/ml/pdf_tokens_type_trainer/ModelConfiguration.py
new file mode 100644
index 0000000000000000000000000000000000000000..b905cbb567433c01a1b384dabbf275e2b35c29dd
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/ModelConfiguration.py
@@ -0,0 +1,30 @@
+from dataclasses import dataclass, asdict
+
+from pdf_token_type_labels import TokenType
+
+
+@dataclass
+class ModelConfiguration:
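+    # Default LightGBM hyperparameters for the multiclass token type model (one class per TokenType).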
+ context_size: int = 4
+ num_boost_round: int = 700
+ num_leaves: int = 127
+ bagging_fraction: float = 0.6810645192499981
+ lambda_l1: float = 1.1533558410486358e-08
+ lambda_l2: float = 4.91211684620458
+ feature_fraction: float = 0.7087268965467017
+ bagging_freq: int = 10
+ min_data_in_leaf: int = 47
+ feature_pre_filter: bool = False
+ boosting_type: str = "gbdt"
+ objective: str = "multiclass"
+ metric: str = "multi_logloss"
+ learning_rate: float = 0.1
+ seed: int = 22
+ num_class: int = len(TokenType)
+ verbose: int = -1
+ deterministic: bool = True
+ resume_training: bool = False
+    early_stopping_rounds: int | None = None
+
+ def dict(self):
+ return asdict(self)
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/PdfTrainer.py b/src/adapters/ml/pdf_tokens_type_trainer/PdfTrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc14e980c94112c9fd5b2a6fd06941f0cce955f
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/PdfTrainer.py
@@ -0,0 +1,92 @@
+import os
+from os.path import exists, join
+from pathlib import Path
+
+import lightgbm as lgb
+import numpy as np
+
+from pdf_features import PdfFeatures, PdfTokenStyle
+from pdf_features import PdfFont
+from pdf_features import PdfToken
+from pdf_features import Rectangle
+from pdf_token_type_labels import TokenType
+from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration
+from adapters.ml.pdf_tokens_type_trainer.download_models import pdf_tokens_type_model
+
+
+class PdfTrainer:
+ def __init__(self, pdfs_features: list[PdfFeatures], model_configuration: ModelConfiguration = None):
+ self.pdfs_features = pdfs_features
+ self.model_configuration = model_configuration if model_configuration else ModelConfiguration()
+
+ def get_model_input(self) -> np.ndarray:
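+        # Overridden by subclasses (e.g. TokenTypeTrainer) to build the feature matrix.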
+ pass
+
+ @staticmethod
+ def features_rows_to_x(features_rows):
+ if not features_rows:
+ return np.zeros((0, 0))
+
+ x = np.zeros(((len(features_rows)), len(features_rows[0])))
+ for i, v in enumerate(features_rows):
+ x[i] = v
+ return x
+
+ def train(self, model_path: str | Path, labels: list[int]):
+ print("Getting model input")
+ x_train = self.get_model_input()
+
+ if not x_train.any():
+ print("No data for training")
+ return
+
+ lgb_train = lgb.Dataset(x_train, labels)
+ lgb_eval = lgb.Dataset(x_train, labels, reference=lgb_train)
+ print("Training")
+
+ if self.model_configuration.resume_training and exists(model_path):
+ model = lgb.Booster(model_file=model_path)
+ gbm = model.refit(x_train, labels)
+ else:
+ gbm = lgb.train(params=self.model_configuration.dict(), train_set=lgb_train, valid_sets=[lgb_eval])
+
+ print("Saving")
+ gbm.save_model(model_path, num_iteration=gbm.best_iteration)
+
+ def loop_tokens(self):
+ for pdf_features in self.pdfs_features:
+ for page, token in pdf_features.loop_tokens():
+ yield token
+
+ @staticmethod
+ def get_padding_token(segment_number: int, page_number: int):
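+        # Dummy token used to pad the context window at page boundaries (see TokenTypeTrainer.get_model_input).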
+ return PdfToken(
+ page_number=page_number,
+ id="pad_token",
+ content="",
+ font=PdfFont(font_id="pad_font_id", font_size=0, bold=False, italics=False, color="black"),
+ reading_order_no=segment_number,
+ bounding_box=Rectangle.from_coordinates(0, 0, 0, 0),
+ token_type=TokenType.TEXT,
+ token_style=PdfTokenStyle(
+ font=PdfFont(font_id="pad_font_id", font_size=0, bold=False, italics=False, color="black")
+ ),
+ )
+
+ def predict(self, model_path: str | Path = None):
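+        # Fall back to the token type model downloaded from the Hugging Face Hub when no model path is given.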
+ model_path = model_path if model_path else pdf_tokens_type_model
+ x = self.get_model_input()
+
+ if not x.any():
+ return self.pdfs_features
+
+ lightgbm_model = lgb.Booster(model_file=model_path)
+ return lightgbm_model.predict(x)
+
+ def save_training_data(self, save_folder_path: str | Path, labels: list[int]):
+ os.makedirs(save_folder_path, exist_ok=True)
+
+ x = self.get_model_input()
+
+ np.save(join(str(save_folder_path), "x.npy"), x)
+ np.save(join(str(save_folder_path), "y.npy"), labels)
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/TokenFeatures.py b/src/adapters/ml/pdf_tokens_type_trainer/TokenFeatures.py
new file mode 100644
index 0000000000000000000000000000000000000000..1042e34ac7c4d443c232b53254bf7c49b423f1d0
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/TokenFeatures.py
@@ -0,0 +1,126 @@
+import string
+import unicodedata
+
+from pdf_features import PdfFeatures
+from pdf_features import PdfToken
+from adapters.ml.pdf_tokens_type_trainer.config import CHARACTER_TYPE
+
+
+class TokenFeatures:
+ def __init__(self, pdfs_features: PdfFeatures):
+ self.pdfs_features = pdfs_features
+
+ def get_features(self, token_1: PdfToken, token_2: PdfToken, page_tokens: list[PdfToken]):
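+        # Features for a pair of consecutive tokens: shared-font flag, page font-size mode, content lengths,
+        # space and punctuation counts, plus position and unicode-category features.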
+        same_font = token_1.font.font_id == token_2.font.font_id
+
+ return (
+ [
+ same_font,
+ self.pdfs_features.pdf_modes.font_size_mode / 100,
+ len(token_1.content),
+ len(token_2.content),
+ token_1.content.count(" "),
+ token_2.content.count(" "),
+ sum(character in string.punctuation for character in token_1.content),
+ sum(character in string.punctuation for character in token_2.content),
+ ]
+ + self.get_position_features(token_1, token_2, page_tokens)
+ + self.get_unicode_categories(token_1)
+ + self.get_unicode_categories(token_2)
+ )
+
+ def get_position_features(self, token_1: PdfToken, token_2: PdfToken, page_tokens):
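+        # Geometric features: positions and sizes of both tokens and the gaps between them,
+        # compared against the page-level line-spacing and right-margin modes.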
+ left_1 = token_1.bounding_box.left
+ right_1 = token_1.bounding_box.right
+ height_1 = token_1.bounding_box.height
+ width_1 = token_1.bounding_box.width
+
+ left_2 = token_2.bounding_box.left
+ right_2 = token_2.bounding_box.right
+ height_2 = token_2.bounding_box.height
+ width_2 = token_2.bounding_box.width
+
+ right_gap_1, left_gap_2 = (
+ token_1.pdf_token_context.left_of_token_on_the_right - right_1,
+ left_2 - token_2.pdf_token_context.right_of_token_on_the_left,
+ )
+
+ absolute_right_1 = max(right_1, token_1.pdf_token_context.right_of_token_on_the_right)
+ absolute_right_2 = max(right_2, token_2.pdf_token_context.right_of_token_on_the_right)
+
+ absolute_left_1 = min(left_1, token_1.pdf_token_context.left_of_token_on_the_left)
+ absolute_left_2 = min(left_2, token_2.pdf_token_context.left_of_token_on_the_left)
+
+ right_distance, left_distance, height_difference = left_2 - left_1 - width_1, left_1 - left_2, height_1 - height_2
+
+ top_distance = token_2.bounding_box.top - token_1.bounding_box.top - height_1
+ top_distance_gaps = self.get_top_distance_gap(token_1, token_2, page_tokens)
+
+ start_lines_differences = absolute_left_1 - absolute_left_2
+ end_lines_difference = abs(absolute_right_1 - absolute_right_2)
+
+ return [
+ absolute_right_1,
+ token_1.bounding_box.top,
+ right_1,
+ width_1,
+ height_1,
+ token_2.bounding_box.top,
+ right_2,
+ width_2,
+ height_2,
+ right_distance,
+ left_distance,
+ right_gap_1,
+ left_gap_2,
+ height_difference,
+ top_distance,
+ top_distance - self.pdfs_features.pdf_modes.lines_space_mode,
+ top_distance_gaps,
+ top_distance - height_1,
+ end_lines_difference,
+ start_lines_differences,
+ self.pdfs_features.pdf_modes.lines_space_mode - top_distance_gaps,
+ self.pdfs_features.pdf_modes.right_space_mode - absolute_right_1,
+ ]
+
+ @staticmethod
+ def get_top_distance_gap(token_1: PdfToken, token_2: PdfToken, page_tokens):
+ top_distance = token_2.bounding_box.top - token_1.bounding_box.top - token_1.bounding_box.height
+ tokens_in_the_middle = [
+ token
+ for token in page_tokens
+ if token_1.bounding_box.bottom <= token.bounding_box.top < token_2.bounding_box.top
+ ]
+
+ gap_middle_bottom = 0
+ gap_middle_top = 0
+
+ if tokens_in_the_middle:
+ tokens_in_the_middle_top = min([token.bounding_box.top for token in tokens_in_the_middle])
+ tokens_in_the_middle_bottom = max([token.bounding_box.bottom for token in tokens_in_the_middle])
+ gap_middle_top = tokens_in_the_middle_top - token_1.bounding_box.top - token_1.bounding_box.height
+ gap_middle_bottom = token_2.bounding_box.top - tokens_in_the_middle_bottom
+
+ top_distance_gaps = top_distance - (gap_middle_bottom - gap_middle_top)
+ return top_distance_gaps
+
+ @staticmethod
+ def get_unicode_categories(token: PdfToken):
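+        # One-hot encode the Unicode general category of the first two and last two characters of the token;
+        # padding tokens get -1 in every slot.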
+ if token.id == "pad_token":
+ return [-1] * len(CHARACTER_TYPE) * 4
+
+ categories = [unicodedata.category(letter) for letter in token.content[:2] + token.content[-2:]]
+ categories += ["no_category"] * (4 - len(categories))
+
+ categories_one_hot_encoding = list()
+
+ for category in categories:
+ categories_one_hot_encoding.extend([0] * len(CHARACTER_TYPE))
+ if category not in CHARACTER_TYPE:
+ continue
+
+ category_index = len(categories_one_hot_encoding) - len(CHARACTER_TYPE) + CHARACTER_TYPE.index(category)
+ categories_one_hot_encoding[category_index] = 1
+
+ return categories_one_hot_encoding
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/TokenTypeTrainer.py b/src/adapters/ml/pdf_tokens_type_trainer/TokenTypeTrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd2bb0ce73c5d83d1de2637c772ba6a241a57ef7
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/TokenTypeTrainer.py
@@ -0,0 +1,66 @@
+from pathlib import Path
+
+import numpy as np
+from tqdm import tqdm
+
+from pdf_features import PdfToken
+from pdf_token_type_labels import TokenType
+from adapters.ml.pdf_tokens_type_trainer.PdfTrainer import PdfTrainer
+from adapters.ml.pdf_tokens_type_trainer.TokenFeatures import TokenFeatures
+
+
+class TokenTypeTrainer(PdfTrainer):
+ def get_model_input(self) -> np.ndarray:
+ features_rows = []
+
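+        # Pad each page's token list on both sides so that every real token has a full context window.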
+        context_size = self.model_configuration.context_size
+ for token_features, page in self.loop_token_features():
+ page_tokens = [
+                self.get_padding_token(segment_number=i - 999999, page_number=page.page_number) for i in range(context_size)
+ ]
+ page_tokens += page.tokens
+ page_tokens += [
+                self.get_padding_token(segment_number=999999 + i, page_number=page.page_number) for i in range(context_size)
+ ]
+
+            tokens_indexes = range(context_size, len(page_tokens) - context_size)
+ page_features = [self.get_context_features(token_features, page_tokens, i) for i in tokens_indexes]
+ features_rows.extend(page_features)
+
+ return self.features_rows_to_x(features_rows)
+
+ def loop_token_features(self):
+ for pdf_features in tqdm(self.pdfs_features):
+ token_features = TokenFeatures(pdf_features)
+
+ for page in pdf_features.pages:
+ if not page.tokens:
+ continue
+
+ yield token_features, page
+
+ def get_context_features(self, token_features: TokenFeatures, page_tokens: list[PdfToken], token_index: int):
+ token_row_features = []
+ first_token_from_context = token_index - self.model_configuration.context_size
+ for i in range(self.model_configuration.context_size * 2):
+ first_token = page_tokens[first_token_from_context + i]
+ second_token = page_tokens[first_token_from_context + i + 1]
+ token_row_features.extend(token_features.get_features(first_token, second_token, page_tokens))
+
+ return token_row_features
+
+ def predict(self, model_path: str | Path = None):
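+        # Assign the most likely class to each token, keeping a running offset because the prediction
+        # matrix covers the tokens of every page in order.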
+ predictions = super().predict(model_path)
+ predictions_assigned = 0
+ for token_features, page in self.loop_token_features():
+ for token, prediction in zip(
+ page.tokens, predictions[predictions_assigned : predictions_assigned + len(page.tokens)]
+ ):
+ token.prediction = int(np.argmax(prediction))
+
+ predictions_assigned += len(page.tokens)
+
+ def set_token_types(self, model_path: str | Path = None):
+ self.predict(model_path)
+ for token in self.loop_tokens():
+ token.token_type = TokenType.from_index(token.prediction)
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/__init__.py b/src/adapters/ml/pdf_tokens_type_trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/config.py b/src/adapters/ml/pdf_tokens_type_trainer/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc2c303bf25e81783f96835f03d2d3a70de1d6d1
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/config.py
@@ -0,0 +1,48 @@
+from os.path import join
+from pathlib import Path
+
+ROOT_PATH = Path(__file__).parent.parent.parent.parent.parent.absolute()
+PDF_LABELED_DATA_ROOT_PATH = Path(join(ROOT_PATH.parent.absolute(), "pdf-labeled-data"))
+TOKEN_TYPE_LABEL_PATH = Path(join(PDF_LABELED_DATA_ROOT_PATH, "labeled_data", "token_type"))
+
+TRAINED_MODEL_PATH = join(ROOT_PATH, "model", "pdf_tokens_type.model")
+TOKEN_TYPE_RELATIVE_PATH = join("labeled_data", "token_type")
+MISTAKES_RELATIVE_PATH = join("labeled_data", "task_mistakes")
+
+XML_NAME = "etree.xml"
+LABELS_FILE_NAME = "labels.json"
+STATUS_FILE_NAME = "status.txt"
+
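+# Unicode general category codes (as returned by unicodedata.category) used for the character-type features.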
+CHARACTER_TYPE = [
+ "Lt",
+ "Lo",
+ "Sk",
+ "Lm",
+ "Sm",
+ "Cf",
+ "Nl",
+ "Pe",
+ "Po",
+ "Pd",
+ "Me",
+ "Sc",
+ "Ll",
+ "Pf",
+ "Mc",
+ "Lu",
+ "Zs",
+ "Cn",
+ "Cc",
+ "No",
+ "Co",
+ "Ps",
+ "Nd",
+ "Mn",
+ "Pi",
+ "So",
+ "Pc",
+]
+
+if __name__ == "__main__":
+ print(ROOT_PATH)
+ print(PDF_LABELED_DATA_ROOT_PATH)
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/download_models.py b/src/adapters/ml/pdf_tokens_type_trainer/download_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53d2069122e1294a5c4affa1722e8938256718d
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/download_models.py
@@ -0,0 +1,13 @@
+from huggingface_hub import hf_hub_download
+
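+# Models are fetched from the Hugging Face Hub at import time and pinned to specific revisions for reproducibility.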
+pdf_tokens_type_model = hf_hub_download(
+ repo_id="HURIDOCS/pdf-segmentation",
+ filename="pdf_tokens_type.model",
+ revision="c71f833500707201db9f3649a6d2010d3ce9d4c9",
+)
+
+token_type_finding_config_path = hf_hub_download(
+ repo_id="HURIDOCS/pdf-segmentation",
+ filename="tag_type_finding_model_config.txt",
+ revision="7d98776dd34acb2fe3a06495c82e64b9c84bdc16",
+)
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/get_paths.py b/src/adapters/ml/pdf_tokens_type_trainer/get_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..eae277b7f822bdf2b19bc0c007c3a4ffedef4a0e
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/get_paths.py
@@ -0,0 +1,6 @@
+from os.path import join
+from pathlib import Path
+
+
+def get_xml_path(pdf_labeled_data_project_path: str):
+ return Path(join(pdf_labeled_data_project_path, "pdfs"))
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/tests/__init__.py b/src/adapters/ml/pdf_tokens_type_trainer/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/ml/pdf_tokens_type_trainer/tests/test_trainer.py b/src/adapters/ml/pdf_tokens_type_trainer/tests/test_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..beae1e6adeccdf9eecc33d7bfde0c90e941a9d1a
--- /dev/null
+++ b/src/adapters/ml/pdf_tokens_type_trainer/tests/test_trainer.py
@@ -0,0 +1,34 @@
+from os.path import join, exists
+from unittest import TestCase
+
+from pdf_token_type_labels import TokenType
+from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
+
+from pdf_features import PdfFeatures
+
+from configuration import ROOT_PATH
+
+
+class TestTrainer(TestCase):
+ def test_train_blank_pdf(self):
+ pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "blank.pdf"))
+ model_path = join(ROOT_PATH, "model", "blank.model")
+ trainer = TokenTypeTrainer([pdf_features])
+ trainer.train(model_path, [])
+ self.assertFalse(exists(model_path))
+
+ def test_predict_blank_pdf(self):
+ pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "blank.pdf"))
+ trainer = TokenTypeTrainer([pdf_features])
+ trainer.set_token_types()
+ self.assertEqual([], pdf_features.pages[0].tokens)
+
+ def test_predict(self):
+ pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "test.pdf"))
+ trainer = TokenTypeTrainer([pdf_features])
+ trainer.set_token_types()
+ tokens = pdf_features.pages[0].tokens
+ self.assertEqual(TokenType.TITLE, tokens[0].token_type)
+ self.assertEqual("Document Big Centered Title", tokens[0].content)
+ self.assertEqual(TokenType.TEXT, tokens[1].token_type)
+ self.assertEqual("List Title", tokens[10].content)
diff --git a/src/adapters/ml/vgt/__init__.py b/src/adapters/ml/vgt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/ml/vgt/bros/__init__.py b/src/adapters/ml/vgt/bros/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1e1c74269e7e17e684c21611cc4cacb71a53ac0
--- /dev/null
+++ b/src/adapters/ml/vgt/bros/__init__.py
@@ -0,0 +1,70 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2021 NAVER CLOVA Team. All rights reserved.
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from transformers.file_utils import (
+ _LazyModule,
+ is_tokenizers_available,
+ is_torch_available,
+)
+
+_import_structure = {
+ "configuration_bros": ["BROS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BrosConfig"],
+ "tokenization_bros": ["BrosTokenizer"],
+}
+
+if is_tokenizers_available():
+ _import_structure["tokenization_bros_fast"] = ["BrosTokenizerFast"]
+
+if is_torch_available():
+ _import_structure["modeling_bros"] = [
+ "BROS_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "BrosForMaskedLM",
+ "BrosForPreTraining",
+ "BrosForSequenceClassification",
+ "BrosForTokenClassification",
+ "BrosModel",
+ "BrosLMHeadModel",
+ "BrosPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_bros import BROS_PRETRAINED_CONFIG_ARCHIVE_MAP, BrosConfig
+ from .tokenization_bros import BrosTokenizer
+
+ if is_tokenizers_available():
+ from .tokenization_bros_fast import BrosTokenizerFast
+
+ if is_torch_available():
+ from .modeling_bros import (
+ BROS_PRETRAINED_MODEL_ARCHIVE_LIST,
+ BrosForMaskedLM,
+ BrosForPreTraining,
+ BrosForSequenceClassification,
+ BrosForTokenClassification,
+ BrosLMHeadModel,
+ BrosModel,
+ BrosPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
diff --git a/src/adapters/ml/vgt/bros/configuration_bros.py b/src/adapters/ml/vgt/bros/configuration_bros.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe8589faf91e53f02f2dc809730ff27f43394002
--- /dev/null
+++ b/src/adapters/ml/vgt/bros/configuration_bros.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+# Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team..
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" BROS model configuration """
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+BROS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/config.json",
+ "bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/config.json",
+}
+
+
+class BrosConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a :class:`~transformers.BertModel` or a
+ :class:`~transformers.TFBertModel`. It is used to instantiate a BERT model according to the specified arguments,
+ defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
+    to that of the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
+
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+
+
+ Args:
+ vocab_size (:obj:`int`, `optional`, defaults to 30522):
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
+ :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
+ :class:`~transformers.TFBertModel`.
+ hidden_size (:obj:`int`, `optional`, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (:obj:`int`, `optional`, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (:obj:`int`, `optional`, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
+ :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
+ hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ type_vocab_size (:obj:`int`, `optional`, defaults to 2):
+ The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
+ :class:`~transformers.TFBertModel`.
+ initializer_range (:obj:`float`, `optional`, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+ position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
+ Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
+ :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
+            :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
+            <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
+            `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
+            <https://arxiv.org/abs/2009.13658>`__.
+ use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if ``config.is_decoder=True``.
+ classifier_dropout (:obj:`float`, `optional`):
+ The dropout ratio for the classification head.
+
+ Examples::
+
+ >>> from adapters.ml.vgt.bros import BrosModel, BrosConfig
+
+ >>> # Initializing a BROS naver-clova-ocr/bros-base-uncased style configuration
+ >>> configuration = BrosConfig()
+
+ >>> # Initializing a model from the naver-clova-ocr/bros-base-uncased style configuration
+ >>> model = BrosModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ """
+
+ model_type = "bros"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=0,
+ bbox_scale=100.0,
+ pe_type="crel",
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ hidden_size=hidden_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ hidden_act=hidden_act,
+ hidden_dropout_prob=hidden_dropout_prob,
+ attention_probs_dropout_prob=attention_probs_dropout_prob,
+ max_position_embeddings=max_position_embeddings,
+ type_vocab_size=type_vocab_size,
+ initializer_range=initializer_range,
+ layer_norm_eps=layer_norm_eps,
+ pad_token_id=pad_token_id,
+ **kwargs,
+ )
+
+ self.bbox_scale = bbox_scale
+ self.pe_type = pe_type
diff --git a/src/adapters/ml/vgt/bros/modeling_bros.py b/src/adapters/ml/vgt/bros/modeling_bros.py
new file mode 100644
index 0000000000000000000000000000000000000000..57eac0c1dcd547d1b13cc911154b70506eb83b29
--- /dev/null
+++ b/src/adapters/ml/vgt/bros/modeling_bros.py
@@ -0,0 +1,1621 @@
+# coding=utf-8
+# Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team..
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BROS model. """
+
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from packaging import version
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+
+from .configuration_bros import BrosConfig
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "naver-clova-ocr/bros-base-uncased"
+_CONFIG_FOR_DOC = "BrosConfig"
+_TOKENIZER_FOR_DOC = "BrosTokenizer"
+
+BROS_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "bros-base-uncased",
+ "bros-large-uncased",
+]
+
+
+class PositionalEmbedding1D(nn.Module):
+ # Reference: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15
+
+ def __init__(self, demb):
+ super(PositionalEmbedding1D, self).__init__()
+
+ self.demb = demb
+
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, pos_seq, bsz=None):
+ seq_size = pos_seq.size()
+
+ if len(seq_size) == 2:
+ b1, b2 = seq_size
+ sinusoid_inp = pos_seq.view(b1, b2, 1) * self.inv_freq.view(1, 1, self.demb // 2)
+ elif len(seq_size) == 3:
+ b1, b2, b3 = seq_size
+ sinusoid_inp = pos_seq.view(b1, b2, b3, 1) * self.inv_freq.view(1, 1, 1, self.demb // 2)
+ else:
+ raise ValueError(f"Invalid seq_size={len(seq_size)}")
+
+ pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
+
+ return pos_emb
+
+
+class PositionalEmbedding2D(nn.Module):
+ def __init__(self, demb, dim_bbox=8):
+ super(PositionalEmbedding2D, self).__init__()
+
+ self.demb = demb
+ self.dim_bbox = dim_bbox
+
+ self.x_pos_emb = PositionalEmbedding1D(demb // dim_bbox)
+ self.y_pos_emb = PositionalEmbedding1D(demb // dim_bbox)
+
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, bbox):
+ # bbox: [seq_length, batch_size, dim_bbox]
+ stack = []
+ for i in range(self.dim_bbox):
+ if i % 2 == 0:
+ stack.append(self.x_pos_emb(bbox[..., i]))
+ else:
+ stack.append(self.y_pos_emb(bbox[..., i]))
+ bbox_pos_emb = torch.cat(stack, dim=-1)
+ return bbox_pos_emb
+
+
+class BrosEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(
+ self.position_ids.size(),
+ dtype=torch.long,
+ device=self.position_ids.device,
+ ),
+ persistent=False,
+ )
+
+ if config.pe_type == "pdpdq_ws":
+ dim_bbox_sinusoid_emb = config.hidden_size
+ dim_bbox_projection = config.hidden_size
+ elif config.pe_type == "crel":
+ dim_bbox_sinusoid_emb = config.hidden_size // 4
+ dim_bbox_projection = config.hidden_size // config.num_attention_heads
+ else:
+ raise ValueError(f"Unknown config.pe_type={config.pe_type}")
+
+ self.bbox_sinusoid_emb = PositionalEmbedding2D(dim_bbox_sinusoid_emb, dim_bbox=8)
+ self.bbox_projection = nn.Linear(dim_bbox_sinusoid_emb, dim_bbox_projection, bias=False)
+
+ def forward(
+ self,
+ input_ids=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ past_key_values_length=0,
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        # If token_type_ids is not passed, use the all-zeros buffer registered in the constructor (the usual case
+        # when it is auto-generated). The registered buffer lets users trace the model without passing token_type_ids
+        # and solves issue #5664.
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def calc_bbox_pos_emb(self, bbox, pe_type):
+ # bbox_t: [seq_length, batch_size, dim_bbox]
+ bbox_t = bbox.transpose(0, 1)
+
+ if pe_type == "pdpdq_ws":
+ bbox_pos = bbox_t
+ elif pe_type == "crel":
+ # bbox_pos: [seq_length, seq_length, batch_size, dim_bbox]
+ bbox_pos = bbox_t[None, :, :, :] - bbox_t[:, None, :, :]
+ else:
+ raise ValueError(f"Unknown pe_type={pe_type}")
+
+ bbox_pos_emb = self.bbox_sinusoid_emb(bbox_pos)
+ bbox_pos_emb = self.bbox_projection(bbox_pos_emb)
+
+ return bbox_pos_emb
+
+
+class BrosSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+
+ self.pe_type = config.pe_type
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ bbox_pos_emb=None,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ # bbox positional encoding
+ batch_size, n_head, seq_length, d_head = query_layer.shape
+ if self.pe_type == "pdpdq_ws":
+ head_q_pos = self.query(bbox_pos_emb)
+ head_k_pos = self.key(bbox_pos_emb)
+ head_q_pos = head_q_pos.view(seq_length, batch_size, n_head, d_head)
+ head_k_pos = head_k_pos.view(seq_length, batch_size, n_head, d_head)
+ head_q_pos = head_q_pos.permute([1, 2, 0, 3])
+ head_k_pos = head_k_pos.permute([1, 2, 0, 3])
+
+ bbox_pos_scores_1 = torch.einsum("bnid,bnjd->bnij", (torch.mul(query_layer, head_q_pos), head_k_pos))
+ bbox_pos_scores_2 = torch.einsum("bnid,bnjd->bnij", (head_q_pos, head_k_pos))
+ bbox_pos_scores = bbox_pos_scores_1 + bbox_pos_scores_2
+ elif self.pe_type == "crel":
+ bbox_pos_emb = bbox_pos_emb.view(seq_length, seq_length, batch_size, d_head)
+ bbox_pos_emb = bbox_pos_emb.permute([2, 0, 1, 3])
+ bbox_pos_scores = torch.einsum("bnid,bijd->bnij", (query_layer, bbox_pos_emb))
+ else:
+ raise ValueError(f"Unknown self.pe_type={self.pe_type}")
+
+ attention_scores = attention_scores + bbox_pos_scores
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BrosSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BrosAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = BrosSelfAttention(config)
+ self.output = BrosSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads,
+ self.self.num_attention_heads,
+ self.self.attention_head_size,
+ self.pruned_heads,
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ bbox_pos_emb=None,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ bbox_pos_emb=bbox_pos_emb,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BrosIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BrosOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BrosLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BrosAttention(config)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
+ self.crossattention = BrosAttention(config)
+ self.intermediate = BrosIntermediate(config)
+ self.output = BrosOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ bbox_pos_emb=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ bbox_pos_emb=bbox_pos_emb,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ assert hasattr(
+ self, "crossattention"
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ cross_attn_past_key_value,
+ output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk,
+ self.chunk_size_feed_forward,
+ self.seq_len_dim,
+ attention_output,
+ )
+ outputs = (layer_output,) + outputs
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BrosEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ bbox_pos_emb=None,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+ "`use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ bbox_pos_emb=bbox_pos_emb,
+ )
+ else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    bbox_pos_emb=bbox_pos_emb,
+                )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BrosPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BrosPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BrosLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BrosPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BrosOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BrosLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BrosOnlyNSPHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, pooled_output):
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return seq_relationship_score
+
+
+class BrosPreTrainingHeads(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BrosLMPredictionHead(config)
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+ def forward(self, sequence_output, pooled_output):
+ prediction_scores = self.predictions(sequence_output)
+ seq_relationship_score = self.seq_relationship(pooled_output)
+ return prediction_scores, seq_relationship_score
+
+
+class BrosPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BrosConfig
+ base_model_prefix = "bros"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+@dataclass
+class BrosForPreTrainingOutput(ModelOutput):
+ """
+ Output type of :class:`~transformers.BertForPreTraining`.
+
+ Args:
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
+ (classification) loss.
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+ before SoftMax).
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ prediction_logits: torch.FloatTensor = None
+ seq_relationship_logits: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+BROS_START_DOCSTRING = r"""
+
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+ pruning heads etc.)
+
+    This model is also a PyTorch ``torch.nn.Module`` subclass. Use it as a regular PyTorch Module and refer to the
+    PyTorch documentation for all matters related to
+ general usage and behavior.
+
+ Parameters:
+        config (:class:`BrosConfig`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+ weights.
+"""
+
+BROS_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using :class:`~transformers.BertTokenizer`. See
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+ details.
+
+ `What are input IDs? <../glossary.html#input-ids>`__
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+ 1]``:
+
+ - 0 corresponds to a `sentence A` token,
+ - 1 corresponds to a `sentence B` token.
+
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+ config.max_position_embeddings - 1]``.
+
+ `What are position IDs? <../glossary.html#position-ids>`_
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+ vectors than the model's internal embedding lookup matrix.
+ output_attentions (:obj:`bool`, `optional`):
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+ tensors for more detail.
+ output_hidden_states (:obj:`bool`, `optional`):
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+ more detail.
+ return_dict (:obj:`bool`, `optional`):
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ BROS_START_DOCSTRING,
+)
+class BrosModel(BrosPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in
+    *Attention Is All You Need* by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
+    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both :obj:`is_decoder`
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BrosEmbeddings(config)
+ self.encoder = BrosEncoder(config)
+
+ self.pooler = BrosPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ tokenizer_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ bbox=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ scaled_bbox = bbox * self.config.bbox_scale
+ bbox_pos_emb = self.embeddings.calc_bbox_pos_emb(scaled_bbox, self.config.pe_type)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ bbox_pos_emb=bbox_pos_emb,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+    Bros Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+ sentence prediction (classification)` head.
+ """,
+ BROS_START_DOCSTRING,
+)
+class BrosForPreTraining(BrosPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bros = BrosModel(config)
+ self.cls = BrosPreTrainingHeads(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=BrosForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ next_sentence_label=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+ (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
+
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+
+ Returns:
+
+ Example::
+
+ >>> from adapters.ml.vgt.bros import BrosTokenizer, BrosForPreTraining
+ >>> import torch
+
+ >>> tokenizer = BrosTokenizer.from_pretrained('naver-clova-ocr/bros-base-uncased')
+ >>> model = BrosForPreTraining.from_pretrained('naver-clova-ocr/bros-base-uncased')
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.prediction_logits
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bros(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+ total_loss = None
+ if labels is not None and next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ total_loss = masked_lm_loss + next_sentence_loss
+
+ if not return_dict:
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return BrosForPreTrainingOutput(
+ loss=total_loss,
+ prediction_logits=prediction_scores,
+ seq_relationship_logits=seq_relationship_score,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """,
+ BROS_START_DOCSTRING,
+)
+class BrosLMHeadModel(BrosPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if not config.is_decoder:
+ logger.warning("If you want to use `BrosLMHeadModel` as a standalone, add `is_decoder=True.`")
+
+ self.bros = BrosModel(config, add_pooling_layer=False)
+ self.cls = BrosOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+
+ Returns:
+
+ Example::
+
+ >>> from adapters.ml.vgt.bros import BrosTokenizer, BrosLMHeadModel, BrosConfig
+ >>> import torch
+
+            >>> tokenizer = BrosTokenizer.from_pretrained("naver-clova-ocr/bros-base-uncased")
+            >>> config = BrosConfig.from_pretrained("naver-clova-ocr/bros-base-uncased")
+            >>> config.is_decoder = True
+            >>> model = BrosLMHeadModel.from_pretrained("naver-clova-ocr/bros-base-uncased", config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bros(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
+ labels.view(-1),
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BROS_START_DOCSTRING)
+class BrosForMaskedLM(BrosPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.bros = BrosModel(config, add_pooling_layer=False)
+ self.cls = BrosOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ tokenizer_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bros(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ effective_batch_size = input_shape[0]
+
+ # add a dummy token
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
+ attention_mask = torch.cat(
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
+ dim=-1,
+ )
+ dummy_token = torch.full(
+ (effective_batch_size, 1),
+ self.config.pad_token_id,
+ dtype=torch.long,
+ device=input_ids.device,
+ )
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
+@add_start_docstrings(
+ """
+    Bros Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+ output) e.g. for GLUE tasks.
+ """,
+ BROS_START_DOCSTRING,
+)
+class BrosForSequenceClassification(BrosPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.bros = BrosModel(config)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ tokenizer_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bros(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+    Bros Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ BROS_START_DOCSTRING,
+)
+class BrosForTokenClassification(BrosPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bros = BrosModel(config, add_pooling_layer=False)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(BROS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ tokenizer_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+ 1]``.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bros(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss,
+ labels.view(-1),
+ torch.tensor(loss_fct.ignore_index).type_as(labels),
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
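+
+
+# Illustrative sketch (not part of the upstream BROS port): exercises the MLM prediction head in
+# isolation with a lightweight stand-in config instead of a full BrosConfig or checkpoint. The
+# attribute names mirror the ones read by the head classes above; the sizes are arbitrary. If this
+# module uses package-relative imports, run it as a module from the project's src directory.
+if __name__ == "__main__":
+    from types import SimpleNamespace
+
+    _cfg = SimpleNamespace(hidden_size=768, vocab_size=30552, layer_norm_eps=1e-12, hidden_act=nn.GELU())
+    _head = BrosOnlyMLMHead(_cfg)
+    _hidden_states = torch.rand(2, 16, _cfg.hidden_size)  # (batch, seq_len, hidden_size)
+    print(_head(_hidden_states).shape)  # torch.Size([2, 16, 30552])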
diff --git a/src/adapters/ml/vgt/bros/tokenization_bros.py b/src/adapters/ml/vgt/bros/tokenization_bros.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e5ab3563695c994135ec7c0f74f3620e0eb03a
--- /dev/null
+++ b/src/adapters/ml/vgt/bros/tokenization_bros.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team..
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for BROS."""
+
+
+import collections
+
+from transformers.models.bert.tokenization_bert import BertTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt",
+ "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt",
+ }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "naver-clova-ocr/bros-base-uncased": 512,
+ "naver-clova-ocr/bros-large-uncased": 512,
+}
+
+PRETRAINED_INIT_CONFIGURATION = {
+ "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True},
+ "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True},
+}
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class BrosTokenizer(BertTokenizer):
+ r"""
+    Construct a BROS tokenizer. Based on WordPiece.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
+ Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ File containing the vocabulary.
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (:obj:`Iterable`, `optional`):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ :obj:`do_basic_tokenize=True`
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see the related issue in the Transformers repository).
+ strip_accents: (:obj:`bool`, `optional`):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for :obj:`lowercase` (as in the original BERT).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
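+
+
+# Illustrative usage sketch (not part of the upstream file): BrosTokenizer behaves exactly like
+# BertTokenizer and only swaps in the BROS vocabulary mappings above. Requires network access the
+# first time the vocabulary is downloaded.
+if __name__ == "__main__":
+    tokenizer = BrosTokenizer.from_pretrained("naver-clova-ocr/bros-base-uncased")
+    print(tokenizer.tokenize("Document layout analysis"))      # WordPiece tokens
+    print(tokenizer("Document layout analysis")["input_ids"])  # ids framed by [CLS] / [SEP]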
diff --git a/src/adapters/ml/vgt/bros/tokenization_bros_fast.py b/src/adapters/ml/vgt/bros/tokenization_bros_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..845f6794a5b2c0717bfb9f2ac2da67cc18bec1fa
--- /dev/null
+++ b/src/adapters/ml/vgt/bros/tokenization_bros_fast.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+# Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team..
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Tokenization classes for BROS."""
+
+from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
+from transformers.utils import logging
+
+from .tokenization_bros import BrosTokenizer
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt",
+ "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt",
+ },
+ "tokenizer_file": {
+ "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/tokenizer.json",
+ "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/tokenizer.json",
+ },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "naver-clova-ocr/bros-base-uncased": 512,
+ "naver-clova-ocr/bros-large-uncased": 512,
+}
+
+PRETRAINED_INIT_CONFIGURATION = {
+ "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True},
+ "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True},
+}
+
+
+class BrosTokenizerFast(BertTokenizerFast):
+ r"""
+ Construct a "fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library). Based on WordPiece.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
+ methods. Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ File containing the vocabulary.
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to lowercase the input when tokenizing.
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to clean the text before tokenization by removing any control characters and replacing all
+ whitespaces by the classic one.
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see the
+            related issue in the Transformers repository).
+ strip_accents: (:obj:`bool`, `optional`):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for :obj:`lowercase` (as in the original BERT).
+ wordpieces_prefix: (:obj:`str`, `optional`, defaults to :obj:`"##"`):
+ The prefix for subwords.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+ slow_tokenizer_class = BrosTokenizer
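+
+
+# Illustrative usage sketch (not part of the upstream file): the fast tokenizer additionally exposes
+# character offsets, which is convenient when mapping subwords back onto word bounding boxes. Run it
+# as a module so the relative import above resolves; downloads the tokenizer files on first use.
+if __name__ == "__main__":
+    tokenizer = BrosTokenizerFast.from_pretrained("naver-clova-ocr/bros-base-uncased")
+    encoding = tokenizer("Document layout analysis", return_offsets_mapping=True)
+    print(encoding.tokens())           # WordPiece tokens including [CLS] / [SEP]
+    print(encoding["offset_mapping"])  # (start, end) character spans, (0, 0) for special tokens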
diff --git a/src/adapters/ml/vgt/create_word_grid.py b/src/adapters/ml/vgt/create_word_grid.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea5f3188b5f85d26617b881f9f553313a164ca80
--- /dev/null
+++ b/src/adapters/ml/vgt/create_word_grid.py
@@ -0,0 +1,97 @@
+import pickle
+import shutil
+
+import numpy as np
+from os import makedirs
+from os.path import join, exists
+from pdf_features import PdfToken
+from pdf_features import Rectangle
+from pdf_features import PdfFeatures
+
+from adapters.ml.vgt.bros.tokenization_bros import BrosTokenizer
+from configuration import WORD_GRIDS_PATH
+
+tokenizer = BrosTokenizer.from_pretrained("naver-clova-ocr/bros-base-uncased")
+
+
+def rectangle_to_bbox(rectangle: Rectangle):
+ return [rectangle.left, rectangle.top, rectangle.width, rectangle.height]
+
+
+def get_words_positions(text: str, rectangle: Rectangle):
+    text = text.strip()
+    text_len = len(text)
+
+    # Guard against empty token content to avoid a division by zero below
+    if text_len == 0:
+        return [], []
+
+    width_per_letter = rectangle.width / text_len
+
+ words_bboxes = [Rectangle.from_coordinates(rectangle.left, rectangle.top, rectangle.left + 5, rectangle.bottom)]
+ words_bboxes[-1].width = 0
+ words_bboxes[-1].right = words_bboxes[-1].left
+
+ for letter in text:
+ if letter == " ":
+ left = words_bboxes[-1].right + width_per_letter
+ words_bboxes.append(Rectangle.from_coordinates(left, words_bboxes[-1].top, left + 5, words_bboxes[-1].bottom))
+ words_bboxes[-1].width = 0
+ words_bboxes[-1].right = words_bboxes[-1].left
+ else:
+ words_bboxes[-1].right = words_bboxes[-1].right + width_per_letter
+ words_bboxes[-1].width = words_bboxes[-1].width + width_per_letter
+
+ words = text.split()
+ return words, words_bboxes
+
+
+def get_subwords_positions(word: str, rectangle: Rectangle):
+ width_per_letter = rectangle.width / len(word)
+    word_tokens = [x.replace("#", "") for x in tokenizer.tokenize(word)]  # drop the "##" WordPiece continuation markers
+
+ if not word_tokens:
+ return [], []
+
+    ids = [x[-2] for x in tokenizer(word_tokens)["input_ids"]]  # take the id right before [SEP], i.e. the subword itself
+
+ right = rectangle.left + len(word_tokens[0]) * width_per_letter
+ bboxes = [Rectangle.from_coordinates(rectangle.left, rectangle.top, right, rectangle.bottom)]
+
+ for subword in word_tokens[1:]:
+ right = bboxes[-1].right + len(subword) * width_per_letter
+ bboxes.append(Rectangle.from_coordinates(bboxes[-1].right, rectangle.top, right, rectangle.bottom))
+
+ return ids, bboxes
+
+
+def get_grid_words_dict(tokens: list[PdfToken]):
+ texts, bbox_texts_list, inputs_ids, bbox_subword_list = [], [], [], []
+ for token in tokens:
+ words, words_bboxes = get_words_positions(token.content, token.bounding_box)
+ texts += words
+ bbox_texts_list += [rectangle_to_bbox(r) for r in words_bboxes]
+ for word, word_box in zip(words, words_bboxes):
+ ids, subwords_bboxes = get_subwords_positions(word, word_box)
+ inputs_ids += ids
+ bbox_subword_list += [rectangle_to_bbox(r) for r in subwords_bboxes]
+
+ return {
+ "input_ids": np.array(inputs_ids),
+ "bbox_subword_list": np.array(bbox_subword_list),
+ "texts": texts,
+ "bbox_texts_list": np.array(bbox_texts_list),
+ }
+
+
+def create_word_grid(pdf_features_list: list[PdfFeatures]):
+ makedirs(WORD_GRIDS_PATH, exist_ok=True)
+
+ for pdf_features in pdf_features_list:
+ for page in pdf_features.pages:
+ image_id = f"{pdf_features.file_name}_{page.page_number - 1}"
+ if exists(join(WORD_GRIDS_PATH, image_id + ".pkl")):
+ continue
+ grid_words_dict = get_grid_words_dict(page.tokens)
+ with open(join(WORD_GRIDS_PATH, f"{image_id}.pkl"), mode="wb") as file:
+ pickle.dump(grid_words_dict, file)
+
+
+def remove_word_grids():
+ shutil.rmtree(WORD_GRIDS_PATH, ignore_errors=True)
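+
+
+# Illustrative sketch (not part of the upstream file): shows how get_words_positions splits a PDF
+# token's rectangle into per-word boxes by assuming a uniform width per character. Run with the
+# project's src directory on PYTHONPATH; importing this module downloads the BROS vocab on first use.
+if __name__ == "__main__":
+    token_box = Rectangle.from_coordinates(0, 0, 110, 20)  # "hello world" is 11 characters -> 10 units per letter
+    words, boxes = get_words_positions("hello world", token_box)
+    print(words)                                                   # ['hello', 'world']
+    print([(round(box.left), round(box.right)) for box in boxes])  # roughly [(0, 50), (60, 110)]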
diff --git a/src/adapters/ml/vgt/ditod/FeatureMerge.py b/src/adapters/ml/vgt/ditod/FeatureMerge.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5820046a752bc36c1b8bc82d5489b937b8bd38f
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/FeatureMerge.py
@@ -0,0 +1,129 @@
+import torch
+from torch import nn
+
+
+class FeatureMerge(nn.Module):
+ """Multimodal feature fusion used in VSR."""
+
+ def __init__(
+ self,
+ feature_names,
+ visual_dim,
+ semantic_dim,
+ merge_type="Sum",
+ dropout_ratio=0.1,
+ with_extra_fc=True,
+ shortcut=False,
+ ):
+ """Multimodal feature merge used in VSR.
+ Args:
+ visual_dim (list): the dim of visual features, e.g. [256]
+ semantic_dim (list): the dim of semantic features, e.g. [256]
+ merge_type (str): fusion type, e.g. 'Sum', 'Concat', 'Weighted'
+ dropout_ratio (float): dropout ratio of fusion features
+ with_extra_fc (bool): whether add extra fc layers for adaptation
+ shortcut (bool): whether add shortcut connection
+ """
+ super().__init__()
+
+ # merge param
+ self.feature_names = feature_names
+ self.merge_type = merge_type
+ self.visual_dim = visual_dim
+ self.textual_dim = semantic_dim
+ self.with_extra_fc = with_extra_fc
+ self.shortcut = shortcut
+ self.relu = nn.ReLU(inplace=True)
+
+ if self.merge_type == "Sum":
+ assert len(self.visual_dim) == len(self.textual_dim)
+ elif self.merge_type == "Concat":
+ assert len(self.visual_dim) == len(self.textual_dim)
+ # self.concat_proj = nn.ModuleList()
+
+ self.vis_proj = nn.ModuleList()
+ self.text_proj = nn.ModuleList()
+ self.alpha_proj = nn.ModuleList()
+
+ for idx in range(len(self.visual_dim)):
+ # self.concat_proj.append(nn.Conv2d(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx], kernel_size = (1,1), stride=1))
+ if self.with_extra_fc:
+ self.vis_proj.append(nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
+ self.text_proj.append(nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
+ self.alpha_proj.append(nn.Linear(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx]))
+
+ elif self.merge_type == "Weighted":
+ assert len(self.visual_dim) == len(self.textual_dim)
+ self.total_num = len(self.visual_dim)
+
+ # vis projection
+ self.vis_proj = nn.ModuleList()
+ self.vis_proj_relu = nn.ModuleList()
+
+ # text projection
+ self.text_proj = nn.ModuleList()
+ self.text_proj_relu = nn.ModuleList()
+
+ self.alpha_proj = nn.ModuleList()
+ for idx in range(self.total_num):
+ if self.with_extra_fc:
+ self.vis_proj.append(nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
+ self.text_proj.append(nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
+ self.alpha_proj.append(nn.Linear(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx]))
+
+ else:
+ raise "Unknown merge type {}".format(self.merge_type)
+
+ self.dropout = nn.Dropout(dropout_ratio)
+
+ # visual context
+ # self.visual_ap = nn.AdaptiveAvgPool2d((1, 1))
+
+ def forward(self, visual_feat=None, textual_feat=None):
+ """Forward computation
+ Args:
+            visual_feat (dict[str, Tensor]): visual feature maps keyed by feature name, each of shape B x C x H x W
+            textual_feat (dict[str, Tensor]): textual feature maps with the same keys and shapes as ``visual_feat``
+        Returns:
+            dict[str, Tensor]: fused feature maps, same keys, each of shape B x C x H x W
+ """
+ assert len(visual_feat) == len(textual_feat)
+
+ # feature merge
+ merged_feat = {}
+ if self.merge_type == "Sum":
+ for name in self.feature_names:
+ merged_feat[name] = visual_feat[name] + textual_feat[name]
+ elif self.merge_type == "Concat":
+ for idx, name in enumerate(self.feature_names):
+ # merged_feat[name] = self.concat_proj[idx](torch.cat((visual_feat[name],textual_feat[name]),1))
+ per_vis = visual_feat[name].permute(0, 2, 3, 1)
+ per_text = textual_feat[name].permute(0, 2, 3, 1)
+ if self.with_extra_fc:
+ per_vis = self.relu(self.vis_proj[idx](per_vis))
+ per_text = self.relu(self.text_proj[idx](per_text))
+ x_sentence = self.alpha_proj[idx](torch.cat((per_vis, per_text), -1))
+ x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous()
+ merged_feat[name] = x_sentence
+ else:
+ assert self.total_num == len(visual_feat) or self.total_num == 1
+ # for per_vis, per_text in zip(visual_feat, textual_feat):
+ for idx, name in enumerate(self.feature_names):
+ per_vis = visual_feat[name].permute(0, 2, 3, 1)
+ per_text = textual_feat[name].permute(0, 2, 3, 1)
+ if self.with_extra_fc:
+ per_vis = self.relu(self.vis_proj[idx](per_vis))
+ per_text = self.relu(self.text_proj[idx](per_text))
+
+ alpha = torch.sigmoid(self.alpha_proj[idx](torch.cat((per_vis, per_text), -1)))
+ if self.shortcut:
+ # shortcut
+ x_sentence = per_vis + alpha * per_text
+ else:
+ # selection
+ x_sentence = alpha * per_vis + (1 - alpha) * per_text
+
+ x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous()
+ merged_feat[name] = x_sentence
+
+ return merged_feat
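+
+
+# Illustrative sketch (not part of the upstream file): fuses dummy FPN-style visual and textual
+# feature maps with the "Weighted" strategy. The feature names, channel count and spatial size are
+# placeholders chosen for the demo only.
+if __name__ == "__main__":
+    names = ["p2", "p3"]
+    merge = FeatureMerge(feature_names=names, visual_dim=[256, 256], semantic_dim=[256, 256], merge_type="Weighted")
+    visual = {name: torch.rand(2, 256, 32, 32) for name in names}
+    textual = {name: torch.rand(2, 256, 32, 32) for name in names}
+    fused = merge(visual, textual)
+    print({name: tuple(feature.shape) for name, feature in fused.items()})  # every entry: (2, 256, 32, 32)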
diff --git a/src/adapters/ml/vgt/ditod/VGT.py b/src/adapters/ml/vgt/ditod/VGT.py
new file mode 100644
index 0000000000000000000000000000000000000000..50a40eddf721b78a5fba305b01be48d42bcbd9e3
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/VGT.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import logging
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+import torch
+from torch import nn
+
+from detectron2.config import configurable
+from detectron2.data.detection_utils import convert_image_to_rgb
+from detectron2.structures import ImageList, Instances
+from detectron2.utils.events import get_event_storage
+from detectron2.utils.logger import log_first_n
+
+from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
+from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
+
+from .Wordnn_embedding import WordnnEmbedding
+
+__all__ = ["VGT"]
+
+
+def torch_memory(device, tag=""):
+ # Checks and prints GPU memory
+ print(tag, f"{torch.cuda.memory_allocated(device)/1024/1024:.2f} MB USED")
+ print(tag, f"{torch.cuda.memory_reserved(device)/1024/1024:.2f} MB RESERVED")
+ print(tag, f"{torch.cuda.max_memory_allocated(device)/1024/1024:.2f} MB USED MAX")
+ print(tag, f"{torch.cuda.max_memory_reserved(device)/1024/1024:.2f} MB RESERVED MAX")
+ print("")
+
+
+@META_ARCH_REGISTRY.register()
+class VGT(GeneralizedRCNN):
+
+ @configurable
+ def __init__(
+ self,
+ *,
+ vocab_size: int = 30552,
+ hidden_size: int = 768,
+ embedding_dim: int = 64,
+ bros_embedding_path: str = "",
+ use_pretrain_weight: bool = True,
+ use_UNK_text: bool = False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.vocab_size = vocab_size
+ self.embedding_dim = embedding_dim
+ self.Wordgrid_embedding = WordnnEmbedding(
+ vocab_size, hidden_size, embedding_dim, bros_embedding_path, use_pretrain_weight, use_UNK_text
+ )
+
+ @classmethod
+ def from_config(cls, cfg):
+ ret = super().from_config(cfg)
+ ret.update(
+ {
+ "vocab_size": cfg.MODEL.WORDGRID.VOCAB_SIZE,
+ "hidden_size": cfg.MODEL.WORDGRID.HIDDEN_SIZE,
+ "embedding_dim": cfg.MODEL.WORDGRID.EMBEDDING_DIM,
+ "bros_embedding_path": cfg.MODEL.WORDGRID.MODEL_PATH,
+ "use_pretrain_weight": cfg.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT,
+ "use_UNK_text": cfg.MODEL.WORDGRID.USE_UNK_TEXT,
+ }
+ )
+ return ret
+
+ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
+ """
+ Args:
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
+ Each item in the list contains the inputs for one image.
+ For now, each item in the list is a dict that contains:
+
+ * image: Tensor, image in (C, H, W) format.
+ * instances (optional): groundtruth :class:`Instances`
+ * proposals (optional): :class:`Instances`, precomputed proposals.
+
+ Other information that's included in the original dicts, such as:
+
+ * "height", "width" (int): the output resolution of the model, used in inference.
+ See :meth:`postprocess` for details.
+
+ Returns:
+ list[dict]:
+ Each dict is the output for one input image.
+ The dict contains one key "instances" whose value is a :class:`Instances`.
+ The :class:`Instances` object has the following keys:
+ "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
+ """
+ if not self.training:
+ return self.inference(batched_inputs)
+
+ images = self.preprocess_image(batched_inputs)
+
+ if "instances" in batched_inputs[0]:
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+ else:
+ gt_instances = None
+
+ chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
+ features = self.backbone(images.tensor, chargrid)
+
+ if self.proposal_generator is not None:
+ proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
+ else:
+ assert "proposals" in batched_inputs[0]
+ proposals = [x["proposals"].to(self.device) for x in batched_inputs]
+ proposal_losses = {}
+
+ _, detector_losses = self.roi_heads(images, features, proposals, gt_instances)
+ if self.vis_period > 0:
+ storage = get_event_storage()
+ if storage.iter % self.vis_period == 0:
+ self.visualize_training(batched_inputs, proposals)
+
+ losses = {}
+ losses.update(detector_losses)
+ losses.update(proposal_losses)
+
+ return losses
+
+ def inference(
+ self,
+ batched_inputs: List[Dict[str, torch.Tensor]],
+ detected_instances: Optional[List[Instances]] = None,
+ do_postprocess: bool = True,
+ ):
+ """
+ Run inference on the given inputs.
+
+ Args:
+ batched_inputs (list[dict]): same as in :meth:`forward`
+ detected_instances (None or list[Instances]): if not None, it
+ contains an `Instances` object per image. The `Instances`
+ object contains "pred_boxes" and "pred_classes" which are
+ known boxes in the image.
+ The inference will then skip the detection of bounding boxes,
+ and only predict other per-ROI outputs.
+ do_postprocess (bool): whether to apply post-processing on the outputs.
+
+ Returns:
+ When do_postprocess=True, same as in :meth:`forward`.
+ Otherwise, a list[Instances] containing raw network outputs.
+ """
+ assert not self.training
+
+ images = self.preprocess_image(batched_inputs)
+
+ chargrid = self.Wordgrid_embedding(images.tensor, batched_inputs)
+ features = self.backbone(images.tensor, chargrid)
+
+ if detected_instances is None:
+ if self.proposal_generator is not None:
+ proposals, _ = self.proposal_generator(images, features, None)
+ else:
+ assert "proposals" in batched_inputs[0]
+ proposals = [x["proposals"].to(self.device) for x in batched_inputs]
+
+ results, _ = self.roi_heads(images, features, proposals, None)
+ else:
+ detected_instances = [x.to(self.device) for x in detected_instances]
+ results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
+
+ if do_postprocess:
+ assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
+ return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
+ else:
+ return results
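+
+
+# Illustrative sketch (not part of the upstream file): the extra MODEL.WORDGRID keys read by
+# VGT.from_config above. The values below are placeholders; the real settings come from the
+# project's own VGT configuration files, not from this snippet. Run as a module so the relative
+# import of WordnnEmbedding resolves.
+if __name__ == "__main__":
+    from detectron2.config import CfgNode as CN
+    from detectron2.config import get_cfg
+
+    cfg = get_cfg()
+    cfg.MODEL.META_ARCHITECTURE = "VGT"
+    cfg.MODEL.WORDGRID = CN()
+    cfg.MODEL.WORDGRID.VOCAB_SIZE = 30552
+    cfg.MODEL.WORDGRID.HIDDEN_SIZE = 768
+    cfg.MODEL.WORDGRID.EMBEDDING_DIM = 64
+    cfg.MODEL.WORDGRID.MODEL_PATH = ""  # path to the BROS embedding weights (placeholder)
+    cfg.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT = True
+    cfg.MODEL.WORDGRID.USE_UNK_TEXT = False
+    print(cfg.MODEL.WORDGRID)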
diff --git a/src/adapters/ml/vgt/ditod/VGTTrainer.py b/src/adapters/ml/vgt/ditod/VGTTrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2544256c9e6fc053aa7d433d961419f4cb8c6de7
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/VGTTrainer.py
@@ -0,0 +1,804 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+This file contains components with some default boilerplate logic user may need
+in training / testing. They will not work for everyone, but many users may find them useful.
+
+The behavior of functions/classes in this file is subject to change,
+since they are meant to represent the "common default behavior" people need in their projects.
+"""
+
+import argparse
+import logging
+import os
+import sys
+import weakref
+from collections import OrderedDict
+from time import time
+from typing import Optional
+import torch
+from fvcore.nn.precise_bn import get_bn_modules
+from omegaconf import OmegaConf
+from torch.nn.parallel import DistributedDataParallel
+import numpy as np
+
+import detectron2.data.transforms as T
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import CfgNode, LazyConfig
+from detectron2.data import (
+ MetadataCatalog,
+ build_detection_test_loader,
+ build_detection_train_loader,
+)
+from detectron2.evaluation import (
+ DatasetEvaluator,
+ inference_on_dataset,
+ print_csv_format,
+ verify_results,
+)
+from detectron2.modeling import build_model
+from detectron2.solver import build_lr_scheduler, build_optimizer
+from detectron2.utils import comm
+from detectron2.utils.collect_env import collect_env_info
+from detectron2.utils.env import seed_all_rng
+from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
+from detectron2.utils.file_io import PathManager
+from detectron2.utils.logger import setup_logger
+
+from detectron2.engine import hooks
+from detectron2.engine.train_loop import AMPTrainer, SimpleTrainer, TrainerBase
+
+from configuration import service_logger
+from .VGTcheckpointer import MyDetectionCheckpointer
+from typing import Any, Dict, List, Set
+import itertools
+from detectron2.solver.build import maybe_add_gradient_clipping
+from .dataset_mapper import DetrDatasetMapper
+from detectron2.evaluation import COCOEvaluator
+
+import pickle
+from detectron2.data import detection_utils as utils
+from detectron2.structures import (
+ BitMasks,
+ Boxes,
+ BoxMode,
+ Instances,
+ Keypoints,
+ PolygonMasks,
+ RotatedBoxes,
+ polygons_to_bitmask,
+)
+
+__all__ = [
+ "create_ddp_model",
+ "default_argument_parser",
+ "default_setup",
+ "default_writers",
+ "DefaultPredictor",
+ "GridTextTrainer",
+]
+
+
+def torch_memory(device, tag=""):
+ # Checks and prints GPU memory
+ print(tag, f"{torch.cuda.memory_allocated(device)/1024/1024:.2f} MB USED")
+ print(tag, f"{torch.cuda.memory_reserved(device)/1024/1024:.2f} MB RESERVED")
+ print(tag, f"{torch.cuda.max_memory_allocated(device)/1024/1024:.2f} MB USED MAX")
+ print(tag, f"{torch.cuda.max_memory_reserved(device)/1024/1024:.2f} MB RESERVED MAX")
+ print("")
+
+
+def create_ddp_model(model, *, fp16_compression=False, **kwargs):
+ """
+ Create a DistributedDataParallel model if there are >1 processes.
+
+ Args:
+ model: a torch.nn.Module
+ fp16_compression: add fp16 compression hooks to the ddp object.
+ See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
+ kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
+ """ # noqa
+ if comm.get_world_size() == 1:
+ return model
+ if "device_ids" not in kwargs:
+ kwargs["device_ids"] = [comm.get_local_rank()]
+ ddp = DistributedDataParallel(model, **kwargs)
+ if fp16_compression:
+ from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
+
+ ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
+ return ddp
+
+
+def default_argument_parser(epilog=None):
+ """
+ Create a parser with some common arguments used by detectron2 users.
+
+ Args:
+ epilog (str): epilog passed to ArgumentParser describing the usage.
+
+ Returns:
+ argparse.ArgumentParser:
+ """
+ parser = argparse.ArgumentParser(
+ epilog=epilog
+ or f"""
+Examples:
+
+Run on single machine:
+ $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
+
+Change some config options:
+ $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
+
+Run on multiple machines:
+ (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
+ (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
+""",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+ parser.add_argument(
+ "--resume",
+ action="store_true",
+ help="Whether to attempt to resume from the checkpoint directory. "
+ "See documentation of `MyTrainer.resume_or_load()` for what it means.",
+ )
+ parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
+ parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
+ parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
+ parser.add_argument("--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)")
+
+ # PyTorch still may leave orphan processes in multi-gpu training.
+ # Therefore we use a deterministic way to obtain port,
+ # so that users are aware of orphan processes by seeing the port occupied.
+ port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
+ parser.add_argument(
+ "--dist-url",
+ default="tcp://127.0.0.1:{}".format(port),
+ help="initialization URL for pytorch distributed backend. See "
+ "https://pytorch.org/docs/stable/distributed.html for details.",
+ )
+ parser.add_argument(
+ "opts",
+ help="""
+Modify config options at the end of the command. For Yacs configs, use
+space-separated "PATH.KEY VALUE" pairs.
+For python-based LazyConfig, use "path.key=value".
+ """.strip(),
+ default=None,
+ nargs=argparse.REMAINDER,
+ )
+ return parser
+
+
+def _try_get_key(cfg, *keys, default=None):
+ """
+    Try to select keys from cfg and return the first one that exists; otherwise return default.
+ """
+ if isinstance(cfg, CfgNode):
+ cfg = OmegaConf.create(cfg.dump())
+ for k in keys:
+ none = object()
+ p = OmegaConf.select(cfg, k, default=none)
+ if p is not none:
+ return p
+ return default
+
+
+def _highlight(code, filename):
+ try:
+ import pygments
+ except ImportError:
+ return code
+
+ from pygments.lexers import Python3Lexer, YamlLexer
+ from pygments.formatters import Terminal256Formatter
+
+ lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
+ code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
+ return code
+
+
+def default_setup(cfg, args):
+ """
+ Perform some basic common setups at the beginning of a job, including:
+
+ 1. Set up the detectron2 logger
+ 2. Log basic information about environment, cmdline arguments, and config
+ 3. Backup the config to the output directory
+
+ Args:
+ cfg (CfgNode or omegaconf.DictConfig): the full config to be used
+ args (argparse.NameSpace): the command line arguments to be logged
+ """
+ output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
+ if comm.is_main_process() and output_dir:
+ PathManager.mkdirs(output_dir)
+
+ rank = comm.get_rank()
+ setup_logger(output_dir, distributed_rank=rank, name="fvcore")
+ logger = setup_logger(output_dir, distributed_rank=rank)
+
+ logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
+ logger.info("Environment info:\n" + collect_env_info())
+
+ logger.info("Command line arguments: " + str(args))
+ if hasattr(args, "config_file") and args.config_file != "":
+ logger.info(
+ "Contents of args.config_file={}:\n{}".format(
+ args.config_file,
+ _highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
+ )
+ )
+
+ if comm.is_main_process() and output_dir:
+ # Note: some of our scripts may expect the existence of
+ # config.yaml in output directory
+ path = os.path.join(output_dir, "config.yaml")
+ if isinstance(cfg, CfgNode):
+ logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
+ with PathManager.open(path, "w") as f:
+ f.write(cfg.dump())
+ else:
+ LazyConfig.save(cfg, path)
+ logger.info("Full config saved to {}".format(path))
+
+ # make sure each worker has a different, yet deterministic seed if specified
+ seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
+ seed_all_rng(None if seed < 0 else seed + rank)
+
+ # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
+ # typical validation set.
+ if not (hasattr(args, "eval_only") and args.eval_only):
+ torch.backends.cudnn.benchmark = _try_get_key(cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False)
+
+
+def default_writers(output_dir: str, max_iter: Optional[int] = None):
+ """
+ Build a list of :class:`EventWriter` to be used.
+ It now consists of a :class:`CommonMetricPrinter`,
+ :class:`TensorboardXWriter` and :class:`JSONWriter`.
+
+ Args:
+ output_dir: directory to store JSON metrics and tensorboard events
+ max_iter: the total number of iterations
+
+ Returns:
+ list[EventWriter]: a list of :class:`EventWriter` objects.
+ """
+ PathManager.mkdirs(output_dir)
+ return [
+ # It may not always print what you want to see, since it prints "common" metrics only.
+ CommonMetricPrinter(max_iter),
+ JSONWriter(os.path.join(output_dir, "metrics.json")),
+ TensorboardXWriter(output_dir),
+ ]
+
+
+class DefaultPredictor:
+ """
+ Create a simple end-to-end predictor with the given config that runs on
+ single device for a single input image.
+
+ Compared to using the model directly, this class does the following additions:
+
+ 1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
+ 2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
+ 3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
+ 4. Take one input image and produce a single output, instead of a batch.
+
+ This is meant for simple demo purposes, so it does the above steps automatically.
+ This is not meant for benchmarks or running complicated inference logic.
+ If you'd like to do anything more complicated, please refer to its source code as
+ examples to build and use the model manually.
+
+ Attributes:
+ metadata (Metadata): the metadata of the underlying dataset, obtained from
+ cfg.DATASETS.TEST.
+
+ Examples:
+ ::
+ pred = DefaultPredictor(cfg)
+ inputs = cv2.imread("input.jpg")
+        outputs = pred(inputs, "input_grid.pkl")  # second argument: path to the pickled word grid
+ """
+
+ def __init__(self, cfg):
+ self.cfg = cfg.clone() # cfg can be modified by model
+ self.model = build_model(self.cfg)
+ self.model.eval()
+ if len(cfg.DATASETS.TEST):
+ self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
+
+ checkpointer = DetectionCheckpointer(self.model)
+ checkpointer.load(cfg.MODEL.WEIGHTS)
+
+ self.aug = T.ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
+
+ self.input_format = cfg.INPUT.FORMAT
+ assert self.input_format in ["RGB", "BGR"], self.input_format
+
+ def __call__(self, original_image, grid_path):
+ """
+ Args:
+ original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+ Returns:
+ predictions (dict):
+ the output of the model for one image only.
+ See :doc:`/tutorials/models` for details about the format.
+ """
+ with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
+ # Apply pre-processing to image.
+
+ # if self.input_format == "RGB":
+ # # whether the model expects BGR inputs or RGB
+ # import ipdb;ipdb.set_trace()
+ # original_image = original_image[:, :, ::-1]
+
+ height, width = original_image.shape[:2]
+ image, transforms = T.apply_transform_gens([self.aug], original_image)
+
+ # add grid
+ image_shape = image.shape[:2] # h, w
+ image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+ with open(grid_path, "rb") as f:
+ sample_inputs = pickle.load(f)
+ input_ids = sample_inputs["input_ids"]
+ bbox_subword_list = sample_inputs["bbox_subword_list"]
+
+ # word bbox
+ bbox = []
+ for bbox_per_subword in bbox_subword_list:
+ text_word = {}
+ text_word["bbox"] = bbox_per_subword.tolist()
+ text_word["bbox_mode"] = BoxMode.XYWH_ABS
+ utils.transform_instance_annotations(text_word, transforms, image_shape)
+ bbox.append(text_word["bbox"])
+
+ dataset_dict = {}
+ dataset_dict["input_ids"] = input_ids
+ dataset_dict["bbox"] = bbox
+ dataset_dict["image"] = image
+ dataset_dict["height"] = height
+ dataset_dict["width"] = width
+
+ predictions = self.model([dataset_dict])[0]
+ return predictions
+
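+    # Grid-file sketch (illustrative): ``grid_path`` is expected to point to a
+    # pickle with the two keys unpacked in ``__call__`` above, e.g.
+    #   {"input_ids": subword_token_ids,
+    #    "bbox_subword_list": one_XYWH_ABS_box_per_subword}
+    # The boxes are transformed together with the image and passed to the model
+    # as "bbox", alongside "input_ids".
+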
+
+class VGTTrainer(TrainerBase):
+ """
+ A trainer with default training logic. It does the following:
+
+ 1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
+ defined by the given config. Create a LR scheduler defined by the config.
+ 2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
+ `resume_or_load` is called.
+ 3. Register a few common hooks defined by the config.
+
+ It is created to simplify the **standard model training workflow** and reduce code boilerplate
+ for users who only need the standard training workflow, with standard features.
+ It means this class makes *many assumptions* about your training logic that
+    may easily become invalid in new research. In fact, any assumptions beyond those made in the
+ :class:`SimpleTrainer` are too much for research.
+
+ The code of this class has been annotated about restrictive assumptions it makes.
+ When they do not work for you, you're encouraged to:
+
+ 1. Overwrite methods of this class, OR:
+ 2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
+ nothing else. You can then add your own hooks if needed. OR:
+ 3. Write your own training loop similar to `tools/plain_train_net.py`.
+
+ See the :doc:`/tutorials/training` tutorials for more details.
+
+ Note that the behavior of this class, like other functions/classes in
+ this file, is not stable, since it is meant to represent the "common default behavior".
+ It is only guaranteed to work well with the standard models and training workflow in detectron2.
+ To obtain more stable behavior, write your own training logic with other public APIs.
+
+ Examples:
+ ::
+        trainer = VGTTrainer(cfg)
+ trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
+ trainer.train()
+
+ Attributes:
+ scheduler:
+ checkpointer (DetectionCheckpointer):
+ cfg (CfgNode):
+ """
+
+ def __init__(self, cfg):
+ """
+ Args:
+ cfg (CfgNode):
+ """
+ super().__init__()
+ logger = logging.getLogger("detectron2")
+ if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
+ setup_logger()
+ cfg = VGTTrainer.auto_scale_workers(cfg, comm.get_world_size())
+
+ self.cfg = cfg
+
+ # Assume these objects must be constructed in this order.
+ model = self.build_model(cfg)
+ optimizer = self.build_optimizer(cfg, model)
+ data_loader = self.build_train_loader(cfg)
+
+ model = create_ddp_model(model, broadcast_buffers=False)
+ self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(model, data_loader, optimizer)
+
+ self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+ self.checkpointer = MyDetectionCheckpointer(
+ # Assume you want to save checkpoints together with logs/statistics
+ model,
+ cfg.OUTPUT_DIR,
+ trainer=weakref.proxy(self),
+ )
+ self.start_iter = 0
+ self.max_iter = cfg.SOLVER.MAX_ITER
+ self.cfg = cfg
+
+ self.register_hooks(self.build_hooks())
+
+ def resume_or_load(self, resume=True):
+ """
+ If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
+ a `last_checkpoint` file), resume from the file. Resuming means loading all
+        available states (e.g. optimizer and scheduler) and updating the iteration counter
+ from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
+
+ Otherwise, this is considered as an independent training. The method will load model
+ weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
+ from iteration 0.
+
+ Args:
+ resume (bool): whether to do resume or not
+ """
+ self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
+ if resume and self.checkpointer.has_checkpoint():
+ # The checkpoint stores the training iteration that just finished, thus we start
+ # at the next iteration
+ self.start_iter = self.iter + 1
+
+ def build_hooks(self):
+ """
+ Build a list of default hooks, including timing, evaluation,
+ checkpointing, lr scheduling, precise BN, writing events.
+
+ Returns:
+ list[HookBase]:
+ """
+ cfg = self.cfg.clone()
+ cfg.defrost()
+ cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
+
+ ret = [
+ hooks.IterationTimer(),
+ hooks.LRScheduler(),
+ (
+ hooks.PreciseBN(
+ # Run at the same freq as (but before) evaluation.
+ cfg.TEST.EVAL_PERIOD,
+ self.model,
+ # Build a new data loader to not affect training
+ self.build_train_loader(cfg),
+ cfg.TEST.PRECISE_BN.NUM_ITER,
+ )
+ if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
+ else None
+ ),
+ ]
+
+        # Do PreciseBN before the checkpointer, because it updates the model and needs to
+        # be saved by the checkpointer.
+ # This is not always the best: if checkpointing has a different frequency,
+ # some checkpoints may have more precise statistics than others.
+ if comm.is_main_process():
+ ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
+
+ def test_and_save_results():
+ self._last_eval_results = self.test(self.cfg, self.model)
+ return self._last_eval_results
+
+ # Do evaluation after checkpointer, because then if it fails,
+ # we can use the saved checkpoint to debug.
+ ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
+
+ if comm.is_main_process():
+ # Here the default print/log frequency of each writer is used.
+ # run writers in the end, so that evaluation metrics are written
+ ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
+ return ret
+
+ def build_writers(self):
+ """
+ Build a list of writers to be used using :func:`default_writers()`.
+ If you'd like a different list of writers, you can overwrite it in
+ your trainer.
+
+ Returns:
+ list[EventWriter]: a list of :class:`EventWriter` objects.
+ """
+ return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
+
+ def train(self):
+ """
+ Run training.
+
+ Returns:
+ OrderedDict of results, if evaluation is enabled. Otherwise None.
+ """
+ super().train(self.start_iter, self.max_iter)
+ if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
+ assert hasattr(self, "_last_eval_results"), "No evaluation results obtained during training!"
+ verify_results(self.cfg, self._last_eval_results)
+ return self._last_eval_results
+
+ def run_step(self):
+ try:
+ self._trainer.iter = self.iter
+ self._trainer.run_step()
+ except RuntimeError as exception:
+ if "out of memory" in str(exception):
+ logger = logging.getLogger("detectron2")
+ logger.warn("Out of memory")
+ # import ipdb;ipdb.set_trace()
+ if hasattr(torch.cuda, "empty_cache"):
+ torch.cuda.empty_cache()
+ else:
+ raise exception
+
+ @classmethod
+ def build_model(cls, cfg):
+ """
+ Returns:
+ torch.nn.Module:
+
+ It now calls :func:`detectron2.modeling.build_model`.
+ Overwrite it if you'd like a different model.
+ """
+ model = build_model(cfg)
+
+ def compute_para(model):
+ params_num = []
+ filtered_parameters = []
+ for p in filter(lambda p: p.requires_grad, model.parameters()):
+ filtered_parameters.append(p)
+ params_num.append(np.prod(p.size()))
+ total_params = int(sum(params_num))
+ total_params = f"Trainable network params num : {total_params:,}"
+ # print(total_params)
+ return total_params
+
+ logger = logging.getLogger("detectron2")
+ logger.info("Model: {}".format(compute_para(model)))
+ return model
+
+ @classmethod
+ def build_optimizer(cls, cfg, model):
+ params: List[Dict[str, Any]] = []
+ memo: Set[torch.nn.parameter.Parameter] = set()
+ for key, value in model.named_parameters(recurse=True):
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+ lr = cfg.SOLVER.BASE_LR
+ weight_decay = cfg.SOLVER.WEIGHT_DECAY
+ if "backbone" in key:
+ lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER
+ params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
+
+ def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
+ # detectron2 doesn't have full model gradient clipping now
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+ enable = (
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
+ and clip_norm_val > 0.0
+ )
+
+ class FullModelGradientClippingOptimizer(optim):
+ def step(self, closure=None):
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
+ super().step(closure=closure)
+
+ return FullModelGradientClippingOptimizer if enable else optim
+
+ optimizer_type = cfg.SOLVER.OPTIMIZER
+ if optimizer_type == "SGD":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
+ )
+ elif optimizer_type == "ADAMW":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(params, cfg.SOLVER.BASE_LR)
+ else:
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
+        if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
+ return optimizer
+
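+    # Illustrative example of the grouping above (assumed config values): with
+    # BASE_LR=1e-4 and BACKBONE_MULTIPLIER=0.1, every parameter whose name
+    # contains "backbone" lands in a param group with lr=1e-5, all remaining
+    # trainable parameters keep lr=1e-4, and every group uses SOLVER.WEIGHT_DECAY.
+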
+ @classmethod
+ def build_lr_scheduler(cls, cfg, optimizer):
+ """
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
+ Overwrite it if you'd like a different scheduler.
+ """
+ return build_lr_scheduler(cfg, optimizer)
+
+ @classmethod
+ def build_train_loader(cls, cfg):
+ if cfg.AUG.DETR:
+ mapper = DetrDatasetMapper(cfg, is_train=True)
+ else:
+ mapper = None
+ return build_detection_train_loader(cfg, mapper=mapper)
+
+ @classmethod
+ def build_test_loader(cls, cfg, dataset_name):
+ """
+ Returns:
+ iterable
+
+ It now calls :func:`detectron2.data.build_detection_test_loader`.
+ Overwrite it if you'd like a different data loader.
+ """
+ mapper = DetrDatasetMapper(cfg, is_train=False)
+ return build_detection_test_loader(cfg, dataset_name, mapper=mapper)
+
+ # return build_detection_test_loader(cfg, dataset_name)
+
+ @classmethod
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
+ if output_folder is None:
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
+ return COCOEvaluator(dataset_name, output_dir=output_folder)
+
+ # if 'icdar' not in dataset_name:
+ # return COCOEvaluator(dataset_name, output_dir=output_folder)
+ # else:
+ # return ICDAREvaluator(dataset_name, output_dir=output_folder)
+
+ @classmethod
+ def test(cls, cfg, model, evaluators=None):
+ """
+ Evaluate the given model. The given model is expected to already contain
+ weights to evaluate.
+
+ Args:
+ cfg (CfgNode):
+ model (nn.Module):
+ evaluators (list[DatasetEvaluator] or None): if None, will call
+ :meth:`build_evaluator`. Otherwise, must have the same length as
+ ``cfg.DATASETS.TEST``.
+
+ Returns:
+ dict: a dict of result metrics
+ """
+ logger = logging.getLogger(__name__)
+ if isinstance(evaluators, DatasetEvaluator):
+ evaluators = [evaluators]
+ if evaluators is not None:
+ assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(len(cfg.DATASETS.TEST), len(evaluators))
+
+ results = OrderedDict()
+ for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
+ data_loader = cls.build_test_loader(cfg, dataset_name)
+ # When evaluators are passed in as arguments,
+ # implicitly assume that evaluators can be created before data_loader.
+ if evaluators is not None:
+ evaluator = evaluators[idx]
+ else:
+ try:
+ evaluator = cls.build_evaluator(cfg, dataset_name)
+ except NotImplementedError:
+                    logger.warning(
+                        "No evaluator found. Use `VGTTrainer.test(evaluators=)`, "
+ "or implement its `build_evaluator` method."
+ )
+ results[dataset_name] = {}
+ continue
+ results_i = inference_on_dataset(model, data_loader, evaluator)
+ results[dataset_name] = results_i
+ if comm.is_main_process():
+ assert isinstance(
+ results_i, dict
+ ), "Evaluator must return a dict on the main process. Got {} instead.".format(results_i)
+ logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+ print_csv_format(results_i)
+
+ if len(results) == 1:
+ results = list(results.values())[0]
+ return results
+
+ @staticmethod
+ def auto_scale_workers(cfg, num_workers: int):
+ """
+ When the config is defined for certain number of workers (according to
+ ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
+ workers currently in use, returns a new cfg where the total batch size
+ is scaled so that the per-GPU batch size stays the same as the
+ original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
+
+ Other config options are also scaled accordingly:
+        * training steps and warmup steps are scaled inversely proportionally.
+        * the learning rate is scaled proportionally, following :paper:`ImageNet in 1h`.
+
+ For example, with the original config like the following:
+
+ .. code-block:: yaml
+
+ IMS_PER_BATCH: 16
+ BASE_LR: 0.1
+ REFERENCE_WORLD_SIZE: 8
+ MAX_ITER: 5000
+ STEPS: (4000,)
+ CHECKPOINT_PERIOD: 1000
+
+ When this config is used on 16 GPUs instead of the reference number 8,
+ calling this method will return a new config with:
+
+ .. code-block:: yaml
+
+ IMS_PER_BATCH: 32
+ BASE_LR: 0.2
+ REFERENCE_WORLD_SIZE: 16
+ MAX_ITER: 2500
+ STEPS: (2000,)
+ CHECKPOINT_PERIOD: 500
+
+ Note that both the original config and this new config can be trained on 16 GPUs.
+        It's up to the user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
+
+ Returns:
+ CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
+ """
+ old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
+ if old_world_size == 0 or old_world_size == num_workers:
+ return cfg
+ cfg = cfg.clone()
+ frozen = cfg.is_frozen()
+ cfg.defrost()
+
+ assert cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0, "Invalid REFERENCE_WORLD_SIZE in config!"
+ scale = num_workers / old_world_size
+ bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
+ lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
+ max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
+ warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
+ cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
+ cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
+ cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
+ cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
+ logger = logging.getLogger(__name__)
+ logger.info(
+ f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, " f"max_iter={max_iter}, warmup={warmup_iter}."
+ )
+
+ if frozen:
+ cfg.freeze()
+ return cfg
+
+
+# Access basic attributes from the underlying trainer
+for _attr in ["model", "data_loader", "optimizer"]:
+ setattr(
+ VGTTrainer,
+ _attr,
+ property(
+ # getter
+ lambda self, x=_attr: getattr(self._trainer, x),
+ # setter
+ lambda self, value, x=_attr: setattr(self._trainer, x, value),
+ ),
+ )
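+
+# Example (illustrative): after the loop above, ``trainer.model``,
+# ``trainer.data_loader`` and ``trainer.optimizer`` proxy the attributes of the
+# wrapped SimpleTrainer/AMPTrainer, so ``trainer.model is trainer._trainer.model``.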
diff --git a/src/adapters/ml/vgt/ditod/VGTbackbone.py b/src/adapters/ml/vgt/ditod/VGTbackbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e5467b5df08e3a4f936e0a6634a5740ecb0d4b0
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/VGTbackbone.py
@@ -0,0 +1,219 @@
+# --------------------------------------------------------------------------------
+# VIT: Multi-Path Vision Transformer for Dense Prediction
+# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+# All Rights Reserved.
+# Written by Youngwan Lee
+# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# CoaT: https://github.com/mlpc-ucsd/CoaT
+# --------------------------------------------------------------------------------
+
+
+import torch
+import torch.nn.functional as F
+import logging
+
+from detectron2.layers import (
+ ShapeSpec,
+)
+from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
+from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
+
+from .VGTbeit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16, VGT_dit_base_patch16
+from .FeatureMerge import FeatureMerge
+
+__all__ = [
+ "build_VGT_fpn_backbone",
+]
+
+
+class PTM_VIT_Backbone(Backbone):
+ """
+    Implements the ViT backbone.
+ """
+
+ def __init__(self, name, out_features, drop_path, img_size, pos_type, merge_type, model_kwargs):
+ super().__init__()
+ self._out_features = out_features
+ if "base" in name:
+ self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
+ else:
+ self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
+
+ if name == "beit_base_patch16":
+ model_func = beit_base_patch16
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+ elif name == "dit_base_patch16":
+ model_func = dit_base_patch16
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+ elif name == "deit_base_patch16":
+            # NOTE: deit_base_patch16 is not imported in this module, so selecting it raises NameError.
+            model_func = deit_base_patch16
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+ elif name == "VGT_dit_base_patch16":
+ model_func = VGT_dit_base_patch16
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+ elif name == "mae_base_patch16":
+            # NOTE: mae_base_patch16 is not imported in this module, so selecting it raises NameError.
+            model_func = mae_base_patch16
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
+ elif name == "dit_large_patch16":
+ model_func = dit_large_patch16
+ self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+ elif name == "beit_large_patch16":
+ model_func = beit_large_patch16
+ self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
+ else:
+ raise ValueError("Unsupported VIT name yet.")
+
+ if "beit" in name or "dit" in name:
+ if pos_type == "abs":
+ self.backbone = model_func(
+ img_size=img_size,
+ out_features=out_features,
+ drop_path_rate=drop_path,
+ use_abs_pos_emb=True,
+ **model_kwargs,
+ )
+ elif pos_type == "shared_rel":
+ self.backbone = model_func(
+ img_size=img_size,
+ out_features=out_features,
+ drop_path_rate=drop_path,
+ use_shared_rel_pos_bias=True,
+ **model_kwargs,
+ )
+ elif pos_type == "rel":
+ self.backbone = model_func(
+ img_size=img_size,
+ out_features=out_features,
+ drop_path_rate=drop_path,
+ use_rel_pos_bias=True,
+ **model_kwargs,
+ )
+ else:
+                raise ValueError(f"Unsupported pos_type: {pos_type}")
+ else:
+ self.backbone = model_func(
+ img_size=img_size, out_features=out_features, drop_path_rate=drop_path, **model_kwargs
+ )
+
+ logger = logging.getLogger("detectron2")
+ logger.info("Merge using: {}".format(merge_type))
+ self.FeatureMerge = FeatureMerge(
+ feature_names=self._out_features,
+ visual_dim=[768, 768, 768, 768],
+ semantic_dim=[768, 768, 768, 768],
+ merge_type=merge_type,
+ )
+
+ def forward(self, x, grid):
+ """
+ Args:
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+
+ Returns:
+ dict[str->Tensor]: names and the corresponding features
+ """
+ assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+
+ vis_feat_out, grid_feat_out = self.backbone.forward_features(x, grid)
+ return self.FeatureMerge.forward(vis_feat_out, grid_feat_out)
+ # return self.backbone.forward_features(x)
+
+ def output_shape(self):
+ return {
+ name: ShapeSpec(channels=self._out_feature_channels[name], stride=self._out_feature_strides[name])
+ for name in self._out_features
+ }
+
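+    # Shape sketch for the base models (derived from the tables above): with
+    # out_features ["layer3", "layer5", "layer7", "layer11"], output_shape() is
+    #   {"layer3": ShapeSpec(channels=768, stride=4),
+    #    "layer5": ShapeSpec(channels=768, stride=8),
+    #    "layer7": ShapeSpec(channels=768, stride=16),
+    #    "layer11": ShapeSpec(channels=768, stride=32)}
+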
+
+class GridFPN(FPN):
+ def forward(self, x, grid):
+ """
+ Args:
+ input (dict[str->Tensor]): mapping feature map name (e.g., "res5") to
+ feature map tensor for each feature level in high to low resolution order.
+ Returns:
+ dict[str->Tensor]:
+ mapping from feature map name to FPN feature map tensor
+ in high to low resolution order. Returned feature names follow the FPN
+                paper convention: "p<stage>", where stage has stride = 2 ** stage, e.g.,
+ ["p2", "p3", ..., "p6"].
+ """
+ bottom_up_features = self.bottom_up(x, grid)
+ results = []
+ prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]])
+ results.append(self.output_convs[0](prev_features))
+
+ # Reverse feature maps into top-down order (from low to high resolution)
+ for idx, (lateral_conv, output_conv) in enumerate(zip(self.lateral_convs, self.output_convs)):
+ # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336
+ # Therefore we loop over all modules but skip the first one
+ if idx > 0:
+ features = self.in_features[-idx - 1]
+ features = bottom_up_features[features]
+ top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest")
+ lateral_features = lateral_conv(features)
+ prev_features = lateral_features + top_down_features
+ if self._fuse_type == "avg":
+ prev_features /= 2
+ results.insert(0, output_conv(prev_features))
+
+ if self.top_block is not None:
+ if self.top_block.in_feature in bottom_up_features:
+ top_block_in_feature = bottom_up_features[self.top_block.in_feature]
+ else:
+ top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
+ results.extend(self.top_block(top_block_in_feature))
+ assert len(self._out_features) == len(results)
+ return {f: res for f, res in zip(self._out_features, results)}
+
+
+def build_PTM_VIT_Backbone(cfg):
+ """
+ Create a VIT instance from config.
+
+ Args:
+ cfg: a detectron2 CfgNode
+
+ Returns:
+ A VIT backbone instance.
+ """
+ # fmt: off
+ name = cfg.MODEL.VIT.NAME
+ out_features = cfg.MODEL.VIT.OUT_FEATURES
+ drop_path = cfg.MODEL.VIT.DROP_PATH
+ img_size = cfg.MODEL.VIT.IMG_SIZE
+ pos_type = cfg.MODEL.VIT.POS_TYPE
+ merge_type = cfg.MODEL.VIT.MERGE_TYPE
+
+    # MODEL.VIT.MODEL_KWARGS is stored as a (possibly backtick-quoted) string; strip backticks and eval it into a dict
+    model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
+
+ return PTM_VIT_Backbone(name, out_features, drop_path, img_size, pos_type, merge_type, model_kwargs)
+
+
+@BACKBONE_REGISTRY.register()
+def build_VGT_fpn_backbone(cfg, input_shape: ShapeSpec):
+ """
+ Create a VIT w/ FPN backbone.
+
+ Args:
+ cfg: a detectron2 CfgNode
+
+ Returns:
+ backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+ """
+ bottom_up = build_PTM_VIT_Backbone(cfg)
+ in_features = cfg.MODEL.FPN.IN_FEATURES
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
+ backbone = GridFPN(
+ bottom_up=bottom_up,
+ in_features=in_features,
+ out_channels=out_channels,
+ norm=cfg.MODEL.FPN.NORM,
+ top_block=LastLevelMaxPool(),
+ fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
+ )
+ return backbone
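+
+
+# Configuration sketch (assumed example values): with
+#   MODEL.BACKBONE.NAME:   "build_VGT_fpn_backbone"
+#   MODEL.VIT.NAME:        "VGT_dit_base_patch16"
+#   MODEL.FPN.IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
+# detectron2's build_backbone(cfg) resolves the registered function above and
+# returns a GridFPN whose forward(x, grid) yields the usual "p2" ... "p6" maps.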
diff --git a/src/adapters/ml/vgt/ditod/VGTbeit.py b/src/adapters/ml/vgt/ditod/VGTbeit.py
new file mode 100644
index 0000000000000000000000000000000000000000..599afa469440396169f45540f5b1f3a606f4a91c
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/VGTbeit.py
@@ -0,0 +1,1097 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implement of Vision Transformers as described in
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+Status/TODO:
+* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
+* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
+* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
+* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import warnings
+import math
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+
+def _cfg(url="", **kwargs):
+ return {
+ "url": url,
+ "num_classes": 1000,
+ "input_size": (3, 224, 224),
+ "pool_size": None,
+ "crop_pct": 0.9,
+ "interpolation": "bicubic",
+ "mean": (0.5, 0.5, 0.5),
+ "std": (0.5, 0.5, 0.5),
+ **kwargs,
+ }
+
+
+def torch_memory(device, tag=""):
+ # Checks and prints GPU memory
+ print(tag, f"{torch.cuda.memory_allocated(device)/1024/1024:.2f} MB USED")
+ print(tag, f"{torch.cuda.memory_reserved(device)/1024/1024:.2f} MB RESERVED")
+ print(tag, f"{torch.cuda.max_memory_allocated(device)/1024/1024:.2f} MB USED MAX")
+ print(tag, f"{torch.cuda.max_memory_reserved(device)/1024/1024:.2f} MB RESERVED MAX")
+ print("")
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ # x = self.drop(x)
+        # dropout after fc1 is left commented out to match the original BERT implementation
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, all_head_dim, bias=False)
+ self.kv = nn.Linear(dim, all_head_dim * 2, bias=False)
+
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, y):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ kv_bias = torch.cat((torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ kv = F.linear(input=y, weight=self.kv.weight, bias=kv_bias)
+ kv = kv.reshape(B, N, 2, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ k, v = kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple)
+
+ q = F.linear(input=x, weight=self.q.weight, bias=self.q_bias)
+ q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4)[0]
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class CrossBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.norm_vis = norm_layer(dim)
+ self.norm_grid = norm_layer(dim)
+ self.vis_attn = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ attn_head_dim=attn_head_dim,
+ )
+ self.grid_attn = CrossAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ attn_head_dim=attn_head_dim,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2_vis = norm_layer(dim)
+ self.norm2_grid = norm_layer(dim)
+
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.vis_mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.grid_mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.self_block = CrossSelfBlock(
+ dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path,
+ init_values=init_values,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ window_size=window_size,
+ attn_head_dim=attn_head_dim,
+ )
+
+ if init_values is not None:
+ self.gamma_vis = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_grid = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_vis, self.gamma_grid, self.gamma_1, self.gamma_2 = None, None, None, None
+
+ def cross_att(self, vis_input, grid_input):
+ # Cross Attention
+ if self.gamma_vis is None:
+ vis_att_output = vis_input + self.drop_path(self.vis_attn(self.norm_vis(vis_input), self.norm_grid(grid_input)))
+ grid_att_output = grid_input + self.drop_path(
+ self.grid_attn(self.norm_grid(grid_input), self.norm_vis(vis_input))
+ )
+ else:
+ vis_att_output = vis_input + self.drop_path(
+ self.gamma_vis * self.vis_attn(self.norm_vis(vis_input), self.norm_grid(grid_input))
+ )
+ grid_att_output = grid_input + self.drop_path(
+ self.gamma_grid * self.grid_attn(self.norm_grid(grid_input), self.norm_vis(vis_input))
+ )
+ return vis_att_output, grid_att_output
+
+ def forward(self, vis_input, grid_input):
+ vis_att_output, grid_att_output = self.cross_att(vis_input, grid_input)
+ vis_output, grid_output = self.self_block(vis_att_output, grid_att_output)
+
+ if self.gamma_1 is None:
+ vis_output = vis_output + self.drop_path(self.vis_mlp(self.norm2_vis(vis_output)))
+ grid_output = grid_output + self.drop_path(self.grid_mlp(self.norm2_grid(grid_output)))
+ else:
+ vis_output = vis_output + self.drop_path(self.gamma_1 * self.vis_mlp(self.norm2_vis(vis_output)))
+ grid_output = grid_output + self.drop_path(self.gamma_2 * self.grid_mlp(self.norm2_grid(grid_output)))
+
+ return vis_output, grid_output
+
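+    # Dataflow sketch for one CrossBlock (both token streams are assumed to have
+    # the same length N, shape (B, N, C)):
+    #   1. cross_att: visual tokens attend to grid tokens and vice versa (residual add);
+    #   2. self_block: each stream runs its own self-attention (residual add);
+    #   3. per-stream MLPs with residuals, optionally scaled by the gamma_* layer-scale
+    #      parameters when init_values is given.
+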
+
+class CrossSelfBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.norm_vis = norm_layer(dim)
+ self.norm_grid = norm_layer(dim)
+ self.vis_attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ attn_head_dim=attn_head_dim,
+ )
+ self.grid_attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ attn_head_dim=attn_head_dim,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ if init_values is not None:
+ self.gamma_vis = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_grid = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_vis, self.gamma_grid = None, None
+
+ def self_att(self, vis_input, grid_input):
+ # Cross Attention
+ if self.gamma_vis is None:
+ vis_att_output = vis_input + self.drop_path(self.vis_attn(self.norm_vis(vis_input)))
+ grid_att_output = grid_input + self.drop_path(self.grid_attn(self.norm_grid(grid_input)))
+ else:
+ vis_att_output = vis_input + self.drop_path(self.gamma_vis * self.vis_attn(self.norm_vis(vis_input)))
+ grid_att_output = grid_input + self.drop_path(self.gamma_grid * self.grid_attn(self.norm_grid(grid_input)))
+ return vis_att_output, grid_att_output
+
+ def forward(self, vis_input, grid_input):
+ vis_att_output, grid_att_output = self.self_att(vis_input, grid_input)
+
+ return vis_att_output, grid_att_output
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token to cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # trunc_normal_(self.relative_position_bias_table, std=.0)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias=None, training_window_size=None):
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ if self.relative_position_bias_table is not None:
+ if training_window_size == self.window_size:
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+ else:
+ training_window_size = tuple(training_window_size.tolist())
+ new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+                # new_num_relative_distance covers every possible relative-position option, including cls-to-cls, token-to-cls and cls-to-token
+ new_relative_position_bias_table = F.interpolate(
+ self.relative_position_bias_table[:-3, :]
+ .permute(1, 0)
+ .view(1, self.num_heads, 2 * self.window_size[0] - 1, 2 * self.window_size[1] - 1),
+ size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1),
+ mode="bicubic",
+ align_corners=False,
+ )
+ new_relative_position_bias_table = new_relative_position_bias_table.view(
+ self.num_heads, new_num_relative_distance - 3
+ ).permute(1, 0)
+ new_relative_position_bias_table = torch.cat(
+ [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(training_window_size[0])
+ coords_w = torch.arange(training_window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += training_window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(training_window_size[0] * training_window_size[1] + 1,) * 2, dtype=relative_coords.dtype
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = new_num_relative_distance - 3
+ relative_position_index[0:, 0] = new_num_relative_distance - 2
+ relative_position_index[0, 0] = new_num_relative_distance - 1
+
+ relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)].view(
+ training_window_size[0] * training_window_size[1] + 1,
+ training_window_size[0] * training_window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ init_values=None,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ window_size=None,
+ attn_head_dim=None,
+ ):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ window_size=window_size,
+ attn_head_dim=attn_head_dim,
+ )
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if init_values is not None:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias=None, training_window_size=None):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(
+ self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+ )
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(
+ self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+ )
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768, bias=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches_w = self.patch_shape[0]
+ self.num_patches_h = self.patch_shape[1]
+ # the so-called patch_shape is the patch shape during pre-training
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+ def forward(self, x, position_embedding=None, **kwargs):
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1], \
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x)
+ Hp, Wp = x.shape[2], x.shape[3]
+
+ if position_embedding is not None:
+ # interpolate the position embedding to the corresponding size
+ position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3, 1, 2)
+ position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode="bicubic")
+ x = x + position_embedding
+
+ x = x.flatten(2).transpose(1, 2)
+ return x, (Hp, Wp)
+
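+    # Shape example (assuming the defaults img_size=224, patch_size=16, embed_dim=768):
+    #   x: (N, 3, 224, 224) --proj--> (N, 768, 14, 14) --flatten/transpose--> (N, 196, 768)
+    # and the returned (Hp, Wp) is (14, 14).
+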
+
+class HybridEmbed(nn.Module):
+ """CNN Feature Map Embedding
+ Extract feature map from CNN, flatten, project to embedding dim.
+ """
+
+ def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
+ super().__init__()
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ self.img_size = img_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+ # map for all networks, the feature metadata has reliable channel and stride info, but using
+ # stride to calc feature dim requires info about padding of each stage that isn't captured.
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ self.num_patches = feature_size[0] * feature_size[1]
+ self.proj = nn.Linear(feature_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.backbone(x)[-1]
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_heads = num_heads
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token to cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ def forward(self, training_window_size):
+ if training_window_size == self.window_size:
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ else:
+ training_window_size = tuple(training_window_size.tolist())
+ new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
+ # new_num_relative_distance counts every possible relative offset plus the cls-cls, tok-cls and cls-tok entries
+ new_relative_position_bias_table = F.interpolate(
+ self.relative_position_bias_table[:-3, :]
+ .permute(1, 0)
+ .view(1, self.num_heads, 2 * self.window_size[0] - 1, 2 * self.window_size[1] - 1),
+ size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1),
+ mode="bicubic",
+ align_corners=False,
+ )
+ new_relative_position_bias_table = new_relative_position_bias_table.view(
+ self.num_heads, new_num_relative_distance - 3
+ ).permute(1, 0)
+ new_relative_position_bias_table = torch.cat(
+ [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0
+ )
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(training_window_size[0])
+ coords_w = torch.arange(training_window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += training_window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(training_window_size[0] * training_window_size[1] + 1,) * 2, dtype=relative_coords.dtype
+ )
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = new_num_relative_distance - 3
+ relative_position_index[0:, 0] = new_num_relative_distance - 2
+ relative_position_index[0, 0] = new_num_relative_distance - 1
+
+ relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)].view(
+ training_window_size[0] * training_window_size[1] + 1,
+ training_window_size[0] * training_window_size[1] + 1,
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+ return relative_position_bias
+
+
+class BEiT(nn.Module):
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+ def __init__(
+ self,
+ img_size=[224, 224],
+ patch_size=16,
+ in_chans=3,
+ grid_chans=64,
+ num_classes=80,
+ embed_dim=768,
+ self_depth=7,
+ cross_depth=5,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.0,
+ hybrid_backbone=None,
+ norm_layer=None,
+ init_values=None,
+ use_abs_pos_emb=False,
+ use_rel_pos_bias=False,
+ use_shared_rel_pos_bias=False,
+ use_checkpoint=True,
+ pretrained=None,
+ out_features=None,
+ ):
+
+ super(BEiT, self).__init__()
+
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.use_checkpoint = use_checkpoint
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ self.grid_patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=grid_chans, embed_dim=embed_dim, bias=True
+ )
+ num_patches = self.patch_embed.num_patches
+ self.out_features = out_features
+ self.out_indices = [int(name[5:]) for name in out_features]
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.grid_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.grid_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.grid_pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self_depth + cross_depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ init_values=init_values,
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
+ )
+ for i in range(self_depth)
+ ]
+ )
+
+ self.grid_blocks = nn.ModuleList(
+ [
+ Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ init_values=init_values,
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
+ )
+ for i in range(self_depth)
+ ]
+ )
+
+ self.cross_blocks = nn.ModuleList(
+ [
+ CrossBlock(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i + self_depth],
+ norm_layer=norm_layer,
+ init_values=init_values,
+ window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
+ )
+ for i in range(cross_depth)
+ ]
+ )
+
+ # trunc_normal_(self.mask_token, std=.02)
+
+ if patch_size == 16:
+ self.fpn1 = nn.Sequential(
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ # nn.SyncBatchNorm(embed_dim),
+ nn.BatchNorm2d(embed_dim),
+ nn.GELU(),
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ )
+ self.fpn2 = nn.Sequential(
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ )
+ self.fpn3 = nn.Identity()
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ self.grid_fpn1 = nn.Sequential(
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ # nn.SyncBatchNorm(embed_dim),
+ nn.BatchNorm2d(embed_dim),
+ nn.GELU(),
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ )
+ self.grid_fpn2 = nn.Sequential(
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ )
+ self.grid_fpn3 = nn.Identity()
+ self.grid_fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ elif patch_size == 8:
+ self.fpn1 = nn.Sequential(
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ )
+ self.fpn2 = nn.Identity()
+ self.fpn3 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ )
+ self.fpn4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ )
+
+ self.grid_fpn1 = nn.Sequential(
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
+ )
+ self.grid_fpn2 = nn.Identity()
+ self.grid_fpn3 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ )
+ self.grid_fpn4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ )
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=0.02)
+ trunc_normal_(self.grid_pos_embed, std=0.02)
+ trunc_normal_(self.cls_token, std=0.02)
+ trunc_normal_(self.grid_token, std=0.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
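+ # BEiT-style init fix: scale down each block's attention/MLP output projection by sqrt(2 * layer depth)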
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ '''
+ def init_weights(self):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ logger = get_root_logger()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ if self.init_cfg is None:
+ logger.warn(f'No pre-trained weights for '
+ f'{self.__class__.__name__}, '
+ f'training start from scratch')
+ else:
+ assert 'checkpoint' in self.init_cfg, f'Only support ' \
+ f'specify `Pretrained` in ' \
+ f'`init_cfg` in ' \
+ f'{self.__class__.__name__} '
+ logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
+ load_checkpoint(self,
+ filename=self.init_cfg['checkpoint'],
+ strict=False,
+ logger=logger,
+ beit_spec_expand_rel_pos = self.use_rel_pos_bias,
+ )
+ '''
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"pos_embed", "cls_token"}
+
+ def forward_features(self, x, grid):
+ B, C, H, W = x.shape
+ vis_x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
+ grid_x, (grid_Hp, grid_Wp) = self.grid_patch_embed(
+ grid, self.grid_pos_embed[:, 1:, :] if self.grid_pos_embed is not None else None
+ )
+
+ # Hp, Wp are HW for patches
+ batch_size, seq_len, _ = grid_x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ grid_tokens = self.grid_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ if self.pos_embed is not None:
+ cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
+ grid_tokens = grid_tokens + self.grid_pos_embed[:, :1, :]
+ vis_x = torch.cat((cls_tokens, vis_x), dim=1)
+ vis_x = self.pos_drop(vis_x)
+
+ grid_x = torch.cat((grid_tokens, grid_x), dim=1)
+ grid_x = self.pos_drop(grid_x)
+
+ features = []
+ grid_features = []
+ training_window_size = torch.tensor([Hp, Wp])
+ grid_training_window_size = torch.tensor([grid_Hp, grid_Wp])
+
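+ # The shared relative position bias (if enabled) is computed once for the current patch grid and reused by every self-attention block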
+ rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
+
+ for i, blk in enumerate(self.blocks):
+ if self.use_checkpoint:
+ vis_x = checkpoint.checkpoint(blk, vis_x, rel_pos_bias, training_window_size)
+ else:
+ vis_x = blk(vis_x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
+ if i in self.out_indices:
+ xp = vis_x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+ features.append(xp.contiguous())
+
+ for i, grid_blk in enumerate(self.grid_blocks):
+ if self.use_checkpoint:
+ grid_x = checkpoint.checkpoint(grid_blk, grid_x, rel_pos_bias, grid_training_window_size)
+ else:
+ grid_x = grid_blk(grid_x, rel_pos_bias=rel_pos_bias, training_window_size=grid_training_window_size)
+ if i in self.out_indices:
+ gp = grid_x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, grid_Hp, grid_Wp)
+ grid_features.append(gp.contiguous())
+
+ # import ipdb;ipdb.set_trace()
+ for i, cross_blk in enumerate(self.cross_blocks):
+ if self.use_checkpoint:
+ vis_x, grid_x = checkpoint.checkpoint(cross_blk, vis_x, grid_x)
+ else:
+ vis_x, grid_x = cross_blk(vis_input=vis_x, grid_input=grid_x)
+
+ xp = vis_x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
+ features.append(xp.contiguous())
+
+ gp = grid_x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, grid_Hp, grid_Wp)
+ grid_features.append(gp.contiguous())
+
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
+ grid_ops = [self.grid_fpn1, self.grid_fpn2, self.grid_fpn3, self.grid_fpn4]
+
+ for i in range(len(features)):
+ features[i] = ops[i](features[i])
+
+ for i in range(len(grid_features)):
+ grid_features[i] = grid_ops[i](grid_features[i])
+
+ feat_out = {}
+ grid_feat_out = {}
+
+ for name, vis_value, grid_value in zip(self.out_features, features, grid_features):
+ feat_out[name] = vis_value
+ grid_feat_out[name] = grid_value
+
+ return feat_out, grid_feat_out
+
+ def forward(self, x, grid):
+ x, y = self.forward_features(x, grid)
+ return x, y
+
+
+def beit_base_patch16(pretrained=False, **kwargs):
+ model = BEiT(
+ patch_size=16,
+ embed_dim=768,
+ self_depth=12,  # this BEiT variant takes self_depth/cross_depth instead of a single depth
+ cross_depth=0,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ init_values=None,
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+def beit_large_patch16(pretrained=False, **kwargs):
+ model = BEiT(
+ patch_size=16,
+ embed_dim=1024,
+ self_depth=24,
+ cross_depth=0,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ init_values=None,
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+def VGT_dit_base_patch16(pretrained=False, **kwargs):
+ model = BEiT(
+ patch_size=16,
+ embed_dim=768,
+ self_depth=12,
+ cross_depth=0,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ init_values=0.1,
+ in_chans=3,
+ grid_chans=64,
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+def dit_base_patch16(pretrained=False, **kwargs):
+ model = BEiT(
+ patch_size=16,
+ embed_dim=768,
+ self_depth=12,
+ cross_depth=0,
+ num_heads=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ init_values=0.1,
+ in_chans=3,
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+def dit_large_patch16(pretrained=False, **kwargs):
+ model = BEiT(
+ patch_size=16,
+ embed_dim=1024,
+ self_depth=24,
+ cross_depth=0,
+ num_heads=16,
+ mlp_ratio=4,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ init_values=1e-5,
+ **kwargs,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+if __name__ == "__main__":
+ # Minimal smoke test: BEiT.forward expects both an image tensor and a word-grid
+ # tensor with grid_chans channels, and out_features must name the transformer
+ # blocks whose outputs feed the four FPN levels.
+ model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True, out_features=["layer2", "layer4", "layer6"])
+ model = model.to("cuda:0")
+ input1, grid1 = torch.rand(2, 3, 512, 762).to("cuda:0"), torch.rand(2, 64, 512, 762).to("cuda:0")
+ input2, grid2 = torch.rand(2, 3, 800, 1200).to("cuda:0"), torch.rand(2, 64, 800, 1200).to("cuda:0")
+ input3, grid3 = torch.rand(2, 3, 720, 1000).to("cuda:0"), torch.rand(2, 64, 720, 1000).to("cuda:0")
+ output1 = model(input1, grid1)
+ output2 = model(input2, grid2)
+ output3 = model(input3, grid3)
+ print("all done")
diff --git a/src/adapters/ml/vgt/ditod/VGTcheckpointer.py b/src/adapters/ml/vgt/ditod/VGTcheckpointer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb8cc505add11837c1c72910bb61e35a3cd43490
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/VGTcheckpointer.py
@@ -0,0 +1,273 @@
+from detectron2.checkpoint import DetectionCheckpointer
+
+from typing import Any
+import torch
+import torch.nn as nn
+from fvcore.common.checkpoint import (
+ _IncompatibleKeys,
+ _strip_prefix_if_present,
+ TORCH_VERSION,
+ quantization,
+ ObserverBase,
+ FakeQuantizeBase,
+)
+from torch import distributed as dist
+from scipy import interpolate
+import numpy as np
+import torch.nn.functional as F
+
+
+def append_prefix(k):
+ prefix = "backbone."
+ if "Wordgrid_embedding" in k:
+ return k[10:]
+ elif "myFPN" in k:
+ return prefix + k[16:]
+ else:
+ return prefix + k if not k.startswith(prefix) else k
+
+
+def DiT_append_prefix(k):
+ prefix = "backbone.bottom_up.backbone."
+ return prefix + k if not k.startswith(prefix) else k
+
+
+def modify_ckpt_state(model, state_dict, logger=None):
+ # reshape absolute position embedding for Swin
+ if state_dict.get(append_prefix("absolute_pos_embed")) is not None:
+ absolute_pos_embed = state_dict[append_prefix("absolute_pos_embed")]
+ N1, L, C1 = absolute_pos_embed.size()
+ N2, C2, H, W = model.backbone.bottom_up.backbone.absolute_pos_embed.size()
+ if N1 != N2 or C1 != C2 or L != H * W:
+ logger.warning("Error in loading absolute_pos_embed, pass")
+ else:
+ state_dict[append_prefix("absolute_pos_embed")] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
+
+ def get_dist_info():
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+ rank, _ = get_dist_info()
+ all_keys = list(state_dict.keys())
+ for key in all_keys:
+ if "relative_position_index" in key:
+ state_dict.pop(key)
+
+ if "relative_position_bias_table" in key:
+ rel_pos_bias = state_dict[key]
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
+ if key not in model.state_dict():
+ continue
+ dst_num_pos, _ = model.state_dict()[key].size()
+ dst_patch_shape = model.backbone.bottom_up.backbone.patch_embed.patch_shape
+ if dst_patch_shape[0] != dst_patch_shape[1]:
+ raise NotImplementedError()
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
+ if src_size != dst_size:
+ if rank == 0:
+ print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size))
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r**n) / (1.0 - r)
+
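+ # Binary-search a geometric ratio q so that src_size // 2 geometrically spaced steps span roughly dst_size // 2 (BEiT's scheme for resizing the relative position bias table)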
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+
+ # if q > 1.13492:
+ # q = 1.13492
+
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q ** (i + 1)
+
+ r_ids = [-_ for _ in reversed(dis)]
+
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+ if rank == 0:
+ print("x = {}".format(x))
+ print("dx = {}".format(dx))
+
+ all_rel_pos_bias = []
+
+ for i in range(num_attn_heads):
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
+ f = interpolate.interp2d(x, y, z, kind="cubic")
+ all_rel_pos_bias.append(torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
+
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
+ state_dict[key] = new_rel_pos_bias
+
+ if append_prefix("pos_embed") in state_dict:
+ pos_embed_checkpoint = state_dict[append_prefix("pos_embed")]
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.backbone.bottom_up.backbone.patch_embed.num_patches
+ num_extra_tokens = model.backbone.bottom_up.backbone.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ # new_size = int(num_patches ** 0.5)
+ new_size_w = model.backbone.bottom_up.backbone.patch_embed.num_patches_w
+ new_size_h = model.backbone.bottom_up.backbone.patch_embed.num_patches_h
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size_h or orig_size != new_size_w:
+ if rank == 0:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size_w, new_size_h))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size_w, new_size_h), mode="bicubic", align_corners=False
+ )
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ state_dict[append_prefix("pos_embed")] = new_pos_embed
+
+ # interpolate position bias table if needed
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
+ for table_key in relative_position_bias_table_keys:
+ table_pretrained = state_dict[table_key]
+ if table_key not in model.state_dict():
+ continue
+ table_current = model.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ logger.warning(f"Error in loading {table_key}, pass")
+ else:
+ if L1 != L2:
+ S1 = int(L1**0.5)
+ S2 = int(L2**0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), mode="bicubic"
+ )
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
+
+ if (
+ append_prefix("rel_pos_bias.relative_position_bias_table") in state_dict
+ and model.backbone.bottom_up.backbone.use_rel_pos_bias
+ and not model.backbone.bottom_up.backbone.use_shared_rel_pos_bias
+ and append_prefix("blocks.0.attn.relative_position_bias_table") not in state_dict
+ ):
+ logger.info("[BEIT] Expand the shared relative position embedding to each transformer block. ")
+ num_layers = model.backbone.bottom_up.backbone.get_num_layers()
+ rel_pos_bias = state_dict[append_prefix("rel_pos_bias.relative_position_bias_table")]
+ for i in range(num_layers):
+ state_dict["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone()
+ state_dict.pop(append_prefix("rel_pos_bias.relative_position_bias_table"))
+
+ return state_dict
+
+
+class MyDetectionCheckpointer(DetectionCheckpointer):
+ def _load_model(self, checkpoint: Any) -> _IncompatibleKeys:
+ """
+ Load weights from a checkpoint.
+
+ Args:
+ checkpoint (Any): checkpoint contains the weights.
+
+ Returns:
+ ``NamedTuple`` with ``missing_keys``, ``unexpected_keys``,
+ and ``incorrect_shapes`` fields:
+ * **missing_keys** is a list of str containing the missing keys
+ * **unexpected_keys** is a list of str containing the unexpected keys
+ * **incorrect_shapes** is a list of (key, shape in checkpoint, shape in model)
+
+ This is just like the return value of
+ :func:`torch.nn.Module.load_state_dict`, but with extra support
+ for ``incorrect_shapes``.
+ """
+ DiT_checkpoint_state_dict = torch.load("/path/dit-base-224-p16-500k-62d53a.pth", map_location=torch.device("cpu"))[
+ "model"
+ ]
+ checkpoint_state_dict = checkpoint.pop("model")
+ # import ipdb;ipdb.set_trace()
+ self._convert_ndarray_to_tensor(checkpoint_state_dict)
+
+ # if the state_dict comes from a model that was wrapped in a
+ # DataParallel or DistributedDataParallel during serialization,
+ # remove the "module" prefix before performing the matching.
+ _strip_prefix_if_present(checkpoint_state_dict, "module.")
+
+ # workaround https://github.com/pytorch/pytorch/issues/24139
+ model_state_dict = self.model.state_dict()
+ incorrect_shapes = []
+
+ new_checkpoint_state_dict = {}
+ for k in checkpoint_state_dict.keys():
+ new_checkpoint_state_dict[append_prefix(k)] = checkpoint_state_dict[k]
+
+ for k in DiT_checkpoint_state_dict.keys():
+ new_checkpoint_state_dict[DiT_append_prefix(k)] = DiT_checkpoint_state_dict[k]
+
+ checkpoint_state_dict = new_checkpoint_state_dict
+
+ for k in list(checkpoint_state_dict.keys()):
+ if k in model_state_dict:
+ model_param = model_state_dict[k]
+ # Allow mismatch for uninitialized parameters
+ if TORCH_VERSION >= (1, 8) and isinstance(model_param, nn.parameter.UninitializedParameter):
+ continue
+ shape_model = tuple(model_param.shape)
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+ if shape_model != shape_checkpoint:
+
+ has_observer_base_classes = (
+ TORCH_VERSION >= (1, 8)
+ and hasattr(quantization, "ObserverBase")
+ and hasattr(quantization, "FakeQuantizeBase")
+ )
+ if has_observer_base_classes:
+ # Handle the special case of quantization per channel observers,
+ # where buffer shape mismatches are expected.
+ def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
+ # foo.bar.param_or_buffer_name -> [foo, bar]
+ key_parts = key.split(".")[:-1]
+ cur_module = model
+ for key_part in key_parts:
+ cur_module = getattr(cur_module, key_part)
+ return cur_module
+
+ cls_to_skip = (
+ ObserverBase,
+ FakeQuantizeBase,
+ )
+ target_module = _get_module_for_key(self.model, k)
+ if isinstance(target_module, cls_to_skip):
+ # Do not remove modules with expected shape mismatches from the
+ # state_dict loading; they have special logic in
+ # _load_from_state_dict to handle the mismatches.
+ continue
+
+ incorrect_shapes.append((k, shape_checkpoint, shape_model))
+ checkpoint_state_dict.pop(k)
+ incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
+ return _IncompatibleKeys(
+ missing_keys=incompatible.missing_keys,
+ unexpected_keys=incompatible.unexpected_keys,
+ incorrect_shapes=incorrect_shapes,
+ )
diff --git a/src/adapters/ml/vgt/ditod/Wordnn_embedding.py b/src/adapters/ml/vgt/ditod/Wordnn_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a18cd7b9bef2e2650db1b4d0f38880fb032b7b9
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/Wordnn_embedding.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+from torch import nn
+from .tokenization_bros import BrosTokenizer
+
+
+def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+
+class WordnnEmbedding(nn.Module):
+ """Generate chargrid embedding feature map."""
+
+ def __init__(
+ self,
+ vocab_size=30552,
+ hidden_size=768,
+ embedding_dim=64,
+ bros_embedding_path="/bros-base-uncased/",
+ use_pretrain_weight=True,
+ use_UNK_text=False,
+ ):
+ """
+ Args:
+ vocab_size (int): size of vocabulary.
+ embedding_dim (int): dim of input features
+ """
+ super().__init__()
+
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
+ self.embedding_proj = nn.Linear(hidden_size, embedding_dim, bias=False)
+ # self.tokenizer = BrosTokenizer.from_pretrained(bros_embedding_path)
+ self.use_pretrain_weight = use_pretrain_weight
+ self.use_UNK_text = use_UNK_text
+
+ self.init_weights(bros_embedding_path)
+ self.apply(_init_weights)
+
+ def init_weights(self, bros_embedding_path):
+ if self.use_pretrain_weight:
+ state_dict = torch.load(bros_embedding_path + "pytorch_model.bin", map_location="cpu")
+ if "bert" in bros_embedding_path:
+ word_embs = state_dict["bert.embeddings.word_embeddings.weight"]
+ elif "bros" in bros_embedding_path:
+ word_embs = state_dict["embeddings.word_embeddings.weight"]
+ elif "layoutlm" in bros_embedding_path:
+ word_embs = state_dict["layoutlm.embeddings.word_embeddings.weight"]
+ else:
+ raise ValueError("Unsupported bros_embedding_path: {}".format(bros_embedding_path))
+ self.embedding = nn.Embedding.from_pretrained(word_embs)
+ print("use_pretrain_weight: load model from:", bros_embedding_path)
+
+ def forward(self, img, batched_inputs, stride=1):
+ """Forward computation
+ Args:
+ img (Tensor): in shape of [B x 3 x H x W]
+ batched_inputs (list[dict]):
+ Returns:
+ Tensor: in shape of [B x D x H/stride x W/stride], where D is the embedding_dim.
+ """
+ device = img.device
+ batch_b, _, batch_h, batch_w = img.size()
+
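+ # Rasterize sub-word token ids into a [B, H/stride, W/stride] index map: each token's box is filled with its id (or a fixed UNK id when use_UNK_text is set)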
+ chargrid_map = torch.zeros((batch_b, batch_h // stride, batch_w // stride), dtype=torch.int64).to(device)
+
+ for iter_b in range(batch_b):
+ per_input_ids = batched_inputs[iter_b]["input_ids"]
+ per_input_bbox = batched_inputs[iter_b]["bbox"]
+
+ short_length_w = min(len(per_input_ids), len(per_input_bbox))
+
+ if short_length_w > 0:
+ for word_idx in range(short_length_w):
+ per_id = per_input_ids[word_idx]
+
+ bbox = per_input_bbox[word_idx] / stride
+ w_start, h_start, w_end, h_end = bbox.round().astype(int).tolist()
+
+ if self.use_UNK_text:
+ chargrid_map[iter_b, h_start:h_end, w_start:w_end] = 100
+ else:
+ chargrid_map[iter_b, h_start:h_end, w_start:w_end] = per_id
+
+ chargrid_map = self.embedding(chargrid_map)
+ chargrid_map = self.embedding_proj(chargrid_map)
+
+ return chargrid_map.permute(0, 3, 1, 2).contiguous()
diff --git a/src/adapters/ml/vgt/ditod/__init__.py b/src/adapters/ml/vgt/ditod/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e09402088dd457a4a5876a9f2f30f06713fb20b9
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/__init__.py
@@ -0,0 +1,16 @@
+# --------------------------------------------------------------------------------
+# MPViT: Multi-Path Vision Transformer for Dense Prediction
+# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
+# All Rights Reserved.
+# Written by Youngwan Lee
+# This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------------------
+
+from .config import add_vit_config
+from .VGTbackbone import build_VGT_fpn_backbone
+from .dataset_mapper import DetrDatasetMapper
+from .VGTTrainer import VGTTrainer
+from .VGT import VGT
+
+from .utils import eval_and_show, load_gt_from_json, pub_load_gt_from_json
diff --git a/src/adapters/ml/vgt/ditod/config.py b/src/adapters/ml/vgt/ditod/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7de9783fcf225b792d63a85a977c6cd3d4a5f8dd
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/config.py
@@ -0,0 +1,48 @@
+from detectron2.config import CfgNode as CN
+
+
+def add_vit_config(cfg):
+ """
+ Add config for VIT.
+ """
+ _C = cfg
+
+ _C.MODEL.VIT = CN()
+
+ # CoaT model name.
+ _C.MODEL.VIT.NAME = ""
+
+ # Output features from CoaT backbone.
+ _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
+
+ _C.MODEL.VIT.IMG_SIZE = [224, 224]
+
+ _C.MODEL.VIT.POS_TYPE = "shared_rel"
+
+ _C.MODEL.VIT.MERGE_TYPE = "Sum"
+
+ _C.MODEL.VIT.DROP_PATH = 0.0
+
+ _C.MODEL.VIT.MODEL_KWARGS = "{}"
+
+ _C.SOLVER.OPTIMIZER = "ADAMW"
+
+ _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
+
+ _C.AUG = CN()
+
+ _C.AUG.DETR = False
+
+ _C.MODEL.WORDGRID = CN()
+
+ _C.MODEL.WORDGRID.VOCAB_SIZE = 30552
+
+ _C.MODEL.WORDGRID.EMBEDDING_DIM = 64
+
+ _C.MODEL.WORDGRID.MODEL_PATH = ""
+
+ _C.MODEL.WORDGRID.HIDDEN_SIZE = 768
+
+ _C.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT = True
+
+ _C.MODEL.WORDGRID.USE_UNK_TEXT = False
diff --git a/src/adapters/ml/vgt/ditod/dataset_mapper.py b/src/adapters/ml/vgt/ditod/dataset_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eca0ed5e011fcbca640c1c961517ebb8df13b59
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/dataset_mapper.py
@@ -0,0 +1,202 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# from https://github.com/facebookresearch/detr/blob/main/d2/detr/dataset_mapper.py
+
+
+import copy
+import logging
+from os import path
+
+import numpy as np
+import torch
+
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+
+import json
+import pickle
+
+from detectron2.structures import (
+ BitMasks,
+ Boxes,
+ BoxMode,
+ Instances,
+ Keypoints,
+ PolygonMasks,
+ RotatedBoxes,
+ polygons_to_bitmask,
+)
+
+__all__ = ["DetrDatasetMapper"]
+
+
+def build_transform_gen(cfg, is_train):
+ """
+ Create a list of :class:`TransformGen` from config.
+ Returns:
+ list[TransformGen]
+ """
+ if is_train:
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
+ else:
+ min_size = cfg.INPUT.MIN_SIZE_TEST
+ max_size = cfg.INPUT.MAX_SIZE_TEST
+ sample_style = "choice"
+ if sample_style == "range":
+ assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))
+
+ logger = logging.getLogger(__name__)
+ tfm_gens = []
+ # if is_train:
+ # tfm_gens.append(T.RandomFlip())
+ tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
+ if is_train:
+ logger.info("TransformGens used in training: " + str(tfm_gens))
+ return tfm_gens
+
+
+def build_transform_gen_w(cfg, is_train):
+ """
+ Create a list of :class:`TransformGen` from config.
+ Returns:
+ list[TransformGen]
+ """
+ if is_train:
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
+ max_size = 800
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
+ else:
+ min_size = cfg.INPUT.MIN_SIZE_TEST
+ max_size = cfg.INPUT.MAX_SIZE_TEST
+ sample_style = "choice"
+ if sample_style == "range":
+ assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size))
+
+ logger = logging.getLogger(__name__)
+ tfm_gens = []
+ # if is_train:
+ # tfm_gens.append(T.RandomFlip())
+ tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
+ if is_train:
+ logger.info("TransformGens used in training: " + str(tfm_gens))
+ return tfm_gens
+
+
+class DetrDatasetMapper:
+ """
+ A callable which takes a dataset dict in Detectron2 Dataset format,
+ and map it into a format used by DETR.
+
+ The callable currently does the following:
+
+ 1. Read the image from "file_name"
+ 2. Applies geometric transforms to the image and annotation
+ 3. Find and applies suitable cropping to the image and annotation
+ 4. Prepare image and annotation to Tensors
+ """
+
+ def __init__(self, cfg, is_train=True):
+ if cfg.INPUT.CROP.ENABLED and is_train:
+ self.crop_gen = [
+ T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
+ T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
+ ]
+ else:
+ self.crop_gen = None
+
+ self.mask_on = cfg.MODEL.MASK_ON
+ self.tfm_gens = build_transform_gen(cfg, is_train)
+ self.tfm_gens_w = build_transform_gen_w(cfg, is_train)
+ logging.getLogger(__name__).info(
+ "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen))
+ )
+
+ self.img_format = cfg.INPUT.FORMAT
+ self.is_train = is_train
+ self.cfg = cfg
+
+ logger = logging.getLogger("detectron2")
+
+ def __call__(self, dataset_dict):
+ """
+ Args:
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+ Returns:
+ dict: a format that builtin models in detectron2 accept
+ """
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+ utils.check_image_size(dataset_dict, image)
+
+ word_grid_path = dataset_dict["file_name"].replace("images", "word_grids").replace(".jpg", ".pkl")
+ if path.exists(word_grid_path):
+ with open(word_grid_path, "rb") as f:
+ sample_inputs = pickle.load(f)
+ input_ids = sample_inputs["input_ids"]
+ bbox_subword_list = sample_inputs["bbox_subword_list"]
+ else:
+ input_ids = []
+ bbox_subword_list = []
+ print(f"No word grid pkl in: {word_grid_path}")
+
+ image_shape_ori = image.shape[:2] # h, w
+
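+ # Portrait pages go through the standard resize generators; landscape pages use the variant capped at 800 px during training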
+ if self.crop_gen is None:
+ if image_shape_ori[0] > image_shape_ori[1]:
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
+ else:
+ image, transforms = T.apply_transform_gens(self.tfm_gens_w, image)
+ else:
+ if np.random.rand() > 0.5:
+ if image_shape_ori[0] > image_shape_ori[1]:
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
+ else:
+ image, transforms = T.apply_transform_gens(self.tfm_gens_w, image)
+ else:
+ image, transforms = T.apply_transform_gens(
+ self.tfm_gens_w[:-1] + self.crop_gen + self.tfm_gens_w[-1:], image
+ )
+
+ image_shape = image.shape[:2] # h, w
+
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+ # Therefore it's important to use torch.Tensor.
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+
+ # Build the text grid: transform each sub-word bounding box into the resized image space
+ bbox = []
+ for bbox_per_subword in bbox_subword_list:
+ text_word = {}
+ text_word["bbox"] = bbox_per_subword.tolist()
+ text_word["bbox_mode"] = BoxMode.XYWH_ABS
+ utils.transform_instance_annotations(text_word, transforms, image_shape)
+ bbox.append(text_word["bbox"])
+
+ dataset_dict["input_ids"] = input_ids
+ dataset_dict["bbox"] = bbox
+
+ if not self.is_train:
+ # USER: Modify this if you want to keep them for some reason.
+ dataset_dict.pop("annotations", None)
+ return dataset_dict
+
+ if "annotations" in dataset_dict:
+ # USER: Modify this if you want to keep them for some reason.
+ for anno in dataset_dict["annotations"]:
+ if not self.mask_on:
+ anno.pop("segmentation", None)
+ anno.pop("keypoints", None)
+
+ # USER: Implement additional transformations if you have other types of data
+ annos = [
+ utils.transform_instance_annotations(obj, transforms, image_shape)
+ for obj in dataset_dict.pop("annotations")
+ if obj.get("iscrowd", 0) == 0
+ ]
+ instances = utils.annotations_to_instances(annos, image_shape)
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
+
+ return dataset_dict
diff --git a/src/adapters/ml/vgt/ditod/tokenization_bros.py b/src/adapters/ml/vgt/ditod/tokenization_bros.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4000dafc8f27d761dbe0b1aa2da14774587c52f
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/tokenization_bros.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes."""
+
+
+import collections
+
+from transformers.models.bert.tokenization_bert import BertTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt",
+ "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt",
+ }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+ "naver-clova-ocr/bros-base-uncased": 512,
+ "naver-clova-ocr/bros-large-uncased": 512,
+}
+
+PRETRAINED_INIT_CONFIGURATION = {
+ "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True},
+ "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True},
+}
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ tokens = reader.readlines()
+ for index, token in enumerate(tokens):
+ token = token.rstrip("\n")
+ vocab[token] = index
+ return vocab
+
+
+def convert_to_unicode(text):
+ """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, bytes):
+ return text.decode("utf-8", "ignore")
+ else:
+ raise ValueError("Unsupported string type: %s" % (type(text)))
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class BrosTokenizer(BertTokenizer):
+ r"""
+ Construct a BERT tokenizer. Based on WordPiece.
+
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
+ Users should refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (:obj:`str`):
+ File containing the vocabulary.
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (:obj:`Iterable`, `optional`):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ :obj:`do_basic_tokenize=True`
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this `issue
+ `__).
+ strip_accents: (:obj:`bool`, `optional`):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for :obj:`lowercase` (as in the original BERT).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
diff --git a/src/adapters/ml/vgt/ditod/utils.py b/src/adapters/ml/vgt/ditod/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..47015c147ddd5c617dfc1b66d7c1c182f0f9099e
--- /dev/null
+++ b/src/adapters/ml/vgt/ditod/utils.py
@@ -0,0 +1,379 @@
+import json
+import os
+import sys
+
+import cv2
+import numpy as np
+from shapely.geometry import Polygon
+from tabulate import tabulate
+
+
+def get_image_path(image_dir, file_name_wo_ext):
+ ext_list = ["", ".jpg", ".JPG", ".png", ".PNG", ".jpeg"]
+ image_path = None
+ for ext in ext_list:
+ image_path_tmp = os.path.join(image_dir, file_name_wo_ext + ext)
+ if os.path.exists(image_path_tmp):
+ image_path = image_path_tmp
+ break
+ return image_path
+
+
+def visual_badcase(image_path, pred_list, label_list, output_dir="visual_badcase", info=None, prefix=""):
+ """ """
+ img = cv2.imread(image_path) if image_path is not None and os.path.exists(image_path) else None
+ if img is None:
+ print("--> Warning: skip, given image does not exist: {}".format(image_path))
+ return None
+
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ for label in label_list:
+ points, class_id = label["poly"], label["category_id"]
+ pts = np.array(points).reshape((1, -1, 2)).astype(np.int32)
+ cv2.polylines(img, pts, isClosed=True, color=(0, 255, 0), thickness=3)
+ cv2.putText(img, "gt:" + str(class_id), tuple(pts[0][0].tolist()), font, 1, (0, 255, 0), 2)
+
+ for label in pred_list:
+ points, class_id = label["poly"], label["category_id"]
+ pts = np.array(points).reshape((1, -1, 2)).astype(np.int32)
+ cv2.polylines(img, pts, isClosed=True, color=(255, 0, 0), thickness=3)
+ cv2.putText(img, "pred:" + str(class_id), tuple(pts[0][-1].tolist()), font, 1, (255, 0, 0), 2)
+
+ if info is not None:
+ cv2.putText(img, str(info), (40, 40), font, 1, (0, 0, 255), 2)
+ output_path = os.path.join(output_dir, prefix + os.path.basename(image_path) + "_vis.jpg")
+ cv2.imwrite(output_path, img)
+ return output_path
+
+
+def pub_load_gt_from_json(json_path):
+ """ """
+ with open(json_path) as f:
+ gt_info = json.load(f)
+ gt_image_list = gt_info["images"]
+ gt_anno_list = gt_info["annotations"]
+
+ id_to_image_info = {}
+ for image_item in gt_image_list:
+ id_to_image_info[image_item["id"]] = {
+ "file_name": image_item["file_name"],
+ "group_name": image_item.get("group_name", "huntie"),
+ }
+
+ group_info = {}
+ for annotation_item in gt_anno_list:
+ image_info = id_to_image_info[annotation_item["image_id"]]
+ image_name, group_name = image_info["file_name"], image_info["group_name"]
+
+ if group_name not in group_info:
+ group_info[group_name] = {}
+ if image_name not in group_info[group_name]:
+ group_info[group_name][image_name] = []
+
+ box_xywh = annotation_item["bbox"]
+ box_xyxy = [box_xywh[0], box_xywh[1], box_xywh[0] + box_xywh[2], box_xywh[1] + box_xywh[3]]
+ pts = np.round(
+ [box_xyxy[0], box_xyxy[1], box_xyxy[2], box_xyxy[1], box_xyxy[2], box_xyxy[3], box_xyxy[0], box_xyxy[3]]
+ )
+ anno_info = {
+ "category_id": annotation_item["category_id"],
+ "poly": pts,
+ "secondary_id": annotation_item.get("secondary_id", -1),
+ "direction_id": annotation_item.get("direction_id", -1),
+ }
+ group_info[group_name][image_name].append(anno_info)
+
+ group_info_str = ", ".join(["{}[{}]".format(k, len(v)) for k, v in group_info.items()])
+ print("--> load {} groups: {}".format(len(group_info.keys()), group_info_str))
+ return group_info
+
+
+def load_gt_from_json(json_path):
+ """ """
+ with open(json_path) as f:
+ gt_info = json.load(f)
+ gt_image_list = gt_info["images"]
+ gt_anno_list = gt_info["annotations"]
+
+ id_to_image_info = {}
+ for image_item in gt_image_list:
+ id_to_image_info[image_item["id"]] = {
+ "file_name": image_item["file_name"],
+ "group_name": image_item.get("group_name", "huntie"),
+ }
+
+ group_info = {}
+ for annotation_item in gt_anno_list:
+ image_info = id_to_image_info[annotation_item["image_id"]]
+ image_name, group_name = image_info["file_name"], image_info["group_name"]
+
+ if group_name not in group_info:
+ group_info[group_name] = {}
+ if image_name not in group_info[group_name]:
+ group_info[group_name][image_name] = []
+ anno_info = {
+ "category_id": annotation_item["category_id"],
+ "poly": annotation_item["poly"],
+ "secondary_id": annotation_item.get("secondary_id", -1),
+ "direction_id": annotation_item.get("direction_id", -1),
+ }
+ group_info[group_name][image_name].append(anno_info)
+
+ group_info_str = ", ".join(["{}[{}]".format(k, len(v)) for k, v in group_info.items()])
+ print("--> load {} groups: {}".format(len(group_info.keys()), group_info_str))
+ return group_info
+
+
+def calc_iou(label, detect):
+ label_box = []
+ detect_box = []
+
+ d_area = []
+ for i in range(0, len(detect)):
+ pred_poly = detect[i]["poly"]
+ box_det = []
+ for k in range(0, 4):
+ box_det.append([pred_poly[2 * k], pred_poly[2 * k + 1]])
+ detect_box.append(box_det)
+ try:
+ poly = Polygon(box_det)
+ d_area.append(poly.area)
+ except:
+ print("invalid detects", pred_poly)
+ exit(-1)
+
+ l_area = []
+ for i in range(0, len(label)):
+ gt_poly = label[i]["poly"]
+ box_gt = []
+ for k in range(4):
+ box_gt.append([gt_poly[2 * k], gt_poly[2 * k + 1]])
+ label_box.append(box_gt)
+ try:
+ poly = Polygon(box_gt)
+ l_area.append(poly.area)
+ except:
+ print("invalid detects", gt_poly)
+ exit(-1)
+
+ ol_areas = []
+ for i in range(0, len(detect_box)):
+ ol_areas.append([])
+ poly1 = Polygon(detect_box[i])
+ for j in range(0, len(label_box)):
+ poly2 = Polygon(label_box[j])
+ try:
+ ol_area = poly2.intersection(poly1).area
+ except:
+ print("invaild pair", detect_box[i], label_box[j])
+ ol_areas[i].append(0.0)
+ else:
+ ol_areas[i].append(ol_area)
+
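+ # Match score per (det, gt) pair: min(overlap / det_area, overlap / gt_area), i.e. mutual coverage rather than strict IoU, and only when the category ids agree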
+ d_ious = [0.0] * len(detect_box)
+ l_ious = [0.0] * len(label_box)
+ for i in range(0, len(detect_box)):
+ for j in range(0, len(label_box)):
+ if int(label[j]["category_id"]) == int(detect[i]["category_id"]):
+ iou = min(ol_areas[i][j] / (d_area[i] + 1e-10), ol_areas[i][j] / (l_area[j] + 1e-10))
+ else:
+ iou = 0
+ d_ious[i] = max(d_ious[i], iou)
+ l_ious[j] = max(l_ious[j], iou)
+ return l_ious, d_ious
+
+
+def eval(instance_info):
+ img_name, label_info = instance_info
+ label = label_info["gt"]
+ detect = label_info["det"]
+ l_ious, d_ious = calc_iou(label, detect)
+ return [img_name, d_ious, l_ious, detect, label]
+
+
+def static_with_class(rets, iou_thresh=0.7, is_verbose=True, map_info=None, src_image_dir=None, visualization_dir=None):
+ if is_verbose:
+ table_head = ["Class_id", "Class_name", "Pre_hit", "Pre_num", "GT_hit", "GT_num", "Precision", "Recall", "F-score"]
+ else:
+ table_head = ["Class_id", "Class_name", "Precision", "Recall", "F-score"]
+ table_body = []
+ class_dict = {}
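+ # Per-class counters: dm/dv = matched / total detections, lm/lv = matched / total ground-truth labels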
+
+ for i in range(len(rets)):
+ img_name, d_ious, l_ious, detects, labels = rets[i]
+ item_lv, item_dv, item_dm, item_lm = 0, 0, 0, 0
+ for label in labels:
+ item_lv += 1
+ category_id = label["category_id"]
+ if category_id not in class_dict:
+ class_dict[category_id] = {}
+ class_dict[category_id]["dm"] = 0
+ class_dict[category_id]["dv"] = 0
+ class_dict[category_id]["lm"] = 0
+ class_dict[category_id]["lv"] = 0
+ class_dict[category_id]["lv"] += 1
+
+ for det in detects:
+ item_dv += 1
+ category_id = det["category_id"]
+ if category_id not in class_dict:
+ print("--> category_id not exists in gt: {}".format(category_id))
+ continue
+ class_dict[category_id]["dv"] += 1
+
+ for idx, iou in enumerate(d_ious):
+ if iou >= iou_thresh:
+ item_dm += 1
+ class_dict[detects[idx]["category_id"]]["dm"] += 1
+ for idx, iou in enumerate(l_ious):
+ if iou >= iou_thresh:
+ item_lm += 1
+ class_dict[labels[idx]["category_id"]]["lm"] += 1
+ item_p = item_dm / (item_dv + 1e-6)
+ item_r = item_lm / (item_lv + 1e-6)
+ item_f = 2 * item_p * item_r / (item_p + item_r + 1e-6)
+
+ if item_f < 0.97 and src_image_dir is not None:
+ image_path = get_image_path(src_image_dir, os.path.basename(img_name))
+ visualization_output = visualization_dir if visualization_dir is not None else "./visualization_badcase"
+ item_info = "IOU{}, {}, {}, {}".format(iou_thresh, item_r, item_p, item_f)
+ vis_path = visual_badcase(
+ image_path,
+ detects,
+ labels,
+ output_dir=visualization_output,
+ info=item_info,
+ prefix="{:02d}_".format(int(item_f * 100)),
+ )
+ if is_verbose:
+ print("--> info: save visualization at: {}".format(vis_path))
+
+ dm, dv, lm, lv = 0, 0, 0, 0
+ map_info = {} if map_info is None else map_info
+ for key in class_dict.keys():
+ dm += class_dict[key]["dm"]
+ dv += class_dict[key]["dv"]
+ lm += class_dict[key]["lm"]
+ lv += class_dict[key]["lv"]
+ p = class_dict[key]["dm"] / (class_dict[key]["dv"] + 1e-6)
+ r = class_dict[key]["lm"] / (class_dict[key]["lv"] + 1e-6)
+ fscore = 2 * p * r / (p + r + 1e-6)
+ if is_verbose:
+ table_body.append(
+ (
+ key,
+ map_info.get("primary_map", {}).get(str(key), str(key)),
+ class_dict[key]["dm"],
+ class_dict[key]["dv"],
+ class_dict[key]["lm"],
+ class_dict[key]["lv"],
+ p,
+ r,
+ fscore,
+ )
+ )
+ else:
+ table_body.append((key, map_info.get(str(key), str(key)), p, r, fscore))
+
+ p = dm / (dv + 1e-6)
+ r = lm / (lv + 1e-6)
+ f = 2 * p * r / (p + r + 1e-6)
+
+ table_body_sorted = sorted(table_body, key=lambda x: int((x[0])))
+ if is_verbose:
+ table_body_sorted.append(("IOU_{}".format(iou_thresh), "average", dm, dv, lm, lv, p, r, f))
+ else:
+ table_body_sorted.append(("IOU_{}".format(iou_thresh), "average", p, r, f))
+ print(tabulate(table_body_sorted, headers=table_head, tablefmt="pipe"))
+ return [table_head] + table_body_sorted
+
+
+def multiproc(func, task_list, proc_num=30, retv=True, progress_bar=False):
+ from multiprocessing import Pool
+
+ pool = Pool(proc_num)
+
+ rets = []
+ if progress_bar:
+ import tqdm
+
+ with tqdm.tqdm(total=len(task_list)) as t:
+ for ret in pool.imap(func, task_list):
+ rets.append(ret)
+ t.update(1)
+ else:
+ for ret in pool.imap(func, task_list):
+ rets.append(ret)
+
+ pool.close()
+ pool.join()
+
+ if retv:
+ return rets
+
+
+def eval_and_show(
+ label_dict, detect_dict, output_dir, iou_thresh=0.7, map_info=None, src_image_dir=None, visualization_dir=None
+):
+ """ """
+ evaluation_group_info = {}
+ for group_name, gt_info in label_dict.items():
+ group_pair_list = []
+ for file_name, value_list in gt_info.items():
+ if file_name not in detect_dict:
+ print("--> missing pred:", file_name)
+ continue
+ group_pair_list.append([file_name, {"gt": gt_info[file_name], "det": detect_dict[file_name]}])
+ evaluation_group_info[group_name] = group_pair_list
+
+ res_info_all = {}
+ for group_name, group_pair_list in evaluation_group_info.items():
+ print(" ------- group name: {} -----------".format(group_name))
+ rets = multiproc(eval, group_pair_list, proc_num=16)
+ group_name_map_info = map_info.get(group_name, None) if map_info is not None else None
+ res_info = static_with_class(
+ rets,
+ iou_thresh=iou_thresh,
+ map_info=group_name_map_info,
+ src_image_dir=src_image_dir,
+ visualization_dir=visualization_dir,
+ )
+ res_info_all[group_name] = res_info
+
+ evaluation_res_info_path = os.path.join(output_dir, "results_val.json")
+ with open(evaluation_res_info_path, "w") as f:
+ json.dump(res_info_all, f, ensure_ascii=False, indent=4)
+ print("--> info: evaluation result is saved at {}".format(evaluation_res_info_path))
+
+
+if __name__ == "__main__":
+
+ if len(sys.argv) != 5:
+ print("Usage: python {} gt_json_path pred_json_path output_dir iou_thresh".format(__file__))
+ exit(-1)
+ else:
+ print("--> info: {}".format(sys.argv))
+ gt_json_path, pred_json_path, output_dir, iou_thresh = sys.argv[1], sys.argv[2], sys.argv[3], float(sys.argv[4])
+
+ label_dict = load_gt_from_json(gt_json_path)
+ with open(pred_json_path, "r") as f:
+ detect_dict = json.load(f)
+
+ src_image_dir = None
+ eval_and_show(
+ label_dict,
+ detect_dict,
+ output_dir,
+ iou_thresh=iou_thresh,
+ map_info=None,
+ src_image_dir=src_image_dir,
+ visualization_dir=None,
+ )
diff --git a/src/adapters/ml/vgt/get_json_annotations.py b/src/adapters/ml/vgt/get_json_annotations.py
new file mode 100644
index 0000000000000000000000000000000000000000..85a153edb3e8065ae3cb1e23896bf4659855e572
--- /dev/null
+++ b/src/adapters/ml/vgt/get_json_annotations.py
@@ -0,0 +1,71 @@
+import json
+from os import makedirs
+from pdf_features import PdfToken
+from domain.PdfImages import PdfImages
+from configuration import DOCLAYNET_TYPE_BY_ID
+from configuration import JSONS_ROOT_PATH, JSON_TEST_FILE_PATH
+
+
+def save_annotations_json(annotations: list, width_height: list, images: list):
+ images_dict = [
+ {
+ "id": i,
+ "file_name": image_id + ".jpg",
+ "width": width_height[images.index(image_id)][0],
+ "height": width_height[images.index(image_id)][1],
+ }
+ for i, image_id in enumerate(images)
+ ]
+
+ categories_dict = [{"id": key, "name": value} for key, value in DOCLAYNET_TYPE_BY_ID.items()]
+
+ info_dict = {
+ "description": "PDF Document Layout Analysis Dataset",
+ "url": "",
+ "version": "1.0",
+ "year": 2025,
+ "contributor": "",
+ "date_created": "2025-01-01",
+ }
+
+ coco_dict = {"info": info_dict, "images": images_dict, "categories": categories_dict, "annotations": annotations}
+
+ JSON_TEST_FILE_PATH.write_text(json.dumps(coco_dict))
+
+
+def get_annotation(index: int, image_id: str, token: PdfToken):
+ return {
+ "area": 1,
+ "iscrowd": 0,
+ "score": 1,
+ "image_id": image_id,
+ "bbox": [token.bounding_box.left, token.bounding_box.top, token.bounding_box.width, token.bounding_box.height],
+ "category_id": token.token_type.get_index(),
+ "id": index,
+ }
+
+
+def get_annotations_for_document(annotations, images, index, pdf_images, width_height):
+ for page_index, page in enumerate(pdf_images.pdf_features.pages):
+ image_id = f"{pdf_images.pdf_features.file_name}_{page.page_number - 1}"
+ images.append(image_id)
+ width_height.append((pdf_images.pdf_images[page_index].width, pdf_images.pdf_images[page_index].height))
+
+ for token in page.tokens:
+ annotations.append(get_annotation(index, image_id, token))
+ index += 1
+
+
+def get_annotations(pdf_images_list: list[PdfImages]):
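+ # Collect COCO-style annotations for every page of every PDF and write them to test.json; annotation ids continue across documents, offset by each document's token count.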
+ makedirs(JSONS_ROOT_PATH, exist_ok=True)
+
+ annotations = list()
+ images = list()
+ width_height = list()
+ index = 0
+
+ for pdf_images in pdf_images_list:
+ get_annotations_for_document(annotations, images, index, pdf_images, width_height)
+ index += sum([len(page.tokens) for page in pdf_images.pdf_features.pages])
+
+ save_annotations_json(annotations, width_height, images)
diff --git a/src/adapters/ml/vgt/get_model_configuration.py b/src/adapters/ml/vgt/get_model_configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..6733a4583b72da1a786d8aba7dbba7435ffd7d9e
--- /dev/null
+++ b/src/adapters/ml/vgt/get_model_configuration.py
@@ -0,0 +1,51 @@
+import torch
+from os.path import join
+from detectron2.config import get_cfg
+from detectron2.engine import default_setup, default_argument_parser
+from configuration import service_logger, SRC_PATH, ROOT_PATH
+from adapters.ml.vgt.ditod import add_vit_config
+
+
+def is_gpu_available():
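+ # Log per-GPU memory usage and return True only when the combined free GPU memory exceeds 3000 MB; otherwise inference falls back to CPU.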
+ total_free_memory_in_system: float = 0.0
+ if torch.cuda.is_available():
+ for i in range(torch.cuda.device_count()):
+ total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**2
+ allocated_memory = torch.cuda.memory_allocated(i) / 1024**2
+ cached_memory = torch.cuda.memory_reserved(i) / 1024**2
+ service_logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
+ service_logger.info(f" Total Memory: {total_memory} MB")
+ service_logger.info(f" Allocated Memory: {allocated_memory} MB")
+ service_logger.info(f" Cached Memory: {cached_memory} MB")
+ total_free_memory_in_system += total_memory - allocated_memory - cached_memory
+ if total_free_memory_in_system < 3000:
+ service_logger.info(f"Total free GPU memory is {total_free_memory_in_system} < 3000 MB. Switching to CPU.")
+ service_logger.info("The process is probably going to be 15 times slower.")
+ else:
+ service_logger.info("No CUDA-compatible GPU detected. Switching to CPU.")
+ return total_free_memory_in_system > 3000
+
+
+def get_model_configuration():
+ parser = default_argument_parser()
+ args, unknown = parser.parse_known_args()
+ args.config_file = join(SRC_PATH, "adapters", "ml", "vgt", "model_configuration", "doclaynet_VGT_cascade_PTM.yaml")
+ args.eval_only = True
+ args.num_gpus = 1
+ args.opts = [
+ "MODEL.WEIGHTS",
+ join(ROOT_PATH, "models", "doclaynet_VGT_model.pth"),
+ "OUTPUT_DIR",
+ join(ROOT_PATH, "model_output_doclaynet"),
+ ]
+ args.debug = False
+
+ configuration = get_cfg()
+ add_vit_config(configuration)
+ configuration.merge_from_file(args.config_file)
+ configuration.merge_from_list(args.opts)
+ configuration.MODEL.DEVICE = "cuda" if is_gpu_available() else "cpu"
+ configuration.freeze()
+ default_setup(configuration, args)
+
+ return configuration
diff --git a/src/adapters/ml/vgt/get_most_probable_pdf_segments.py b/src/adapters/ml/vgt/get_most_probable_pdf_segments.py
new file mode 100644
index 0000000000000000000000000000000000000000..15eb0fc1c70ac4a153f46311a57d7bfb7768a1d3
--- /dev/null
+++ b/src/adapters/ml/vgt/get_most_probable_pdf_segments.py
@@ -0,0 +1,141 @@
+import json
+import pickle
+from os.path import join
+from pathlib import Path
+from statistics import mode
+
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfFeatures
+from pdf_features import PdfToken
+from pdf_features import Rectangle
+from pdf_token_type_labels import TokenType
+from domain.PdfImages import PdfImages
+from configuration import ROOT_PATH, DOCLAYNET_TYPE_BY_ID
+from domain.Prediction import Prediction
+
+
+def get_prediction_from_annotation(annotation, images_names, vgt_predictions_dict):
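+ # Turn a COCO detection into a Prediction, keyed by the page image name without its ".jpg" extension; scores are rescaled to percentages.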
+ pdf_name = images_names[annotation["image_id"]][:-4]
+ category_id = annotation["category_id"]
+ bounding_box = Rectangle.from_width_height(
+ left=int(annotation["bbox"][0]),
+ top=int(annotation["bbox"][1]),
+ width=int(annotation["bbox"][2]),
+ height=int(annotation["bbox"][3]),
+ )
+
+ prediction = Prediction(
+ bounding_box=bounding_box, category_id=category_id, score=round(float(annotation["score"]) * 100, 2)
+ )
+ vgt_predictions_dict.setdefault(pdf_name, list()).append(prediction)
+
+
+def get_vgt_predictions(model_name: str) -> dict[str, list[Prediction]]:
+ output_dir: str = f"model_output_{model_name}"
+ model_output_json_path = join(str(ROOT_PATH), output_dir, "inference", "coco_instances_results.json")
+ annotations = json.loads(Path(model_output_json_path).read_text())
+
+ test_json_path = join(str(ROOT_PATH), "jsons", "test.json")
+ coco_truth = json.loads(Path(test_json_path).read_text())
+
+ images_names = {value["id"]: value["file_name"] for value in coco_truth["images"]}
+
+ vgt_predictions_dict = dict()
+ for annotation in annotations:
+ get_prediction_from_annotation(annotation, images_names, vgt_predictions_dict)
+
+ return vgt_predictions_dict
+
+
+def find_best_prediction_for_token(page_pdf_name, token, vgt_predictions_dict, most_probable_tokens_by_predictions):
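+ # Assign the token to the highest-scoring prediction that overlaps it, stopping early at a near-perfect score; tokens with no overlapping prediction get a dummy "Text" prediction (category 10, score 0).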
+ best_score: float = 0
+ most_probable_prediction: Prediction | None = None
+ for prediction in vgt_predictions_dict[page_pdf_name]:
+ if prediction.score > best_score and prediction.bounding_box.get_intersection_percentage(token.bounding_box):
+ best_score = prediction.score
+ most_probable_prediction = prediction
+ if best_score >= 99:
+ break
+ if most_probable_prediction:
+ most_probable_tokens_by_predictions.setdefault(most_probable_prediction, list()).append(token)
+ else:
+ dummy_prediction = Prediction(bounding_box=token.bounding_box, category_id=10, score=0.0)
+ most_probable_tokens_by_predictions.setdefault(dummy_prediction, list()).append(token)
+
+
+def get_merged_prediction_type(to_merge: list[Prediction]):
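+ # A table (category 9) wins if present; otherwise use the most common category among the merged predictions, with ties resolved in favour of higher-scoring ones.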
+ table_exists = any([p.category_id == 9 for p in to_merge])
+ if not table_exists:
+ return mode([p.category_id for p in sorted(to_merge, key=lambda x: -x.score)])
+ return 9
+
+
+def merge_colliding_predictions(predictions: list[Prediction]):
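+ # Drop predictions scoring below 20, then repeatedly merge any predictions whose boxes intersect until no more merges happen; the merged box is the union of its parts.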
+ predictions = [p for p in predictions if p.score >= 20]
+ while True:
+ new_predictions, merged = [], False
+ while predictions:
+ p1 = predictions.pop(0)
+ to_merge = [p for p in predictions if p1.bounding_box.get_intersection_percentage(p.bounding_box) > 0]
+ for prediction in to_merge:
+ predictions.remove(prediction)
+ if to_merge:
+ to_merge.append(p1)
+ p1.bounding_box = Rectangle.merge_rectangles([prediction.bounding_box for prediction in to_merge])
+ p1.category_id = get_merged_prediction_type(to_merge)
+ merged = True
+ new_predictions.append(p1)
+ if not merged:
+ return new_predictions
+ predictions = new_predictions
+
+
+def get_pdf_segments_for_page(page, pdf_name, page_pdf_name, vgt_predictions_dict):
+ most_probable_pdf_segments_for_page: list[PdfSegment] = []
+ most_probable_tokens_by_predictions: dict[Prediction, list[PdfToken]] = {}
+ vgt_predictions_dict[page_pdf_name] = merge_colliding_predictions(vgt_predictions_dict[page_pdf_name])
+
+ for token in page.tokens:
+ find_best_prediction_for_token(page_pdf_name, token, vgt_predictions_dict, most_probable_tokens_by_predictions)
+
+ for prediction, tokens in most_probable_tokens_by_predictions.items():
+ new_segment = PdfSegment.from_pdf_tokens(tokens, pdf_name)
+ new_segment.bounding_box = prediction.bounding_box
+ new_segment.segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id])
+ most_probable_pdf_segments_for_page.append(new_segment)
+
+ no_token_predictions = [
+ prediction
+ for prediction in vgt_predictions_dict[page_pdf_name]
+ if prediction not in most_probable_tokens_by_predictions
+ ]
+
+ for prediction in no_token_predictions:
+ segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id])
+ page_number = page.page_number
+ new_segment = PdfSegment(page_number, prediction.bounding_box, "", segment_type, pdf_name)
+ most_probable_pdf_segments_for_page.append(new_segment)
+
+ return most_probable_pdf_segments_for_page
+
+
+def prediction_exists_for_page(page_pdf_name, vgt_predictions_dict):
+ return page_pdf_name in vgt_predictions_dict
+
+
+def get_most_probable_pdf_segments(model_name: str, pdf_images_list: list[PdfImages], save_output: bool = False):
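+ # Load the VGT detections, merge overlapping boxes per page, group tokens under their best prediction and turn each group into a PdfSegment; optionally pickle the result.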
+ most_probable_pdf_segments: list[PdfSegment] = []
+ vgt_predictions_dict = get_vgt_predictions(model_name)
+ pdf_features_list: list[PdfFeatures] = [pdf_images.pdf_features for pdf_images in pdf_images_list]
+ for pdf_features in pdf_features_list:
+ for page in pdf_features.pages:
+ page_pdf_name = pdf_features.file_name + "_" + str(page.page_number - 1)
+ if not prediction_exists_for_page(page_pdf_name, vgt_predictions_dict):
+ continue
+ page_segments = get_pdf_segments_for_page(page, pdf_features.file_name, page_pdf_name, vgt_predictions_dict)
+ most_probable_pdf_segments.extend(page_segments)
+ if save_output:
+ save_path = join(ROOT_PATH, f"model_output_{model_name}", "predicted_segments.pickle")
+ with open(save_path, mode="wb") as file:
+ pickle.dump(most_probable_pdf_segments, file)
+ return most_probable_pdf_segments
diff --git a/src/adapters/ml/vgt/get_reading_orders.py b/src/adapters/ml/vgt/get_reading_orders.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8bbfdd8427cdc72b590f9198a7a684320bcd12
--- /dev/null
+++ b/src/adapters/ml/vgt/get_reading_orders.py
@@ -0,0 +1,88 @@
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfPage
+from pdf_features import PdfToken
+from pdf_token_type_labels import TokenType
+
+from domain.PdfImages import PdfImages
+
+
+def find_segment_for_token(token: PdfToken, segments: list[PdfSegment], tokens_by_segments):
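+ # Attach the token to the predicted segment it overlaps the most, stopping early once the overlap is effectively complete.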
+ best_score: float = 0
+ most_probable_segment: PdfSegment | None = None
+ for segment in segments:
+ intersection_percentage = token.bounding_box.get_intersection_percentage(segment.bounding_box)
+ if intersection_percentage > best_score:
+ best_score = intersection_percentage
+ most_probable_segment = segment
+ if best_score >= 99:
+ break
+ if most_probable_segment:
+ tokens_by_segments.setdefault(most_probable_segment, list()).append(token)
+
+
+def get_average_reading_order_for_segment(page: PdfPage, tokens_for_segment: list[PdfToken]):
+ reading_order_sum: int = sum(page.tokens.index(token) for token in tokens_for_segment)
+ return reading_order_sum / len(tokens_for_segment)
+
+
+def get_distance_between_segments(segment1: PdfSegment, segment2: PdfSegment):
+ center_1_x = (segment1.bounding_box.left + segment1.bounding_box.right) / 2
+ center_1_y = (segment1.bounding_box.top + segment1.bounding_box.bottom) / 2
+ center_2_x = (segment2.bounding_box.left + segment2.bounding_box.right) / 2
+ center_2_y = (segment2.bounding_box.top + segment2.bounding_box.bottom) / 2
+ return ((center_1_x - center_2_x) ** 2 + (center_1_y - center_2_y) ** 2) ** 0.5
+
+
+def add_no_token_segments(segments, no_token_segments):
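+ # Slot segments that matched no tokens next to their nearest ordered segment (after it when the neighbour sits higher on the page, before it otherwise); with no ordered segments yet, fall back to a left-to-right, top-to-bottom sort.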
+ if segments:
+ for no_token_segment in no_token_segments:
+ closest_segment = sorted(segments, key=lambda seg: get_distance_between_segments(no_token_segment, seg))[0]
+ closest_index = segments.index(closest_segment)
+ if closest_segment.bounding_box.top < no_token_segment.bounding_box.top:
+ segments.insert(closest_index + 1, no_token_segment)
+ else:
+ segments.insert(closest_index, no_token_segment)
+ else:
+ for segment in sorted(no_token_segments, key=lambda r: (r.bounding_box.left, r.bounding_box.top)):
+ segments.append(segment)
+
+
+def filter_and_sort_segments(page, tokens_by_segments, types):
+ filtered_segments = [seg for seg in tokens_by_segments.keys() if seg.segment_type in types]
+ order = {seg: get_average_reading_order_for_segment(page, tokens_by_segments[seg]) for seg in filtered_segments}
+ return sorted(filtered_segments, key=lambda seg: order[seg])
+
+
+def get_ordered_segments_for_page(segments_for_page: list[PdfSegment], page: PdfPage):
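+ # Reading order: page headers first, then body segments sorted by the average position of their tokens in the PDF token stream, then footers and footnotes; the lowest short segment is treated as the page number and moved to the end, and segments without tokens are slotted in afterwards.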
+ tokens_by_segments: dict[PdfSegment, list[PdfToken]] = {}
+ for token in page.tokens:
+ find_segment_for_token(token, segments_for_page, tokens_by_segments)
+
+ page_number_segment: None | PdfSegment = None
+ if tokens_by_segments:
+ last_segment = max(tokens_by_segments.keys(), key=lambda seg: seg.bounding_box.top)
+ if last_segment.text_content and len(last_segment.text_content) < 5:
+ page_number_segment = last_segment
+ del tokens_by_segments[last_segment]
+
+ header_segments: list[PdfSegment] = filter_and_sort_segments(page, tokens_by_segments, {TokenType.PAGE_HEADER})
+ paragraph_types = {t for t in TokenType if t.name not in {"PAGE_HEADER", "PAGE_FOOTER", "FOOTNOTE"}}
+ paragraph_segments = filter_and_sort_segments(page, tokens_by_segments, paragraph_types)
+ footer_segments = filter_and_sort_segments(page, tokens_by_segments, {TokenType.PAGE_FOOTER, TokenType.FOOTNOTE})
+ if page_number_segment:
+ footer_segments.append(page_number_segment)
+ ordered_segments = header_segments + paragraph_segments + footer_segments
+ no_token_segments = [segment for segment in segments_for_page if segment not in ordered_segments]
+ add_no_token_segments(ordered_segments, no_token_segments)
+ return ordered_segments
+
+
+def get_reading_orders(pdf_images_list: list[PdfImages], predicted_segments: list[PdfSegment]):
+ ordered_segments: list[PdfSegment] = []
+ for pdf_images in pdf_images_list:
+ pdf_name = pdf_images.pdf_features.file_name
+ segments_for_file = [segment for segment in predicted_segments if segment.pdf_name == pdf_name]
+ for page in pdf_images.pdf_features.pages:
+ segments_for_page = [segment for segment in segments_for_file if segment.page_number == page.page_number]
+ ordered_segments.extend(get_ordered_segments_for_page(segments_for_page, page))
+ return ordered_segments
diff --git a/src/adapters/ml/vgt/model_configuration/Base-RCNN-FPN.yaml b/src/adapters/ml/vgt/model_configuration/Base-RCNN-FPN.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9f8cd0b54430ebcdbd5a8e03133432f4dbf38b3c
--- /dev/null
+++ b/src/adapters/ml/vgt/model_configuration/Base-RCNN-FPN.yaml
@@ -0,0 +1,70 @@
+MODEL:
+ MASK_ON: True
+ META_ARCHITECTURE: "GeneralizedRCNN"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ BACKBONE:
+ NAME: "build_vit_fpn_backbone"
+ VIT:
+ OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
+ DROP_PATH: 0.1
+ IMG_SIZE: [224,224]
+ POS_TYPE: "abs"
+ FPN:
+ IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
+ ANCHOR_GENERATOR:
+ SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
+ ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
+ RPN:
+ IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
+ PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
+ PRE_NMS_TOPK_TEST: 1000 # Per FPN level
+ # Detectron1 uses 2000 proposals per-batch,
+ # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
+ # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
+ POST_NMS_TOPK_TRAIN: 1000
+ POST_NMS_TOPK_TEST: 1000
+ ROI_HEADS:
+ NAME: "StandardROIHeads"
+ IN_FEATURES: ["p2", "p3", "p4", "p5"]
+ NUM_CLASSES: 5
+ ROI_BOX_HEAD:
+ NAME: "FastRCNNConvFCHead"
+ NUM_FC: 2
+ POOLER_RESOLUTION: 7
+ ROI_MASK_HEAD:
+ NAME: "MaskRCNNConvUpsampleHead"
+ NUM_CONV: 4
+ POOLER_RESOLUTION: 14
+DATASETS:
+ TRAIN: ("docbank_train",)
+ TEST: ("docbank_val",)
+SOLVER:
+ LR_SCHEDULER_NAME: "WarmupCosineLR"
+ AMP:
+ ENABLED: True
+ OPTIMIZER: "ADAMW"
+ BACKBONE_MULTIPLIER: 1.0
+ CLIP_GRADIENTS:
+ ENABLED: True
+ CLIP_TYPE: "full_model"
+ CLIP_VALUE: 1.0
+ NORM_TYPE: 2.0
+ WARMUP_FACTOR: 0.01
+ BASE_LR: 0.0004
+ WEIGHT_DECAY: 0.05
+ IMS_PER_BATCH: 32
+INPUT:
+ CROP:
+ ENABLED: True
+ TYPE: "absolute_range"
+ SIZE: (384, 600)
+ MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
+ FORMAT: "RGB"
+DATALOADER:
+ NUM_WORKERS: 6
+ FILTER_EMPTY_ANNOTATIONS: False
+VERSION: 2
+AUG:
+ DETR: True
+SEED: 42
\ No newline at end of file
diff --git a/src/adapters/ml/vgt/model_configuration/doclaynet_VGT_cascade_PTM.yaml b/src/adapters/ml/vgt/model_configuration/doclaynet_VGT_cascade_PTM.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..57fbfdd918333f872dddbb6858291d443e61ffba
--- /dev/null
+++ b/src/adapters/ml/vgt/model_configuration/doclaynet_VGT_cascade_PTM.yaml
@@ -0,0 +1,41 @@
+DATASETS:
+ TEST: ("predict_data",)
+ TRAIN: ("train_data",)
+MODEL:
+ BACKBONE:
+ NAME: build_VGT_fpn_backbone
+ MASK_ON: false
+ META_ARCHITECTURE: VGT
+ PIXEL_MEAN:
+ - 127.5
+ - 127.5
+ - 127.5
+ PIXEL_STD:
+ - 127.5
+ - 127.5
+ - 127.5
+ ROI_BOX_HEAD:
+ CLS_AGNOSTIC_BBOX_REG: true
+ ROI_HEADS:
+ NAME: CascadeROIHeads
+ NUM_CLASSES: 11
+ RPN:
+ POST_NMS_TOPK_TRAIN: 2000
+ VIT:
+ MERGE_TYPE: Sum
+ NAME: VGT_dit_base_patch16
+ WEIGHTS: https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth
+ WORDGRID:
+ EMBEDDING_DIM: 64
+ MODEL_PATH: ../models/layoutlm-base-uncased/
+ USE_PRETRAIN_WEIGHT: true
+ VOCAB_SIZE: 30552
+SOLVER:
+ BASE_LR: 0.0002
+ IMS_PER_BATCH: 12
+ MAX_ITER: 10000
+ STEPS: (6000, 8000)
+ WARMUP_ITERS: 100
+TEST:
+ EVAL_PERIOD: 2000
+_BASE_: ./Base-RCNN-FPN.yaml
diff --git a/src/adapters/ml/vgt/model_configuration/doclaynet_configuration.pickle b/src/adapters/ml/vgt/model_configuration/doclaynet_configuration.pickle
new file mode 100644
index 0000000000000000000000000000000000000000..7c9d3b094d887f32252a829aea270c97d0876985
--- /dev/null
+++ b/src/adapters/ml/vgt/model_configuration/doclaynet_configuration.pickle
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e26d79332727c133aca47b8c79c11158679889e3c0c76ac498767a73dbd083d1
+size 5668
diff --git a/src/adapters/ml/vgt_model_adapter.py b/src/adapters/ml/vgt_model_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..67d466269d30bd18ed014244882f313846b4fddf
--- /dev/null
+++ b/src/adapters/ml/vgt_model_adapter.py
@@ -0,0 +1,45 @@
+from domain.PdfImages import PdfImages
+from domain.PdfSegment import PdfSegment
+from ports.services.ml_model_service import MLModelService
+from adapters.ml.vgt.ditod import VGTTrainer
+from adapters.ml.vgt.get_model_configuration import get_model_configuration
+from adapters.ml.vgt.get_most_probable_pdf_segments import get_most_probable_pdf_segments
+from adapters.ml.vgt.get_reading_orders import get_reading_orders
+from adapters.ml.vgt.get_json_annotations import get_annotations
+from adapters.ml.vgt.create_word_grid import create_word_grid, remove_word_grids
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.data.datasets import register_coco_instances
+from detectron2.data import DatasetCatalog
+from configuration import JSON_TEST_FILE_PATH, IMAGES_ROOT_PATH
+
+configuration = get_model_configuration()
+model = VGTTrainer.build_model(configuration)
+DetectionCheckpointer(model, save_dir=configuration.OUTPUT_DIR).resume_or_load(configuration.MODEL.WEIGHTS, resume=True)
+
+
+class VGTModelAdapter(MLModelService):
+
+ def _register_data(self) -> None:
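+ # Remove any stale "predict_data" registration before re-registering it, so repeated requests do not trip over Detectron2's duplicate-dataset check.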
+ try:
+ DatasetCatalog.remove("predict_data")
+ except KeyError:
+ pass
+
+ register_coco_instances("predict_data", {}, JSON_TEST_FILE_PATH, IMAGES_ROOT_PATH)
+
+ def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
+ create_word_grid([pdf_images_obj.pdf_features for pdf_images_obj in pdf_images])
+ get_annotations(pdf_images)
+
+ self._register_data()
+ VGTTrainer.test(configuration, model)
+
+ predicted_segments = get_most_probable_pdf_segments("doclaynet", pdf_images, False)
+
+ PdfImages.remove_images()
+ remove_word_grids()
+
+ return get_reading_orders(pdf_images, predicted_segments)
+
+ def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
+ raise NotImplementedError("Fast prediction should be handled by FastTrainerAdapter")
diff --git a/src/adapters/storage/__init__.py b/src/adapters/storage/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/storage/file_system_repository.py b/src/adapters/storage/file_system_repository.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1051aebe01debdd5b0d1987f253bcb0a87eb15
--- /dev/null
+++ b/src/adapters/storage/file_system_repository.py
@@ -0,0 +1,56 @@
+import tempfile
+import uuid
+from pathlib import Path
+from typing import AnyStr
+from ports.repositories.file_repository import FileRepository
+from configuration import XMLS_PATH
+
+
+class FileSystemRepository(FileRepository):
+ def save_pdf(self, content: AnyStr, filename: str = "") -> Path:
+ if not filename:
+ filename = str(uuid.uuid1())
+
+ pdf_path = Path(tempfile.gettempdir(), f"{filename}.pdf")
+ pdf_path.write_bytes(content)
+ return pdf_path
+
+ def save_xml(self, content: str, filename: str) -> Path:
+ if not filename.endswith(".xml"):
+ filename = f"{filename}.xml"
+
+ xml_path = Path(XMLS_PATH, filename)
+ xml_path.parent.mkdir(parents=True, exist_ok=True)
+ xml_path.write_text(content)
+ return xml_path
+
+ def get_xml(self, filename: str) -> str:
+ if not filename.endswith(".xml"):
+ filename = f"{filename}.xml"
+
+ xml_path = Path(XMLS_PATH, filename)
+ if not xml_path.exists():
+ raise FileNotFoundError(f"XML file {filename} not found")
+
+ return xml_path.read_text()
+
+ def delete_file(self, filepath: Path) -> None:
+ filepath.unlink(missing_ok=True)
+
+ def cleanup_temp_files(self) -> None:
+ pass
+
+ def save_pdf_to_directory(self, content: AnyStr, filename: str, directory: Path, namespace: str = "") -> Path:
+ if namespace:
+ target_path = Path(directory, namespace, filename)
+ else:
+ target_path = Path(directory, filename)
+
+ target_path.parent.mkdir(parents=True, exist_ok=True)
+ target_path.write_bytes(content)
+ return target_path
+
+ def save_markdown(self, content: str, filepath: Path) -> Path:
+ filepath.parent.mkdir(parents=True, exist_ok=True)
+ filepath.write_text(content, encoding="utf-8")
+ return filepath
diff --git a/src/adapters/web/__init__.py b/src/adapters/web/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/adapters/web/fastapi_controllers.py b/src/adapters/web/fastapi_controllers.py
new file mode 100644
index 0000000000000000000000000000000000000000..faaf8f896b90f7bde532fb22258d15619174a949
--- /dev/null
+++ b/src/adapters/web/fastapi_controllers.py
@@ -0,0 +1,120 @@
+import sys
+import subprocess
+from fastapi import UploadFile, File, Form
+from typing import Optional, Union
+from starlette.responses import Response
+from starlette.concurrency import run_in_threadpool
+from use_cases.pdf_analysis.analyze_pdf_use_case import AnalyzePDFUseCase
+from use_cases.text_extraction.extract_text_use_case import ExtractTextUseCase
+from use_cases.toc_extraction.extract_toc_use_case import ExtractTOCUseCase
+from use_cases.visualization.create_visualization_use_case import CreateVisualizationUseCase
+from use_cases.ocr.process_ocr_use_case import ProcessOCRUseCase
+from use_cases.markdown_conversion.convert_to_markdown_use_case import ConvertToMarkdownUseCase
+from use_cases.html_conversion.convert_to_html_use_case import ConvertToHtmlUseCase
+from adapters.storage.file_system_repository import FileSystemRepository
+
+
+class FastAPIControllers:
+ def __init__(
+ self,
+ analyze_pdf_use_case: AnalyzePDFUseCase,
+ extract_text_use_case: ExtractTextUseCase,
+ extract_toc_use_case: ExtractTOCUseCase,
+ create_visualization_use_case: CreateVisualizationUseCase,
+ process_ocr_use_case: ProcessOCRUseCase,
+ convert_to_markdown_use_case: ConvertToMarkdownUseCase,
+ convert_to_html_use_case: ConvertToHtmlUseCase,
+ file_repository: FileSystemRepository,
+ ):
+ self.analyze_pdf_use_case = analyze_pdf_use_case
+ self.extract_text_use_case = extract_text_use_case
+ self.extract_toc_use_case = extract_toc_use_case
+ self.create_visualization_use_case = create_visualization_use_case
+ self.process_ocr_use_case = process_ocr_use_case
+ self.convert_to_markdown_use_case = convert_to_markdown_use_case
+ self.convert_to_html_use_case = convert_to_html_use_case
+ self.file_repository = file_repository
+
+ async def root(self):
+ import torch
+
+ return sys.version + " Using GPU: " + str(torch.cuda.is_available())
+
+ async def info(self):
+ return {
+ "sys": sys.version,
+ "tesseract_version": subprocess.run("tesseract --version", shell=True, text=True, capture_output=True).stdout,
+ "ocrmypdf_version": subprocess.run("ocrmypdf --version", shell=True, text=True, capture_output=True).stdout,
+ "supported_languages": self.process_ocr_use_case.get_supported_languages(),
+ }
+
+ async def error(self):
+ raise FileNotFoundError("This is a test error from the error endpoint")
+
+ async def analyze_pdf(
+ self, file: UploadFile = File(...), fast: bool = Form(False), parse_tables_and_math: bool = Form(False)
+ ):
+ return await run_in_threadpool(
+ self.analyze_pdf_use_case.execute, file.file.read(), "", parse_tables_and_math, fast, False
+ )
+
+ async def analyze_and_save_xml(
+ self, file: UploadFile = File(...), xml_file_name: str | None = None, fast: bool = Form(False)
+ ):
+ if not xml_file_name.endswith(".xml"):
+ xml_file_name = f"{xml_file_name}.xml"
+ return await run_in_threadpool(self.analyze_pdf_use_case.execute_and_save_xml, file.file.read(), xml_file_name, fast)
+
+ async def get_xml_by_name(self, xml_file_name: str):
+ if not xml_file_name.endswith(".xml"):
+ xml_file_name = f"{xml_file_name}.xml"
+ return await run_in_threadpool(self.file_repository.get_xml, xml_file_name)
+
+ async def get_toc_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False)):
+ return await run_in_threadpool(self.extract_toc_use_case.execute, file, fast)
+
+ async def toc_legacy_uwazi_compatible(self, file: UploadFile = File(...)):
+ return await run_in_threadpool(self.extract_toc_use_case.execute_uwazi_compatible, file)
+
+ async def get_text_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False), types: str = Form("all")):
+ return await run_in_threadpool(self.extract_text_use_case.execute, file, fast, types)
+
+ async def get_visualization_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False)):
+ return await run_in_threadpool(self.create_visualization_use_case.execute, file, fast)
+
+ async def ocr_pdf_sync(self, file: UploadFile = File(...), language: str = Form("en")):
+ return await run_in_threadpool(self.process_ocr_use_case.execute, file, language)
+
+ async def convert_to_markdown_endpoint(
+ self,
+ file: UploadFile = File(...),
+ fast: bool = Form(False),
+ extract_toc: bool = Form(False),
+ dpi: int = Form(120),
+ output_file: Optional[str] = Form(None),
+ ) -> Union[str, Response]:
+ return await run_in_threadpool(
+ self.convert_to_markdown_use_case.execute,
+ file.file.read(),
+ fast,
+ extract_toc,
+ dpi,
+ output_file,
+ )
+
+ async def convert_to_html_endpoint(
+ self,
+ file: UploadFile = File(...),
+ fast: bool = Form(False),
+ extract_toc: bool = Form(False),
+ dpi: int = Form(120),
+ output_file: Optional[str] = Form(None),
+ ) -> Union[str, Response]:
+ return await run_in_threadpool(
+ self.convert_to_html_use_case.execute,
+ file.file.read(),
+ fast,
+ extract_toc,
+ dpi,
+ output_file,
+ )
diff --git a/src/app.py b/src/app.py
new file mode 100755
index 0000000000000000000000000000000000000000..bf186036cf50e6d10c82a7b4996512d280714772
--- /dev/null
+++ b/src/app.py
@@ -0,0 +1,12 @@
+from configuration import RESTART_IF_NO_GPU
+from drivers.web.fastapi_app import create_app
+from drivers.web.dependency_injection import setup_dependencies
+import torch
+
+if RESTART_IF_NO_GPU:
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU available. Restarting the service is required.")
+
+controllers = setup_dependencies()
+
+app = create_app(controllers)
diff --git a/src/catch_exceptions.py b/src/catch_exceptions.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a1524d9b8ceac7fb68106be6b8050fffc1c15ab
--- /dev/null
+++ b/src/catch_exceptions.py
@@ -0,0 +1,23 @@
+from functools import wraps
+from fastapi import HTTPException
+
+from configuration import service_logger
+
+
+def catch_exceptions(func):
+ @wraps(func)
+ async def wrapper(*args, **kwargs):
+ try:
+ service_logger.info(f"Calling endpoint: {func.__name__}")
+ if kwargs and "file" in kwargs:
+ service_logger.info(f"Processing file: {kwargs['file'].filename}")
+ if kwargs and "xml_file_name" in kwargs:
+ service_logger.info(f"Asking for file: {kwargs['xml_file_name']}")
+ return await func(*args, **kwargs)
+ except FileNotFoundError:
+ raise HTTPException(status_code=404, detail="No xml file")
+ except Exception:
+ service_logger.error("Error, see traceback", exc_info=True)
+ raise HTTPException(status_code=422, detail="Error, see traceback")
+
+ return wrapper
diff --git a/src/configuration.py b/src/configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d988b16b6d218ec218a15ce9780ebd09d66d0f
--- /dev/null
+++ b/src/configuration.py
@@ -0,0 +1,37 @@
+import logging
+import os
+from pathlib import Path
+
+
+SRC_PATH = Path(__file__).parent.absolute()
+ROOT_PATH = Path(__file__).parent.parent.absolute()
+
+handlers = [logging.StreamHandler()]
+logging.root.handlers = []
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=handlers)
+service_logger = logging.getLogger(__name__)
+
+RESTART_IF_NO_GPU = os.environ.get("RESTART_IF_NO_GPU", "false").lower().strip() == "true"
+IMAGES_ROOT_PATH = Path(ROOT_PATH, "images")
+WORD_GRIDS_PATH = Path(ROOT_PATH, "word_grids")
+JSONS_ROOT_PATH = Path(ROOT_PATH, "jsons")
+OCR_SOURCE = Path(ROOT_PATH, "ocr", "source")
+OCR_OUTPUT = Path(ROOT_PATH, "ocr", "output")
+OCR_FAILED = Path(ROOT_PATH, "ocr", "failed")
+JSON_TEST_FILE_PATH = Path(JSONS_ROOT_PATH, "test.json")
+MODELS_PATH = Path(ROOT_PATH, "models")
+XMLS_PATH = Path(ROOT_PATH, "xmls")
+
+DOCLAYNET_TYPE_BY_ID = {
+ 1: "Caption",
+ 2: "Footnote",
+ 3: "Formula",
+ 4: "List_Item",
+ 5: "Page_Footer",
+ 6: "Page_Header",
+ 7: "Picture",
+ 8: "Section_Header",
+ 9: "Table",
+ 10: "Text",
+ 11: "Title",
+}
diff --git a/src/domain/PdfImages.py b/src/domain/PdfImages.py
new file mode 100644
index 0000000000000000000000000000000000000000..83e22c78eb692b5f27c48e619fdf04b20f208e5f
--- /dev/null
+++ b/src/domain/PdfImages.py
@@ -0,0 +1,55 @@
+import os
+import shutil
+
+import cv2
+import numpy as np
+from os import makedirs
+from os.path import join
+from pathlib import Path
+from PIL import Image
+from pdf2image import convert_from_path
+from pdf_features import PdfFeatures
+
+from configuration import IMAGES_ROOT_PATH, XMLS_PATH
+
+
+class PdfImages:
+ def __init__(self, pdf_features: PdfFeatures, pdf_images: list[Image.Image], dpi: int = 72):
+ self.pdf_features: PdfFeatures = pdf_features
+ self.pdf_images: list[Image.Image] = pdf_images
+ self.dpi: int = dpi
+ self.save_images()
+
+ def show_images(self, next_image_delay: int = 2):
+ for image_index, image in enumerate(self.pdf_images):
+ image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+ cv2.imshow(f"Page: {image_index + 1}", image_np)
+ cv2.waitKey(next_image_delay * 1000)
+ cv2.destroyAllWindows()
+
+ def save_images(self):
+ makedirs(IMAGES_ROOT_PATH, exist_ok=True)
+ for image_index, image in enumerate(self.pdf_images):
+ image_name = f"{self.pdf_features.file_name}_{image_index}.jpg"
+ image.save(join(IMAGES_ROOT_PATH, image_name))
+
+ @staticmethod
+ def remove_images():
+ shutil.rmtree(IMAGES_ROOT_PATH)
+
+ @staticmethod
+ def from_pdf_path(pdf_path: str | Path, pdf_name: str = "", xml_file_name: str = "", dpi: int = 72):
+ xml_path = None if not xml_file_name else Path(XMLS_PATH, xml_file_name)
+
+ if xml_path and not xml_path.parent.exists():
+ os.makedirs(xml_path.parent, exist_ok=True)
+
+ pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, xml_path)
+
+ if pdf_name:
+ pdf_features.file_name = pdf_name
+ else:
+ pdf_name = Path(pdf_path).parent.name if Path(pdf_path).name == "document.pdf" else Path(pdf_path).stem
+ pdf_features.file_name = pdf_name
+ pdf_images = convert_from_path(pdf_path, dpi=dpi)
+ return PdfImages(pdf_features, pdf_images, dpi)
diff --git a/src/domain/PdfSegment.py b/src/domain/PdfSegment.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3eb34aa7eed5d6be756dd29e276d42a265686d1
--- /dev/null
+++ b/src/domain/PdfSegment.py
@@ -0,0 +1,24 @@
+from statistics import mode
+from pdf_features import PdfToken
+from pdf_features import Rectangle
+from pdf_token_type_labels import TokenType
+
+
+class PdfSegment:
+ def __init__(
+ self, page_number: int, bounding_box: Rectangle, text_content: str, segment_type: TokenType, pdf_name: str = ""
+ ):
+ self.page_number = page_number
+ self.bounding_box = bounding_box
+ self.text_content = text_content
+ self.segment_type = segment_type
+ self.pdf_name = pdf_name
+
+ @staticmethod
+ def from_pdf_tokens(pdf_tokens: list[PdfToken], pdf_name: str = ""):
+ text: str = " ".join([pdf_token.content for pdf_token in pdf_tokens])
+ bounding_boxes = [pdf_token.bounding_box for pdf_token in pdf_tokens]
+ segment_type = mode([token.token_type for token in pdf_tokens])
+ return PdfSegment(
+ pdf_tokens[0].page_number, Rectangle.merge_rectangles(bounding_boxes), text, segment_type, pdf_name
+ )
diff --git a/src/domain/Prediction.py b/src/domain/Prediction.py
new file mode 100644
index 0000000000000000000000000000000000000000..59cd061bf8b6acccdc5a313a2fa2d0e90c90af8f
--- /dev/null
+++ b/src/domain/Prediction.py
@@ -0,0 +1,8 @@
+from pdf_features import Rectangle
+
+
+class Prediction:
+ def __init__(self, bounding_box: Rectangle, category_id: int, score: float):
+ self.bounding_box: Rectangle = bounding_box
+ self.category_id: int = category_id
+ self.score: float = score
diff --git a/src/domain/SegmentBox.py b/src/domain/SegmentBox.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddc3beb59f328a2278e813b1fdf5bcda999dd4bf
--- /dev/null
+++ b/src/domain/SegmentBox.py
@@ -0,0 +1,65 @@
+from domain.PdfSegment import PdfSegment
+from pdf_features import PdfPage
+from pdf_token_type_labels import TokenType
+from pydantic import BaseModel
+
+
+class SegmentBox(BaseModel):
+ left: float
+ top: float
+ width: float
+ height: float
+ page_number: int
+ page_width: int
+ page_height: int
+ text: str = ""
+ type: TokenType = TokenType.TEXT
+ id: str = ""
+
+ def __hash__(self):
+ return hash(
+ (
+ self.left,
+ self.top,
+ self.width,
+ self.height,
+ self.page_number,
+ self.page_width,
+ self.page_height,
+ self.text,
+ self.type,
+ self.id,
+ )
+ )
+
+ def to_dict(self):
+ return {
+ "left": self.left,
+ "top": self.top,
+ "width": self.width,
+ "height": self.height,
+ "page_number": self.page_number,
+ "page_width": self.page_width,
+ "page_height": self.page_height,
+ "text": self.text,
+ "type": self.type.value,
+ }
+
+ @staticmethod
+ def from_pdf_segment(pdf_segment: PdfSegment, pdf_pages: list[PdfPage]):
+ return SegmentBox(
+ left=pdf_segment.bounding_box.left,
+ top=pdf_segment.bounding_box.top,
+ width=pdf_segment.bounding_box.width,
+ height=pdf_segment.bounding_box.height,
+ page_number=pdf_segment.page_number,
+ page_width=pdf_pages[pdf_segment.page_number - 1].page_width,
+ page_height=pdf_pages[pdf_segment.page_number - 1].page_height,
+ text=pdf_segment.text_content,
+ type=pdf_segment.segment_type,
+ )
+
diff --git a/src/download_models.py b/src/download_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..db86462d0640b7624dc1bb315984772b79fe1a60
--- /dev/null
+++ b/src/download_models.py
@@ -0,0 +1,64 @@
+import math
+from os import makedirs
+from os.path import join, exists
+from pathlib import Path
+from urllib.request import urlretrieve
+from huggingface_hub import snapshot_download, hf_hub_download
+
+from configuration import service_logger, MODELS_PATH
+
+
+def download_progress(count, block_size, total_size):
+ total_counts = total_size // block_size
+ # Log roughly every 20% of the download; the max() guards against a zero step for very small files.
+ show_counts_percentages = max(total_counts // 5, 1)
+ percent = count * block_size * 100 / total_size
+ if count % show_counts_percentages == 0:
+ service_logger.info(f"Downloaded {math.ceil(percent)}%")
+
+
+def download_vgt_model(model_name: str):
+ service_logger.info(f"Downloading {model_name} model")
+ model_path = join(MODELS_PATH, f"{model_name}_VGT_model.pth")
+ if exists(model_path):
+ return
+ download_link = f"https://github.com/AlibabaResearch/AdvancedLiterateMachinery/releases/download/v1.3.0-VGT-release/{model_name}_VGT_model.pth"
+ urlretrieve(download_link, model_path, reporthook=download_progress)
+
+
+def download_embedding_model():
+ model_path = join(MODELS_PATH, "layoutlm-base-uncased")
+ if exists(model_path):
+ return
+ makedirs(model_path, exist_ok=True)
+ service_logger.info("Embedding model is being downloaded")
+ snapshot_download(repo_id="microsoft/layoutlm-base-uncased", local_dir=model_path, local_dir_use_symlinks=False)
+
+
+def download_from_hf_hub(path: Path):
+ if path.exists():
+ return
+
+ file_name = path.name
+ makedirs(path.parent, exist_ok=True)
+ repo_id = "HURIDOCS/pdf-document-layout-analysis"
+ hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=path.parent, local_dir_use_symlinks=False)
+
+
+def download_lightgbm_models():
+ download_from_hf_hub(Path(MODELS_PATH, "token_type_lightgbm.model"))
+ download_from_hf_hub(Path(MODELS_PATH, "paragraph_extraction_lightgbm.model"))
+ download_from_hf_hub(Path(MODELS_PATH, "config.json"))
+
+
+def download_models(model_name: str):
+ makedirs(MODELS_PATH, exist_ok=True)
+ if model_name == "fast":
+ download_lightgbm_models()
+ return
+ download_vgt_model(model_name)
+ download_embedding_model()
+
+
+if __name__ == "__main__":
+ download_models("doclaynet")
+ download_models("fast")
diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/drivers/web/__init__.py b/src/drivers/web/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/drivers/web/dependency_injection.py b/src/drivers/web/dependency_injection.py
new file mode 100644
index 0000000000000000000000000000000000000000..a06b87a8c58f60ca0a2e679d5f22a9956763f06c
--- /dev/null
+++ b/src/drivers/web/dependency_injection.py
@@ -0,0 +1,76 @@
+from adapters.storage.file_system_repository import FileSystemRepository
+from adapters.ml.vgt_model_adapter import VGTModelAdapter
+from adapters.ml.fast_trainer_adapter import FastTrainerAdapter
+from adapters.infrastructure.pdf_analysis_service_adapter import PDFAnalysisServiceAdapter
+from adapters.infrastructure.text_extraction_adapter import TextExtractionAdapter
+from adapters.infrastructure.toc_service_adapter import TOCServiceAdapter
+from adapters.infrastructure.visualization_service_adapter import VisualizationServiceAdapter
+from adapters.infrastructure.ocr_service_adapter import OCRServiceAdapter
+from adapters.infrastructure.format_conversion_service_adapter import FormatConversionServiceAdapter
+from adapters.infrastructure.markdown_conversion_service_adapter import MarkdownConversionServiceAdapter
+from adapters.infrastructure.html_conversion_service_adapter import HtmlConversionServiceAdapter
+from adapters.web.fastapi_controllers import FastAPIControllers
+from use_cases.pdf_analysis.analyze_pdf_use_case import AnalyzePDFUseCase
+from use_cases.text_extraction.extract_text_use_case import ExtractTextUseCase
+from use_cases.toc_extraction.extract_toc_use_case import ExtractTOCUseCase
+from use_cases.visualization.create_visualization_use_case import CreateVisualizationUseCase
+from use_cases.ocr.process_ocr_use_case import ProcessOCRUseCase
+from use_cases.markdown_conversion.convert_to_markdown_use_case import ConvertToMarkdownUseCase
+from use_cases.html_conversion.convert_to_html_use_case import ConvertToHtmlUseCase
+
+
+def setup_dependencies():
+ file_repository = FileSystemRepository()
+
+ vgt_model_service = VGTModelAdapter()
+ fast_model_service = FastTrainerAdapter()
+
+ format_conversion_service = FormatConversionServiceAdapter()
+ markdown_conversion_service = MarkdownConversionServiceAdapter()
+ html_conversion_service = HtmlConversionServiceAdapter()
+ text_extraction_service = TextExtractionAdapter()
+ toc_service = TOCServiceAdapter()
+ visualization_service = VisualizationServiceAdapter()
+ ocr_service = OCRServiceAdapter()
+
+ pdf_analysis_service = PDFAnalysisServiceAdapter(
+ vgt_model_service=vgt_model_service,
+ fast_model_service=fast_model_service,
+ format_conversion_service=format_conversion_service,
+ file_repository=file_repository,
+ )
+
+ analyze_pdf_use_case = AnalyzePDFUseCase(pdf_analysis_service=pdf_analysis_service, ml_model_service=vgt_model_service)
+
+ extract_text_use_case = ExtractTextUseCase(
+ pdf_analysis_service=pdf_analysis_service, text_extraction_service=text_extraction_service
+ )
+
+ extract_toc_use_case = ExtractTOCUseCase(pdf_analysis_service=pdf_analysis_service, toc_service=toc_service)
+
+ create_visualization_use_case = CreateVisualizationUseCase(
+ pdf_analysis_service=pdf_analysis_service, visualization_service=visualization_service
+ )
+
+ process_ocr_use_case = ProcessOCRUseCase(ocr_service=ocr_service, file_repository=file_repository)
+
+ convert_to_markdown_use_case = ConvertToMarkdownUseCase(
+ pdf_analysis_service=pdf_analysis_service, markdown_conversion_service=markdown_conversion_service
+ )
+
+ convert_to_html_use_case = ConvertToHtmlUseCase(
+ pdf_analysis_service=pdf_analysis_service, html_conversion_service=html_conversion_service
+ )
+
+ controllers = FastAPIControllers(
+ analyze_pdf_use_case=analyze_pdf_use_case,
+ extract_text_use_case=extract_text_use_case,
+ extract_toc_use_case=extract_toc_use_case,
+ create_visualization_use_case=create_visualization_use_case,
+ process_ocr_use_case=process_ocr_use_case,
+ convert_to_markdown_use_case=convert_to_markdown_use_case,
+ convert_to_html_use_case=convert_to_html_use_case,
+ file_repository=file_repository,
+ )
+
+ return controllers
diff --git a/src/drivers/web/fastapi_app.py b/src/drivers/web/fastapi_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..54f125797fa8ba9b1878a940f3120faf49baf661
--- /dev/null
+++ b/src/drivers/web/fastapi_app.py
@@ -0,0 +1,31 @@
+import torch
+from fastapi import FastAPI
+from fastapi.responses import PlainTextResponse
+from adapters.web.fastapi_controllers import FastAPIControllers
+from catch_exceptions import catch_exceptions
+from configuration import service_logger
+
+
+def create_app(controllers: FastAPIControllers) -> FastAPI:
+ service_logger.info(f"Is PyTorch using GPU: {torch.cuda.is_available()}")
+
+ app = FastAPI()
+
+ app.get("/")(controllers.root)
+ app.get("/info")(controllers.info)
+ app.get("/error")(controllers.error)
+
+ app.post("/")(catch_exceptions(controllers.analyze_pdf))
+ app.post("/save_xml/{xml_file_name}")(catch_exceptions(controllers.analyze_and_save_xml))
+ app.get("/get_xml/{xml_file_name}", response_class=PlainTextResponse)(catch_exceptions(controllers.get_xml_by_name))
+
+ app.post("/toc")(catch_exceptions(controllers.get_toc_endpoint))
+ app.post("/toc_legacy_uwazi_compatible")(catch_exceptions(controllers.toc_legacy_uwazi_compatible))
+
+ app.post("/text")(catch_exceptions(controllers.get_text_endpoint))
+ app.post("/visualize")(catch_exceptions(controllers.get_visualization_endpoint))
+ app.post("/markdown", response_model=None)(catch_exceptions(controllers.convert_to_markdown_endpoint))
+ app.post("/html", response_model=None)(catch_exceptions(controllers.convert_to_html_endpoint))
+ app.post("/ocr")(catch_exceptions(controllers.ocr_pdf_sync))
+
+ return app
diff --git a/src/ports/__init__.py b/src/ports/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/ports/repositories/__init__.py b/src/ports/repositories/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/ports/repositories/file_repository.py b/src/ports/repositories/file_repository.py
new file mode 100644
index 0000000000000000000000000000000000000000..31140359f339e26716857fcaec1f00df380b6420
--- /dev/null
+++ b/src/ports/repositories/file_repository.py
@@ -0,0 +1,33 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import AnyStr
+
+
+class FileRepository(ABC):
+ @abstractmethod
+ def save_pdf(self, content: AnyStr, filename: str = "") -> Path:
+ pass
+
+ @abstractmethod
+ def save_xml(self, content: str, filename: str) -> Path:
+ pass
+
+ @abstractmethod
+ def get_xml(self, filename: str) -> str:
+ pass
+
+ @abstractmethod
+ def delete_file(self, filepath: Path) -> None:
+ pass
+
+ @abstractmethod
+ def cleanup_temp_files(self) -> None:
+ pass
+
+ @abstractmethod
+ def save_pdf_to_directory(self, content: AnyStr, filename: str, directory: Path, namespace: str = "") -> Path:
+ pass
+
+ @abstractmethod
+ def save_markdown(self, content: str, filepath: Path) -> Path:
+ pass
diff --git a/src/ports/services/__init__.py b/src/ports/services/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/ports/services/format_conversion_service.py b/src/ports/services/format_conversion_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..d901ce32b1ba23818280cf4f5fbcc29244a40e87
--- /dev/null
+++ b/src/ports/services/format_conversion_service.py
@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+from domain.PdfImages import PdfImages
+from domain.PdfSegment import PdfSegment
+
+
+class FormatConversionService(ABC):
+
+ @abstractmethod
+ def convert_table_to_html(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
+ pass
+
+ @abstractmethod
+ def convert_formula_to_latex(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
+ pass
diff --git a/src/ports/services/html_conversion_service.py b/src/ports/services/html_conversion_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b23d3b6505e58bcc508e719a44abe0b546a0ea4
--- /dev/null
+++ b/src/ports/services/html_conversion_service.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+from starlette.responses import Response
+from domain.SegmentBox import SegmentBox
+
+
+class HtmlConversionService(ABC):
+
+ @abstractmethod
+ def convert_to_html(
+ self,
+ pdf_content: bytes,
+ segments: list[SegmentBox],
+ extract_toc: bool = False,
+ dpi: int = 120,
+ output_file: Optional[str] = None,
+ ) -> Union[str, Response]:
+ pass
diff --git a/src/ports/services/markdown_conversion_service.py b/src/ports/services/markdown_conversion_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b48aea45d51f0ca69e36336acd20e121853223e
--- /dev/null
+++ b/src/ports/services/markdown_conversion_service.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+from starlette.responses import Response
+from domain.SegmentBox import SegmentBox
+
+
+class MarkdownConversionService(ABC):
+
+ @abstractmethod
+ def convert_to_markdown(
+ self,
+ pdf_content: bytes,
+ segments: list[SegmentBox],
+ extract_toc: bool = False,
+ dpi: int = 120,
+ output_file: Optional[str] = None,
+ ) -> Union[str, Response]:
+ pass
diff --git a/src/ports/services/ml_model_service.py b/src/ports/services/ml_model_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..c119eaf898d3b095262a6a63f83e553ef56d9d18
--- /dev/null
+++ b/src/ports/services/ml_model_service.py
@@ -0,0 +1,13 @@
+from abc import ABC, abstractmethod
+from domain.PdfImages import PdfImages
+from domain.PdfSegment import PdfSegment
+
+
+class MLModelService(ABC):
+ @abstractmethod
+ def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
+ pass
+
+ @abstractmethod
+ def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
+ pass
diff --git a/src/ports/services/ocr_service.py b/src/ports/services/ocr_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c96471ad01e5cdf23f303fd82c5bd3d8050e5a6
--- /dev/null
+++ b/src/ports/services/ocr_service.py
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+
+class OCRService(ABC):
+ @abstractmethod
+ def process_pdf_ocr(self, filename: str, namespace: str, language: str = "en") -> Path:
+ pass
+
+ @abstractmethod
+ def get_supported_languages(self) -> list[str]:
+ pass
diff --git a/src/ports/services/pdf_analysis_service.py b/src/ports/services/pdf_analysis_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..fda31a30a85711479a7b77c3392567f21e991dc1
--- /dev/null
+++ b/src/ports/services/pdf_analysis_service.py
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+from typing import AnyStr
+
+
+class PDFAnalysisService(ABC):
+ @abstractmethod
+ def analyze_pdf_layout(
+ self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
+ ) -> list[dict]:
+ pass
+
+ @abstractmethod
+ def analyze_pdf_layout_fast(
+ self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
+ ) -> list[dict]:
+ pass
diff --git a/src/ports/services/text_extraction_service.py b/src/ports/services/text_extraction_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69c45acee2fbbc8a8f558269a99d30355c960ac
--- /dev/null
+++ b/src/ports/services/text_extraction_service.py
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from pdf_token_type_labels import TokenType
+
+
+class TextExtractionService(ABC):
+ @abstractmethod
+ def extract_text_by_types(self, segment_boxes: list[dict], token_types: list[TokenType]) -> dict:
+ pass
+
+ @abstractmethod
+ def extract_all_text(self, segment_boxes: list[dict]) -> dict:
+ pass
diff --git a/src/ports/services/toc_service.py b/src/ports/services/toc_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..1139e5cc29220e91113873efc42ae8a0520286f5
--- /dev/null
+++ b/src/ports/services/toc_service.py
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from typing import AnyStr
+
+
+class TOCService(ABC):
+ @abstractmethod
+ def extract_table_of_contents(self, pdf_content: AnyStr, segment_boxes: list[dict]) -> list[dict]:
+ pass
+
+ @abstractmethod
+ def format_toc_for_uwazi(self, toc_items: list[dict]) -> list[dict]:
+ pass
diff --git a/src/ports/services/visualization_service.py b/src/ports/services/visualization_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..492e0a67d0332782e747941e44035be48424b3f1
--- /dev/null
+++ b/src/ports/services/visualization_service.py
@@ -0,0 +1,13 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from starlette.responses import FileResponse
+
+
+class VisualizationService(ABC):
+ @abstractmethod
+ def create_pdf_visualization(self, pdf_path: Path, segment_boxes: list[dict]) -> Path:
+ pass
+
+ @abstractmethod
+ def get_visualization_response(self, pdf_path: Path) -> FileResponse:
+ pass
diff --git a/src/tests/__init__.py b/src/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/tests/test_end_to_end.py b/src/tests/test_end_to_end.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee1164b24976861bd0d5a0cc1682f3364fceb749
--- /dev/null
+++ b/src/tests/test_end_to_end.py
@@ -0,0 +1,405 @@
+from pathlib import Path
+import requests
+from unittest import TestCase
+from configuration import ROOT_PATH
+
+SRC_PATH = Path(__file__).parent.parent.parent
+
+
+class TestEndToEnd(TestCase):
+ service_url = "http://localhost:5060"
+
+ def test_info(self):
+ results = requests.get(f"{self.service_url}/info")
+
+ self.assertEqual(200, results.status_code)
+ self.assertIn("ko", results.json()["supported_languages"])
+ self.assertIn("kor-vert", results.json()["supported_languages"])
+ self.assertIn("ru", results.json()["supported_languages"])
+ self.assertIn("el", results.json()["supported_languages"])
+
+ def test_error_file(self):
+ with open(f"{ROOT_PATH}/test_pdfs/error.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}", files=files)
+
+ self.assertEqual(422, results.status_code)
+
+ def test_blank_pdf(self):
+ with open(f"{ROOT_PATH}/test_pdfs/blank.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}", files=files)
+
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(0, len(results.json()))
+
+ def test_segmentation_some_empty_pages(self):
+ with open(f"{ROOT_PATH}/test_pdfs/some_empty_pages.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}", files=files)
+
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(2, len(results.json()))
+
+ def test_image_pdfs(self):
+ with open(f"{ROOT_PATH}/test_pdfs/image.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}", files=files)
+
+ self.assertEqual(200, results.status_code)
+
+ def test_regular_pdf(self):
+ with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}", files=files)
+
+ results_list = results.json()
+ expected_content = "RESOLUCIÓN DE LA CORTE INTERAMERICANA DE DERECHOS HUMANOS DEL 29 DE JULIO DE 1991"
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(expected_content, results_list[0]["text"])
+ self.assertEqual(157, results_list[0]["left"])
+ self.assertEqual(105, results_list[0]["top"])
+ self.assertEqual(282, results_list[0]["width"])
+ self.assertEqual(36, results_list[0]["height"])
+ self.assertEqual(1, results_list[0]["page_number"])
+ self.assertEqual(595, results_list[0]["page_width"])
+ self.assertEqual(842, results_list[0]["page_height"])
+ self.assertEqual("Section header", results_list[0]["type"])
+
+ def test_error_file_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/error.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ results = requests.post(f"{self.service_url}", files=files, data=data)
+
+ self.assertEqual(422, results.status_code)
+
+ def test_blank_pdf_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/blank.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ results = requests.post(f"{self.service_url}", files=files, data=data)
+
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(0, len(results.json()))
+
+ def test_segmentation_some_empty_pages_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/some_empty_pages.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ results = requests.post(f"{self.service_url}", files=files, data=data)
+
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(2, len(results.json()))
+
+ def test_image_pdfs_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/image.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ results = requests.post(f"{self.service_url}", files=files, data=data)
+
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(0, len(results.json()))
+
+ def test_regular_pdf_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+ results = requests.post(f"{self.service_url}", files=files, data=data)
+
+ results_list = results.json()
+ expected_content = "RESOLUCIÓN DE LA CORTE INTERAMERICANA DE DERECHOS HUMANOS"
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(expected_content, results_list[0]["text"])
+ self.assertEqual(157, results_list[0]["left"])
+ self.assertEqual(106, results_list[0]["top"])
+ self.assertEqual(278, results_list[0]["width"])
+ self.assertEqual(24, results_list[0]["height"])
+ self.assertEqual(1, results_list[0]["page_number"])
+ self.assertEqual(595, results_list[0]["page_width"])
+ self.assertEqual(842, results_list[0]["page_height"])
+ self.assertEqual("Section header", results_list[0]["type"])
+
+ def test_save_xml_fast(self):
+ xml_name = "test_fast.xml"
+ with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+ requests.post(f"{self.service_url}/save_xml/{xml_name}", files=files, data=data)
+
+ result_xml = requests.get(f"{self.service_url}/get_xml/{xml_name}")
+ self.assertEqual(200, result_xml.status_code)
+ self.assertIsNotNone(result_xml.text)
+
+ def test_save_xml(self):
+ xml_name = "test.xml"
+ with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "False"}
+ requests.post(f"{self.service_url}/save_xml/{xml_name}", files=files, data=data)
+
+ result_xml = requests.get(f"{self.service_url}/get_xml/{xml_name}")
+ self.assertEqual(200, result_xml.status_code)
+ self.assertIsNotNone(result_xml.text)
+
+ def test_korean(self):
+ with open(f"{ROOT_PATH}/test_pdfs/korean.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}", files=files)
+
+ self.assertEqual(200, results.status_code)
+
+ def test_chinese(self):
+ with open(f"{ROOT_PATH}/test_pdfs/chinese.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}", files=files)
+
+ self.assertEqual(200, results.status_code)
+
+ def test_korean_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/korean.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ results = requests.post(f"{self.service_url}", files=files, data=data)
+
+ self.assertEqual(200, results.status_code)
+
+ def test_chinese_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/chinese.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ results = requests.post(f"{self.service_url}", files=files, data=data)
+
+ self.assertEqual(200, results.status_code)
+
+ def test_toc(self):
+ with open(f"{ROOT_PATH}/test_pdfs/toc-test.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ response = requests.post(f"{self.service_url}/toc", files=files)
+
+ response_json = response.json()
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(len(response_json), 5)
+ self.assertEqual(response_json[0]["label"], "TEST")
+ self.assertEqual(response_json[0]["indentation"], 0)
+ self.assertEqual(response_json[-1]["label"], "C. TITLE LONGER")
+ self.assertEqual(response_json[-1]["indentation"], 2)
+
+ def test_toc_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/toc-test.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ response = requests.post(f"{self.service_url}/toc", files=files, data=data)
+
+ response_json = response.json()
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(len(response_json), 5)
+ self.assertEqual(response_json[0]["label"], "TEST")
+ self.assertEqual(response_json[0]["indentation"], 0)
+ self.assertEqual(response_json[-1]["label"], "C. TITLE LONGER")
+ self.assertEqual(response_json[-1]["indentation"], 2)
+
+ def test_text_extraction(self):
+ with open(f"{ROOT_PATH}/test_pdfs/test.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ response = requests.post(f"{self.service_url}/text", files=files)
+
+ response_json = response.json()
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response_json.split()[0], "Document")
+ self.assertEqual(response_json.split()[1], "Big")
+ self.assertEqual(response_json.split()[-1], "TEXT")
+
+ def test_text_extraction_fast(self):
+ with open(f"{ROOT_PATH}/test_pdfs/test.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"fast": "True"}
+
+ response = requests.post(f"{self.service_url}/text", files=files, data=data)
+
+ response_json = response.json()
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response_json.split()[0], "Document")
+ self.assertEqual(response_json.split()[1], "Big")
+ self.assertEqual(response_json.split()[-1], "TEXT")
+
+ def test_table_extraction(self):
+ with open(f"{ROOT_PATH}/test_pdfs/table.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"parse_tables_and_math": "true"}
+
+ response = requests.post(f"{self.service_url}", files=files, data=data)
+
+ response_json = response.json()
+ table_html = response_json[0]["text"]
+
+ parts = table_html.split("<td>")
+ values = []
+ for part in parts[1:]:
+ value = part.split("</td>")[0]
+ values.append(value)
+
+ col1, col2, data_1a, data_1b, data_2a, data_2b = values
+
+ self.assertEqual(response.status_code, 200)
+ self.assertIn("Column 1", col1)
+ self.assertIn("Column 2", col2)
+ self.assertIn("Data 1A", data_1a)
+ self.assertIn("Data 1B", data_1b)
+ self.assertIn("Data 2A", data_2a)
+ self.assertIn("Data 2B", data_2b)
+
+ def test_formula_extraction(self):
+ with open(f"{ROOT_PATH}/test_pdfs/formula.pdf", "rb") as stream:
+ files = {"file": stream}
+ data = {"parse_tables_and_math": "true"}
+
+ response = requests.post(f"{self.service_url}", files=files, data=data)
+
+ response_json = response.json()
+ formula_text = response_json[1]["text"]
+ self.assertEqual(response.status_code, 200)
+ self.assertIn("E_{p r i o r}", formula_text)
+ self.assertIn("(\\Theta)\\", formula_text)
+
+ def test_ocr_english(self):
+ with open(Path(ROOT_PATH, "test_pdfs", "ocr-sample-english.pdf"), "rb") as stream:
+ files = {"file": stream}
+ result_ocr = requests.post(f"{self.service_url}/ocr", files=files)
+ files = {"file": result_ocr.content}
+ results = requests.post(f"{self.service_url}", files=files)
+
+ results_list = results.json()
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(1, len(results_list))
+ self.assertEqual("Test text OCR", results_list[0]["text"])
+ self.assertEqual(248, results_list[0]["left"])
+ self.assertEqual(264, results_list[0]["top"])
+ self.assertEqual(313, results_list[0]["width"])
+ self.assertEqual(50, results_list[0]["height"])
+ self.assertEqual(1, results_list[0]["page_number"])
+ self.assertEqual(842, results_list[0]["page_width"])
+ self.assertEqual(595, results_list[0]["page_height"])
+ self.assertEqual("Section header", results_list[0]["type"])
+
+ def test_ocr_pdf_with_text(self):
+ with open(Path(ROOT_PATH, "test_pdfs", "ocr-sample-already-ocred.pdf"), "rb") as stream:
+ files = {"file": stream}
+ result_ocr = requests.post(f"{self.service_url}/ocr", files=files)
+ files = {"file": result_ocr.content}
+ results = requests.post(f"{self.service_url}", files=files)
+
+ results_list = results.json()
+ self.assertEqual(200, results.status_code)
+ self.assertEqual(2, len(results_list))
+ self.assertEqual("This is some real text", results_list[0]["text"])
+ self.assertEqual("Text", results_list[0]["type"])
+ self.assertEqual("This is some text in an image", results_list[1]["text"])
+ self.assertEqual("Text", results_list[1]["type"])
+
+ def test_ocr_failing(self):
+ with open(Path(ROOT_PATH, "test_pdfs", "not_a_pdf.pdf"), "rb") as stream:
+ files = {"file": stream}
+ result_ocr = requests.post(f"{self.service_url}/ocr", files=files)
+
+ self.assertEqual(500, result_ocr.status_code)
+
+ def test_html_extraction(self):
+ with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}/html", files=files)
+
+ result = results.json()
+
+ span_elements = [
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ]
+
+ heading_elements = [
+ "RESOLUCIÓN DE LA CORTE INTERAMERICANA DE DERECHOS HUMANOS DEL 29 DE JULIO DE 1991",
+ "MEDIDAS PROVISIONALES SOLICITADAS POR LA COMISIÓN INTERAMERICANA DE DERECHOS HUMANOS RESPECTO DE GUATEMALA",
+ "CASO CHUNIMA",
+ "LA CORTE INTERAMERICANA DE DERECHOS HUMANOS,",
+ "VISTOS:",
+ "CONSIDERANDO:",
+ "POR TANTO:",
+ "LA CORTE INTERAMERICANA DE DERECHOS HUMANOS,",
+ "RESUELVE:",
+ ]
+
+ bold_elements = [
+ "RESOLUCIÓN DE LA",
+ "CORTE INTERAMERICANA DE DERECHOS HUMANOS",
+ "DEL 29 DE JULIO DE 1991",
+ "MEDIDAS PROVISIONALES SOLICITADAS POR LA COMISIÓN",
+ "INTERAMERICANA DE DERECHOS HUMANOS",
+ "RESPECTO DE GUATEMALA",
+ "CASO CHUNIMA",
+ "LA CORTE INTERAMERICANA DE DERECHOS HUMANOS,",
+ "VISTOS:",
+ "CONSIDERANDO:",
+ "POR TANTO:",
+ "LA CORTE INTERAMERICANA DE DERECHOS HUMANOS,",
+ "RESUELVE:",
+ ]
+
+ self.assertEqual(200, results.status_code)
+
+ for span_element in span_elements:
+ self.assertIn(span_element, result)
+
+ for heading_element in heading_elements:
+ self.assertIn(heading_element, result)
+
+ for bold_element in bold_elements:
+ self.assertIn(bold_element, result)
+
+ def test_markdown_extraction(self):
+ with open(f"{ROOT_PATH}/test_pdfs/regular.pdf", "rb") as stream:
+ files = {"file": stream}
+
+ results = requests.post(f"{self.service_url}/markdown", files=files)
+
+ heading_elements = [
+ "#### **RESOLUCIÓN DE LA** **CORTE INTERAMERICANA DE DERECHOS HUMANOS** **DEL 29 DE JULIO DE 1991**\n\n",
+ "#### **MEDIDAS PROVISIONALES SOLICITADAS POR LA COMISIÓN** **INTERAMERICANA DE DERECHOS HUMANOS** **RESPECTO DE GUATEMALA**\n\n",
+ "#### **CASO CHUNIMA**\n\n",
+ "#### **LA CORTE INTERAMERICANA DE DERECHOS HUMANOS,**\n\n",
+ "#### **VISTOS:**\n\n",
+ "#### **CONSIDERANDO:**\n\n",
+ "#### **POR TANTO:**\n\n",
+ "#### **LA CORTE INTERAMERICANA DE DERECHOS HUMANOS,**\n\n",
+ "#### **RESUELVE:**\n\n",
+ ]
+
+ result = results.json()
+
+ self.assertEqual(200, results.status_code)
+
+ for heading_element in heading_elements:
+ self.assertIn(heading_element, result)
diff --git a/src/use_cases/__init__.py b/src/use_cases/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/html_conversion/__init__.py b/src/use_cases/html_conversion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/html_conversion/convert_to_html_use_case.py b/src/use_cases/html_conversion/convert_to_html_use_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..f620d7e411c4ecc1c40daec920edfd27c6a0229e
--- /dev/null
+++ b/src/use_cases/html_conversion/convert_to_html_use_case.py
@@ -0,0 +1,48 @@
+from typing import Optional, Union
+from starlette.responses import Response
+from ports.services.html_conversion_service import HtmlConversionService
+from ports.services.pdf_analysis_service import PDFAnalysisService
+from domain.SegmentBox import SegmentBox
+
+
+class ConvertToHtmlUseCase:
+ def __init__(
+ self,
+ pdf_analysis_service: PDFAnalysisService,
+ html_conversion_service: HtmlConversionService,
+ ):
+ self.pdf_analysis_service = pdf_analysis_service
+ self.html_conversion_service = html_conversion_service
+
+ def execute(
+ self,
+ pdf_content: bytes,
+ use_fast_mode: bool = False,
+ extract_toc: bool = False,
+ dpi: int = 120,
+ output_file: Optional[str] = None,
+ ) -> Union[str, Response]:
+ if use_fast_mode:
+ analysis_result = self.pdf_analysis_service.analyze_pdf_layout_fast(pdf_content, "", True, False)
+ else:
+ analysis_result = self.pdf_analysis_service.analyze_pdf_layout(pdf_content, "", True, False)
+
+ segments: list[SegmentBox] = []
+ for item in analysis_result:
+ if isinstance(item, dict):
+ segment = SegmentBox(
+ left=item.get("left", 0),
+ top=item.get("top", 0),
+ width=item.get("width", 0),
+ height=item.get("height", 0),
+ page_number=item.get("page_number", 1),
+ page_width=item.get("page_width", 0),
+ page_height=item.get("page_height", 0),
+ text=item.get("text", ""),
+ type=item.get("type", "TEXT"),
+ )
+ segments.append(segment)
+ elif isinstance(item, SegmentBox):
+ segments.append(item)
+
+ return self.html_conversion_service.convert_to_html(pdf_content, segments, extract_toc, dpi, output_file)
diff --git a/src/use_cases/markdown_conversion/__init__.py b/src/use_cases/markdown_conversion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/markdown_conversion/convert_to_markdown_use_case.py b/src/use_cases/markdown_conversion/convert_to_markdown_use_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..71329bc1378874df04d385be8be2f13bf746995b
--- /dev/null
+++ b/src/use_cases/markdown_conversion/convert_to_markdown_use_case.py
@@ -0,0 +1,48 @@
+from typing import Optional, Union
+from starlette.responses import Response
+from ports.services.markdown_conversion_service import MarkdownConversionService
+from ports.services.pdf_analysis_service import PDFAnalysisService
+from domain.SegmentBox import SegmentBox
+
+
+class ConvertToMarkdownUseCase:
+ def __init__(
+ self,
+ pdf_analysis_service: PDFAnalysisService,
+ markdown_conversion_service: MarkdownConversionService,
+ ):
+ self.pdf_analysis_service = pdf_analysis_service
+ self.markdown_conversion_service = markdown_conversion_service
+
+ def execute(
+ self,
+ pdf_content: bytes,
+ use_fast_mode: bool = False,
+ extract_toc: bool = False,
+ dpi: int = 120,
+ output_file: Optional[str] = None,
+ ) -> Union[str, Response]:
+ if use_fast_mode:
+ analysis_result = self.pdf_analysis_service.analyze_pdf_layout_fast(pdf_content, "", True, False)
+ else:
+ analysis_result = self.pdf_analysis_service.analyze_pdf_layout(pdf_content, "", True, False)
+
+ segments: list[SegmentBox] = []
+ for item in analysis_result:
+ if isinstance(item, dict):
+ segment = SegmentBox(
+ left=item.get("left", 0),
+ top=item.get("top", 0),
+ width=item.get("width", 0),
+ height=item.get("height", 0),
+ page_number=item.get("page_number", 1),
+ page_width=item.get("page_width", 0),
+ page_height=item.get("page_height", 0),
+ text=item.get("text", ""),
+ type=item.get("type", "TEXT"),
+ )
+ segments.append(segment)
+ elif isinstance(item, SegmentBox):
+ segments.append(item)
+
+ return self.markdown_conversion_service.convert_to_markdown(pdf_content, segments, extract_toc, dpi, output_file)
diff --git a/src/use_cases/ocr/__init__.py b/src/use_cases/ocr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/ocr/process_ocr_use_case.py b/src/use_cases/ocr/process_ocr_use_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1fe7147cab35ce42d044f8500b88c53f893132
--- /dev/null
+++ b/src/use_cases/ocr/process_ocr_use_case.py
@@ -0,0 +1,26 @@
+from pathlib import Path
+from fastapi import UploadFile
+from starlette.responses import FileResponse
+from ports.services.ocr_service import OCRService
+from ports.repositories.file_repository import FileRepository
+from configuration import OCR_SOURCE
+
+
+class ProcessOCRUseCase:
+ def __init__(self, ocr_service: OCRService, file_repository: FileRepository):
+ self.ocr_service = ocr_service
+ self.file_repository = file_repository
+
+ def execute(self, file: UploadFile, language: str = "en") -> FileResponse:
+ namespace = "sync_pdfs"
+
+ self.file_repository.save_pdf_to_directory(
+ content=file.file.read(), filename=file.filename, directory=Path(OCR_SOURCE), namespace=namespace
+ )
+
+ processed_pdf_filepath = self.ocr_service.process_pdf_ocr(file.filename, namespace, language)
+
+ return FileResponse(path=processed_pdf_filepath, media_type="application/pdf")
+
+ def get_supported_languages(self) -> list:
+ return self.ocr_service.get_supported_languages()
diff --git a/src/use_cases/pdf_analysis/__init__.py b/src/use_cases/pdf_analysis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/pdf_analysis/analyze_pdf_use_case.py b/src/use_cases/pdf_analysis/analyze_pdf_use_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2033e8c1e09e7a3e2a067636ea32298c1159a5d
--- /dev/null
+++ b/src/use_cases/pdf_analysis/analyze_pdf_use_case.py
@@ -0,0 +1,32 @@
+from typing import AnyStr
+from ports.services.pdf_analysis_service import PDFAnalysisService
+from ports.services.ml_model_service import MLModelService
+
+
+class AnalyzePDFUseCase:
+ def __init__(
+ self,
+ pdf_analysis_service: PDFAnalysisService,
+ ml_model_service: MLModelService,
+ ):
+ self.pdf_analysis_service = pdf_analysis_service
+ self.ml_model_service = ml_model_service
+
+ def execute(
+ self,
+ pdf_content: AnyStr,
+ xml_filename: str = "",
+ parse_tables_and_math: bool = False,
+ use_fast_mode: bool = False,
+ keep_pdf: bool = False,
+ ) -> list[dict]:
+ if use_fast_mode:
+ return self.pdf_analysis_service.analyze_pdf_layout_fast(
+ pdf_content, xml_filename, parse_tables_and_math, keep_pdf
+ )
+ else:
+ return self.pdf_analysis_service.analyze_pdf_layout(pdf_content, xml_filename, parse_tables_and_math, keep_pdf)
+
+ def execute_and_save_xml(self, pdf_content: AnyStr, xml_filename: str, use_fast_mode: bool = False) -> list[dict]:
+ result = self.execute(pdf_content, xml_filename, False, use_fast_mode, keep_pdf=False)
+ return result
diff --git a/src/use_cases/text_extraction/__init__.py b/src/use_cases/text_extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/text_extraction/extract_text_use_case.py b/src/use_cases/text_extraction/extract_text_use_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..06166c8c09012e71784815da8c3b81136a805245
--- /dev/null
+++ b/src/use_cases/text_extraction/extract_text_use_case.py
@@ -0,0 +1,25 @@
+from fastapi import UploadFile
+from pdf_token_type_labels import TokenType
+from ports.services.pdf_analysis_service import PDFAnalysisService
+from ports.services.text_extraction_service import TextExtractionService
+
+
+class ExtractTextUseCase:
+ def __init__(self, pdf_analysis_service: PDFAnalysisService, text_extraction_service: TextExtractionService):
+ self.pdf_analysis_service = pdf_analysis_service
+ self.text_extraction_service = text_extraction_service
+
+ def execute(self, file: UploadFile, use_fast_mode: bool = False, types: str = "all") -> dict:
+ file_content = file.file.read()
+
+ if types == "all":
+ token_types: list[TokenType] = [t for t in TokenType]
+ else:
+ token_types = list(set([TokenType.from_text(t.strip().replace(" ", "_")) for t in types.split(",")]))
+
+ if use_fast_mode:
+ segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(file_content)
+ else:
+ segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(file_content, "")
+
+ return self.text_extraction_service.extract_text_by_types(segment_boxes, token_types)
diff --git a/src/use_cases/toc_extraction/__init__.py b/src/use_cases/toc_extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/toc_extraction/extract_toc_use_case.py b/src/use_cases/toc_extraction/extract_toc_use_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..437cfd8145b0913cb21e85b6bf1a14082bd91d44
--- /dev/null
+++ b/src/use_cases/toc_extraction/extract_toc_use_case.py
@@ -0,0 +1,23 @@
+from fastapi import UploadFile
+from ports.services.pdf_analysis_service import PDFAnalysisService
+from ports.services.toc_service import TOCService
+
+
+class ExtractTOCUseCase:
+ def __init__(self, pdf_analysis_service: PDFAnalysisService, toc_service: TOCService):
+ self.pdf_analysis_service = pdf_analysis_service
+ self.toc_service = toc_service
+
+ def execute(self, file: UploadFile, use_fast_mode: bool = False) -> list[dict]:
+ file_content = file.file.read()
+
+ if use_fast_mode:
+ segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(file_content)
+ else:
+ segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(file_content, "")
+
+ return self.toc_service.extract_table_of_contents(file_content, segment_boxes)
+
+ def execute_uwazi_compatible(self, file: UploadFile) -> list[dict]:
+ toc_items = self.execute(file, use_fast_mode=True)
+ return self.toc_service.format_toc_for_uwazi(toc_items)
diff --git a/src/use_cases/visualization/__init__.py b/src/use_cases/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/use_cases/visualization/create_visualization_use_case.py b/src/use_cases/visualization/create_visualization_use_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f7318d3ff3e078edf3ecfb7796e8c6646cc691
--- /dev/null
+++ b/src/use_cases/visualization/create_visualization_use_case.py
@@ -0,0 +1,27 @@
+from fastapi import UploadFile
+from starlette.responses import FileResponse
+from ports.services.pdf_analysis_service import PDFAnalysisService
+from ports.services.visualization_service import VisualizationService
+from glob import glob
+from os.path import getctime, join
+from tempfile import gettempdir
+from pathlib import Path
+
+
+class CreateVisualizationUseCase:
+ def __init__(self, pdf_analysis_service: PDFAnalysisService, visualization_service: VisualizationService):
+ self.pdf_analysis_service = pdf_analysis_service
+ self.visualization_service = visualization_service
+
+ def execute(self, file: UploadFile, use_fast_mode: bool = False) -> FileResponse:
+ file_content = file.file.read()
+
+ if use_fast_mode:
+ segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(file_content, "", "", True)
+ else:
+ segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(file_content, "", "", True)
+
+ pdf_path = Path(max(glob(join(gettempdir(), "*.pdf")), key=getctime))
+ visualization_path = self.visualization_service.create_pdf_visualization(pdf_path, segment_boxes)
+
+ return self.visualization_service.get_visualization_response(visualization_path)
diff --git a/start.sh b/start.sh
new file mode 100755
index 0000000000000000000000000000000000000000..ea1f2d34a9f6f64510010521179abf9af24cfa9a
--- /dev/null
+++ b/start.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+gunicorn -k uvicorn.workers.UvicornWorker --chdir ./src app:app --bind 0.0.0.0:5060 --timeout 10000
\ No newline at end of file
diff --git a/test_pdfs/blank.pdf b/test_pdfs/blank.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..80e44a4e4f0d027b44d3951383f74c96b462eab2
--- /dev/null
+++ b/test_pdfs/blank.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:42851873b3d0687137259e415e114ab0649445159d96a8fe4dfeb3eb8c923072
+size 798
diff --git a/test_pdfs/chinese.pdf b/test_pdfs/chinese.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..a7bb956e45b2f2edf3d29203f4c06a1b7f75f7b6
--- /dev/null
+++ b/test_pdfs/chinese.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c18524785cb14c73299151c4b7c92741241310ab35b9465f526a70b101cb1048
+size 104119
diff --git a/test_pdfs/error.pdf b/test_pdfs/error.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..5c5f405a9d222ca9dc73ce55f19da7d5a0e9df62
--- /dev/null
+++ b/test_pdfs/error.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca00fccfb408989eddc401062c4d1219a6aceb6b9b55412357f1790862e8f178
+size 5
diff --git a/test_pdfs/formula.pdf b/test_pdfs/formula.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..05e058eca6f3c7f2432fd65c02221be9f283cec4
--- /dev/null
+++ b/test_pdfs/formula.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:959e2c673d48a9e36948281a71c2713e77882e31303201747f82db43d9d613ff
+size 38726
diff --git a/test_pdfs/image.pdf b/test_pdfs/image.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..7f398525f969e933bb304842a436a85b6a2a444c
--- /dev/null
+++ b/test_pdfs/image.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d5b449b2c2f91b81c306222352019d0bc6299fbb9a68041440cbde3ea2221e1
+size 353829
diff --git a/test_pdfs/korean.pdf b/test_pdfs/korean.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..634e83d4fa0fe3b5c41522979abccc901d7136b3
--- /dev/null
+++ b/test_pdfs/korean.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24e41f246a060065aaff2e5ccbc7be8f4591f7cb75e9dfe6a1f138ea22b823f5
+size 198455
diff --git a/test_pdfs/not_a_pdf.pdf b/test_pdfs/not_a_pdf.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/test_pdfs/ocr-sample-already-ocred.pdf b/test_pdfs/ocr-sample-already-ocred.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..3b18e550b17b53b543b673b302917d9aaeb61850
--- /dev/null
+++ b/test_pdfs/ocr-sample-already-ocred.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:97b92540609317cc4699ada6f40e06569ba424bb368d2f8c5efbe5e192a907b6
+size 15939
diff --git a/test_pdfs/ocr-sample-english.pdf b/test_pdfs/ocr-sample-english.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..c7262f81020d89391fc622639c8f096c86524a2e
--- /dev/null
+++ b/test_pdfs/ocr-sample-english.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:522a6295db6b2055f1506d1c94dbc00285fda1d00b3ce005f537dac2a9ad6dc0
+size 16960
diff --git a/test_pdfs/ocr-sample-french.pdf b/test_pdfs/ocr-sample-french.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..3bb97362f4c8955fb68496cd84ebdbe7917ac719
--- /dev/null
+++ b/test_pdfs/ocr-sample-french.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:477be3bb9203c17b0c90806b2949b15312c2badc910b74a27ea8144e01a136a2
+size 30560
diff --git a/test_pdfs/ocr_pdf.pdf b/test_pdfs/ocr_pdf.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..8098414f8feb915b63b86b7762ae146dca6c761a
--- /dev/null
+++ b/test_pdfs/ocr_pdf.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:16fcebc32f714ef1cf3d134ab47ca36f61c7af782b1808fb426346e2a6e3fa75
+size 125292
diff --git a/test_pdfs/regular.pdf b/test_pdfs/regular.pdf
new file mode 100755
index 0000000000000000000000000000000000000000..fd10d5e74a357939aeaafe0602f8cbd18562c9be
--- /dev/null
+++ b/test_pdfs/regular.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a672e07971dd45ce7770324ab314e29b3f106ad246379282ab661e3f8fc8b4f4
+size 38697
diff --git a/test_pdfs/some_empty_pages.pdf b/test_pdfs/some_empty_pages.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..0d29c669f76ccf72cfe4b173b2572e1705cd503f
--- /dev/null
+++ b/test_pdfs/some_empty_pages.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77d5bfabdb40d3d6fc3c2b56fec4fd54cce427c4162550bf2b5054c330bfaa42
+size 14798
diff --git a/test_pdfs/table.pdf b/test_pdfs/table.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..56c72ff98245f565c6c934f73d12cda5d8512b52
--- /dev/null
+++ b/test_pdfs/table.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d05b3acdd25a2c18854ca3f3ac0b6307443cd34c4040a335cd24dbd520c2768d
+size 5224
diff --git a/test_pdfs/test.pdf b/test_pdfs/test.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..f2a4c6246b4f589002b0fe268805f8035ab52e01
--- /dev/null
+++ b/test_pdfs/test.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:095937acc2bd0f790c11b8a3e80a7e8d2c99e555dfb49edf3599a1533dcbc19c
+size 48752
diff --git a/test_pdfs/toc-test.pdf b/test_pdfs/toc-test.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..34f968f7023399ee3e33b6fe59f8de86671de466
--- /dev/null
+++ b/test_pdfs/toc-test.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e71fd8c9b27d9dd0050cb54893003b93aff327cd7cecfc5fa809057f322e711b
+size 44462
diff --git a/update_docx_with_pdf.py b/update_docx_with_pdf.py
index ffd7c7b923dbc13b4694675f612db54ce567df73..0b534c2c2628dd417b9e750438c5fd8ef7691565 100644
--- a/update_docx_with_pdf.py
+++ b/update_docx_with_pdf.py
@@ -1,177 +1,1470 @@
-from openai import OpenAI
+#!/usr/bin/env python3
+"""
+Enhanced NHVAS PDF to DOCX JSON Merger
+Comprehensive extraction and mapping from PDF to DOCX structure
+(keep pipeline intact; fix spacing, operator info mapping, vehicle-reg header mapping, date fallback)
+"""
import json
-import os
import re
+import sys
+from pathlib import Path
+from typing import Dict, List, Any, Optional
+from collections import OrderedDict  # preserves insertion order of vehicle registrations
+
+
+def _nz(x):
+ return x if isinstance(x, str) and x.strip() else ""
+
+SUMMARY_SECTIONS = {
+ "MAINTENANCE MANAGEMENT": "Maintenance Management Summary",
+ "MASS MANAGEMENT": "Mass Management Summary",
+ "FATIGUE MANAGEMENT": "Fatigue Management Summary",
+}
+
+# ───────────────────────────── helpers: text cleanup & label matching ─────────────────────────────
+def _canon_header(s: str) -> str:
+ if not s: return ""
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = s.replace("–", "-").replace("—", "-")
+ s = re.sub(r"[/]+", " / ", s)
+ s = re.sub(r"[^a-z0-9#/ ]+", " ", s)
+ s = re.sub(r"\s+", " ", s).strip()
+ return s
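+# Illustrative sanity check (sample header strings, not taken from any specific PDF):
+#   _canon_header("Fault Recording/ Reporting")      -> "fault recording / reporting"
+#   _canon_header("RFS Suspension Certification #")  -> "rfs suspension certification #"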
+
+
+# Header aliases -> internal keys we already use later during mapping
+_VEH_HEADER_ALIASES = {
+ # common
+ "registration number": "registration",
+ "reg no": "registration",
+ "reg.#": "registration",
+ "no.": "no",
+ "no": "no",
+
+ # maintenance table
+ "roadworthiness certificates": "roadworthiness",
+ "maintenance records": "maintenance_records",
+ "daily checks": "daily_checks",
+ "fault recording reporting": "fault_recording",
+ "fault recording / reporting": "fault_recording",
+ "fault repair": "fault_repair",
+
+ # mass table
+ "sub contractor": "sub_contractor",
+ "sub-contractor": "sub_contractor",
+ "sub contracted vehicles statement of compliance": "sub_comp",
+ "sub-contracted vehicles statement of compliance": "sub_comp",
+ "weight verification records": "weight_verification",
+ "rfs suspension certification #": "rfs_certification",
+ "rfs suspension certification number": "rfs_certification",
+ "suspension system maintenance": "suspension_maintenance",
+ "trip records": "trip_records",
+ "fault recording reporting on suspension system": "fault_reporting_suspension",
+ "fault recording / reporting on suspension system": "fault_reporting_suspension",
+}
+
+# --- helpers ---
+def build_vehicle_sections(extracted: dict) -> dict:
+ """Build arrays for Maintenance and Mass tables. Maintenance uses recorded rows to include ALL entries."""
+ maint = {
+ "Registration Number": [],
+ "Roadworthiness Certificates": [],
+ "Maintenance Records": [],
+ "Daily Checks": [],
+ "Fault Recording/ Reporting": [],
+ "Fault Repair": [],
+ }
+ mass = {
+ "Registration Number": [],
+ "Weight Verification Records": [],
+ "RFS Suspension Certification #": [],
+ "Suspension System Maintenance": [],
+ "Trip Records": [],
+ "Fault Recording/ Reporting on Suspension System": [],
+ }
+
+ # Prefer authoritative maintenance rows captured during parsing (spans all pages)
+ if extracted.get("_maint_rows"):
+ for row in extracted["_maint_rows"]:
+ maint["Registration Number"].append(_smart_space(row.get("registration", "")))
+ maint["Roadworthiness Certificates"].append(_nz(row.get("roadworthiness", "")))
+ maint["Maintenance Records"].append(_nz(row.get("maintenance_records", "")))
+ maint["Daily Checks"].append(_nz(row.get("daily_checks", "")))
+ maint["Fault Recording/ Reporting"].append(_nz(row.get("fault_recording", "")))
+ maint["Fault Repair"].append(_nz(row.get("fault_repair", "")))
+ else:
+ # Fallback to vehicles map (older behavior)
+ for v in extracted.get("vehicles", []) or []:
+ if not v.get("registration"): continue
+ if v.get("seen_in_maintenance") or any(v.get(k) for k in ["roadworthiness","maintenance_records","daily_checks","fault_recording","fault_repair"]):
+ rw = _nz(v.get("roadworthiness", "")); mr = _nz(v.get("maintenance_records", "")); dc = _nz(v.get("daily_checks", ""))
+ fr = _nz(v.get("fault_recording", "")); rp = _nz(v.get("fault_repair", ""))
+ if not mr and dc: mr = dc
+ if not rp and fr: rp = fr
+ if not fr and rp: fr = rp
+ maint["Registration Number"].append(_smart_space(v["registration"]))
+ maint["Roadworthiness Certificates"].append(rw)
+ maint["Maintenance Records"].append(mr)
+ maint["Daily Checks"].append(dc)
+ maint["Fault Recording/ Reporting"].append(fr)
+ maint["Fault Repair"].append(rp)
+
+ # Mass stays as-is (from vehicles)
+ for v in extracted.get("vehicles", []) or []:
+ if not v.get("registration"): continue
+ if v.get("seen_in_mass") or any(v.get(k) for k in ["weight_verification","rfs_certification","suspension_maintenance","trip_records","fault_reporting_suspension"]):
+ mass["Registration Number"].append(_smart_space(v["registration"]))
+ mass["Weight Verification Records"].append(_nz(v.get("weight_verification", "")))
+ mass["RFS Suspension Certification #"].append(_nz(v.get("rfs_certification", "")))
+ mass["Suspension System Maintenance"].append(_nz(v.get("suspension_maintenance", "")))
+ mass["Trip Records"].append(_nz(v.get("trip_records", "")))
+ mass["Fault Recording/ Reporting on Suspension System"].append(_nz(v.get("fault_reporting_suspension", "")))
+
+ return {
+ "Vehicle Registration Numbers Maintenance": maint,
+ "Vehicle Registration Numbers Mass": mass,
+ }
+
+
+def _map_header_indices(headers: list[str]) -> dict:
+ """Return {internal_key: column_index} by matching/aliasing header text."""
+ idx = {}
+ for i, h in enumerate(headers or []):
+ ch = _canon_header(h)
+ # try direct alias
+ if ch in _VEH_HEADER_ALIASES:
+ idx[_VEH_HEADER_ALIASES[ch]] = i
+ continue
+ # relax a little for 'registration number' variants
+ if "registration" in ch and "number" in ch:
+ idx["registration"] = i
+ continue
+ if "roadworthiness" in ch:
+ idx["roadworthiness"] = i
+ continue
+ if "maintenance" in ch and "records" in ch:
+ idx["maintenance_records"] = i
+ continue
+ if "daily" in ch and "check" in ch:
+ idx["daily_checks"] = i
+ continue
+ if "fault" in ch and "record" in ch and "suspension" not in ch:
+ # maintenance fault-recording column
+ if "repair" in ch:
+ idx["fault_repair"] = i
+ else:
+ idx["fault_recording"] = i
+ continue
+ if "weight" in ch and "verification" in ch:
+ idx["weight_verification"] = i
+ continue
+ if "rfs" in ch and "certification" in ch:
+ idx["rfs_certification"] = i
+ continue
+ if "suspension" in ch and "maintenance" in ch:
+ idx["suspension_maintenance"] = i
+ continue
+ if "trip" in ch and "record" in ch:
+ idx["trip_records"] = i
+ continue
+ if "fault" in ch and "report" in ch and "suspension" in ch:
+ idx["fault_reporting_suspension"] = i
+ continue
+ return idx
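+# Illustrative mapping with hypothetical headers (real tables vary per audit PDF):
+#   _map_header_indices(["No.", "Registration Number", "Daily Checks", "Fault Repair"])
+#   -> {"no": 0, "registration": 1, "daily_checks": 2, "fault_repair": 3}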
+
+def _canon(s: str) -> str:
+ if not s: return ""
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = re.sub(r"[^a-z0-9#]+", " ", s)
+ return re.sub(r"\s+", " ", s).strip()
+
+def _smart_space(s: str) -> str:
+ if not s: return s
+ s = str(s)
+
+ # Insert spaces at typical OCR glue points
+ s = re.sub(r'([a-z])([A-Z])', r'\1 \2', s)
+ s = re.sub(r'([A-Za-z])(\d)', r'\1 \2', s)
+ s = re.sub(r'(\d)([A-Za-z])', r'\1 \2', s)
+ s = re.sub(r'([A-Z]{2,})(\d)', r'\1 \2', s)
+
+ # Fix common glued tokens
+ s = s.replace("POBox", "PO Box")
+
+ # Compact ordinals back together: "9 th" -> "9th", but preserve a space after the ordinal if followed by a word
+ s = re.sub(r'\b(\d+)\s*(st|nd|rd|th)\b', r'\1\2', s)
+
+ s = re.sub(r"\s+", " ", s).strip()
+ return s
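+# Illustrative de-gluing of OCR-joined text (made-up samples):
+#   _smart_space("9thJuly2023") -> "9th July 2023"
+#   _smart_space("POBox12")     -> "PO Box 12"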
+
+def looks_like_plate(s: str) -> bool:
+ if not s: return False
+ t = re.sub(r"[\s-]", "", str(s).upper())
+ if not (5 <= len(t) <= 8): return False
+ if not re.fullmatch(r"[A-Z0-9]+", t): return False
+ if sum(c.isalpha() for c in t) < 2: return False
+ if sum(c.isdigit() for c in t) < 2: return False
+ if t in {"ENTRY","YES","NO","N/A","NA"}: return False
+ return True
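+# Illustrative: looks_like_plate("XV 12 AB") -> True, looks_like_plate("ENTRY") -> False (no digits)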
+
+def is_dateish(s: str) -> bool:
+ if not s: return False
+ s = _smart_space(s)
+ # tokens like 03/22, 20/02/2023, 01.02.21, 2023-02-20
+ return bool(re.search(r"\b\d{1,4}(?:[./-]\d{1,2}){1,2}\b", s))
+
+def extract_date_tokens(s: str) -> list[str]:
+ if not s: return []
+ s = _smart_space(s)
+ return re.findall(r"\b\d{1,4}(?:[./-]\d{1,2}){1,2}\b", s)
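+# Illustrative: extract_date_tokens("checked 03/22 and 2023-02-20") -> ["03/22", "2023-02-20"]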
+
+
+def _clean_list(vals: List[str]) -> List[str]:
+ out = []
+ for v in vals:
+ v = _smart_space(v)
+ if v:
+ out.append(v)
+ return out
+
+def _looks_like_manual_value(s: str) -> bool:
+ if not s: return False
+ s = s.strip()
+ # reject pure digits (e.g., "51902") and very short tokens
+ if re.fullmatch(r"\d{3,}", s):
+ return False
+ # accept if it has any letters or typical version hints
+ return bool(re.search(r"[A-Za-z]", s))
+
+def _looks_like_company(s: str) -> bool:
+ """Very light validation to avoid capturing labels as values."""
+ if not s: return False
+ s = _smart_space(s)
+ # at least two words containing letters (e.g., "Kangaroo Transport")
+ return bool(re.search(r"[A-Za-z]{2,}\s+[A-Za-z&]{2,}", s))
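+# Illustrative guards (sample values): _looks_like_company("Kangaroo Transport Pty Ltd") -> True,
+# _looks_like_company("51902") -> False; _looks_like_manual_value("51902") -> False.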
+
+# ───────────────────────────── label index (non-summary only; no values) ─────────────────────────────
+LABEL_INDEX: Dict[str, Dict[str, Dict[str, Any]]] = {
+ "Audit Information": {
+ "Date of Audit": {"alts": ["Date of Audit"]},
+ "Location of audit": {"alts": ["Location of audit", "Location"]},
+ "Auditor name": {"alts": ["Auditor name", "Auditor"]},
+ "Audit Matrix Identifier (Name or Number)": {"alts": ["Audit Matrix Identifier (Name or Number)", "Audit Matrix Identifier"]},
+ "Auditor Exemplar Global Reg No.": {"alts": ["Auditor Exemplar Global Reg No."]},
+ "NHVR Auditor Registration Number": {"alts": ["NHVR Auditor Registration Number"]},
+ "expiry Date:": {"alts": ["expiry Date:", "Expiry Date:"]},
+ },
+ "Operator Information": {
+ "Operator name (Legal entity)": {"alts": ["Operator name (Legal entity)", "Operator's Name (legal entity)"]},
+ "NHVAS Accreditation No. (If applicable)": {"alts": ["NHVAS Accreditation No. (If applicable)", "NHVAS Accreditation No."]},
+ "Registered trading name/s": {"alts": ["Registered trading name/s", "Trading name/s"]},
+ "Australian Company Number": {"alts": ["Australian Company Number", "ACN"]},
+ "NHVAS Manual (Policies and Procedures) developed by": {"alts": [
+ "NHVAS Manual (Policies and Procedures) developed by",
+ "NHVAS Manual developed by",
+ "Manual developed by"
+ ]},
+ },
+ "Operator contact details": {
+ "Operator business address": {"alts": ["Operator business address", "Business address"]},
+ "Operator Postal address": {"alts": ["Operator Postal address", "Postal address"]},
+ "Email address": {"alts": ["Email address", "Email"]},
+ "Operator Telephone Number": {"alts": ["Operator Telephone Number", "Telephone", "Phone"]},
+ },
+ "Attendance List (Names and Position Titles)": {
+ "Attendance List (Names and Position Titles)": {"alts": ["Attendance List (Names and Position Titles)", "Attendance List"]},
+ },
+ "Nature of the Operators Business (Summary)": {
+ "Nature of the Operators Business (Summary):": {"alts": ["Nature of the Operators Business (Summary):"]},
+ },
+ "Accreditation Vehicle Summary": {
+ "Number of powered vehicles": {"alts": ["Number of powered vehicles"]},
+ "Number of trailing vehicles": {"alts": ["Number of trailing vehicles"]},
+ },
+ "Accreditation Driver Summary": {
+ "Number of drivers in BFM": {"alts": ["Number of drivers in BFM"]},
+ "Number of drivers in AFM": {"alts": ["Number of drivers in AFM"]},
+ },
+ "Vehicle Registration Numbers Maintenance": {
+ "No.": {"alts": ["No.", "No"]},
+ "Registration Number": {"alts": ["Registration Number", "Registration"]},
+ "Roadworthiness Certificates": {"alts": ["Roadworthiness Certificates", "Roadworthiness"]},
+ "Maintenance Records": {"alts": ["Maintenance Records"]},
+ "Daily Checks": {"alts": ["Daily Checks", "Daily Check"]},
+ "Fault Recording/ Reporting": {"alts": ["Fault Recording/ Reporting", "Fault Recording / Reporting"]},
+ "Fault Repair": {"alts": ["Fault Repair"]},
+ },
+ "Vehicle Registration Numbers Mass": {
+ "No.": {"alts": ["No.", "No"]},
+ "Registration Number": {"alts": ["Registration Number", "Registration"]},
+ "Sub contractor": {"alts": ["Sub contractor", "Sub-contractor"]},
+ "Sub-contracted Vehicles Statement of Compliance": {"alts": ["Sub-contracted Vehicles Statement of Compliance"]},
+ "Weight Verification Records": {"alts": ["Weight Verification Records"]},
+ "RFS Suspension Certification #": {"alts": ["RFS Suspension Certification #", "RFS Suspension Certification Number"]},
+ "Suspension System Maintenance": {"alts": ["Suspension System Maintenance"]},
+ "Trip Records": {"alts": ["Trip Records"]},
+ "Fault Recording/ Reporting on Suspension System": {"alts": ["Fault Recording/ Reporting on Suspension System"]},
+ },
+ "Driver / Scheduler Records Examined": {
+ "No.": {"alts": ["No.", "No"]},
+ "Driver / Scheduler Name": {"alts": ["Driver / Scheduler Name"]},
+ "Driver TLIF Course # Completed": {"alts": ["Driver TLIF Course # Completed"]},
+ "Scheduler TLIF Course # Completed": {"alts": ["Scheduler TLIF Course # Completed"]},
+ "Medical Certificates (Current Yes/No) Date of expiry": {"alts": ["Medical Certificates (Current Yes/No) Date of expiry"]},
+ "Roster / Schedule / Safe Driving Plan (Date Range)": {"alts": ["Roster / Schedule / Safe Driving Plan (Date Range)"]},
+ "Fit for Duty Statement Completed (Yes/No)": {"alts": ["Fit for Duty Statement Completed (Yes/No)"]},
+ "Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)": {"alts": ["Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)"]},
+ },
+ "NHVAS Approved Auditor Declaration": {
+ "Print Name": {"alts": ["Print Name"]},
+ "NHVR or Exemplar Global Auditor Registration Number": {"alts": ["NHVR or Exemplar Global Auditor Registration Number"]},
+ },
+ "Audit Declaration dates": {
+ "Audit was conducted on": {"alts": ["Audit was conducted on"]},
+ "Unconditional CARs closed out on:": {"alts": ["Unconditional CARs closed out on:"]},
+ "Conditional CARs to be closed out by:": {"alts": ["Conditional CARs to be closed out by:"]},
+ },
+ "Print accreditation name": {
+ "(print accreditation name)": {"alts": ["(print accreditation name)"]},
+ },
+ "Operator Declaration": {
+ "Print Name": {"alts": ["Print Name"]},
+ "Position Title": {"alts": ["Position Title"]},
+ },
+}
+
+class NHVASMerger:
+ def __init__(self):
+ self.debug_mode = True
+ self._vehicle_by_reg = OrderedDict()
+
+ def log_debug(self, msg: str):
+ if self.debug_mode:
+ print(f"🔍 {msg}")
+
+ def normalize_std_label(self, label: str) -> str:
+ if not label: return ""
+ base = re.sub(r"\([^)]*\)", "", label)
+ base = re.sub(r"\s+", " ", base).strip()
+ m = re.match(r"^(Std\s*\d+\.\s*[^:]+?)\s*$", base, flags=re.IGNORECASE)
+ return m.group(1).strip() if m else base
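+ # Illustrative: normalize_std_label("Std 1. Daily Check (Maintenance Management)") -> "Std 1. Daily Check"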
+
+ def _pick_nearby(self, row, anchor_idx: int | None, want: str = "plate", window: int = 3) -> str:
+ """Return the best cell for a field by looking at the anchor index and nearby columns.
+ want ∈ {"plate","date","rf","yn"}"""
+ def cell(i):
+ if i is None or i < 0 or i >= len(row): return ""
+ v = row[i]
+ return v.strip() if isinstance(v, str) else str(v).strip()
+
+ # 1) try the anchor cell
+ cand = cell(anchor_idx)
+ if want == "plate" and looks_like_plate(cand): return _smart_space(cand)
+ if want == "date" and is_dateish(cand): return _smart_space(cand)
+ if want == "rf" and re.search(r"\bRF\s*\d+\b", cand, re.I): return _smart_space(re.search(r"\bRF\s*\d+\b", cand, re.I).group(0))
+ if want == "yn" and cand.strip().lower() in {"yes","no"}: return cand.strip().title()
+
+ # 2) scan a window around the anchor
+ if anchor_idx is not None:
+ for offset in range(1, window+1):
+ for i in (anchor_idx - offset, anchor_idx + offset):
+ c = cell(i)
+ if not c: continue
+ if want == "plate" and looks_like_plate(c): return _smart_space(c)
+ if want == "date" and is_dateish(c): return _smart_space(c)
+ if want == "rf":
+ m = re.search(r"\bRF\s*\d+\b", c, re.I)
+ if m: return _smart_space(m.group(0))
+ if want == "yn" and c.strip().lower() in {"yes","no"}: return c.strip().title()
+
+ # 3) last resort: scan whole row
+ joined = " ".join(str(c or "") for c in row)
+ if want == "plate":
+ for tok in joined.split():
+ if looks_like_plate(tok): return _smart_space(tok)
+ if want == "date":
+ tok = extract_date_tokens(joined)
+ return tok[0] if tok else ""
+ if want == "rf":
+ m = re.search(r"\bRF\s*\d+\b", joined, re.I)
+ return _smart_space(m.group(0)) if m else ""
+ if want == "yn":
+ j = f" {joined.lower()} "
+ if " yes " in j: return "Yes"
+ if " no " in j: return "No"
+ return ""
+
+
+ def _force_fill_maintenance_from_tables(self, pdf_data: Dict, merged: Dict) -> None:
+ """Overwrite Maintenance arrays by scanning ALL maintenance tables across pages."""
+ maint = merged.get("Vehicle Registration Numbers Maintenance")
+ if not isinstance(maint, dict):
+ return
+
+ tables = (pdf_data.get("extracted_data") or {}).get("all_tables") or []
+ regs, rw, mr, dc, fr, rp = [], [], [], [], [], []
+
+ for t in tables:
+ hdrs = [_canon_header(h or "") for h in t.get("headers") or []]
+ if not hdrs:
+ continue
+ # detect a maintenance table
+ txt = " ".join(hdrs)
+ if ("registration" not in txt) or not any(
+ k in txt for k in ["maintenance records", "daily", "fault recording", "fault repair", "roadworthiness"]
+ ):
+ continue
+
+ def fidx(pred):
+ for i, h in enumerate(hdrs):
+ if pred(h):
+ return i
+ return None
+
+ reg_i = fidx(lambda h: "registration" in h)
+ rw_i = fidx(lambda h: "roadworthiness" in h)
+ mr_i = fidx(lambda h: "maintenance" in h and "record" in h)
+ dc_i = fidx(lambda h: "daily" in h and "check" in h)
+ fr_i = fidx(lambda h: "fault" in h and "record" in h and "suspension" not in h)
+ rp_i = fidx(lambda h: "fault" in h and "repair" in h)
+
+ for r in t.get("data") or []:
+ def cell(i):
+ if i is None or i >= len(r): return ""
+ v = r[i]
+ return v.strip() if isinstance(v, str) else str(v).strip()
+
+ plate = _smart_space(cell(reg_i))
+ if not plate or not looks_like_plate(plate):
+ continue
+
+ v_rw = _nz(cell(rw_i))
+ v_mr = _nz(cell(mr_i))
+ v_dc = _nz(cell(dc_i))
+ v_fr = _nz(cell(fr_i))
+ v_rp = _nz(cell(rp_i))
+
+ # sensible fallbacks
+ if not v_mr and v_dc: v_mr = v_dc
+ if not v_rp and v_fr: v_rp = v_fr
+ if not v_fr and v_rp: v_fr = v_rp
+
+ regs.append(plate); rw.append(v_rw); mr.append(v_mr)
+ dc.append(v_dc); fr.append(v_fr); rp.append(v_rp)
+
+ if regs: # overwrite arrays only if we found rows
+ maint["Registration Number"] = regs
+ maint["Roadworthiness Certificates"] = rw
+ maint["Maintenance Records"] = mr
+ maint["Daily Checks"] = dc
+ maint["Fault Recording/ Reporting"] = fr
+ maint["Fault Repair"] = rp
+
+ def _collapse_multiline_headers(self, headers: List[str], data_rows: List[List[str]]):
+ """
+ Merge header continuation rows (when first data rows are not numeric '1.', '2.', …)
+ into the main headers, then return (merged_headers, remaining_data_rows).
+ """
+ merged = [_smart_space(h or "") for h in (headers or [])]
+ consumed = 0
+ header_frags: List[List[str]] = []
+
+ # Collect up to 5 leading rows that look like header fragments
+ for r in data_rows[:5]:
+ first = (str(r[0]).strip() if r else "")
+ if re.match(r"^\d+\.?$", first):
+ break # real data starts
+ consumed += 1
+ header_frags.append(r)
+
+ # Merge every collected fragment row into merged
+ for frag in header_frags:
+ for i, cell in enumerate(frag):
+ cell_txt = _smart_space(str(cell or "").strip())
+ if not cell_txt:
+ continue
+ if i >= len(merged):
+ merged.append(cell_txt)
+ else:
+ merged[i] = (merged[i] + " " + cell_txt).strip()
+
+ return merged, data_rows[consumed:]
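+ # Illustrative (hypothetical split header): headers ["Registration", "Daily"] with a leading
+ # fragment row ["Number", "Checks"] merge to ["Registration Number", "Daily Checks"], and the
+ # fragment row is dropped from the returned data rows.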
+
+ def _first_attendance_name_title(self, att_list: List[str]) -> Optional[tuple[str, str]]:
+ """Return (print_name, position_title) from the first 'Name - Title' in attendance."""
+ if not att_list:
+ return None
+ # First "Name - Title", stop before next "Name -"
+ pat = re.compile(
+ r'([A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,3})\s*-\s*(.*?)(?=(?:\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+){0,3}\s*-\s*)|$)'
+ )
+ for item in att_list:
+ s = _smart_space(str(item))
+ m = pat.search(s)
+ if m:
+ name = _smart_space(m.group(1))
+ title = _smart_space(m.group(2))
+ return name, title
+ return None
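+ # Illustrative: _first_attendance_name_title(["John Smith - Operations Manager"])
+ # -> ("John Smith", "Operations Manager")   (names here are invented)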
-def update_json_with_pdf(word_json_file, pdf_txt_file, output_file):
- """
- word_json_file: file-like object or file path (docx extraction JSON)
- pdf_txt_file: file-like object or file path (PDF plain text)
- output_file: file-like object (opened for writing) or file path
- """
- # --- Load files ---
- def read_any(f):
- if hasattr(f, "read"):
- f.seek(0)
- content = f.read()
- if isinstance(content, bytes):
- content = content.decode("utf-8")
- return content
- else:
- with open(f, "r", encoding="utf-8") as fh:
- return fh.read()
-
- word_json = read_any(word_json_file)
- pdf_txt = read_any(pdf_txt_file)
-
- # --- Build prompt ---
- user_prompt = f"""Here is a JSON template with fields that need updating with data from the PDF:
-{word_json}
-
-Here is the extracted text from a PDF document:
-{pdf_txt}
-
-EXTRACTION INSTRUCTIONS:
-1. COMPREHENSIVE EXTRACTION: Extract data for EVERY field present in the JSON template. Do not skip any field.
-
-2. FIELD-SPECIFIC EXTRACTION RULES:
- - Dates: Look for patterns like "5th July 2023", "28th February 2024"
- - Company Names: Extract the EXACT company name from the current PDF document
- - Registration Numbers: Look for vehicle registrations (format: XX ## XX)
- - Contact Details: Extract addresses, phone numbers, emails exactly as written
- - ACN Numbers: Extract 9-digit Australian Company Numbers
- - Audit Numbers: Look for audit matrix identifiers, CAR numbers
-
-3. TABLE DATA EXTRACTION:
- - For Vehicle Registration tables: Extract ALL columns including maintenance records, weight verification, suspension data
- - For attendance lists: Extract ALL names with their positions/roles
- - For management summaries: Extract specific dates, numbers, and compliance details
-
-4. MISSING DATA HANDLING:
- - If data is not found in PDF, use "Not Available" instead of "Entry"
- - For empty date ranges, use "Date range not specified"
- - For missing numbers, use "Not provided"
- - Only use actual data found in the PDF text
-
-5. OPERATOR DECLARATION CRITICAL RULES:
- - "Print Name": Must be the COMPANY REPRESENTATIVE signing the operator declaration (NOT the auditor)
- - Look for "OPERATOR DECLARATION" section - the person signing this is from the company
- - "Position Title": Their job role within the company (Director, Compliance Officer, Manager, etc.)
- - NEVER use the auditor's name (Greg Dyer) for operator declaration
-
-6. DATA CONSISTENCY:
- - Ensure the same company name appears throughout all sections
- - Ensure the same people appear consistently with correct roles
- - Cross-reference data between sections for accuracy
-
-7. QUALITY VALIDATION:
- - Verify extracted company name matches throughout the document
- - Check that dates are logical and properly formatted
- - Ensure vehicle registrations follow proper format
-
-CRITICAL: Extract data ONLY from the provided PDF text. Do not use any external knowledge or previous document data.
-
-Output ONLY the updated JSON with all fields filled using the extracted data. No markdown, no explanations, just valid JSON."""
-
- # --- Call OpenAI API ---
- api_key = os.environ.get("OPENAI_API_KEY")
- if not api_key:
- raise RuntimeError("OPENAI_API_KEY not found in environment variables!")
-
- client = OpenAI(api_key=api_key)
- response = client.chat.completions.create(
- model="gpt-4o",
- messages=[
- {"role": "system", "content": "You are a precise data extraction assistant specializing in audit documents. Extract data EXACTLY as it appears in the source document. Only reply with valid JSON - no markdown, no explanations, no extra formatting. Be thorough and extract ALL requested fields from the provided text."},
- {"role": "user", "content": user_prompt}
- ],
- max_tokens=6000, # Increased for more comprehensive extraction
- temperature=0.1 # Slightly increased for better handling of variations in text
- )
-
- updated_json_str = response.choices[0].message.content.strip()
-
- # Clean up common formatting issues
- if updated_json_str.startswith("```json"):
- updated_json_str = updated_json_str[7:]
- if updated_json_str.endswith("```"):
- updated_json_str = updated_json_str[:-3]
- updated_json_str = updated_json_str.strip()
-
- # --- Try to parse as JSON ---
- try:
- parsed = json.loads(updated_json_str)
-
- # Basic validation
- print("🔍 Validating extracted data...")
- original_data = json.loads(word_json)
-
- # Check if we have the same structure
- original_keys = set(original_data.keys()) if isinstance(original_data, dict) else set()
- parsed_keys = set(parsed.keys()) if isinstance(parsed, dict) else set()
-
- if original_keys and parsed_keys:
- missing_keys = original_keys - parsed_keys
- if missing_keys:
- print(f"⚠️ Warning: Missing keys in extraction: {missing_keys}")
-
- added_keys = parsed_keys - original_keys
- if added_keys:
- print(f"⚠️ Warning: Unexpected keys added: {added_keys}")
-
- # Save the parsed JSON
- if hasattr(output_file, "write"):
- json.dump(parsed, output_file, indent=2, ensure_ascii=False)
- output_file.flush()
- else:
- with open(output_file, "w", encoding="utf-8") as f:
- json.dump(parsed, f, indent=2, ensure_ascii=False)
- print("✅ JSON updated and saved to", getattr(output_file, "name", output_file))
-
- # Print extraction summary
- print(f"📊 Extraction Summary:")
- if isinstance(parsed, dict):
- total_fields = sum(len(v) if isinstance(v, list) else 1 for v in parsed.values())
- print(f" - Total sections: {len(parsed)}")
- print(f" - Total data points extracted: {total_fields}")
-
- # Debug: Print the updated JSON content
- print("\n🔍 UPDATED JSON CONTENT:")
- print("=" * 80)
- print(json.dumps(parsed, indent=2, ensure_ascii=False)[:3000] + ("..." if len(json.dumps(parsed, indent=2)) > 3000 else ""))
- print("=" * 80)
-
- except json.JSONDecodeError as e:
- print("⚠️ Model did not return valid JSON. Raw output below:\n")
- print(updated_json_str[:1000] + "..." if len(updated_json_str) > 1000 else updated_json_str)
- print(f"\n❌ JSON Parse Error: {e}")
- print("🔧 Attempting to fix common JSON issues...")
-
- # Try to fix common issues
+ # ───────────────────────────── summary tables (unchanged logic) ─────────────────────────────
+ def build_summary_maps(self, pdf_json: dict) -> dict:
+ out = {v: {} for v in SUMMARY_SECTIONS.values()}
try:
- # Remove trailing commas
- fixed_json = re.sub(r',(\s*[}\]])', r'\1', updated_json_str)
- parsed = json.loads(fixed_json)
- print("✅ Fixed JSON formatting issues")
-
- if hasattr(output_file, "write"):
- json.dump(parsed, output_file, indent=2, ensure_ascii=False)
- output_file.flush()
+ tables = pdf_json["extracted_data"]["all_tables"]
+ except Exception:
+ return out
+
+ for t in tables:
+ headers = [re.sub(r"\s+", " ", (h or "")).strip().upper() for h in t.get("headers", [])]
+ if "DETAILS" not in headers:
+ continue
+ section_key_raw = next((h for h in headers if h in SUMMARY_SECTIONS), None)
+ if not section_key_raw:
+ continue
+ section_name = SUMMARY_SECTIONS[section_key_raw]
+ for row in t.get("data", []):
+ if not row: continue
+ left = str(row[0]) if len(row) >= 1 else ""
+ right = str(row[1]) if len(row) >= 2 else ""
+ left_norm = self.normalize_std_label(left)
+ if left_norm and right:
+ prev = out[section_name].get(left_norm, "")
+ merged_text = (prev + " " + right).strip() if prev else right.strip()
+ out[section_name][left_norm] = merged_text
+
+ for sec in out:
+ out[sec] = {k: [_smart_space(v)] for k, v in out[sec].items() if v}
+ return out
+
+ # ───────────────────────────── NEW: find cell by label in tables ─────────────────────────────
+ def _find_table_value(self, tables: List[Dict], label_variants: List[str]) -> Optional[str]:
+ targets = {_canon(v) for v in label_variants}
+ for t in tables:
+ data = t.get("data", [])
+ if not data: continue
+ for row in data:
+ if not row: continue
+ key = _canon(str(row[0]))
+ if key in targets:
+ vals = [str(c).strip() for c in row[1:] if str(c).strip()]
+ if vals:
+ return _smart_space(" ".join(vals))
+ return None
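+ # Illustrative: for a table row ["Australian Company Number", "123 456 789"] (made-up value),
+ # _find_table_value(tables, ["Australian Company Number", "ACN"]) -> "123 456 789"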
+
+ # ───────────────────────────── comprehensive extraction (minimal changes) ─────────────────────────────
+ def extract_from_pdf_comprehensive(self, pdf_data: Dict) -> Dict[str, Any]:
+ self._vehicle_by_reg.clear()
+ extracted = {}
+ extracted_data = pdf_data.get("extracted_data", {})
+ tables = extracted_data.get("all_tables", [])
+
+ # Capture "Audit was conducted on" from tables; ignore placeholder "Date"
+ awd = self._find_table_value(
+ tables,
+ LABEL_INDEX["Audit Declaration dates"]["Audit was conducted on"]["alts"]
+ )
+ if awd:
+ awd = _smart_space(awd)
+ if re.search(r"\d", awd) and not re.fullmatch(r"date", awd, re.I):
+ extracted["audit_conducted_date"] = awd
+
+
+
+ # 1) Audit Information (table first)
+ audit_info = extracted_data.get("audit_information", {})
+ if audit_info:
+ extracted["audit_info"] = {
+ "date_of_audit": _smart_space(audit_info.get("DateofAudit", "")),
+ "location": _smart_space(audit_info.get("Locationofaudit", "")),
+ "auditor_name": _smart_space(audit_info.get("Auditorname", "")),
+ "matrix_id": _smart_space(audit_info.get("AuditMatrixIdentifier (Name or Number)", "")),
+ }
+ # If missing, try generic table lookup
+ for label, meta in LABEL_INDEX.get("Audit Information", {}).items():
+ if label == "expiry Date:": # not used in your DOCX example
+ continue
+ val = self._find_table_value(tables, meta.get("alts", [label]))
+ if val:
+ extracted.setdefault("audit_info", {})
+ if _canon(label) == _canon("Date of Audit"): extracted["audit_info"]["date_of_audit"] = val
+ elif _canon(label) == _canon("Location of audit"): extracted["audit_info"]["location"] = val
+ elif _canon(label) == _canon("Auditor name"): extracted["audit_info"]["auditor_name"] = val
+ elif _canon(label) == _canon("Audit Matrix Identifier (Name or Number)"): extracted["audit_info"]["matrix_id"] = val
+
+ # 2) Operator Information (prefer table rows)
+ operator_info = extracted_data.get("operator_information", {})
+ if operator_info:
+ extracted["operator_info"] = {
+ "name": "",
+ "trading_name": _smart_space(operator_info.get("trading_name", "")),
+ "acn": _smart_space(operator_info.get("company_number", "")),
+ "manual": _smart_space(operator_info.get("nhvas_accreditation", "")),
+ "business_address": _smart_space(operator_info.get("business_address", "")),
+ "postal_address": _smart_space(operator_info.get("postal_address", "")),
+ "email": operator_info.get("email", ""),
+ "phone": _smart_space(operator_info.get("phone", "")),
+ }
+
+ # Fill operator info via table lookup
+ for label, meta in LABEL_INDEX.get("Operator Information", {}).items():
+ val = self._find_table_value(tables, meta.get("alts", [label]))
+ if not val: continue
+ if _canon(label) == _canon("Operator name (Legal entity)") and _looks_like_company(val):
+ extracted.setdefault("operator_info", {})
+ extracted["operator_info"]["name"] = val
+ elif _canon(label) == _canon("Registered trading name/s"):
+ extracted.setdefault("operator_info", {})
+ extracted["operator_info"]["trading_name"] = val
+ elif _canon(label) == _canon("Australian Company Number"):
+ extracted.setdefault("operator_info", {})
+ extracted["operator_info"]["acn"] = val
+ elif _canon(label) == _canon("NHVAS Manual (Policies and Procedures) developed by"):
+ extracted.setdefault("operator_info", {})
+ if _looks_like_manual_value(val):
+ extracted["operator_info"]["manual"] = val
+
+ # 3) Generic table parsing (unchanged logic for other sections)
+ self._extract_table_data(tables, extracted)
+
+ # 4) Text parsing (kept, but spacing applied)
+ self._extract_text_content(extracted_data.get("all_text_content", []), extracted)
+ # Vehicle tables sometimes fail to land in all_tables; parse from text as a fallback
+ self._extract_vehicle_tables_from_text(extracted_data.get("all_text_content", []), extracted)
+
+ # 5) Vehicle/Driver data (kept)
+ self._extract_vehicle_driver_data(extracted_data, extracted)
+
+ # 6) Detailed mgmt (kept)
+ self._extract_detailed_management_data(extracted_data, extracted)
+
+ return extracted
+
+ # ───────────────────────────── table classifiers ─────────────────────────────
+ # header-driven dispatch: classify each table and route it to the matching extractor
+ def _extract_table_data(self, tables: List[Dict], extracted: Dict):
+ for table in tables:
+ headers = table.get("headers", []) or []
+ data_rows = table.get("data", []) or []
+ if not data_rows:
+ continue
+
+ page_num = table.get("page", 0)
+ self.log_debug(f"Processing table on page {page_num} with headers: {headers[:3]}...")
+
+ # 🔧 NEW: collapse possible multi-line headers once up front
+ collapsed_headers, collapsed_rows = self._collapse_multiline_headers(headers, data_rows)
+
+ # 🔧 Try vehicle tables FIRST using either raw or collapsed headers
+ if self._is_vehicle_registration_table(headers) or self._is_vehicle_registration_table(collapsed_headers):
+ # always extract with the collapsed header/rows so we see "Registration Number", etc.
+ self._extract_vehicle_registration_table(collapsed_headers, collapsed_rows, extracted, page_num)
+ continue
+
+ # the rest keep their existing order/logic (use the original headers/rows)
+ if self._is_audit_info_table(headers):
+ self._extract_audit_info_table(data_rows, extracted)
+ elif self._is_operator_info_table(headers):
+ self._extract_operator_info_table(data_rows, extracted)
+ elif self._is_attendance_table(headers):
+ self._extract_attendance_table(data_rows, extracted)
+ elif self._is_vehicle_summary_table(headers):
+ self._extract_vehicle_summary_table(data_rows, extracted)
+ elif self._is_driver_table(headers):
+ self._extract_driver_table(headers, data_rows, extracted)
+ elif self._is_management_compliance_table(headers):
+ self._extract_management_table(data_rows, extracted, headers)
+
+
+ def _is_audit_info_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["audit", "date", "location", "auditor"])
+
+ def _is_operator_info_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["operator", "company", "trading", "address"])
+
+ def _is_attendance_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return "attendance" in txt
+
+ def _is_vehicle_summary_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["powered vehicles", "trailing vehicles", "drivers in bfm"])
+
+ def _is_vehicle_registration_table(self, headers: List[str]) -> bool:
+ if not headers: return False
+ ch = [_canon_header(h) for h in headers]
+ has_reg = any(
+ ("registration" in h) or re.search(r"\breg(?:istration)?\b", h) or ("reg" in h and "no" in h)
+ for h in ch
+ )
+ others = ["roadworthiness","maintenance records","daily checks","fault recording","fault repair",
+ "sub contractor","sub-contractor","weight verification","rfs suspension","suspension system maintenance",
+ "trip records","fault recording reporting on suspension system","fault reporting suspension"]
+ has_signal = any(any(tok in h for tok in others) for h in ch)
+ return has_reg and has_signal
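+
+ # Example classification (header strings invented, but built from the tokens checked above):
+ # ["Registration Number", "Roadworthiness Certificates", "Daily Checks"] -> True (reg + signal column)
+ # ["Registration Number", "Operator name (Legal entity)"] -> False (no maintenance/mass signal column)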
+
+ def _is_driver_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["driver", "scheduler", "tlif", "medical"])
+
+ def _is_management_compliance_table(self, headers: List[str]) -> bool:
+ txt = " ".join(str(h) for h in headers).lower()
+ return any(t in txt for t in ["maintenance management", "mass management", "fatigue management"])
+
+ def _extract_vehicle_tables_from_text(self, text_pages: List[Dict], extracted: Dict):
+ # flatten text
+ lines = []
+ for p in text_pages or []:
+ for ln in re.split(r"\s*\n\s*", p.get("text", "")):
+ ln = _smart_space(ln)
+ if ln: lines.append(ln)
+
+ maint_rows, mass_rows = [], []
+ rf_pat = re.compile(r"\bRF\s*\d+\b", re.IGNORECASE)
+
+ for ln in lines:
+ # find first token that looks like a rego
+ tokens = ln.split()
+ reg = next((t for t in tokens if looks_like_plate(t)), None)
+ if not reg:
+ continue
+
+ # everything after the reg on that line
+ tail = _smart_space(ln.split(reg, 1)[1]) if reg in ln else ""
+ dates = extract_date_tokens(tail)
+ has_rf = bool(rf_pat.search(ln)) or "suspension" in ln.lower()
+
+ if has_rf:
+ rfs = (rf_pat.search(ln).group(0).upper().replace(" ", "") if rf_pat.search(ln) else "")
+ wv = dates[0] if len(dates) > 0 else ""
+ rest = dates[1:]
+ mass_rows.append({
+ "registration": reg,
+ "sub_contractor": "Yes" if " yes " in f" {ln.lower()} " else ("No" if " no " in f" {ln.lower()} " else ""),
+ "sub_comp": "Yes" if " yes " in f" {ln.lower()} " else ("No" if " no " in f" {ln.lower()} " else ""),
+ "weight_verification": wv,
+ "rfs_certification": rfs or ("N/A" if "n/a" in ln.lower() else ""),
+ "suspension_maintenance": rest[0] if len(rest) > 0 else "",
+ "trip_records": rest[1] if len(rest) > 1 else "",
+ "fault_reporting_suspension": rest[2] if len(rest) > 2 else "",
+ })
else:
- with open(output_file, "w", encoding="utf-8") as f:
- json.dump(parsed, f, indent=2, ensure_ascii=False)
- print("✅ JSON saved after fixes")
- except Exception as fix_error:
- print(f"❌ Could not fix JSON: {fix_error}")
- raise e
- except Exception as e:
- print(f"❌ Unexpected error: {e}")
- raise e
+ # map first 5 date-like tokens in sensible order; fallbacks keep table consistent
+ rw = dates[0] if len(dates) > 0 else ""
+ mr = dates[1] if len(dates) > 1 else ""
+ dc = dates[2] if len(dates) > 2 else ""
+ fr = dates[3] if len(dates) > 3 else ""
+ rp = dates[4] if len(dates) > 4 else ""
+ maint_rows.append({
+ "registration": reg,
+ "roadworthiness": rw,
+ "maintenance_records": mr or dc,
+ "daily_checks": dc,
+ "fault_recording": fr or rp,
+ "fault_repair": rp or fr,
+ })
-if __name__ == "__main__":
- import sys
+ # fold the maintenance and mass rows parsed above into extracted["vehicles"]
+ vlist = extracted.setdefault("vehicles", []) # ensure it always exists
+
+ if maint_rows or mass_rows:
+ for r in maint_rows:
+ r["section"] = "maintenance"
+ vlist.append(r)
+ for r in mass_rows:
+ r["section"] = "mass"
+ vlist.append(r)
+ self.log_debug(f"Vehicle rows (text fallback): maint={len(maint_rows)} mass={len(mass_rows)} total={len(vlist)}")
+ else:
+ self.log_debug("Vehicle rows (text fallback): none detected.")
+
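+ # Hedged example of lines the text fallback can parse (registrations and dates invented):
+ # "ABC123 01/01/2024 02/01/2024 03/01/2024 04/01/2024 05/01/2024" -> maintenance row
+ # "XYZ987 Yes 01/01/2024 RF 1234 02/02/2024 03/03/2024" -> mass row (the RF token routes it)
+ # This assumes looks_like_plate() accepts registrations like these and extract_date_tokens()
+ # returns the dd/mm/yyyy tokens in order of appearance.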
+
+ # ───────────────────────────── simple extractors (spacing applied) ─────────────────────────────
+ def _extract_audit_info_table(self, data_rows: List[List], extracted: Dict):
+ ai = extracted.setdefault("audit_info", {})
+ for row in data_rows:
+ if len(row) < 2: continue
+ key = _canon(row[0])
+ val = _smart_space(" ".join(str(c).strip() for c in row[1:] if str(c).strip()))
+ if not val: continue
+ if "date" in key and "audit" in key: ai["date_of_audit"] = val
+ elif "location" in key: ai["location"] = val
+ elif "auditor" in key and "name" in key: ai["auditor_name"] = val
+ elif "matrix" in key: ai["matrix_id"] = val
+
+ def _extract_operator_info_table(self, data_rows: List[List], extracted: Dict):
+ oi = extracted.setdefault("operator_info", {})
+ for row in data_rows:
+ if len(row) < 2: continue
+ key = _canon(row[0])
+ val = _smart_space(" ".join(str(c).strip() for c in row[1:] if str(c).strip()))
+ if not val: continue
+ if "operator" in key and "name" in key and _looks_like_company(val): oi["name"] = val
+ elif "trading" in key: oi["trading_name"] = val
+ elif "australian" in key and "company" in key: oi["acn"] = val
+ elif "business" in key and "address" in key: oi["business_address"] = val
+ elif "postal" in key and "address" in key: oi["postal_address"] = val
+ elif "email" in key: oi["email"] = val
+ elif "telephone" in key or "phone" in key: oi["phone"] = val
+ elif "manual" in key or ("nhvas" in key and "manual" in key) or "developed" in key:
+ if _looks_like_manual_value(val):
+ oi["manual"] = val
+
+ def _extract_attendance_table(self, data_rows: List[List], extracted: Dict):
+ lst = []
+ for row in data_rows:
+ if not row: continue
+ cells = [str(c).strip() for c in row if str(c).strip()]
+ if not cells: continue
+ lst.append(_smart_space(" ".join(cells)))
+ if lst:
+ extracted["attendance"] = lst
+
+ def _extract_vehicle_summary_table(self, data_rows: List[List], extracted: Dict):
+ vs = extracted.setdefault("vehicle_summary", {})
+ for row in data_rows:
+ if len(row) < 2: continue
+ key = _canon(row[0])
+ value = ""
+ for c in row[1:]:
+ if str(c).strip():
+ value = _smart_space(str(c).strip()); break
+ if not value: continue
+ if "powered" in key and "vehicle" in key: vs["powered_vehicles"] = value
+ elif "trailing" in key and "vehicle" in key: vs["trailing_vehicles"] = value
+ elif "drivers" in key and "bfm" in key: vs["drivers_bfm"] = value
+ elif "drivers" in key and "afm" in key: vs["drivers_afm"] = value
+
+ # ▶▶ REPLACED: column mapping by headers
+ def _extract_vehicle_registration_table(self, headers, rows, extracted, page_num):
+ ch = [_canon_header(h) for h in (headers or [])]
+ alias = _map_header_indices(headers or [])
+
+ # header indices (may be misaligned vs data; that's OK, we’ll search near them)
+ def idx_of(*needles):
+ for i, h in enumerate(ch):
+ if all(n in h for n in needles): return i
+ return None
+
+ reg_i = alias.get("registration") or idx_of("registration number") or idx_of("registration") or idx_of("reg","no")
+ rw_i = alias.get("roadworthiness") or idx_of("roadworthiness")
+ maint_i = alias.get("maintenance_records") or idx_of("maintenance","records")
+ daily_i = alias.get("daily_checks") or idx_of("daily","check")
+ fr_i = alias.get("fault_recording") or idx_of("fault","recording")
+ rep_i = alias.get("fault_repair") or idx_of("fault","repair")
+
+ weight_i = alias.get("weight_verification") or idx_of("weight","verification")
+ rfs_i = alias.get("rfs_certification") or idx_of("rfs","certification")
+ susp_i = alias.get("suspension_maintenance") or idx_of("suspension","maintenance")
+ trip_i = alias.get("trip_records") or idx_of("trip","records")
+ frs_i = alias.get("fault_reporting_suspension") or idx_of("fault","reporting","suspension")
+
+ # classify table type by header signals
+ is_maint = any("roadworthiness" in h or "maintenance records" in h or ("daily" in h and "check" in h) or "fault repair" in h for h in ch)
+ is_mass = any("weight verification" in h or "rfs" in h or "suspension system" in h or "trip records" in h or "reporting on suspension" in h for h in ch)
+
+ maint_rows = extracted.setdefault("_maint_rows", []) if is_maint else None
+ added = 0
+
+ for r in rows or []:
+ # tolerant plate pick (handles misaligned columns)
+ reg = self._pick_nearby(r, reg_i, "plate", window=4)
+ if not reg or not looks_like_plate(reg):
+ continue
+
+ # collect values using tolerant picks
+ if is_maint:
+ rw = self._pick_nearby(r, rw_i, "date", window=4)
+ mr = self._pick_nearby(r, maint_i, "date", window=4)
+ dc = self._pick_nearby(r, daily_i, "date", window=4)
+ fr = self._pick_nearby(r, fr_i, "date", window=4)
+ rep = self._pick_nearby(r, rep_i, "date", window=4)
+
+ # sensible fallbacks
+ if not mr and dc: mr = dc
+ if not rep and fr: rep = fr
+ if not fr and rep: fr = rep
+
+ else: # mass or mixed
+ wv = self._pick_nearby(r, weight_i, "date", window=4)
+ rfs = self._pick_nearby(r, rfs_i, "rf", window=5)
+ sm = self._pick_nearby(r, susp_i, "date", window=4)
+ tr = self._pick_nearby(r, trip_i, "date", window=4)
+ frs = self._pick_nearby(r, frs_i, "date", window=4)
+ yn1 = self._pick_nearby(r, idx_of("sub","contractor"), "yn", window=3) or ""
+ yn2 = self._pick_nearby(r, idx_of("sub contracted vehicles statement of compliance"), "yn", window=3) or yn1
+
+ # merge into vehicle map
+ v = self._vehicle_by_reg.get(reg)
+ if v is None:
+ v = {"registration": reg}
+ self._vehicle_by_reg[reg] = v
+ added += 1
+
+ if is_maint:
+ v["seen_in_maintenance"] = True
+ if rw: v.setdefault("roadworthiness", rw)
+ if mr: v.setdefault("maintenance_records", mr)
+ if dc: v.setdefault("daily_checks", dc)
+ if fr: v.setdefault("fault_recording", fr)
+ if rep: v.setdefault("fault_repair", rep)
+
+ if maint_rows is not None:
+ maint_rows.append({
+ "registration": reg,
+ "roadworthiness": rw,
+ "maintenance_records": mr or dc,
+ "daily_checks": dc,
+ "fault_recording": fr or rep,
+ "fault_repair": rep or fr,
+ })
+ else:
+ v["seen_in_mass"] = True
+ if yn1: v.setdefault("sub_contractor", yn1)
+ if yn2: v.setdefault("sub_comp", yn2)
+ if wv: v.setdefault("weight_verification", wv)
+ if rfs: v.setdefault("rfs_certification", _smart_space(rfs).upper().replace(" ", ""))
+ if sm: v.setdefault("suspension_maintenance", sm)
+ if tr: v.setdefault("trip_records", tr)
+ if frs: v.setdefault("fault_reporting_suspension", frs)
+
+ extracted["vehicles"] = list(self._vehicle_by_reg.values())
+ return added
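+
+ # Note: _pick_nearby is defined elsewhere in this class; the assumed contract (a sketch, not
+ # verified here) is that _pick_nearby(row, col_idx, kind, window=N) inspects row[col_idx] and up
+ # to N neighbouring cells and returns the first cell matching the requested kind
+ # ("plate", "date", "rf" or "yn"), or an empty value when nothing nearby qualifies.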
+
+ def _extract_driver_table(self, headers: List[str], data_rows: List[List], extracted: Dict):
+ """Header-driven extraction for Driver / Scheduler Records."""
+ drivers = []
+ ch = [_canon_header(h) for h in headers or []]
+
+ # helpers
+ def find_col(needles: list[str]) -> Optional[int]:
+ for i, h in enumerate(ch):
+ if any(n in h for n in needles):
+ return i
+ return None
+
+ def find_col_rx(patterns: list[str]) -> Optional[int]:
+ for i, h in enumerate(ch):
+ if any(re.search(p, h) for p in patterns):
+ return i
+ return None
+
+ name_idx = find_col_rx([r"\bdriver\s*/\s*scheduler\s*name\b",
+ r"\bdriver\s+name\b", r"\bscheduler\s+name\b", r"\bname\b"])
+ tlif_d_idx = find_col(["driver tlif"])
+ tlif_s_idx = find_col(["scheduler tlif"])
+ medical_idx= find_col(["medical", "expiry"])
+ roster_idx = find_col_rx([r"\broster\b", r"\bsafe\s+driving\s+plan\b", r"\bschedule\b(?!r\b)"])
+ fit_idx = find_col(["fit for duty"])
+ diary_idx = find_col(["work diary", "electronic work diary", "page numbers"])
+
+ for row in data_rows:
+ if not row:
+ continue
+
+ name = None
+ if name_idx is not None and name_idx < len(row):
+ name = _smart_space(str(row[name_idx]).strip())
+ if not name:
+ continue
+
+ d = {"name": name}
+
+ if tlif_d_idx is not None and tlif_d_idx < len(row):
+ d["driver_tlif"] = _smart_space(str(row[tlif_d_idx]).strip())
+ if tlif_s_idx is not None and tlif_s_idx < len(row):
+ d["scheduler_tlif"] = _smart_space(str(row[tlif_s_idx]).strip())
+ if medical_idx is not None and medical_idx < len(row):
+ d["medical_expiry"] = _smart_space(str(row[medical_idx]).strip())
+
+ # Roster/Schedule/SDP: prefer the detected column; accept only date/range-like, not the name
+ if roster_idx is not None and roster_idx < len(row):
+ raw_roster = _smart_space(str(row[roster_idx]).strip())
+ if raw_roster and re.search(r"[0-9/–-]", raw_roster) and raw_roster.lower() != name.lower():
+ d["roster_schedule"] = raw_roster
+
+ # Fallback: scan the row for the first date/range-like cell that's not the name cell
+ if "roster_schedule" not in d:
+ for j, cell in enumerate(row):
+ if j == name_idx:
+ continue
+ s = _smart_space(str(cell).strip())
+ if s and re.search(r"[0-9/–-]", s) and s.lower() != name.lower():
+ d["roster_schedule"] = s
+ break
+
+ if fit_idx is not None and fit_idx < len(row):
+ d["fit_for_duty"] = _smart_space(str(row[fit_idx]).strip())
+ if diary_idx is not None and diary_idx < len(row):
+ d["work_diary"] = _smart_space(str(row[diary_idx]).strip())
+
+ drivers.append(d)
+
+ if drivers:
+ extracted["drivers_detailed"] = drivers
+ self.log_debug(f"Driver rows extracted (header-based): {len(drivers)}")
+
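+ # Hedged sketch: given headers such as ["Driver / Scheduler Name", "Driver TLIF", "Medical Expiry",
+ # "Roster / Safe Driving Plan", "Fit for Duty", "Work Diary Pages"] (invented, and assuming
+ # _canon_header keeps these words), the columns resolve to name, driver_tlif, medical_expiry,
+ # roster_schedule, fit_for_duty and work_diary respectively.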
+
+ def _extract_management_table(self, data_rows: List[List], extracted: Dict, headers: List[str]):
+ txt = " ".join(str(h) for h in headers).lower()
+ comp = {}
+ for row in data_rows:
+ if len(row) < 2: continue
+ std = str(row[0]).strip()
+ val = _smart_space(str(row[1]).strip())
+ if std.startswith("Std") and val:
+ comp[std] = val
+ if comp:
+ if "maintenance" in txt: extracted["maintenance_compliance"] = comp
+ elif "mass" in txt: extracted["mass_compliance"] = comp
+ elif "fatigue" in txt: extracted["fatigue_compliance"] = comp
+
+ def _extract_text_content(self, text_pages: List[Dict], extracted: Dict):
+ all_text = " ".join(page.get("text", "") for page in text_pages)
+ all_text = _smart_space(all_text)
+
+ # business summary
+ patt = [
+ r"Nature of the Operators? Business.*?:\s*(.*?)(?:Accreditation Number|Expiry Date|$)",
+ r"Nature of.*?Business.*?Summary.*?:\s*(.*?)(?:Accreditation|$)"
+ ]
+ for p in patt:
+ m = re.search(p, all_text, re.IGNORECASE | re.DOTALL)
+ if m:
+ txt = re.sub(r'\s+', ' ', m.group(1).strip())
+ txt = re.sub(r'\s*(Accreditation Number.*|Expiry Date.*)', '', txt, flags=re.IGNORECASE)
+ if len(txt) > 50:
+ extracted["business_summary"] = txt
+ break
+
+ # audit conducted date
+ for p in [
+ r"Audit was conducted on\s+([0-9]+(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})",
+ r"DATE\s+([0-9]+(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})",
+ r"AUDITOR SIGNATURE\s+DATE\s+([0-9]+(?:st|nd|rd|th)?\s+[A-Za-z]+\s+\d{4})"
+ ]:
+ m = re.search(p, all_text, re.IGNORECASE)
+ if m:
+ extracted["audit_conducted_date"] = _smart_space(m.group(1).strip())
+ break
+
+ # print accreditation name
+ for p in [
+ r"\(print accreditation name\)\s*([A-Za-z0-9\s&().,'/\-]+?)(?:\s+DOES|\s+does|\n|$)",
+ r"print accreditation name.*?\n\s*([A-Za-z0-9\s&().,'/\-]+?)(?:\s+DOES|\s+does|\n|$)"
+ ]:
+ m = re.search(p, all_text, re.IGNORECASE)
+ if m:
+ extracted["print_accreditation_name"] = _smart_space(m.group(1).strip())
+ break
+
+ # numbers in text (optional)
+ for p in [
+ r"Number of powered vehicles\s+(\d+)",
+ r"powered vehicles\s+(\d+)",
+ r"Number of trailing vehicles\s+(\d+)",
+ r"trailing vehicles\s+(\d+)",
+ r"Number of drivers in BFM\s+(\d+)",
+ r"drivers in BFM\s+(\d+)"
+ ]:
+ m = re.search(p, all_text, re.IGNORECASE)
+ if m:
+ val = m.group(1)
+ if "powered" in p: extracted.setdefault("vehicle_summary", {})["powered_vehicles"] = val
+ elif "trailing" in p: extracted.setdefault("vehicle_summary", {})["trailing_vehicles"] = val
+ elif "bfm" in p.lower(): extracted.setdefault("vehicle_summary", {})["drivers_bfm"] = val
+
+ def _extract_detailed_management_data(self, extracted_data: Dict, extracted: Dict):
+ all_tables = extracted_data.get("all_tables", [])
+ for table in all_tables:
+ headers = table.get("headers", [])
+ data_rows = table.get("data", [])
+ page_num = table.get("page", 0)
+ if self._has_details_column(headers):
+ section = self._identify_management_section(headers)
+ if section:
+ self._extract_management_details(data_rows, extracted, section)
+ elif 6 <= page_num <= 15:
+ self._extract_summary_by_content(data_rows, headers, extracted, page_num)
+
+ def _extract_summary_by_content(self, data_rows: List[List], headers: List[str], extracted: Dict, page_num: int):
+ section_type = "maintenance" if 6 <= page_num <= 9 else "mass" if 10 <= page_num <= 12 else "fatigue" if 13 <= page_num <= 15 else None
+ if not section_type: return
+ details_key = f"{section_type}_summary_details"
+ extracted[details_key] = {}
+ for row in data_rows:
+ if len(row) < 2: continue
+ standard = str(row[0]).strip()
+ details = _smart_space(str(row[1]).strip())
+ if standard.startswith("Std") and details and len(details) > 10:
+ m = re.search(r"Std\s+(\d+)\.\s*([^(]+)", standard)
+ if m:
+ key = f"Std {m.group(1)}. {m.group(2).strip()}"
+ extracted[details_key][key] = details
+
+ def _has_details_column(self, headers: List[str]) -> bool:
+ return "details" in " ".join(str(h) for h in headers).lower()
+
+ def _identify_management_section(self, headers: List[str]) -> Optional[str]:
+ txt = " ".join(str(h) for h in headers).lower()
+ if "maintenance" in txt: return "maintenance"
+ if "mass" in txt: return "mass"
+ if "fatigue" in txt: return "fatigue"
+ return None
+
+ def _extract_management_details(self, data_rows: List[List], extracted: Dict, section: str):
+ details_key = f"{section}_details"
+ extracted[details_key] = {}
+ for row in data_rows:
+ if len(row) < 2: continue
+ standard = str(row[0]).strip()
+ details = _smart_space(str(row[1]).strip())
+ if standard.startswith("Std") and details and details != "V" and len(details) > 10:
+ m = re.search(r"Std\s+\d+\.\s*([^(]+)", standard)
+ if m:
+ extracted[details_key][m.group(1).strip()] = details
+
+ def _extract_vehicle_driver_data(self, extracted_data: Dict, extracted: Dict):
+ vehicle_regs = extracted_data.get("vehicle_registrations", [])
+ if vehicle_regs:
+ extracted["vehicle_registrations"] = vehicle_regs
+ driver_records = extracted_data.get("driver_records", [])
+ if driver_records:
+ extracted["driver_records"] = driver_records
+
+ # Map consolidated vehicle rows onto the DOCX Maintenance and Mass registration arrays.
+
+ def map_vehicle_registration_arrays(self, pdf_extracted: Dict, merged: Dict):
+ """Extract and map vehicle registration data (Maintenance + Mass) to DOCX arrays."""
+ vehicles_src = []
+
+ # Prefer rows we parsed ourselves (header-based). Fall back to curated list if present.
+ if "vehicles" in pdf_extracted and isinstance(pdf_extracted["vehicles"], list):
+ vehicles_src = pdf_extracted["vehicles"]
+ elif "vehicle_registrations" in pdf_extracted and isinstance(pdf_extracted["vehicle_registrations"], list):
+ # Normalize curated structure (list of dicts with keys like 'registration_number', etc.)
+ for row in pdf_extracted["vehicle_registrations"]:
+ if not isinstance(row, dict):
+ continue
+ v = {
+ "registration": _smart_space(row.get("registration_number") or row.get("registration") or ""),
+ # Maintenance table columns (names as seen in curated JSON)
+ "roadworthiness": _smart_space(row.get("roadworthiness_certificates", "")),
+ "maintenance_records": _smart_space(row.get("maintenance_records", "")),
+ "daily_checks": _smart_space(row.get("daily_checks", "")),
+ "fault_recording": _smart_space(row.get("fault_recording_reporting", "")),
+ "fault_repair": _smart_space(row.get("fault_repair", "")),
+ # Mass table columns (in case the curated list ever includes them)
+ "sub_contractor": _smart_space(row.get("sub_contractor", "")),
+ "sub_comp": _smart_space(row.get("sub_contracted_vehicles_statement_of_compliance", "")),
+ "weight_verification": _smart_space(row.get("weight_verification_records", "")),
+ "rfs_certification": _smart_space(row.get("rfs_suspension_certification", row.get("rfs_suspension_certification_#", ""))),
+ "suspension_maintenance": _smart_space(row.get("suspension_system_maintenance", "")),
+ "trip_records": _smart_space(row.get("trip_records", "")),
+ "fault_reporting_suspension": _smart_space(row.get("fault_recording_reporting_on_suspension_system", "")),
+ }
+ if v["registration"]:
+ vehicles_src.append(v)
+
+ if not vehicles_src:
+ return # nothing to map
+
+ # Build column arrays
+ regs = []
+ roadworthiness = []
+ maint_records = []
+ daily_checks = []
+ fault_recording = []
+ fault_repair = []
+
+ sub_contractors = []
+ weight_verification = []
+ rfs_certification = []
+ suspension_maintenance = []
+ trip_records = []
+ fault_reporting_suspension = []
+
+ for v in vehicles_src:
+ reg = _smart_space(v.get("registration", "")).strip()
+ if not reg:
+ continue
+ regs.append(reg)
+
+ roadworthiness.append(_smart_space(v.get("roadworthiness", "")).strip())
+ maint_records.append(_smart_space(v.get("maintenance_records", "")).strip())
+ daily_checks.append(_smart_space(v.get("daily_checks", "")).strip())
+ fault_recording.append(_smart_space(v.get("fault_recording", "")).strip())
+ fault_repair.append(_smart_space(v.get("fault_repair", "")).strip())
+
+ sub_contractors.append(_smart_space(v.get("sub_contractor", "")).strip())
+ weight_verification.append(_smart_space(v.get("weight_verification", "")).strip())
+ rfs_certification.append(_smart_space(v.get("rfs_certification", "")).strip())
+ suspension_maintenance.append(_smart_space(v.get("suspension_maintenance", "")).strip())
+ trip_records.append(_smart_space(v.get("trip_records", "")).strip())
+ fault_reporting_suspension.append(_smart_space(v.get("fault_reporting_suspension", "")).strip())
+
+ # Update Maintenance table arrays (if present in template)
+ if "Vehicle Registration Numbers Maintenance" in merged and regs:
+ m = merged["Vehicle Registration Numbers Maintenance"]
+ m["Registration Number"] = regs
+ m["Roadworthiness Certificates"] = roadworthiness
+ m["Maintenance Records"] = maint_records
+ m["Daily Checks"] = daily_checks
+ m["Fault Recording/ Reporting"] = fault_recording
+ m["Fault Repair"] = fault_repair
+
+ # Update Mass table arrays (if present in template)
+ if "Vehicle Registration Numbers Mass" in merged and regs:
+ ms = merged["Vehicle Registration Numbers Mass"]
+ ms["Registration Number"] = regs
+ ms["Sub contractor"] = sub_contractors
+ ms["Weight Verification Records"] = weight_verification
+ ms["RFS Suspension Certification #"] = rfs_certification
+ ms["Suspension System Maintenance"] = suspension_maintenance
+ ms["Trip Records"] = trip_records
+ ms["Fault Recording/ Reporting on Suspension System"] = fault_reporting_suspension
+
+ self.log_debug(f"Updated vehicle registration arrays for {len(regs)} vehicles")
+ # ───────────────────────────── map to DOCX (apply spacing + safe fallbacks) ─────────────────────────────
+ def map_to_docx_structure(self, pdf_extracted: Dict, docx_data: Dict, pdf_data: Dict) -> Dict:
+ merged = json.loads(json.dumps(docx_data))
+
+ # Audit Information
+ if "audit_info" in pdf_extracted and "Audit Information" in merged:
+ ai = pdf_extracted["audit_info"]
+ if ai.get("date_of_audit"):
+ merged["Audit Information"]["Date of Audit"] = [_smart_space(ai["date_of_audit"])]
+ if ai.get("location"):
+ merged["Audit Information"]["Location of audit"] = [_smart_space(ai["location"])]
+ if ai.get("auditor_name"):
+ merged["Audit Information"]["Auditor name"] = [_smart_space(ai["auditor_name"])]
+ if ai.get("matrix_id"):
+ merged["Audit Information"]["Audit Matrix Identifier (Name or Number)"] = [_smart_space(ai["matrix_id"])]
+
+ # Operator Information
+ if "operator_info" in pdf_extracted and "Operator Information" in merged:
+ op = pdf_extracted["operator_info"]
+ if op.get("name") and _looks_like_company(op["name"]):
+ merged["Operator Information"]["Operator name (Legal entity)"] = [_smart_space(op["name"])]
+ if op.get("trading_name"):
+ merged["Operator Information"]["Registered trading name/s"] = [_smart_space(op["trading_name"])]
+ if op.get("acn"):
+ merged["Operator Information"]["Australian Company Number"] = [_smart_space(op["acn"])]
+ if op.get("manual"):
+ merged["Operator Information"]["NHVAS Manual (Policies and Procedures) developed by"] = [_smart_space(op["manual"])]
+
+ # Contact details
+ if "operator_info" in pdf_extracted and "Operator contact details" in merged:
+ op = pdf_extracted["operator_info"]
+ if op.get("business_address"):
+ merged["Operator contact details"]["Operator business address"] = [_smart_space(op["business_address"])]
+ if op.get("postal_address"):
+ merged["Operator contact details"]["Operator Postal address"] = [_smart_space(op["postal_address"])]
+ if op.get("email"):
+ merged["Operator contact details"]["Email address"] = [op["email"]]
+ if op.get("phone"):
+ merged["Operator contact details"]["Operator Telephone Number"] = [_smart_space(op["phone"])]
+
+ # Attendance
+ if "attendance" in pdf_extracted and "Attendance List (Names and Position Titles)" in merged:
+ merged["Attendance List (Names and Position Titles)"]["Attendance List (Names and Position Titles)"] = _clean_list(pdf_extracted["attendance"])
+
+ # Business summary
+ if "business_summary" in pdf_extracted and "Nature of the Operators Business (Summary)" in merged:
+ merged["Nature of the Operators Business (Summary)"]["Nature of the Operators Business (Summary):"] = [_smart_space(pdf_extracted["business_summary"])]
+
+ # Vehicle summary
+ if "vehicle_summary" in pdf_extracted:
+ vs = pdf_extracted["vehicle_summary"]
+ if "Accreditation Vehicle Summary" in merged:
+ if vs.get("powered_vehicles"):
+ merged["Accreditation Vehicle Summary"]["Number of powered vehicles"] = [vs["powered_vehicles"]]
+ if vs.get("trailing_vehicles"):
+ merged["Accreditation Vehicle Summary"]["Number of trailing vehicles"] = [vs["trailing_vehicles"]]
+ if "Accreditation Driver Summary" in merged:
+ if vs.get("drivers_bfm"):
+ merged["Accreditation Driver Summary"]["Number of drivers in BFM"] = [vs["drivers_bfm"]]
+ if vs.get("drivers_afm"):
+ merged["Accreditation Driver Summary"]["Number of drivers in AFM"] = [vs["drivers_afm"]]
+
+ # Summary sections (unchanged behavior)
+ summary_maps = self.build_summary_maps(pdf_data)
+ for section_name, std_map in summary_maps.items():
+ if section_name in merged and std_map:
+ for detail_key, details_list in std_map.items():
+ if detail_key in merged[section_name]:
+ merged[section_name][detail_key] = details_list
+ continue
+ for docx_key in list(merged[section_name].keys()):
+ m1 = re.search(r"Std\s+(\d+)", detail_key)
+ m2 = re.search(r"Std\s+(\d+)", docx_key)
+ if m1 and m2 and m1.group(1) == m2.group(1):
+ merged[section_name][docx_key] = details_list
+ break
+
+ # Vehicle registration arrays via consolidated builder
+ sections = build_vehicle_sections(pdf_extracted)
+ if "Vehicle Registration Numbers Maintenance" in merged:
+ merged["Vehicle Registration Numbers Maintenance"].update(
+ sections["Vehicle Registration Numbers Maintenance"]
+ )
+ if "Vehicle Registration Numbers Mass" in merged:
+ merged["Vehicle Registration Numbers Mass"].update(
+ sections["Vehicle Registration Numbers Mass"]
+ )
+
+
+ # Driver / Scheduler Records: map extracted driver rows onto the DOCX columns
+ if "drivers_detailed" in pdf_extracted and "Driver / Scheduler Records Examined" in merged:
+ drivers = pdf_extracted["drivers_detailed"]
+
+ def _looks_like_range(s):
+ return bool(re.search(r"[0-9]{1,2}[/-]", s or ""))
+
+ merged["Driver / Scheduler Records Examined"]["Roster / Schedule / Safe Driving Plan (Date Range)"] = [d.get("roster_schedule","") for d in drivers]
+ merged["Driver / Scheduler Records Examined"]["Fit for Duty Statement Completed (Yes/No)"] = [d.get("fit_for_duty","") for d in drivers]
+ merged["Driver / Scheduler Records Examined"]["Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)"] = [d.get("work_diary","") for d in drivers]
+
+
+ # --- Print accreditation name (robust, no UnboundLocalError) ---
+ if "Print accreditation name" in merged:
+ acc_name = _smart_space(pdf_extracted.get("print_accreditation_name") or "")
+ if not acc_name:
+ oi = pdf_extracted.get("operator_info") or {}
+ acc_name = _smart_space(oi.get("name") or "") or _smart_space(oi.get("trading_name") or "")
+ if acc_name:
+ merged["Print accreditation name"]["(print accreditation name)"] = [acc_name]
+
+ # Audit Declaration dates: prefer explicit extracted date; fallback to audit_info; ignore literal "Date"
+ if "Audit Declaration dates" in merged:
+ def _real_date(s: Optional[str]) -> bool:
+ return bool(s and re.search(r"\d", s) and not re.fullmatch(r"date", s.strip(), re.I))
+
+ val = pdf_extracted.get("audit_conducted_date")
+ if not _real_date(val):
+ val = (pdf_extracted.get("audit_info", {}) or {}).get("date_of_audit")
+
+ if _real_date(val):
+ merged["Audit Declaration dates"]["Audit was conducted on"] = [_smart_space(val)]
+
+
+ # Operator Declaration: page 22 image missing → derive from first Attendance "Name - Title"
+ if "Operator Declaration" in merged:
+ # If an explicit operator declaration exists, use it
+ if "operator_declaration" in pdf_extracted:
+ od = pdf_extracted["operator_declaration"]
+ pn = _smart_space(od.get("print_name", ""))
+ pt = _smart_space(od.get("position_title", ""))
+ if pn:
+ merged["Operator Declaration"]["Print Name"] = [pn]
+ if pt:
+ merged["Operator Declaration"]["Position Title"] = [pt]
+ else:
+ # Fallback: first "Name - Title" from Attendance
+ nt = self._first_attendance_name_title(pdf_extracted.get("attendance", []))
+ if nt:
+ merged["Operator Declaration"]["Print Name"] = [nt[0]]
+ merged["Operator Declaration"]["Position Title"] = [nt[1]]
+
+
+ # Paragraphs: fill company name for the 3 management headings; set the 2 dates
+ if "paragraphs" in merged:
+ paras = merged["paragraphs"]
+
+ audit_date = (
+ pdf_extracted.get("audit_conducted_date")
+ or pdf_extracted.get("audit_info", {}).get("date_of_audit")
+ )
+
+ # Prefer accreditation name, else operator legal name, else trading name
+ company_name = (
+ _smart_space(pdf_extracted.get("print_accreditation_name") or "")
+ or _smart_space(pdf_extracted.get("operator_info", {}).get("name") or "")
+ or _smart_space(pdf_extracted.get("operator_info", {}).get("trading_name") or "")
+ )
+
+ # Update the three layered headings
+ for key in ("MAINTENANCE MANAGEMENT", "MASS MANAGEMENT", "FATIGUE MANAGEMENT"):
+ if key in paras and company_name:
+ paras[key] = [company_name]
+
+ # Second-last page: date under page heading
+ if "NHVAS APPROVED AUDITOR DECLARATION" in paras and audit_date:
+ paras["NHVAS APPROVED AUDITOR DECLARATION"] = [_smart_space(audit_date)]
+
+ # Last page: date under long acknowledgement paragraph
+ ack_key = ("I hereby acknowledge and agree with the findings detailed in this NHVAS Audit Summary Report. "
+ "I have read and understand the conditions applicable to the Scheme, including the NHVAS Business Rules and Standards.")
+ if ack_key in paras and audit_date:
+ paras[ack_key] = [_smart_space(audit_date)]
+
+ self._force_fill_maintenance_from_tables(pdf_data, merged)
+ return merged
+
+ # ───────────────────────────── merge & CLI (unchanged) ─────────────────────────────
+ def merge_pdf_to_docx(self, docx_data: Dict, pdf_data: Dict) -> Dict:
+ self.log_debug("Starting comprehensive PDF extraction...")
+ pdf_extracted = self.extract_from_pdf_comprehensive(pdf_data)
+ self.log_debug(f"Extracted PDF data keys: {list(pdf_extracted.keys())}")
+
+ self.log_debug("Mapping to DOCX structure...")
+ merged_data = self.map_to_docx_structure(pdf_extracted, docx_data, pdf_data)
+
+ for section_name, section_data in docx_data.items():
+ if isinstance(section_data, dict):
+ for label in section_data:
+ if (section_name in merged_data and
+ label in merged_data[section_name] and
+ merged_data[section_name][label] != docx_data[section_name][label]):
+ print(f"✓ Updated {section_name}.{label}: {merged_data[section_name][label]}")
+ return merged_data
+
+ def process_files(self, docx_file: str, pdf_file: str, output_file: str):
+ try:
+ print(f"Loading DOCX JSON from: {docx_file}")
+ with open(docx_file, 'r', encoding='utf-8') as f:
+ docx_data = json.load(f)
+ print(f"Loading PDF JSON from: {pdf_file}")
+ with open(pdf_file, 'r', encoding='utf-8') as f:
+ pdf_data = json.load(f)
+
+ print("Merging PDF data into DOCX structure...")
+ merged_data = self.merge_pdf_to_docx(docx_data, pdf_data)
+
+ print(f"Saving merged data to: {output_file}")
+ with open(output_file, 'w', encoding='utf-8') as f:
+ json.dump(merged_data, f, indent=2, ensure_ascii=False)
+
+ print("✅ Merge completed successfully!")
+ return merged_data
+ except Exception as e:
+ print(f"❌ Error processing files: {str(e)}")
+ import traceback
+ traceback.print_exc()
+ raise
+
+def main():
if len(sys.argv) != 4:
- print("Usage: python update_docx_with_pdf.py ")
- exit(1)
- update_json_with_pdf(sys.argv[1], sys.argv[2], sys.argv[3])
\ No newline at end of file
+ print("Usage: python nhvas_merger.py ")
+ print("Example: python nhvas_merger.py docx_template.json pdf_extracted.json merged_output.json")
+ sys.exit(1)
+
+ docx_file = sys.argv[1]
+ pdf_file = sys.argv[2]
+ output_file = sys.argv[3]
+
+ for file_path in [docx_file, pdf_file]:
+ if not Path(file_path).exists():
+ print(f"❌ File not found: {file_path}")
+ sys.exit(1)
+
+ merger = NHVASMerger()
+ merger.process_files(docx_file, pdf_file, output_file)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/updated_word.py b/updated_word.py
index 58f58b85035a723c2ba1eacc5d02a5cf8b1fea77..fc13cf6770ac2717695632cfd12464a8a2206df3 100644
--- a/updated_word.py
+++ b/updated_word.py
@@ -1,1801 +1,1189 @@
#!/usr/bin/env python3
-"""
-pipeline.py — safer matching and operator-declaration protections
-
-Key improvements:
- - find_matching_json_key_and_value() returns (key, value) so callers can accept/reject by key.
- - Higher fuzzy thresholds for risky substitutions.
- - Operator Declaration: avoid using attendance lists / unrelated keys for Position Title.
- - Vehicle header mapping: stronger normalized substring/ token matching for long headers.
- - Preserves existing logging and all previous handlers/logic.
-"""
-
-import json
+# updated_word.py (update the DOCX template from the merged JSON)
+import sys, json, re
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional
from docx import Document
-from docx.shared import RGBColor
-import re
-from typing import Any, Tuple, Optional
-
-# ============================================================================
-# Heading patterns for document structure detection (unchanged)
-# ============================================================================
-HEADING_PATTERNS = {
- "main": [
- r"NHVAS\s+Audit\s+Summary\s+Report",
- r"NATIONAL\s+HEAVY\s+VEHICLE\s+ACCREDITATION\s+AUDIT\s+SUMMARY\s+REPORT",
- r"NHVAS\s+AUDIT\s+SUMMARY\s+REPORT"
- ],
- "sub": [
- r"AUDIT\s+OBSERVATIONS\s+AND\s+COMMENTS",
- r"MAINTENANCE\s+MANAGEMENT",
- r"MASS\s+MANAGEMENT",
- r"FATIGUE\s+MANAGEMENT",
- r"Fatigue\s+Management\s+Summary\s+of\s+Audit\s+findings",
- r"MAINTENANCE\s+MANAGEMENT\s+SUMMARY\s+OF\s+AUDIT\s+FINDINGS",
- r"MASS\s+MANAGEMENT\s+SUMMARY\s+OF\s+AUDIT\s+FINDINGS",
- r"Vehicle\s+Registration\s+Numbers\s+of\s+Records\s+Examined",
- r"CORRECTIVE\s+ACTION\s+REQUEST\s+\(CAR\)",
- r"NHVAS\s+APPROVED\s+AUDITOR\s+DECLARATION",
- r"Operator\s+Declaration",
- r"Operator\s+Information",
- r"Driver\s*/\s*Scheduler\s+Records\s+Examined"
- ]
-}
-
-# ============================================================================
-# Utility helpers
-# ============================================================================
-_unmatched_headers = {}
-def record_unmatched_header(header: str):
- if not header:
- return
- _unmatched_headers[header] = _unmatched_headers.get(header, 0) + 1
+from docx.shared import RGBColor, Pt # add Pt
+from docx.table import _Cell, Table
+from docx.text.paragraph import Paragraph
+from copy import deepcopy
+from docx.oxml.ns import qn
+from docx.oxml.table import CT_Tbl
+from docx.oxml.text.paragraph import CT_P
+
+BLACK = RGBColor(0, 0, 0)
+RED = RGBColor(0xFF, 0x00, 0x00)
+
+# ----------------------------- text helpers -----------------------------
+def _find_table_with_headers(doc: Document, must_have: list[str]) -> Optional[Table]:
+ for t in doc.tables:
+ if not t.rows:
+ continue
+ head = canon(" ".join(cell_text(c) for c in t.rows[0].cells))
+ if all(canon_label(x) in head for x in must_have):
+ return t
+ return None
-def load_json(filepath):
- with open(filepath, 'r', encoding='utf-8') as file:
- return json.load(file)
+def ensure_auditor_decl_headers(doc: Document) -> bool:
+ """
+ Second-last page table under 'NHVAS APPROVED AUDITOR DECLARATION'.
+ Force the HEADER row to read exactly:
+ [ Print Name | NHVR or Exemplar Global Auditor Registration Number ]
+ Never touch the bottom (values) row.
+ """
+ changed = False
+ expected_left = "Print Name"
+ expected_right = "NHVR or Exemplar Global Auditor Registration Number"
-def flatten_json(y, prefix=''):
- out = {}
- for key, val in y.items():
- new_key = f"{prefix}.{key}" if prefix else key
- if isinstance(val, dict):
- out.update(flatten_json(val, new_key))
- else:
- out[new_key] = val
- out[key] = val
- return out
+ for t in doc.tables:
+ if not t.rows or not t.rows[0].cells:
+ continue
+ # must look like the auditor table: header left says "Print Name", 2+ cols, 2+ rows
+ head_left = canon_label(cell_text(t.rows[0].cells[0]))
+ if head_left == "print name" and len(t.rows[0].cells) >= 2 and len(t.rows) >= 2:
+ # fix left header if needed
+ if canon_label(cell_text(t.rows[0].cells[0])) != canon_label(expected_left) or \
+ any(is_red_run(r) for p in t.rows[0].cells[0].paragraphs for r in p.runs):
+ _set_cell_text_black(t.rows[0].cells[0], expected_left)
+ changed = True
+ # unconditionally set the RIGHT header text (this is where "Peter Sheppard" was sitting)
+ if canon_label(cell_text(t.rows[0].cells[1])) != canon_label(expected_right) or \
+ any(is_red_run(r) for p in t.rows[0].cells[1].paragraphs for r in p.runs):
+ _set_cell_text_black(t.rows[0].cells[1], expected_right)
+ changed = True
+ # found and fixed the table; no need to continue
+ break
-def is_red(run):
- color = run.font.color
- try:
- return color and ((getattr(color, "rgb", None) and color.rgb == RGBColor(255, 0, 0)) or getattr(color, "theme_color", None) == 1)
- except Exception:
+ return changed
+
+
+def fill_operator_declaration(doc: Document, print_name: str, position_title: str) -> bool:
+ """Last page table: write values ONLY into the bottom row (red placeholders)."""
+ t = _find_table_with_headers(doc, ["Print Name", "Position Title"])
+ if not t or len(t.rows) < 2 or len(t.rows[0].cells) < 2:
return False
+ bot_left = t.rows[1].cells[0]
+ bot_right = t.rows[1].cells[1]
+
+ # only replace if that cell has a red placeholder
+ if any(is_red_run(r) for p in bot_left.paragraphs for r in p.runs):
+ _set_cell_text_black(bot_left, print_name)
+ if any(is_red_run(r) for p in bot_right.paragraphs for r in p.runs):
+ _set_cell_text_black(bot_right, position_title)
+ return True
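+
+# Hedged usage (name and title invented):
+# fill_operator_declaration(doc, "Jane Citizen", "Director")
+# Only bottom-row cells that still contain red placeholder runs are overwritten, in black.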
-def get_value_as_string(value, field_name=""):
- if isinstance(value, list):
- if len(value) == 0:
- return ""
- elif len(value) == 1:
- return str(value[0])
- else:
- # Keep lists intact for special patterns (e.g., ACN digits) but default to join
- if "australian company number" in field_name.lower() or "company number" in field_name.lower():
- return value
- return " ".join(str(v) for v in value)
- else:
- return str(value)
-
-def get_clean_text(cell):
- text = ""
- for paragraph in cell.paragraphs:
- for run in paragraph.runs:
- text += run.text
- return text.strip()
-
-def has_red_text(cell):
- for paragraph in cell.paragraphs:
- for run in paragraph.runs:
- if is_red(run) and run.text.strip():
- return True
+def find_heading_index_from_end(doc: Document, heading: str) -> Optional[int]:
+ key = canon(heading)
+ allp = iter_paragraphs(doc)
+ for i in range(len(allp) - 1, -1, -1):
+ if key in canon(para_text(allp[i])):
+ return i
+ return None
+
+def set_date_by_heading_from_end(doc: Document, heading: str, date_text: str, max_scan: int = 60) -> bool:
+ """Find the LAST occurrence of `heading`, then replace the FIRST red run in the next paragraphs."""
+ if not date_text:
+ return False
+ allp = iter_paragraphs(doc)
+ idx = find_heading_index_from_end(doc, heading)
+ if idx is None:
+ return False
+ for p in allp[idx + 1 : min(idx + 1 + max_scan, len(allp))]:
+ if replace_red_in_paragraph(p, date_text): # writes in black
+ return True
return False
-def has_red_text_in_paragraph(paragraph):
- for run in paragraph.runs:
- if is_red(run) and run.text.strip():
+def set_date_by_paragraph_from_end(doc: Document, paragraph_text: str, date_text: str, max_scan: int = 60) -> bool:
+ """Find the LAST paragraph matching `paragraph_text`, then set the FIRST red run after it."""
+ if not date_text:
+ return False
+ key = canon(paragraph_text)
+ allp = iter_paragraphs(doc)
+ hit = None
+ for i in range(len(allp) - 1, -1, -1):
+ if key in canon(para_text(allp[i])):
+ hit = i
+ break
+ if hit is None:
+ return False
+ # date placeholder is on the LAST page, right after this long paragraph
+ for p in allp[hit + 1 : min(hit + 1 + max_scan, len(allp))]:
+ if replace_red_in_paragraph(p, date_text): # writes in black
return True
return False
-def normalize_header_text(s: str) -> str:
- if not s:
- return ""
- s = re.sub(r'\([^)]*\)', ' ', s) # remove parenthetical content
- s = s.replace("/", " ")
- s = re.sub(r'[^\w\s\#\%]', ' ', s)
- s = re.sub(r'\s+', ' ', s).strip().lower()
- # canonical tweaks
- s = s.replace('registrationno', 'registration number')
- s = s.replace('registrationnumber', 'registration number')
- s = s.replace('sub-contractor', 'sub contractor')
- s = s.replace('sub contracted', 'sub contractor')
- return s.strip()
-
-# ============================================================================
-# JSON matching functions
-# - find_matching_json_value: (keeps behavior used elsewhere)
-# - find_matching_json_key_and_value: returns (key, value) so callers can
-# decide whether to use an entry based on the matched key.
-# ============================================================================
-def find_matching_json_value(field_name, flat_json):
- """Legacy API: return value only (preserves existing callers)."""
- result = find_matching_json_key_and_value(field_name, flat_json)
- return result[1] if result else None
-
-def find_matching_json_key_and_value(field_name, flat_json) -> Optional[Tuple[str, Any]]:
+def set_layer3_name_after_management_heading(doc: Document, mid_heading: str, allowed_prev_titles: List[str], name: str) -> bool:
+ if not name:
+ return False
+
+ allp = iter_paragraphs(doc)
+ wrote = False
+ mid = canon(mid_heading)
+ allowed_prev = {canon(t) for t in allowed_prev_titles}
+
+ for i, p in enumerate(allp):
+ if canon(para_text(p)) != mid:
+ continue
+
+ # previous non-empty must be one of the allowed titles
+ j = i - 1
+ while j >= 0 and not nz(para_text(allp[j])):
+ j -= 1
+ if j < 0 or canon(para_text(allp[j])) not in allowed_prev:
+ continue
+
+ # next non-empty is the 3rd line we overwrite
+ k = i + 1
+ while k < len(allp) and not nz(para_text(allp[k])):
+ k += 1
+ if k >= len(allp):
+ continue
+
+ # compute target size from the middle heading; fall back to a sensible bump
+ target_size = _para_effective_font_size(allp[i]) or Pt(16)
+
+ _clear_para_and_write_black(allp[k], name)
+
+ # apply size to all runs explicitly (overrides style)
+ for r in allp[k].runs:
+ r.font.size = target_size
+
+ wrote = True
+
+ return wrote
+
+def _para_effective_font_size(p: Paragraph):
+ # try explicit run sizes first
+ for r in p.runs:
+ if r.font.size:
+ return r.font.size
+ # then the paragraph style
+ if p.style and p.style.font and p.style.font.size:
+ return p.style.font.size
+ return None
+
+# --- helpers for summary tables ---
+# --- helpers for summary overwrite ---
+def _std_key(s: str) -> str:
+ """
+ Normalize a label to match a 'Std N' key.
+ e.g. 'Std 7. Internal Review' -> 'std 7'
+ """
+ t = canon_label(s)
+ m = re.match(r"(std\s+\d+)", t)
+ return m.group(1) if m else t
+
+def _looks_like_summary_table(table: Table) -> Optional[Tuple[int, int]]:
"""
- Return (matched_key, matched_value) or None.
- Safer thresholds: fuzzy matches require >=0.35 by default.
+ Return (label_col_idx, details_col_idx) if this is a Summary table
+ with a DETAILS column; otherwise None.
"""
- field_name = (field_name or "").strip()
- if not field_name:
+ if not table.rows:
+ return None
+ first = table.rows[0]
+ cols = len(first.cells)
+ if cols < 2:
+ return None
+
+ # header texts for first row
+ head = [canon(cell_text(c)) for c in first.cells]
+
+ # find DETAILS column
+ details_col = None
+ for j, t in enumerate(head):
+ if "detail" in t:
+ details_col = j
+ break
+ if details_col is None:
return None
- # Exact match
- if field_name in flat_json:
- print(f" ✅ Direct match found for key '{field_name}'")
- return field_name, flat_json[field_name]
-
- # Case-insensitive exact
- for key, value in flat_json.items():
- if key.lower() == field_name.lower():
- print(f" ✅ Case-insensitive match found for key '{field_name}' -> '{key}'")
- return key, value
-
- # Special-case 'print name' preference for operator vs auditor (prefer fully-qualified)
- if field_name.lower().strip() == "print name":
- operator_keys = [k for k in flat_json.keys() if "operator" in k.lower() and "print name" in k.lower()]
- auditor_keys = [k for k in flat_json.keys() if "auditor" in k.lower() and ("print name" in k.lower() or "name" in k.lower())]
- if operator_keys:
- print(f" ✅ Operator Print Name match: '{field_name}' -> '{operator_keys[0]}'")
- return operator_keys[0], flat_json[operator_keys[0]]
- elif auditor_keys:
- print(f" ✅ Auditor Name match: '{field_name}' -> '{auditor_keys[0]}'")
- return auditor_keys[0], flat_json[auditor_keys[0]]
-
- # Suffix match for nested keys (e.g., 'section.field')
- for key, value in flat_json.items():
- if '.' in key and key.split('.')[-1].lower() == field_name.lower():
- print(f" ✅ Suffix match found for key '{field_name}' -> '{key}'")
- return key, value
-
- # Clean and exact
- clean_field = re.sub(r'[^\w\s]', ' ', field_name.lower()).strip()
- clean_field = re.sub(r'\s+', ' ', clean_field)
- for key, value in flat_json.items():
- clean_key = re.sub(r'[^\w\s]', ' ', key.lower()).strip()
- clean_key = re.sub(r'\s+', ' ', clean_key)
- if clean_field == clean_key:
- print(f" ✅ Clean match found for key '{field_name}' -> '{key}'")
- return key, value
-
- # Fuzzy matching with token scoring
- field_words = set(word.lower() for word in re.findall(r'\b\w+\b', field_name) if len(word) > 2)
- if not field_words:
+ # find the label column (left-hand standards column)
+ label_col = None
+ for j, t in enumerate(head):
+ if any(k in t for k in ["maintenance management", "mass management", "fatigue management"]):
+ label_col = j
+ break
+ if label_col is None:
+ # fallback: assume the first non-DETAILS column is the label column
+ label_col = 0 if details_col != 0 else 1
+
+ return (label_col, details_col)
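+
+# Hedged example (header text invented): a first row such as ["MAINTENANCE MANAGEMENT", "DETAILS"]
+# yields (0, 1): column 0 carries the Std labels, column 1 is the DETAILS column that
+# overwrite_summary_details_cells() rewrites.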
+def _header_col_texts(table: Table, scan_rows: int = 5) -> List[str]:
+ scan_rows = min(scan_rows, len(table.rows))
+ if scan_rows == 0:
+ return []
+ # pick the row with the most cells as base
+ base_row = max(range(scan_rows), key=lambda i: len(table.rows[i].cells))
+ base_cols = len(table.rows[base_row].cells)
+ cols = []
+ for j in range(base_cols):
+ parts = []
+ for i in range(scan_rows):
+ row = table.rows[i]
+ if j < len(row.cells):
+ parts.append(cell_text(row.cells[j]))
+ cols.append(canon(" ".join(parts)))
+ return cols
+
+def count_header_rows(table: Table, scan_up_to: int = 6) -> int:
+ """Header ends right before the first row whose 1st cell looks like '1.'"""
+ limit = min(scan_up_to, len(table.rows))
+ for i in range(limit):
+ first = cell_text(table.rows[i].cells[0]).strip()
+ if re.match(r"^\d+\.?$", first):
+ return i
+ # fallback to 1 header row
+ return 1
+
+def map_cols_mass_strict(table: Table) -> Dict[str, int]:
+ cols = _header_col_texts(table, 5)
+ def first_col(*needles):
+ for j, t in enumerate(cols):
+ if all(n in t for n in needles):
+ return j
return None
+ idx = {
+ "no": first_col("no"),
+ "reg": first_col("registration", "number") or first_col("registration"),
+ "wv": first_col("weight", "verification"),
+ "rfs": first_col("rfs", "cert") or first_col("rfs", "certification"),
+ "susp": first_col("suspension", "maintenance"),
+ "trip": first_col("trip", "record"),
+ "frs": first_col("fault", "suspension") or first_col("fault", "reporting", "suspension"),
+ }
+ return {k: v for k, v in idx.items() if v is not None}
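+
+# Illustrative result (headers invented, assuming canon() lowercases them): a header row like
+# ["No.", "Registration Number", "Weight Verification", "RFS Certification",
+# "Suspension System Maintenance", "Trip Records", "Fault Recording/ Reporting on Suspension System"]
+# maps to {"no": 0, "reg": 1, "wv": 2, "rfs": 3, "susp": 4, "trip": 5, "frs": 6}.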
+
+def find_mass_vehicle_numbers_table(doc: Document) -> Optional[Table]:
+ """Pick the Mass vehicle-number table by matching its column set (not the Summary table)."""
+ best = None
+ best_score = -1
+ for t in iter_tables(doc):
+ cols = _header_col_texts(t, 5)
+ allhdr = " ".join(cols)
+ # must look like the vehicle numbers table
+ hits = 0
+ hits += int(any("registration" in c and "number" in c for c in cols))
+ hits += int(any("weight" in c and "verification" in c for c in cols))
+ hits += int(any("rfs" in c and ("cert" in c or "certification" in c) for c in cols))
+ hits += int(any("suspension" in c and "maintenance" in c for c in cols))
+ hits += int(any("trip" in c and "record" in c for c in cols))
+ hits += int(any("fault" in c and "suspension" in c for c in cols))
+ # reject obvious Summary tables
+ if "details" in allhdr:
+ continue
+ # prefer tables with numbering column and many rows
+ score = hits + (0.5 if any("no" == c or c.startswith("no ") for c in cols) else 0) + (len(t.rows) / 100.0)
+ if hits >= 4 and score > best_score:
+ best, best_score = t, score
+ return best
- best_key = None
- best_value = None
- best_score = 0.0
+def update_operator_declaration(doc: Document, print_name: str, position_title: str) -> bool:
+ """
+ First try strict table label mapping for 'Print Name' and 'Position Title'.
+ If not found, fallback to the first two red placeholders under the 'Operator Declaration' heading.
+ """
+ changed = False
+ # 1) Table label approach
+ for lbl, val in (("Print Name", print_name), ("Position Title", position_title)):
+ if not val:
+ continue
+ loc = find_label_cell(doc, lbl)
+ if not loc:
+ # tolerate odd spacing/colon/camelcase, but only try variants of the label we are filling
+ alts = ("PrintName", "Print Name", "Print Name:") if lbl == "Print Name" else ("PositionTitle", "Position Title", "Position Title:")
+ for alt in alts:
+ loc = find_label_cell(doc, alt)
+ if loc:
+ break
+ if loc:
+ t, r, c = loc
+ cell = get_adjacent_value_cell(t, r, c)
+ if not replace_red_in_cell(cell, val):
+ _set_cell_text_black(cell, val)
+ changed = True
+
+ if changed:
+ return True
+
+ # 2) Fallback: heading-scoped red placeholders
+ head = "OPERATOR DECLARATION"
+ p = find_heading_paragraph(doc, head) or find_heading_paragraph(doc, head.title())
+ if not p:
+ return False
+ allp = iter_paragraphs(doc)
+ try:
+ i = allp.index(p)
+ except ValueError:
+ i = 0
+ red_targets = []
+ for q in allp[i+1:i+1+20]:
+ reds = [r for r in q.runs if is_red_run(r)]
+ if reds:
+ red_targets.extend(reds)
+ if len(red_targets) >= 2:
+ break
+ wrote = False
+ if print_name and red_targets:
+        _set_text_and_black(red_targets[0], print_name)
+        wrote = True
+    if position_title and len(red_targets) >= 2:
+        _set_text_and_black(red_targets[1], position_title)
+        wrote = True
+ return wrote
+
+
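+# Minimal usage sketch (the name and title below are placeholders, not real data):
+# fills the Print Name / Position Title cells, falling back to the red placeholders
+# under the Operator Declaration heading when no labelled cells are found.
+def _example_operator_declaration(doc: Document) -> bool:
+    return update_operator_declaration(doc, "Jane Citizen", "Operations Manager")
+
+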
+def fill_mass_vehicle_table_preserve_headers(table: Table, arrays: Dict[str, List[str]]):
+ colmap = map_cols_mass_strict(table)
+ if "reg" not in colmap:
+ return
+ hdr_rows = count_header_rows(table, 6)
+ regs = arrays.get("Registration Number", [])
+ n = len(regs)
+
+ # clear data rows only
+ while len(table.rows) > hdr_rows:
+ table._tbl.remove(table.rows[-1]._tr)
+ # ensure enough rows
+ while len(table.rows) < hdr_rows + n:
+ table.add_row()
+
+ def put(row, key, arr_key, i):
+ if key in colmap:
+ vals = arrays.get(arr_key, [])
+ val = nz(vals[i]) if i < len(vals) else ""
+ replace_red_in_cell(row.cells[colmap[key]], val)
+
+ for i in range(n):
+ row = table.rows[hdr_rows + i]
+ replace_red_in_cell(row.cells[colmap["reg"]], nz(regs[i]))
+ put(row, "wv", "Weight Verification Records", i)
+ put(row, "rfs", "RFS Suspension Certification #", i)
+ put(row, "susp", "Suspension System Maintenance", i)
+ put(row, "trip", "Trip Records", i)
+ put(row, "frs", "Fault Recording/ Reporting on Suspension System", i)
+
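+# Usage sketch (illustrative only; the registrations and record values are invented):
+# locate the Mass vehicle-number table, then rebuild its data rows from parallel
+# arrays keyed the same way fill_mass_vehicle_table_preserve_headers expects.
+def _example_fill_mass_vehicles(doc: Document) -> None:
+    arrays = {
+        "Registration Number": ["ABC123", "XYZ789"],
+        "Weight Verification Records": ["Weighbridge docket", "Weighbridge docket"],
+    }
+    table = find_mass_vehicle_numbers_table(doc)
+    if table is not None:
+        fill_mass_vehicle_table_preserve_headers(table, arrays)
+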
+def overwrite_summary_details_cells(doc: Document, section_name: str, section_dict: Dict[str, List[str]]) -> int:
+ """For a Summary table (Maintenance/Mass/Fatigue), replace the entire DETAILS cell
+ for each Std N row with the JSON text (written in black)."""
+ # build desired texts
+ desired: Dict[str, str] = { _std_key(k): join_value(v) for k, v in section_dict.items() }
+
+ # pick which tables belong to this section by header sniff
+ wanted_prefix = canon_label(section_name.split()[0]) # "maintenance" | "mass" | "fatigue"
+
+ updated = 0
+ for t in doc.tables:
+ cols = _looks_like_summary_table(t)
+ if not cols:
+ continue
+ label_col, details_col = cols
- for key, value in flat_json.items():
- key_words = set(word.lower() for word in re.findall(r'\b\w+\b', key) if len(word) > 2)
- if not key_words:
+ head_txt = table_header_text(t, up_to_rows=2)
+ if wanted_prefix not in head_txt: # keep to the correct section
continue
- common = field_words.intersection(key_words)
- if not common:
- # allow substring in normalized forms as a weaker fallback
- norm_field = normalize_header_text(field_name)
- norm_key = normalize_header_text(key)
- if norm_field and norm_key and (norm_field in norm_key or norm_key in norm_field):
- # substring score based on length ratio
- substring_score = min(len(norm_field), len(norm_key)) / max(len(norm_field), len(norm_key))
- final_score = 0.4 * substring_score
- else:
- final_score = 0.0
- else:
- similarity = len(common) / len(field_words.union(key_words))
- coverage = len(common) / len(field_words)
- final_score = (similarity * 0.6) + (coverage * 0.4)
-
- if final_score > best_score:
- best_score = final_score
- best_key = key
- best_value = value
-
- # Accept only reasonable fuzzy matches (threshold 0.35)
- if best_key and best_score >= 0.35:
- print(f" ✅ Fuzzy match found for key '{field_name}' with JSON key '{best_key}' (score: {best_score:.2f})")
- return best_key, best_value
-
- print(f" ❌ No match found for '{field_name}'")
- return None
+ # walk body rows
+ for i in range(1, len(t.rows)):
+ row = t.rows[i]
+ key = _std_key(cell_text(row.cells[label_col]))
+
+ # exact match or "std N" prefix match
+ cand = desired.get(key)
+ if not cand:
+ m = re.match(r"(std\s+\d+)", key)
+ if m:
+ for k2, v2 in desired.items():
+ if k2.startswith(m.group(1)):
+ cand = v2
+ break
+ if not cand:
+ continue
+
+ _set_cell_text_black(row.cells[details_col], cand) # full overwrite, black
+ updated += 1
+ return updated
+
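+# Sketch of a call (the section name follows this script's summary-table naming; the
+# Std label and detail text are placeholders): every matching "Std N" row gets its
+# DETAILS cell fully overwritten in black.
+def _example_overwrite_summary(doc: Document) -> int:
+    section = {"Std 5. Verification": ["Weighbridge dockets sighted for all nominated vehicles."]}
+    return overwrite_summary_details_cells(doc, "Mass Management Summary", section)
+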
+SPLIT_SENT_PAT = re.compile(r"(?<=\.|\?|!)\s+")
+ORDINAL_DATE_PAT = re.compile(r"\b(\d{1,2}(?:st|nd|rd|th)\s+[A-Za-z]+\s+\d{4})\b", re.I)
+
+def split_sentences_keep(text: str) -> List[str]:
+ s = " ".join(str(text or "").split())
+ if not s:
+ return []
+ out = []
+ start = 0
+ for m in SPLIT_SENT_PAT.finditer(s):
+ out.append(s[start:m.start()].strip())
+ start = m.end()
+ last = s[start:].strip()
+ if last:
+ out.append(last)
+ return out
+
+_sent_split = re.compile(r'(?<=[.!?])\s+|\n+')
+_date_pat = re.compile(r'\b(?:\d{1,2}(?:st|nd|rd|th)\s+[A-Za-z]+\s+\d{4}|\d{1,2}/\d{1,2}/\d{2,4}|[A-Za-z]+\s+\d{1,2},\s*\d{4})\b')
+
+def extract_summary_snippets(desired_text: str):
+ sents = _sentences(desired_text)
+ dates = [m.group(0) for m in _date_pat.finditer(desired_text)]
+ pick = lambda rx: next((s for s in sents if re.search(rx, s, re.I)), None)
+ return {
+ "sheet_sent": pick(r'\b(daily\s+check|sheet)\b'),
+ "sheet_phrase": _extract_sheet_phrase_from_desired(desired_text),
+ "review": pick(r'\binternal\s+review\b'),
+ "qcs": pick(r'\bquarterly\b.*\bcompliance\b') or pick(r'\bquarterly\b'),
+ "dates": dates,
+ "sents": sents,
+ }
+
+def fill_management_summary_tables(doc: Document, section_key: str, section_data: Dict[str, List[str]]):
+ """
+ Fill ALL summary tables for the given section_key ('maintenance'|'mass'|'fatigue')
+ by matching each row label (left column) against keys in section_data and
+ patching only the red text inside the DETAILS cell.
+ """
+ targets = [x for x in find_all_summary_tables(doc) if x[0] == section_key]
+ if not targets:
+ return
+
+ # build list of (normalized label, original label, desired_text)
+ desired = []
+ for label, vals in section_data.items():
+ want = canon_label(label)
+ if not want:
+ continue
+ desired.append((want, label, join_value(vals)))
+
+ for _, table, lcol, dcol in targets:
+ # iterate data rows (skip header)
+ for i in range(1, len(table.rows)):
+ left_txt_norm = canon_label(cell_text(table.rows[i].cells[lcol]))
+ if not left_txt_norm:
+ continue
+ for want_norm, _orig_lbl, value in desired:
+ # loose contains match handles minor punctuation differences
+ if want_norm and want_norm in left_txt_norm:
+ patch_details_cell_from_json(table.rows[i].cells[dcol], value)
+
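+# Sketch (the section key follows the 'maintenance'|'mass'|'fatigue' convention above;
+# the label and text are placeholders): unlike overwrite_summary_details_cells, this
+# patches only the red runs inside the DETAILS cells and leaves black text alone.
+def _example_fill_summary_red_only(doc: Document) -> None:
+    fill_management_summary_tables(doc, "mass", {"Verification": ["Weighbridge dockets sighted."]})
+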
+def _set_text_and_black(run, new_text: str):
+ """Replace a run's text and force color to black (clears theme color too)."""
+ if new_text is None:
+ new_text = ""
+ run.text = str(new_text)
+ run.font.color.rgb = BLACK
+ try:
+ # clear any theme color so rgb sticks
+ run.font.color.theme_color = None
+ except Exception:
+ pass
+
+def update_business_summary_once(doc: Document, value) -> bool:
+ """Replace only the red summary paragraph; keep 'Accreditation Number' and 'Expiry Date' lines."""
+ loc = (find_label_cell(doc, "Nature of the Operators Business (Summary)")
+ or find_label_cell(doc, "Nature of the Operators Business (Summary):"))
+ if not loc:
+ return False
-# ============================================================================
-# Red text helpers (unchanged except kept robust)
-# ============================================================================
-def extract_red_text_segments(cell):
- red_segments = []
- for para_idx, paragraph in enumerate(cell.paragraphs):
- current_segment = ""
- segment_runs = []
- for run_idx, run in enumerate(paragraph.runs):
- if is_red(run):
- if run.text:
- current_segment += run.text
- segment_runs.append((para_idx, run_idx, run))
- else:
- if segment_runs:
- red_segments.append({'text': current_segment, 'runs': segment_runs.copy(), 'paragraph_idx': para_idx})
- current_segment = ""
- segment_runs = []
- if segment_runs:
- red_segments.append({'text': current_segment, 'runs': segment_runs.copy(), 'paragraph_idx': para_idx})
- return red_segments
-
-def replace_all_red_segments(red_segments, replacement_text):
- if not red_segments:
- return 0
- if '\n' in replacement_text:
- replacement_lines = replacement_text.split('\n')
+ t, r, c = loc
+ cell = get_adjacent_value_cell(t, r, c)
+ if not cell.paragraphs:
+ cell.add_paragraph("")
+
+ txt = join_value(value)
+
+ # find paragraphs with any red runs (the placeholders for the summary)
+ red_paras = [p for p in cell.paragraphs if any(is_red_run(run) for run in p.runs)]
+
+ if red_paras:
+ # write the summary into the first red paragraph (in black)
+ _clear_para_and_write_black(red_paras[0], txt)
+ # clear any extra red placeholders
+ for p in red_paras[1:]:
+ _clear_para_and_write_black(p, "")
else:
- replacement_lines = [replacement_text]
- replacements_made = 0
- first_segment = red_segments[0]
- if first_segment['runs']:
- first_run = first_segment['runs'][0][2]
- first_run.text = replacement_lines[0]
- first_run.font.color.rgb = RGBColor(0, 0, 0)
- replacements_made = 1
- for _, _, run in first_segment['runs'][1:]:
- run.text = ''
- for segment in red_segments[1:]:
- for _, _, run in segment['runs']:
- run.text = ''
- if len(replacement_lines) > 1 and red_segments:
- try:
- first_run = red_segments[0]['runs'][0][2]
- paragraph = first_run.element.getparent()
- from docx.oxml import OxmlElement
- for line in replacement_lines[1:]:
- if line.strip():
- br = OxmlElement('w:br')
- first_run.element.append(br)
- new_run = paragraph.add_run(line.strip())
- new_run.font.color.rgb = RGBColor(0, 0, 0)
- except Exception:
- if red_segments and red_segments[0]['runs']:
- first_run = red_segments[0]['runs'][0][2]
- first_run.text = ' '.join(replacement_lines)
- first_run.font.color.rgb = RGBColor(0, 0, 0)
- return replacements_made
-
-def replace_single_segment(segment, replacement_text):
- if not segment['runs']:
+ # no red placeholder found: just put the summary into the first paragraph, leave others
+ _clear_para_and_write_black(cell.paragraphs[0], txt)
+
+ return True
+
+
+def _nuke_cell_paragraphs(cell: _Cell):
+ """Remove ALL paragraphs from a cell (true delete, not just emptying runs)."""
+ for p in list(cell.paragraphs):
+ p._element.getparent().remove(p._element)
+
+def _clear_para_and_write_black(paragraph, text: str):
+ """Clear a whole paragraph and write fresh black text."""
+ # wipe existing runs
+ for r in list(paragraph.runs):
+ r.text = ""
+ r = paragraph.add_run(str(text or ""))
+ r.font.color.rgb = BLACK
+ try:
+ r.font.color.theme_color = None
+ except Exception:
+ pass
+
+def _set_cell_text_black(cell, text: str):
+ """Clear a table cell and insert black text."""
+ # remove text from all runs in all paragraphs
+ for p in cell.paragraphs:
+ for r in p.runs:
+ r.text = ""
+ p = cell.paragraphs[0] if cell.paragraphs else cell.add_paragraph()
+ r = p.add_run(str(text or ""))
+ r.font.color.rgb = BLACK
+ try:
+ r.font.color.theme_color = None
+ except Exception:
+ pass
+
+def nz(x: Optional[str]) -> str:
+ return (x or "").strip()
+
+def canon(s: str) -> str:
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = s.replace("–", "-").replace("—", "-")
+ return re.sub(r"[^a-z0-9/#()+,.\- ]+", "", s)
+
+def canon_label(s: str) -> str:
+ # labels often vary by punctuation/casing; keep digits/letters
+ s = re.sub(r"\s+", " ", str(s)).strip().lower()
+ s = s.replace("–", "-").replace("—", "-")
+ s = re.sub(r"[^a-z0-9 ]+", " ", s)
+ return re.sub(r"\s+", " ", s).strip()
+
+def join_value(value) -> str:
+ if isinstance(value, list):
+ # Keep multi-line when list provided
+ return "\n".join([str(v) for v in value if nz(v)])
+ return str(value)
+
+def split_digits(s: str) -> List[str]:
+ return re.findall(r"\d", s)
+
+def para_text(p: Paragraph) -> str:
+ return "".join(run.text for run in p.runs)
+
+def cell_text(c: _Cell) -> str:
+ return "\n".join(para_text(p) for p in c.paragraphs)
+
+def is_red_run(run) -> bool:
+ col = run.font.color
+ if not col:
return False
- first_run = segment['runs'][0][2]
- first_run.text = replacement_text
- first_run.font.color.rgb = RGBColor(0, 0, 0)
- for _, _, run in segment['runs'][1:]:
- run.text = ''
+ if col.rgb is not None:
+ return col.rgb == RED
+    # Some templates use theme colors; only explicit RGB red is treated as a placeholder here
+ return False
+
+def replace_red_in_paragraph(p: Paragraph, new_text: str) -> bool:
+ replaced = False
+ red_runs = [r for r in p.runs if is_red_run(r)]
+ if not red_runs:
+ return False
+ # collapse all red runs into one and write value (in black)
+ first = red_runs[0]
+ _set_text_and_black(first, new_text)
+ for r in red_runs[1:]:
+ r.text = ""
+ replaced = True
+ return replaced
+
+def replace_red_in_cell(cell: _Cell, new_text: str) -> bool:
+ # replace only red runs; if none, replace whole cell with a single run (fallback)
+ any_red = False
+ for p in cell.paragraphs:
+ if replace_red_in_paragraph(p, new_text):
+ any_red = True
+ if any_red:
+ return True
+ # fallback: clear cell, set single paragraph text in black
+ _set_cell_text_black(cell, new_text)
return True
-def replace_red_text_in_cell(cell, replacement_text):
- red_segments = extract_red_text_segments(cell)
- if not red_segments:
- return 0
- return replace_all_red_segments(red_segments, replacement_text)
-
-# ============================================================================
-# Specialized handlers (vehicle, attendance, management, operator) with fixes
-# ============================================================================
-
-def handle_australian_company_number(row, company_numbers):
- replacements_made = 0
- for i, digit in enumerate(company_numbers):
- cell_idx = i + 1
- if cell_idx < len(row.cells):
- cell = row.cells[cell_idx]
- if has_red_text(cell):
- cell_replacements = replace_red_text_in_cell(cell, str(digit))
- replacements_made += cell_replacements
- print(f" -> Placed digit '{digit}' in cell {cell_idx + 1}")
- return replacements_made
-
-def handle_vehicle_registration_table(table, flat_json):
+def parse_attendance_lines(value) -> List[str]:
"""
- Stronger header normalization + substring matching for long headers.
- Keeps existing behavior but reduces 'No mapping found' by using normalized substring matching.
+ Parse strings like:
+ "Peter Sheppard - Compliance Greg Dyer - Auditor"
+ into:
+ ["Peter Sheppard - Compliance", "Greg Dyer - Auditor"]
+ Handles lists, newlines, semicolons, and pipes too.
"""
- replacements_made = 0
-
- # Build candidate vehicle_section similar to prior logic
- vehicle_section = None
- # Prefer keys explicitly mentioning 'registration' or 'vehicle'
- candidates = [(k, v) for k, v in flat_json.items() if 'registration' in k.lower() or 'vehicle' in k.lower()]
- if candidates:
- # prefer the one with longest key match (likely most specific)
- candidates.sort(key=lambda kv: -len(kv[0]))
- vehicle_section = candidates[0][1]
-
- # fallback: collect flattened keys that look like vehicle columns
- if vehicle_section is None:
- potential_columns = {}
- for key, value in flat_json.items():
- lk = key.lower()
- if any(col_name in lk for col_name in ["registration number", "sub-contractor", "weight verification", "rfs suspension", "trip records", "fault recording", "fault repair", "daily checks", "roadworthiness"]):
- if "." in key:
- column_name = key.split(".")[-1]
- else:
- column_name = key
- potential_columns[column_name] = value
- if potential_columns:
- vehicle_section = potential_columns
- print(f" ✅ Found vehicle data from flattened keys: {list(vehicle_section.keys())}")
-
- if not vehicle_section:
- print(f" ❌ Vehicle registration data not found in JSON")
- return 0
-
- # Normalize vehicle_section into dict of column_label -> list/value
- if isinstance(vehicle_section, list):
- # if list of dicts, pivot
- if vehicle_section and isinstance(vehicle_section[0], dict):
- flattened = {}
- for entry in vehicle_section:
- for k, v in entry.items():
- flattened.setdefault(k, []).append(v)
- vehicle_section = flattened
- else:
- # can't interpret, bail
- vehicle_section = {}
-
- if not isinstance(vehicle_section, dict):
- try:
- vehicle_section = dict(vehicle_section)
- except Exception:
- vehicle_section = {}
-
- print(f" ✅ Found vehicle registration data with {len(vehicle_section)} columns")
-
- # Find header row (look for registration + number or reg no)
- header_row_idx = -1
- header_row = None
- for row_idx, row in enumerate(table.rows):
- row_text = " ".join(get_clean_text(cell).lower() for cell in row.cells)
- if ("registration" in row_text and "number" in row_text) or "reg no" in row_text or "registration no" in row_text:
- header_row_idx = row_idx
- header_row = row
- break
+ if isinstance(value, list):
+ s = " ".join(str(v) for v in value if v)
+ else:
+ s = str(value or "")
+ s = re.sub(r"\s+", " ", s).strip()
+ if not s:
+ return []
- if header_row_idx == -1:
- print(f" ❌ Could not find header row in vehicle table")
- return 0
-
- print(f" ✅ Found header row at index {header_row_idx}")
-
- # Build master labels from vehicle_section keys
- master_labels = {}
- for orig_key in vehicle_section.keys():
- norm = normalize_header_text(str(orig_key))
- if norm:
- # if there is collision, prefer longer orig_key (more specific)
- if norm in master_labels:
- if len(orig_key) > len(master_labels[norm]):
- master_labels[norm] = orig_key
- else:
- master_labels[norm] = orig_key
-
- # Map header cells using normalized token overlap + substring fallback
- column_mapping = {}
- for col_idx, cell in enumerate(header_row.cells):
- header_text = get_clean_text(cell).strip()
- if not header_text:
- continue
- header_key = header_text.strip().lower()
- if header_key in {"no", "no.", "#"}:
+ # First split on explicit separators; then within each chunk, extract Name - Title pairs.
+ chunks = re.split(r"\s*[\n;|]\s*", s)
+ items: List[str] = []
+
+ pair_pat = re.compile(
+ r"([A-Z][A-Za-z.'-]+(?:\s+[A-Z][A-Za-z.'-]+){0,3})\s*-\s*"
+ r"([^-\n]+?)(?=\s+[A-Z][A-Za-z.'-]+(?:\s+[A-Z][A-Za-z.'-]+){0,3}\s*-\s*|$)"
+ )
+
+ for chunk in chunks:
+ chunk = chunk.strip()
+ if not chunk:
continue
+ found = False
+ for m in pair_pat.finditer(chunk):
+ name = m.group(1).strip()
+ title = m.group(2).strip()
+ items.append(f"{name} - {title}")
+ found = True
+ if not found:
+ # Fallback: single "Name - Title"
+ if " - " in chunk:
+ a, b = chunk.split(" - ", 1)
+ items.append(f"{a.strip()} - {b.strip()}")
+ elif chunk:
+ items.append(chunk)
+
+ return items
+
+def fill_attendance_block(doc: Document, value) -> bool:
+ items = parse_attendance_lines(value)
+ if not items:
+ return False
- norm_header = normalize_header_text(header_text)
- best_match = None
- best_score = 0.0
-
- # exact normalized match
- if norm_header in master_labels:
- best_match = master_labels[norm_header]
- best_score = 1.0
- else:
- # token overlap
- header_tokens = set(t for t in norm_header.split() if len(t) > 2)
- for norm_key, orig_label in master_labels.items():
- key_tokens = set(t for t in norm_key.split() if len(t) > 2)
- if not key_tokens:
- continue
- common = header_tokens.intersection(key_tokens)
- if common:
- score = len(common) / max(1, len(header_tokens.union(key_tokens)))
- else:
- # substring fallback on normalized strings
- if norm_header in norm_key or norm_key in norm_header:
- score = min(len(norm_header), len(norm_key)) / max(len(norm_header), len(norm_key))
- else:
- score = 0.0
- if score > best_score:
- best_score = score
- best_match = orig_label
-
- # additional heuristic: if header contains 'roadworthiness' and any master_labels key contains that token, accept
- if not best_match:
- for norm_key, orig_label in master_labels.items():
- if 'roadworthiness' in norm_header and 'roadworthiness' in norm_key:
- best_match = orig_label
- best_score = 0.65
- break
-
- if best_match and best_score >= 0.30:
- column_mapping[col_idx] = best_match
- print(f" 📌 Column {col_idx}: '{header_text}' -> '{best_match}' (norm:'{norm_header}' score:{best_score:.2f})")
- else:
- print(f" ⚠️ No mapping found for '{header_text}' (norm:'{norm_header}')")
- record_unmatched_header(header_text)
-
- if not column_mapping:
- print(f" ❌ No column mappings found")
- return 0
-
- # Determine how many rows of data to populate
- max_data_rows = 0
- for json_key, data in vehicle_section.items():
- if isinstance(data, list):
- max_data_rows = max(max_data_rows, len(data))
-
- print(f" 📌 Need to populate {max_data_rows} data rows")
-
- # Populate or add rows
- for data_row_index in range(max_data_rows):
- table_row_idx = header_row_idx + 1 + data_row_index
- if table_row_idx >= len(table.rows):
- print(f" ⚠️ Row {table_row_idx + 1} doesn't exist, adding one")
- table.add_row()
-
- row = table.rows[table_row_idx]
- print(f" 📌 Processing data row {table_row_idx + 1} (vehicle {data_row_index + 1})")
- for col_idx, json_key in column_mapping.items():
- if col_idx < len(row.cells):
- cell = row.cells[col_idx]
- column_data = vehicle_section.get(json_key, [])
- if isinstance(column_data, list) and data_row_index < len(column_data):
- replacement_value = str(column_data[data_row_index])
- cell_text = get_clean_text(cell)
- if has_red_text(cell) or not cell_text.strip():
- if not cell_text.strip():
- cell.text = replacement_value
- replacements_made += 1
- print(f" -> Added '{replacement_value}' to empty cell (col '{json_key}')")
- else:
- cell_replacements = replace_red_text_in_cell(cell, replacement_value)
- replacements_made += cell_replacements
- if cell_replacements > 0:
- print(f" -> Replaced red text with '{replacement_value}' (col '{json_key}')")
-
- return replacements_made
-
-def handle_attendance_list_table_enhanced(table, flat_json):
- """Same as before — preserved behavior."""
- replacements_made = 0
- attendance_patterns = ["attendance list", "names and position titles", "attendees"]
- found_attendance_row = None
- for row_idx, row in enumerate(table.rows[:3]):
- for cell_idx, cell in enumerate(row.cells):
- cell_text = get_clean_text(cell).lower()
- if any(pattern in cell_text for pattern in attendance_patterns):
- found_attendance_row = row_idx
- print(f" 🎯 ENHANCED: Found Attendance List in row {row_idx + 1}, cell {cell_idx + 1}")
- break
- if found_attendance_row is not None:
- break
- if found_attendance_row is None:
- return 0
+ loc = find_label_cell(doc, "Attendance List (Names and Position Titles)")
+ if not loc:
+ return False
- attendance_value = None
- attendance_search_keys = [
- "Attendance List (Names and Position Titles).Attendance List (Names and Position Titles)",
- "Attendance List (Names and Position Titles)",
- "attendance list",
- "attendees"
- ]
- print(f" 🔍 Searching for attendance data in JSON...")
- for search_key in attendance_search_keys:
- kv = find_matching_json_key_and_value(search_key, flat_json)
- if kv:
- attendance_value = kv[1]
- print(f" ✅ Found attendance data with key: '{kv[0]}'")
- print(f" 📊 Raw value: {attendance_value}")
- break
- if attendance_value is None:
- print(f" ❌ No attendance data found in JSON")
- return 0
-
- # Find red text candidate cell
- target_cell = None
- print(f" 🔍 Scanning ALL cells in attendance table for red text...")
- for row_idx, row in enumerate(table.rows):
- for cell_idx, cell in enumerate(row.cells):
- if has_red_text(cell):
- red_text = ""
- for paragraph in cell.paragraphs:
- for run in paragraph.runs:
- if is_red(run):
- red_text += run.text
- if red_text.strip():
- print(f" 🎯 Found red text in row {row_idx + 1}, cell {cell_idx + 1}")
- print(f" 📋 Red text content: '{red_text[:60]}...'")
- red_lower = red_text.lower()
- if any(ind in red_lower for ind in ['manager', 'director', 'auditor', '–', '-']):
- target_cell = cell
- print(f" ✅ This looks like attendance data - using this cell")
- break
- if target_cell:
- break
+ t, r, c = loc
+ # value cell: usually directly under the heading cell
+ target = (
+ t.rows[r + 1].cells[c]
+ if r + 1 < len(t.rows) and c < len(t.rows[r + 1].cells)
+ else get_adjacent_value_cell(t, r, c)
+ )
+
+ # ---- read ONLY the target cell (don’t touch the row)
+ def is_red_para(p): return any(is_red_run(run) for run in p.runs)
+ def looks_like_pair(s: str) -> bool:
+ if " - " not in s: return False
+ a, b = s.split(" - ", 1)
+ return bool(a.strip()) and bool(b.strip())
+
+ paras = list(target.paragraphs)
+ red_count = sum(1 for p in paras if is_red_para(p))
+ existing_black = [para_text(p).strip() for p in paras
+ if (not is_red_para(p)) and looks_like_pair(para_text(p))]
+
+ # compose final lines
+ out_lines: List[str] = []
+ out_lines.extend(items[:red_count]) # replace red placeholders
+ out_lines.extend(existing_black) # keep black lines
+ norm = lambda s: re.sub(r"\s+", " ", s.strip().lower())
+ seen = {norm(x) for x in out_lines}
+ for extra in items[red_count:]:
+ k = norm(extra)
+ if k not in seen:
+ out_lines.append(extra); seen.add(k)
+
+ # ---- hard clear target cell and write fresh (all black)
+ _nuke_cell_paragraphs(target)
+ # first line
+ p = target.add_paragraph()
+ _clear_para_and_write_black(p, out_lines[0] if out_lines else "")
+ # remaining lines
+ for line in out_lines[1:]:
+ p = target.add_paragraph()
+ _clear_para_and_write_black(p, line)
- if target_cell is None:
- print(f" ⚠️ No red text found that looks like attendance data")
- return 0
-
- if has_red_text(target_cell):
- print(f" 🔧 Replacing red text with properly formatted attendance list...")
- if isinstance(attendance_value, list):
- attendance_list = [str(item).strip() for item in attendance_value if str(item).strip()]
- else:
- attendance_list = [str(attendance_value).strip()]
- print(f" 📝 Attendance items to add:")
- for i, item in enumerate(attendance_list):
- print(f" {i+1}. {item}")
- replacement_text = "\n".join(attendance_list)
- cell_replacements = replace_red_text_in_cell(target_cell, replacement_text)
- replacements_made += cell_replacements
- print(f" ✅ Added {len(attendance_list)} attendance items")
- print(f" 📊 Replacements made: {cell_replacements}")
- return replacements_made
-
-def fix_management_summary_details_column(table, flat_json):
- """CORRECTED VERSION: Replace red text with UPDATED values from JSON (not old extracted values)"""
- replacements_made = 0
- print(f" 🎯 FIX: Management Summary DETAILS column processing")
- print(f" 📋 NOTE: JSON contains UPDATED values to replace red text with")
-
- # Determine which type of management summary this is
- table_text = ""
- for row in table.rows[:3]:
- for cell in row.cells:
- table_text += get_clean_text(cell).lower() + " "
-
- mgmt_types = []
- if "mass management" in table_text or "mass" in table_text:
- mgmt_types.append("Mass Management Summary")
- if "maintenance management" in table_text or "maintenance" in table_text:
- mgmt_types.append("Maintenance Management Summary")
- if "fatigue management" in table_text or "fatigue" in table_text:
- mgmt_types.append("Fatigue Management Summary")
-
- # Fallback detection
- if not mgmt_types:
- if any("std 5" in get_clean_text(c).lower() for r in table.rows for c in r.cells):
- mgmt_types.append("Mass Management Summary")
-
- if not mgmt_types:
- print(f" ⚠️ Could not determine management summary type")
- return 0
-
- for mgmt_type in mgmt_types:
- print(f" ✅ Confirmed {mgmt_type} table processing")
-
- # Build management data dict from flattened keys - these contain UPDATED values
- mgmt_data = {}
-
- # Look for flattened keys like "Mass Management Summary.Std 5. Verification"
- # IMPORTANT: Prioritize longer, more detailed values over shorter ones
- for key, value in flat_json.items():
- if key.startswith(mgmt_type + "."):
- # Extract the standard part (after the management type)
- std_key = key[len(mgmt_type) + 1:] # Remove "Mass Management Summary." prefix
-
- # Check if this is a longer, more detailed version than what we already have
- if std_key in mgmt_data:
- # Compare value lengths - prefer longer, more detailed content
- existing_value = mgmt_data[std_key]
- existing_length = len(str(existing_value)) if not isinstance(existing_value, list) else len(str(existing_value[0]) if existing_value else "")
- new_length = len(str(value)) if not isinstance(value, list) else len(str(value[0]) if value else "")
-
- if new_length > existing_length:
- mgmt_data[std_key] = value
- print(f" ✅ UPDATED to longer standard: '{std_key}' = {value}")
- else:
- print(f" ⏭️ Keeping existing longer standard: '{std_key}'")
- else:
- mgmt_data[std_key] = value
- print(f" ✅ Found UPDATED standard: '{std_key}' = {value}")
-
- if not mgmt_data:
- print(f" ⚠️ No UPDATED JSON data found for {mgmt_type}")
- continue
-
- print(f" 📋 Processing {mgmt_type} with {len(mgmt_data)} updated standards: {list(mgmt_data.keys())}")
-
- # Process each row looking for red text in details column
- print(f" 🔍 Analyzing all {len(table.rows)} rows in table:")
-
- for row_idx, row in enumerate(table.rows):
- if len(row.cells) >= 2:
- standard_cell = row.cells[0]
- details_cell = row.cells[1]
- standard_text = get_clean_text(standard_cell).strip()
- details_text = get_clean_text(details_cell).strip()
- standard_text_lower = standard_text.lower()
-
- print(f" 📋 Row {row_idx + 1}:")
- print(f" 📄 Standard: '{standard_text}'")
- print(f" 📄 Current Details: '{details_text[:50]}...' (length: {len(details_text)})")
- print(f" 🔴 Has red text (OLD data): {has_red_text(details_cell)}")
-
- # Skip header rows - be more specific about what constitutes a header
- header_indicators = ["standard", "requirement", "details", mgmt_type.lower().split()[0]]
- if any(header in standard_text_lower for header in header_indicators) and len(standard_text) < 50:
- print(f" ⏭️ Skipping header row")
- continue
-
- # IMPORTANT: We want to replace red text (old data) with updated data from JSON
- # Check if this row has red text in details cell - this is what we need to replace
- if not has_red_text(details_cell):
- print(f" ⏭️ No red text found in details cell (already updated?), skipping")
- continue
-
- print(f" 🎯 PROCESSING row {row_idx + 1} - REPLACING OLD red text with NEW data")
-
- # Extract current red text (this is the OLD data we're replacing)
- red_segments = extract_red_text_segments(details_cell)
- current_red_text = ""
- for segment in red_segments:
- current_red_text += segment['text']
-
- print(f" 🔴 Current red text (OLD): '{current_red_text[:100]}...'")
-
- # Find the UPDATED replacement value from JSON
- replacement_value = None
- matched_std = None
-
- # Strategy 1: Extract standard number and match
- std_match = re.search(r'std\s*(\d+)', standard_text_lower)
- if std_match:
- std_num = std_match.group(1)
- print(f" 🎯 Looking for UPDATED Standard {std_num} data")
-
- # Look for matching standard in mgmt_data (contains UPDATED values)
- for std_key, std_value in mgmt_data.items():
- if f"std {std_num}" in std_key.lower():
- replacement_value = std_value
- matched_std = std_key
- print(f" ✅ Found UPDATED data for std {std_num}: '{std_key}'")
- break
-
- # Strategy 2: Keyword-based matching if std number doesn't work
- if not replacement_value:
- print(f" 🔍 No std number match, trying keyword matching for UPDATED data...")
-
- # More comprehensive keyword matching
- keyword_mappings = {
- "daily check": ["Std 1. Daily Check", "Daily Check"],
- "verification": ["Std 5. Verification", "Verification"],
- "internal review": ["Std 6. Internal Review", "Std 7. Internal Review", "Std 5. Internal Review", "Internal Review"],
- "fault recording": ["Std 2. Fault Recording", "Fault Recording/ Reporting"],
- "fault repair": ["Std 3. Fault Repair", "Fault Repair"],
- "maintenance schedules": ["Std 4. Maintenance Schedules", "Maintenance Schedules"],
- "responsibilities": ["Std 1. Responsibilities", "Std 6. Responsibilities"],
- "vehicle control": ["Std 2. Vehicle Control", "Vehicle Control"],
- "vehicle use": ["Std 3. Vehicle Use", "Vehicle Use"],
- "records and documentation": ["Std 4. Records", "Std 5. Records", "Records and Documentation"],
- "training": ["Std 8. Training", "Std 3. Training", "Training"],
- "suspension": ["Std 8. Maintenance of Suspension", "Suspension"],
- "scheduling": ["Std 1. Scheduling", "Scheduling"],
- "health and wellbeing": ["Std 2. Health", "Health and wellbeing"],
- "workplace conditions": ["Std 7. Workplace", "Workplace conditions"]
- }
-
- for keyword, candidates in keyword_mappings.items():
- if keyword in standard_text_lower:
- replacement_value = find_best_standard_value(mgmt_data, candidates)
- if replacement_value:
- matched_std = f"{keyword} related"
- print(f" ✅ Found UPDATED data for keyword '{keyword}'")
- break
-
- # Strategy 3: Try exact standard name matching
- if not replacement_value:
- print(f" 🔍 Trying exact standard name matching for UPDATED data...")
- # Clean the standard text for better matching
- clean_standard = re.sub(r'\([^)]*\)', '', standard_text).strip()
-
- for std_key, std_value in mgmt_data.items():
- # Try partial matching
- if (clean_standard.lower() in std_key.lower() or
- std_key.lower() in clean_standard.lower()):
- replacement_value = std_value
- matched_std = std_key
- print(f" ✅ Found UPDATED data via partial match: '{std_key}'")
- break
-
- # Apply replacement if found
- if replacement_value:
- # Handle list values properly
- if isinstance(replacement_value, list):
- if len(replacement_value) == 1:
- replacement_text = str(replacement_value[0])
- else:
- replacement_text = "\n".join(str(item) for item in replacement_value)
- else:
- replacement_text = str(replacement_value)
-
- print(f" 🎯 REPLACING old red text with UPDATED data: '{replacement_text[:100]}...'")
-
- # Use robust red text replacement
- cell_replacements = replace_red_text_in_cell(details_cell, replacement_text)
-
- # FALLBACK: If replace_red_text_in_cell fails, try manual replacement
- if cell_replacements == 0:
- print(f" ⚠️ Standard replacement failed, trying manual approach...")
-
- # Try to replace red text manually
- for paragraph in details_cell.paragraphs:
- for run in paragraph.runs:
- if is_red(run) and run.text.strip():
- print(f" 🔧 Manually replacing red run: '{run.text[:50]}...'")
- run.text = replacement_text
- run.font.color.rgb = RGBColor(0, 0, 0)
- cell_replacements = 1
- break
- if cell_replacements > 0:
- break
-
- replacements_made += cell_replacements
-
- if cell_replacements > 0:
- print(f" ✅ SUCCESSFULLY UPDATED '{standard_text}' with NEW data in {mgmt_type}")
- print(f" 📋 Used UPDATED data from: '{matched_std}'")
-
- # Verify the replacement worked
- new_details_text = get_clean_text(details_cell).strip()
- print(f" 🔍 NEW details text: '{new_details_text[:100]}...'")
- print(f" 🎉 OLD red text replaced with UPDATED data!")
- else:
- print(f" ❌ Failed to replace red text in cell")
- print(f" 🔍 Cell still contains OLD data: '{get_clean_text(details_cell)[:100]}...'")
- else:
- print(f" ⚠️ No UPDATED replacement found for '{standard_text}' in {mgmt_type}")
- print(f" 📋 Available UPDATED standards: {list(mgmt_data.keys())}")
-
- # FALLBACK: Try to find ANY available standard that might fit
- if mgmt_data and current_red_text:
- print(f" 🔄 Trying fallback - any available UPDATED standard...")
- # Use the first available standard as a fallback
- first_std_key = list(mgmt_data.keys())[0]
- fallback_value = mgmt_data[first_std_key]
-
- if isinstance(fallback_value, list):
- fallback_text = "\n".join(str(item) for item in fallback_value)
- else:
- fallback_text = str(fallback_value)
-
- print(f" 🔄 Using fallback UPDATED data: '{fallback_text[:100]}...'")
-
- cell_replacements = replace_red_text_in_cell(details_cell, fallback_text)
- if cell_replacements > 0:
- replacements_made += cell_replacements
- print(f" ✅ Applied fallback UPDATED data successfully")
-
- else:
- print(f" ⚠️ Row {row_idx + 1} has insufficient columns ({len(row.cells)})")
-
- print(f" 📊 Total management summary UPDATES: {replacements_made}")
- return replacements_made
-
-
-def find_best_standard_value(mgmt_data, candidate_keys):
- """ENHANCED: Find the best matching value for a standard from management data"""
- print(f" 🔍 Searching for candidates: {candidate_keys}")
- print(f" 📋 In available keys: {list(mgmt_data.keys())}")
-
- # Direct match
- for candidate in candidate_keys:
- if candidate in mgmt_data:
- print(f" ✅ Direct match found: '{candidate}'")
- return mgmt_data[candidate]
-
- # Case insensitive match
- for candidate in candidate_keys:
- for key, value in mgmt_data.items():
- if candidate.lower() == key.lower():
- print(f" ✅ Case-insensitive match found: '{key}' for '{candidate}'")
- return value
-
- # Partial match (contains)
- for candidate in candidate_keys:
- for key, value in mgmt_data.items():
- if candidate.lower() in key.lower() or key.lower() in candidate.lower():
- print(f" ✅ Partial match found: '{key}' for '{candidate}'")
- return value
-
- # Extract number and match by number
- for candidate in candidate_keys:
- candidate_num = re.search(r'(\d+)', candidate)
- if candidate_num:
- for key, value in mgmt_data.items():
- key_num = re.search(r'(\d+)', key)
- if key_num and candidate_num.group(1) == key_num.group(1):
- print(f" ✅ Number match found: '{key}' for '{candidate}'")
- return value
-
- print(f" ❌ No match found for any candidate")
+ return True
+
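+# Usage sketch (the attendee strings are placeholders taken from the parser docstring):
+# red placeholder lines in the attendance cell are replaced in order, existing black
+# "Name - Title" lines are kept, and extra items are appended without duplicates.
+def _example_fill_attendance(doc: Document) -> bool:
+    return fill_attendance_block(doc, ["Peter Sheppard - Compliance", "Greg Dyer - Auditor"])
+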
+# ----------------------------- document search -----------------------------
+def iter_tables(doc: Document) -> List[Table]:
+ return list(doc.tables)
+
+def iter_paragraphs(doc: Document) -> List[Paragraph]:
+ # paragraphs at doc level + inside tables
+ out = list(doc.paragraphs)
+ for t in doc.tables:
+ for row in t.rows:
+ for cell in row.cells:
+ out.extend(cell.paragraphs)
+ return out
+
+def find_heading_paragraph(doc: Document, heading_text: str, window: int = 60) -> Optional[Paragraph]:
+ key = canon(heading_text)
+ for p in iter_paragraphs(doc):
+ if canon(para_text(p)).startswith(key):
+ return p
+ # fuzzy contains
+ for p in iter_paragraphs(doc):
+ if key in canon(para_text(p)):
+ return p
return None
-# ============================================================================
-# Canonical operator declaration fixer — SAFER
-# ============================================================================
-def fix_operator_declaration_empty_values(table, flat_json):
- """
- FIXED: Properly distinguish between auditor and operator data for Operator Declaration table
- """
- replacements_made = 0
- print(f" 🎯 FIX: Operator Declaration empty values processing")
-
- # Verify this is actually an operator declaration table
- table_context = ""
- for row in table.rows:
- for cell in row.cells:
- table_context += get_clean_text(cell).lower() + " "
-
- if not ("print name" in table_context and "position title" in table_context):
- return 0
-
- print(f" ✅ Confirmed Operator Declaration table")
-
- def parse_name_and_position(value):
- """Enhanced parsing for name/position combinations"""
- if value is None:
- return None, None
-
- if isinstance(value, list):
- if len(value) == 0:
- return None, None
- if len(value) == 1:
- # Check if single item looks like "Name - Position" format
- single_item = str(value[0]).strip()
- if ' - ' in single_item:
- parts = single_item.split(' - ', 1)
- if len(parts) == 2:
- return parts[0].strip(), parts[1].strip()
- return single_item, None
-
- # Handle [name, position] pattern or multiple attendance entries
- if len(value) == 2:
- first = str(value[0]).strip()
- second = str(value[1]).strip()
-
- # Check if both look like names (attendance list pattern)
- if (' ' in first and ' ' in second and
- not any(role in first.lower() for role in ['manager', 'director', 'auditor', 'officer']) and
- not any(role in second.lower() for role in ['manager', 'director', 'auditor', 'officer'])):
- # This is likely attendance list data, return first name only
- return first, None
-
- return first, second
-
- # Multiple items - check if it's attendance list format
- attendance_like = any(' - ' in str(item) for item in value)
- if attendance_like:
- # Extract first person's name from attendance format
- first_entry = str(value[0]).strip()
- if ' - ' in first_entry:
- return first_entry.split(' - ')[0].strip(), first_entry.split(' - ')[1].strip()
- return first_entry, None
-
- # Join list elements as fallback
- value = " ".join(str(v).strip() for v in value if str(v).strip())
-
- s = str(value).strip()
- if not s:
- return None, None
-
- # Split on common separators
- separators = [r'\s+[-–—]\s+', r'\s*,\s*', r'\s*\|\s*', r'\s*;\s*']
- parts = None
-
- for sep_pattern in separators:
- parts = re.split(sep_pattern, s)
- if len(parts) >= 2:
- break
-
- if parts and len(parts) >= 2:
- left = parts[0].strip()
- right = parts[1].strip()
-
- # Check which part is more likely to be a position
- role_indicators = ['manager', 'auditor', 'owner', 'director', 'supervisor',
- 'coordinator', 'driver', 'operator', 'representative', 'chief',
- 'president', 'ceo', 'cfo', 'secretary', 'treasurer', 'officer',
- 'compliance']
-
- right_has_role = any(ind in right.lower() for ind in role_indicators)
- left_has_role = any(ind in left.lower() for ind in role_indicators)
-
- if right_has_role and not left_has_role:
- return left, right # Standard: name, position
- elif left_has_role and not right_has_role:
- return right, left # Reversed: position, name
- else:
- # Default to left=name, right=position
- return left, right
-
- # Look for single word position at end
- tokens = s.split()
- if len(tokens) >= 2:
- last_token = tokens[-1].lower()
- role_indicators = ['manager', 'auditor', 'owner', 'director', 'supervisor',
- 'coordinator', 'driver', 'operator', 'representative', 'chief', 'officer']
- if any(ind == last_token for ind in role_indicators):
- return " ".join(tokens[:-1]), tokens[-1]
-
- return s, None
-
- def looks_like_role(s: str) -> bool:
- """Check if string looks like a job role/position"""
- if not s:
- return False
-
- s = s.lower().strip()
-
- # Common role words
- roles = ['manager', 'auditor', 'owner', 'director', 'supervisor',
- 'coordinator', 'driver', 'operator', 'representative', 'chief',
- 'president', 'ceo', 'cfo', 'secretary', 'treasurer', 'officer']
-
- # Direct role match
- if any(role in s for role in roles):
- return True
-
- # Short descriptive terms (likely roles)
- if len(s.split()) <= 3 and any(c.isalpha() for c in s) and len(s) > 1:
+def find_label_cell_in_table(table: Table, label: str) -> Optional[Tuple[int, int]]:
+ target = canon_label(label)
+ for r_i, row in enumerate(table.rows):
+ for c_i, cell in enumerate(row.cells):
+ if canon_label(cell_text(cell)) == target:
+ return (r_i, c_i)
+ # allow contains (safe-ish)
+ for r_i, row in enumerate(table.rows):
+ for c_i, cell in enumerate(row.cells):
+ if target and target in canon_label(cell_text(cell)):
+ return (r_i, c_i)
+ return None
+
+def find_label_cell(doc: Document, label: str) -> Optional[Tuple[Table, int, int]]:
+ for t in iter_tables(doc):
+ pos = find_label_cell_in_table(t, label)
+ if pos:
+ return (t, pos[0], pos[1])
+ return None
+
+def get_adjacent_value_cell(table: Table, r: int, c: int) -> _Cell:
+ # Prefer right cell, otherwise next row same col, otherwise this cell
+    cols = len(table.rows[r].cells)  # use this row's cell count rather than assuming row 0 matches
+ if c + 1 < cols:
+ return table.rows[r].cells[c+1]
+ if r + 1 < len(table.rows):
+ return table.rows[r+1].cells[c]
+ return table.rows[r].cells[c]
+
+# ----------------------------- label/value updates -----------------------------
+def update_label_value_in_tables(doc: Document, label: str, value) -> bool:
+ tup = find_label_cell(doc, label)
+ val = join_value(value)
+ if not tup:
+ return False
+ t, r, c = tup
+ target_cell = get_adjacent_value_cell(t, r, c)
+ return replace_red_in_cell(target_cell, val)
+
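+# Minimal sketch, assuming the template has a label cell such as "Operator name
+# (Legal entity)"; the label and value below are placeholders. Writes the value into
+# the cell to the right of (or below) the label, replacing red runs when present.
+def _example_update_label(doc: Document) -> bool:
+    return update_label_value_in_tables(doc, "Operator name (Legal entity)", "Example Transport Pty Ltd")
+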
+def update_heading_followed_red(doc: Document, heading: str, value, max_scan: int = 12) -> bool:
+    """Find the heading paragraph, then replace the first red run found within the next max_scan paragraphs (including those inside tables)."""
+ start = find_heading_paragraph(doc, heading)
+ if not start:
+ return False
+ # Build a linear list of paragraphs across whole doc to get an index
+ allp = iter_paragraphs(doc)
+    # match by underlying XML element; fresh Paragraph wrappers defeat list.index()
+    idx = next((k for k, q in enumerate(allp) if q._p is start._p), 0)
+ new_text = join_value(value)
+ # Scan forward
+ for p in allp[idx+1: idx+1+max_scan]:
+ if replace_red_in_paragraph(p, new_text):
return True
-
+ # Also check any red in table cells inside this paragraph's parent (already covered via iter_paragraphs)
+ return False
+
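+# Sketch of the heading-scoped fallback (the heading matches one used elsewhere in
+# this script; the value is a placeholder): replaces the first red run found after
+# the heading within the scan window.
+def _example_update_after_heading(doc: Document) -> bool:
+    return update_heading_followed_red(doc, "Operator Declaration", "Jane Citizen")
+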
+# ----------------------------- ACN per-digit fill -----------------------------
+def fill_acn_digits(doc: Document, acn_value: str) -> bool:
+ digits = split_digits(acn_value)
+ if not digits:
+ return False
+ loc = find_label_cell(doc, "Australian Company Number")
+ if not loc:
return False
- def looks_like_person_name(s: str) -> bool:
- """Check if string looks like a person's name"""
- if not s:
- return False
-
- s = s.strip()
-
- # Exclude company-like terms
- company_terms = ['pty ltd', 'ltd', 'inc', 'corp', 'company', 'llc', 'plc']
- s_lower = s.lower()
- if any(term in s_lower for term in company_terms):
- return False
-
- # Should have letters and reasonable length
- if len(s) > 1 and any(c.isalpha() for c in s):
- return True
-
+ t, r, c = loc
+
+ # Collect cells to the RIGHT in the same row first
+ targets: List[_Cell] = [t.rows[r].cells[j] for j in range(c + 1, len(t.rows[r].cells))]
+
+ # If not enough, continue row-by-row below (left→right)
+ rr = r + 1
+ while len(targets) < len(digits) and rr < len(t.rows):
+ targets.extend(list(t.rows[rr].cells))
+ rr += 1
+
+ targets = targets[:len(digits)]
+ if not targets:
return False
- # Process the table
- for row_idx, row in enumerate(table.rows):
- if len(row.cells) >= 2:
- cell1_text = get_clean_text(row.cells[0]).strip().lower()
- cell2_text = get_clean_text(row.cells[1]).strip().lower()
-
- # Detect header row
- if "print name" in cell1_text and "position" in cell2_text:
- print(f" 📌 Found header row at {row_idx + 1}")
-
- # Process data row (next row after header)
- if row_idx + 1 < len(table.rows):
- data_row = table.rows[row_idx + 1]
- if len(data_row.cells) >= 2:
- name_cell = data_row.cells[0]
- position_cell = data_row.cells[1]
-
- current_name = get_clean_text(name_cell).strip()
- current_position = get_clean_text(position_cell).strip()
-
- print(f" 📋 Current values: Name='{current_name}', Position='{current_position}'")
-
- # IMPROVED: More comprehensive search for operator declaration data
- final_name = None
- final_position = None
-
- # IMPROVED: Better strategy to find OPERATOR (not auditor) data
- final_name = None
- final_position = None
-
- # Strategy 1: Look specifically in Attendance List for operator names
- attendance_kv = find_matching_json_key_and_value("Attendance List (Names and Position Titles)", flat_json)
- if attendance_kv and attendance_kv[1]:
- attendance_data = attendance_kv[1]
- print(f" 📋 Found attendance data: {attendance_data}")
-
- # Parse attendance list to find non-auditor names
- if isinstance(attendance_data, list):
- for entry in attendance_data:
- entry_str = str(entry).strip()
- if 'auditor' not in entry_str.lower() and entry_str:
- # Parse this entry for name and position
- parsed_name, parsed_pos = parse_name_and_position(entry_str)
- if parsed_name and looks_like_person_name(parsed_name):
- final_name = parsed_name
- if parsed_pos and looks_like_role(parsed_pos):
- final_position = parsed_pos
- break
-
- # Strategy 2: If no good name from attendance, try nested attendance keys
- if not final_name:
- nested_attendance_kv = find_matching_json_key_and_value("Attendance List (Names and Position Titles).Attendance List (Names and Position Titles)", flat_json)
- if nested_attendance_kv and nested_attendance_kv[1]:
- nested_data = nested_attendance_kv[1]
- print(f" 📋 Found nested attendance data: {nested_data}")
-
- if isinstance(nested_data, list):
- for entry in nested_data:
- entry_str = str(entry).strip()
- if 'auditor' not in entry_str.lower() and entry_str:
- parsed_name, parsed_pos = parse_name_and_position(entry_str)
- if parsed_name and looks_like_person_name(parsed_name):
- final_name = parsed_name
- if parsed_pos and looks_like_role(parsed_pos):
- final_position = parsed_pos
- break
-
- # Strategy 3: Direct operator declaration keys (with filtering)
- if not final_name:
- search_strategies = [
- ("Operator Declaration.Print Name", "Operator Declaration.Position Title"),
- ("Print Name", "Position Title"),
- ]
-
- for name_key_pattern, pos_key_pattern in search_strategies:
- name_kv = find_matching_json_key_and_value(name_key_pattern, flat_json)
- pos_kv = find_matching_json_key_and_value(pos_key_pattern, flat_json)
-
- if name_kv and name_kv[1]:
- # Filter out auditor names
- potential_name = str(name_kv[1]).strip()
-
- # Skip if this is clearly auditor data
- if name_kv[0] and 'auditor' in name_kv[0].lower():
- continue
-
- # Skip common auditor names that appear in our data
- auditor_names = ['greg dyer', 'greg', 'dyer']
- if any(aud_name in potential_name.lower() for aud_name in auditor_names):
- continue
-
- name_from_val, pos_from_val = parse_name_and_position(name_kv[1])
- if name_from_val and looks_like_person_name(name_from_val):
- # Additional check - avoid auditor names
- if not any(aud_name in name_from_val.lower() for aud_name in auditor_names):
- final_name = name_from_val
- if pos_from_val and looks_like_role(pos_from_val):
- final_position = pos_from_val
-
- if pos_kv and pos_kv[1] and not final_position:
- # Only use if key doesn't indicate auditor data
- if not (pos_kv[0] and 'auditor' in pos_kv[0].lower()):
- pos_val = str(pos_kv[1]).strip()
- if looks_like_role(pos_val) and 'auditor' not in pos_val.lower():
- final_position = pos_val
-
- if final_name:
- break
-
- # Strategy 4: Last resort - search all keys but with strict filtering
- if not final_name:
- print(f" 🔍 Searching all keys with strict operator filtering...")
- for key, value in flat_json.items():
- key_lower = key.lower()
-
- # Skip keys that clearly relate to auditor
- if 'auditor' in key_lower:
- continue
-
- # Look for operator-related keys
- if (("operator" in key_lower and "name" in key_lower) or
- ("print name" in key_lower and "operator" in key_lower)):
-
- if value and looks_like_person_name(str(value)):
- potential_name = str(value).strip()
- # Skip auditor names
- auditor_names = ['greg dyer', 'greg', 'dyer']
- if not any(aud_name in potential_name.lower() for aud_name in auditor_names):
- name_from_val, pos_from_val = parse_name_and_position(value)
- if name_from_val and looks_like_person_name(name_from_val):
- final_name = name_from_val
- if pos_from_val and looks_like_role(pos_from_val):
- final_position = pos_from_val
- break
-
- # Clean up final values
- if isinstance(final_name, (list, tuple)):
- final_name = " ".join(str(x) for x in final_name).strip()
- if isinstance(final_position, (list, tuple)):
- final_position = " ".join(str(x) for x in final_position).strip()
-
- final_name = str(final_name).strip() if final_name else None
- final_position = str(final_position).strip() if final_position else None
-
- print(f" 🎯 Final extracted values: Name='{final_name}', Position='{final_position}'")
-
- # Update name cell if needed
- if (not current_name or has_red_text(name_cell)) and final_name and looks_like_person_name(final_name):
- if has_red_text(name_cell):
- replace_red_text_in_cell(name_cell, final_name)
- else:
- name_cell.text = final_name
- replacements_made += 1
- print(f" ✅ Updated Print Name -> '{final_name}'")
-
- # Update position cell if needed
- if (not current_position or has_red_text(position_cell)) and final_position and looks_like_role(final_position):
- if has_red_text(position_cell):
- replace_red_text_in_cell(position_cell, final_position)
- else:
- position_cell.text = final_position
- replacements_made += 1
- print(f" ✅ Updated Position Title -> '{final_position}'")
-
- break # Found and processed the header row
-
- # Mark table as processed
- if replacements_made > 0:
- try:
- setattr(table, "_processed_operator_declaration", True)
- print(" 🔖 Marked table as processed by Operator Declaration handler")
- except Exception:
- pass
-
- return replacements_made
-
-def handle_multiple_red_segments_in_cell(cell, flat_json):
- replacements_made = 0
- red_segments = extract_red_text_segments(cell)
- if not red_segments:
- return 0
- for i, segment in enumerate(red_segments):
- segment_text = segment['text'].strip()
- if segment_text:
- kv = find_matching_json_key_and_value(segment_text, flat_json)
- if kv:
- replacement_text = get_value_as_string(kv[1], segment_text)
- if replace_single_segment(segment, replacement_text):
- replacements_made += 1
- print(f" ✅ Replaced segment {i+1}: '{segment_text}' -> '{replacement_text}'")
- return replacements_made
-
-def handle_nature_business_multiline_fix(cell, flat_json):
- replacements_made = 0
- red_text = ""
- for paragraph in cell.paragraphs:
- for run in paragraph.runs:
- if is_red(run):
- red_text += run.text
- red_text = red_text.strip()
- if not red_text:
- return 0
- nature_indicators = ["transport", "logistics", "freight", "delivery", "trucking", "haulage"]
- if any(indicator in red_text.lower() for indicator in nature_indicators):
- kv = find_matching_json_key_and_value("Nature of Business", flat_json) or find_matching_json_key_and_value("Nature of the Operators Business (Summary)", flat_json)
- if kv:
- replacement_text = get_value_as_string(kv[1], "Nature of Business")
- cell_replacements = replace_red_text_in_cell(cell, replacement_text)
- replacements_made += cell_replacements
- print(f" ✅ Fixed Nature of Business multiline content")
- return replacements_made
-
-def handle_management_summary_fix(cell, flat_json):
- replacements_made = 0
- red_text = ""
- for paragraph in cell.paragraphs:
- for run in paragraph.runs:
- if is_red(run):
- red_text += run.text
- red_text = red_text.strip()
- if not red_text:
- return 0
- management_types = ["Mass Management Summary", "Maintenance Management Summary", "Fatigue Management Summary"]
- for mgmt_type in management_types:
- if mgmt_type in flat_json and isinstance(flat_json[mgmt_type], dict):
- mgmt_data = flat_json[mgmt_type]
- for std_key, std_value in mgmt_data.items():
- if isinstance(std_value, list) and std_value:
- if len(red_text) > 10:
- for item in std_value:
- if red_text.lower() in str(item).lower() or str(item).lower() in red_text.lower():
- replacement_text = "\n".join(str(i) for i in std_value)
- cell_replacements = replace_red_text_in_cell(cell, replacement_text)
- replacements_made += cell_replacements
- print(f" ✅ Fixed {mgmt_type} - {std_key}")
- return replacements_made
- return replacements_made
-
-def handle_print_accreditation_section(table, flat_json):
- replacements_made = 0
- if getattr(table, "_processed_operator_declaration", False):
- print(f" ⏭️ Skipping Print Accreditation - this is an Operator Declaration table")
- return 0
- table_context = ""
- for row in table.rows:
- for cell in row.cells:
- table_context += get_clean_text(cell).lower() + " "
- if "operator declaration" in table_context or ("print name" in table_context and "position title" in table_context):
- print(f" ⏭️ Skipping Print Accreditation - this is an Operator Declaration table")
- return 0
- print(f" 📋 Processing Print Accreditation section")
- for row_idx, row in enumerate(table.rows):
- for cell_idx, cell in enumerate(row.cells):
- if has_red_text(cell):
- accreditation_fields = [
- "(print accreditation name)",
- "Operator name (Legal entity)",
- "Print accreditation name"
- ]
- for field in accreditation_fields:
- kv = find_matching_json_key_and_value(field, flat_json)
- if kv:
- replacement_text = get_value_as_string(kv[1], field)
- if replacement_text.strip():
- cell_replacements = replace_red_text_in_cell(cell, replacement_text)
- replacements_made += cell_replacements
- if cell_replacements > 0:
- print(f" ✅ Fixed accreditation: {kv[0]}")
- break
- return replacements_made
-
-def process_single_column_sections(cell, key_text, flat_json):
- replacements_made = 0
- if has_red_text(cell):
- red_text = ""
- for paragraph in cell.paragraphs:
- for run in paragraph.runs:
- if is_red(run):
- red_text += run.text
- if red_text.strip():
- kv = find_matching_json_key_and_value(red_text.strip(), flat_json)
- if not kv:
- kv = find_matching_json_key_and_value(key_text, flat_json)
- if kv:
- section_replacement = get_value_as_string(kv[1], red_text.strip())
- cell_replacements = replace_red_text_in_cell(cell, section_replacement)
- replacements_made += cell_replacements
- if cell_replacements > 0:
- print(f" ✅ Fixed single column section: '{key_text}'")
- return replacements_made
-
-# ============================================================================
-# Main table/paragraph/heading processing (preserve logic + use new helpers)
-# ============================================================================
-def process_tables(document, flat_json):
- replacements_made = 0
- for table_idx, table in enumerate(document.tables):
- print(f"\n🔍 Processing table {table_idx + 1}:")
- table_text = ""
- for row in table.rows[:3]:
- for cell in row.cells:
- table_text += get_clean_text(cell).lower() + " "
-
- management_summary_indicators = ["mass management", "maintenance management", "fatigue management"]
- has_management = any(indicator in table_text for indicator in management_summary_indicators)
- has_details = "details" in table_text
-
- if has_management and has_details:
- print(f" 📋 Detected Management Summary table")
- summary_fixes = fix_management_summary_details_column(table, flat_json)
- replacements_made += summary_fixes
-
- summary_replacements = 0
- for row_idx, row in enumerate(table.rows):
- for cell_idx, cell in enumerate(row.cells):
- if has_red_text(cell):
- for mgmt_type in ["Mass Management Summary", "Maintenance Management Summary", "Fatigue Management Summary"]:
- if mgmt_type.lower().replace(" summary", "") in table_text:
- if mgmt_type in flat_json:
- mgmt_data = flat_json[mgmt_type]
- if isinstance(mgmt_data, dict):
- for std_key, std_value in mgmt_data.items():
- if isinstance(std_value, list) and len(std_value) > 0:
- red_text = "".join(run.text for p in cell.paragraphs for run in p.runs if is_red(run)).strip()
- for item in std_value:
- if len(red_text) > 15 and red_text.lower() in str(item).lower():
- replacement_text = "\n".join(str(i) for i in std_value)
- cell_replacements = replace_red_text_in_cell(cell, replacement_text)
- summary_replacements += cell_replacements
- print(f" ✅ Updated {std_key} with summary data")
- break
- break
-
- if summary_replacements == 0:
- cell_replacements = handle_management_summary_fix(cell, flat_json)
- summary_replacements += cell_replacements
-
- replacements_made += summary_replacements
- continue
+ # Clear each target cell and write ONE digit in black
+ for d, cell in zip(digits, targets):
+ _set_cell_text_black(cell, d)
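+        # Note: zip() stops at the shorter of (digits, targets), so surplus boxes or surplus digits are simply skipped.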
- # Vehicle tables detection
- vehicle_indicators = ["registration number", "sub-contractor", "weight verification", "rfs suspension", "registration"]
- indicator_count = sum(1 for indicator in vehicle_indicators if indicator in table_text)
- if indicator_count >= 2:
- print(f" 🚗 Detected Vehicle Registration table")
- vehicle_replacements = handle_vehicle_registration_table(table, flat_json)
- replacements_made += vehicle_replacements
- continue
+ return True
- # Attendance
- if "attendance list" in table_text and "names and position titles" in table_text:
- print(f" 👥 Detected Attendance List table")
- attendance_replacements = handle_attendance_list_table_enhanced(table, flat_json)
- replacements_made += attendance_replacements
- continue
- # Print Accreditation / Operator Declaration
- print_accreditation_indicators = ["print name", "position title"]
- indicator_count = sum(1 for indicator in print_accreditation_indicators if indicator in table_text)
- if indicator_count >= 2 or ("print name" in table_text and "position title" in table_text):
- print(f" 📋 Detected Print Accreditation/Operator Declaration table")
- declaration_fixes = fix_operator_declaration_empty_values(table, flat_json)
- replacements_made += declaration_fixes
- if not getattr(table, "_processed_operator_declaration", False):
- print_accreditation_replacements = handle_print_accreditation_section(table, flat_json)
- replacements_made += print_accreditation_replacements
- continue
+# ----------------------------- vehicle tables -----------------------------
+def table_header_text(table: Table, up_to_rows: int = 3) -> str:
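+    """Join the text of the first `up_to_rows` rows into one canonicalised string for header keyword matching."""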
+ heads = []
+    for row in table.rows[:up_to_rows]:
+ for cell in row.cells:
+ heads.append(cell_text(cell))
+ return canon(" ".join(heads))
- # Regular table rows handling (preserved)
- for row_idx, row in enumerate(table.rows):
- if len(row.cells) < 1:
- continue
- key_cell = row.cells[0]
- key_text = get_clean_text(key_cell)
- if not key_text:
- continue
- print(f" 📌 Row {row_idx + 1}: Key = '{key_text}'")
- kv = find_matching_json_key_and_value(key_text, flat_json)
- json_value = kv[1] if kv else None
-
- if json_value is not None:
- replacement_text = get_value_as_string(json_value, key_text)
-
- # ACN handling
- if ("australian company number" in key_text.lower() or "company number" in key_text.lower()) and isinstance(json_value, list):
- cell_replacements = handle_australian_company_number(row, json_value)
- replacements_made += cell_replacements
-
- # section headers
- elif ("attendance list" in key_text.lower() or "nature of" in key_text.lower()) and row_idx + 1 < len(table.rows):
- print(f" ✅ Section header detected, checking next row...")
- next_row = table.rows[row_idx + 1]
- for cell_idx, cell in enumerate(next_row.cells):
- if has_red_text(cell):
- print(f" ✅ Found red text in next row, cell {cell_idx + 1}")
- if isinstance(json_value, list):
- section_text = "\n".join(str(item) for item in json_value)
- else:
- section_text = replacement_text
- cell_replacements = replace_red_text_in_cell(cell, section_text)
- replacements_made += cell_replacements
- if cell_replacements > 0:
- print(f" -> Replaced section content")
-
- # single column
- elif len(row.cells) == 1 or (len(row.cells) > 1 and not any(has_red_text(row.cells[i]) for i in range(1, len(row.cells)))):
- if has_red_text(key_cell):
- cell_replacements = process_single_column_sections(key_cell, key_text, flat_json)
- replacements_made += cell_replacements
-
- # key-value pairs
- else:
- for cell_idx in range(1, len(row.cells)):
- value_cell = row.cells[cell_idx]
- if has_red_text(value_cell):
- print(f" ✅ Found red text in column {cell_idx + 1}")
- cell_replacements = replace_red_text_in_cell(value_cell, replacement_text)
- replacements_made += cell_replacements
-
- else:
- # fallback single cell red-text key
- if len(row.cells) == 1 and has_red_text(key_cell):
- red_text = ""
- for paragraph in key_cell.paragraphs:
- for run in paragraph.runs:
- if is_red(run):
- red_text += run.text
- if red_text.strip():
- kv2 = find_matching_json_key_and_value(red_text.strip(), flat_json)
- if kv2:
- section_replacement = get_value_as_string(kv2[1], red_text.strip())
- cell_replacements = replace_red_text_in_cell(key_cell, section_replacement)
- replacements_made += cell_replacements
-
- # attempt multiple red-segments or surgical fixes
- for cell_idx in range(len(row.cells)):
- cell = row.cells[cell_idx]
- if has_red_text(cell):
- cell_replacements = handle_multiple_red_segments_in_cell(cell, flat_json)
- replacements_made += cell_replacements
- if cell_replacements == 0:
- surgical_fix = handle_nature_business_multiline_fix(cell, flat_json)
- replacements_made += surgical_fix
- if cell_replacements == 0:
- management_summary_fix = handle_management_summary_fix(cell, flat_json)
- replacements_made += management_summary_fix
-
- # Final operator/auditor declaration check on last few tables
- print(f"\n🎯 Final check for Declaration tables...")
- for table in document.tables[-3:]:
- if len(table.rows) <= 4:
- if getattr(table, "_processed_operator_declaration", False):
- print(f" ⏭️ Skipping - already processed by operator declaration handler")
- continue
- declaration_fix = fix_operator_declaration_empty_values(table, flat_json)
- replacements_made += declaration_fix
-
- return replacements_made
-
-def process_paragraphs(document, flat_json):
- replacements_made = 0
- print(f"\n🔍 Processing paragraphs:")
- for para_idx, paragraph in enumerate(document.paragraphs):
- red_runs = [run for run in paragraph.runs if is_red(run) and run.text.strip()]
- if red_runs:
- red_text_only = "".join(run.text for run in red_runs).strip()
- print(f" 📌 Paragraph {para_idx + 1}: Found red text: '{red_text_only}'")
-
- kv = find_matching_json_key_and_value(red_text_only, flat_json)
- json_value = kv[1] if kv else None
-
- if json_value is None:
- if "AUDITOR SIGNATURE" in red_text_only.upper() or "DATE" in red_text_only.upper():
- kv = find_matching_json_key_and_value("auditor signature", flat_json)
- elif "OPERATOR SIGNATURE" in red_text_only.upper():
- kv = find_matching_json_key_and_value("operator signature", flat_json)
- json_value = kv[1] if kv else None
-
- if json_value is not None:
- replacement_text = get_value_as_string(json_value)
- print(f" ✅ Replacing red text with: '{replacement_text}'")
- red_runs[0].text = replacement_text
- red_runs[0].font.color.rgb = RGBColor(0, 0, 0)
- for run in red_runs[1:]:
- run.text = ''
- replacements_made += 1
- return replacements_made
-
-def process_headings(document, flat_json):
+def find_vehicle_table(doc: Document, want: str) -> Optional[Table]:
"""
- FIXED: Better heading processing with proper red text replacement
+ want = "maintenance" or "mass"
"""
- replacements_made = 0
- print(f"\n🔍 Processing headings:")
- paragraphs = document.paragraphs
-
- # Extract the correct operator name from the JSON data
- operator_name = None
- for key, value in flat_json.items():
- if "operator name" in key.lower() and "legal entity" in key.lower():
- if isinstance(value, list) and value:
- operator_name = str(value[0]).strip()
- else:
- operator_name = str(value).strip()
- break
-
- if not operator_name:
- # Fallback - try other operator name keys
- for key, value in flat_json.items():
- if ("operator" in key.lower() and "name" in key.lower()) or key.lower() == "operator name":
- if isinstance(value, list) and value:
- operator_name = str(value[0]).strip()
- elif value:
- operator_name = str(value).strip()
- break
-
- print(f" 📋 Using operator name: '{operator_name}'")
-
- for para_idx, paragraph in enumerate(paragraphs):
- paragraph_text = paragraph.text.strip()
- if not paragraph_text:
- continue
-
- matched_heading = None
- for category, patterns in HEADING_PATTERNS.items():
- for pattern in patterns:
- if re.search(pattern, paragraph_text, re.IGNORECASE):
- matched_heading = pattern
- break
- if matched_heading:
- break
-
- if matched_heading:
- print(f" 📌 Found heading at paragraph {para_idx + 1}: '{paragraph_text}'")
-
- # Check if the heading itself has red text
- if has_red_text_in_paragraph(paragraph):
- print(f" 🔴 Found red text in heading itself")
- heading_replacements = process_red_text_in_heading_paragraph(paragraph, paragraph_text, flat_json, operator_name)
- replacements_made += heading_replacements
-
- # Look for red text in paragraphs immediately following this heading
- for next_para_offset in range(1, 6):
- next_para_idx = para_idx + next_para_offset
- if next_para_idx >= len(paragraphs):
- break
-
- next_paragraph = paragraphs[next_para_idx]
- next_text = next_paragraph.text.strip()
-
- if not next_text:
- continue
-
- # Stop if we hit another heading
- is_another_heading = False
- for category, patterns in HEADING_PATTERNS.items():
- for pattern in patterns:
- if re.search(pattern, next_text, re.IGNORECASE):
- is_another_heading = True
- break
- if is_another_heading:
- break
-
- if is_another_heading:
- break
-
- if has_red_text_in_paragraph(next_paragraph):
- print(f" 🔴 Found red text in paragraph {next_para_idx + 1} after heading")
- context_replacements = process_red_text_in_context_paragraph(
- next_paragraph,
- paragraph_text,
- flat_json,
- operator_name
- )
- replacements_made += context_replacements
-
- return replacements_made
-
-def process_red_text_in_heading_paragraph(paragraph, paragraph_text, flat_json, operator_name):
- """Process red text found in heading paragraphs - FIXED"""
- replacements_made = 0
- red_text_segments = []
-
- for run in paragraph.runs:
- if is_red(run) and run.text.strip():
- red_text_segments.append(run.text.strip())
-
- if not red_text_segments:
- return 0
-
- combined_red_text = " ".join(red_text_segments).strip()
- print(f" 🔍 Red text found in heading: '{combined_red_text}'")
-
- replacement_value = None
-
- # Determine what to replace based on heading context
- if any(mgmt_type in paragraph_text.upper() for mgmt_type in ["MAINTENANCE MANAGEMENT", "MASS MANAGEMENT", "FATIGUE MANAGEMENT"]):
- # For management section headings, replace with operator name
- if operator_name:
- replacement_value = operator_name
- print(f" ✅ Using operator name for management section: '{operator_name}'")
-
- elif "NHVAS APPROVED AUDITOR DECLARATION" in paragraph_text.upper():
- # For auditor declarations, look for auditor name
- auditor_name = None
- for key, value in flat_json.items():
- if "auditor" in key.lower() and "name" in key.lower():
- if isinstance(value, list) and value:
- auditor_name = str(value[0]).strip()
- elif value:
- auditor_name = str(value).strip()
- break
-
- if auditor_name:
- replacement_value = auditor_name
- print(f" ✅ Using auditor name: '{auditor_name}'")
-
- elif "OPERATOR DECLARATION" in paragraph_text.upper():
- # For operator declarations, use operator name
- if operator_name:
- replacement_value = operator_name
- print(f" ✅ Using operator name for operator declaration: '{operator_name}'")
-
- else:
- # For other headings, try to find a relevant match
- # First try direct match
- kv = find_matching_json_key_and_value(combined_red_text, flat_json)
- if kv:
- replacement_value = get_value_as_string(kv[1], combined_red_text)
- else:
- # Try contextual search with heading
- context_queries = [f"{paragraph_text} {combined_red_text}", combined_red_text, paragraph_text]
- for query in context_queries:
- kv = find_matching_json_key_and_value(query, flat_json)
- if kv:
- replacement_value = get_value_as_string(kv[1], combined_red_text)
- print(f" ✅ Found match with combined query: {kv[0]}")
- break
-
- # FIXED: Apply the replacement if we found a suitable value
- if replacement_value:
- red_runs = [run for run in paragraph.runs if is_red(run) and run.text.strip()]
- if red_runs:
- # Replace the first red run with the new text
- red_runs[0].text = replacement_value
- red_runs[0].font.color.rgb = RGBColor(0, 0, 0)
- # Clear subsequent red runs
- for run in red_runs[1:]:
- run.text = ''
- replacements_made = 1
- print(f" ✅ Replaced heading red text with: '{replacement_value}'")
- else:
- print(f" ❌ No suitable replacement found for: '{combined_red_text}'")
-
- return replacements_made
-
-def process_red_text_in_context_paragraph(paragraph, heading_text, flat_json, operator_name):
- """Process red text found in paragraphs following headings - FIXED"""
- replacements_made = 0
- red_text_segments = []
-
- for run in paragraph.runs:
- if is_red(run) and run.text.strip():
- red_text_segments.append(run.text.strip())
-
- if not red_text_segments:
- return 0
-
- combined_red_text = " ".join(red_text_segments).strip()
- print(f" 🔍 Red text found: '{combined_red_text}'")
-
- replacement_value = None
-
- # Determine what to replace based on heading context
- if any(mgmt_type in heading_text.upper() for mgmt_type in ["MAINTENANCE MANAGEMENT", "MASS MANAGEMENT", "FATIGUE MANAGEMENT"]):
- # For management section headings, replace with operator name
- if operator_name:
- replacement_value = operator_name
- print(f" ✅ Using operator name for management section: '{operator_name}'")
-
- elif "NHVAS APPROVED AUDITOR DECLARATION" in heading_text.upper():
- # For auditor declarations, look for auditor name
- auditor_name = None
- for key, value in flat_json.items():
- if "auditor" in key.lower() and "name" in key.lower():
- if isinstance(value, list) and value:
- auditor_name = str(value[0]).strip()
- elif value:
- auditor_name = str(value).strip()
- break
-
- if auditor_name:
- replacement_value = auditor_name
- print(f" ✅ Using auditor name: '{auditor_name}'")
-
- elif "OPERATOR DECLARATION" in heading_text.upper():
- # For operator declarations, use operator name
- if operator_name:
- replacement_value = operator_name
- print(f" ✅ Using operator name for operator declaration: '{operator_name}'")
-
+ MAINT_KEYS = ["registration number", "maintenance records", "daily checks", "fault recording", "fault repair"]
+ MASS_KEYS = ["registration number", "weight verification", "rfs suspension", "suspension system maintenance", "trip records", "reporting on suspension"]
+ candidates = []
+ for t in iter_tables(doc):
+ htxt = table_header_text(t)
+ if want == "maintenance":
+ if all(k in htxt for k in ["registration", "maintenance", "fault"]) and "suspension" not in htxt:
+ candidates.append(t)
+ elif want == "mass":
+ if "suspension" in htxt and "weight" in htxt:
+ candidates.append(t)
+    # Prefer the candidate with the most rows
+ if not candidates:
+ return None
+ return max(candidates, key=lambda tb: len(tb.rows))
+
+def map_cols(table: Table, want: str) -> Dict[str, int]:
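+    """Map logical column keys (e.g. "reg", "wv") to header column indexes; columns whose headers are not found are omitted."""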
+ # map header columns by keywords from the first 2 rows that contain headers
+ header_rows = table.rows[:2]
+ col_texts = []
+ cols = len(table.rows[0].cells)
+ for j in range(cols):
+ txt = " ".join(cell_text(r.cells[j]) for r in header_rows if j < len(r.cells))
+ col_texts.append(canon(txt))
+ idx = {}
+ def first_col(*needles) -> Optional[int]:
+ for j, t in enumerate(col_texts):
+ if all(n in t for n in needles):
+ return j
+ return None
+ if want == "maintenance":
+ idx["reg"] = first_col("registration")
+ idx["rw"] = first_col("roadworthiness")
+ idx["mr"] = first_col("maintenance", "records")
+ idx["daily"] = first_col("daily", "check")
+ idx["fr"] = first_col("fault", "recording")
+ idx["rep"] = first_col("fault", "repair")
else:
- # For other headings, try to find a relevant match
- # First try direct match
- kv = find_matching_json_key_and_value(combined_red_text, flat_json)
- if kv:
- replacement_value = get_value_as_string(kv[1], combined_red_text)
- else:
- # Try contextual search with heading
- context_queries = [f"{heading_text} {combined_red_text}", combined_red_text, heading_text]
- for query in context_queries:
- kv = find_matching_json_key_and_value(query, flat_json)
- if kv:
- replacement_value = get_value_as_string(kv[1], combined_red_text)
- print(f" ✅ Found match with combined query: {kv[0]}")
- break
-
- # FIXED: Apply the replacement if we found a suitable value
- if replacement_value:
- red_runs = [run for run in paragraph.runs if is_red(run) and run.text.strip()]
- if red_runs:
- # Replace the first red run with the new text
- red_runs[0].text = replacement_value
- red_runs[0].font.color.rgb = RGBColor(0, 0, 0)
- # Clear subsequent red runs
- for run in red_runs[1:]:
- run.text = ''
- replacements_made = 1
- print(f" ✅ Replaced context red text with: '{replacement_value}'")
+ idx["reg"] = first_col("registration")
+ idx["wv"] = first_col("weight", "verification")
+ idx["rfs"] = first_col("rfs", "cert")
+ idx["susp"] = first_col("suspension", "maintenance")
+ idx["trip"] = first_col("trip", "record")
+ idx["frs"] = first_col("fault", "suspension")
+ return {k:v for k,v in idx.items() if v is not None}
+
+def clear_data_rows_keep_headers(table: Table, header_rows: int = 1):
+ # Keep first header_rows, drop everything else
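+    # python-docx has no public row-deletion API, so we remove the underlying <w:tr> elements via the table's _tbl element.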
+ while len(table.rows) > header_rows:
+ table._tbl.remove(table.rows[-1]._tr)
+
+def ensure_rows(table: Table, need_rows: int):
+ # assumes 1 header row; add rows to reach need_rows + 1 total
+ while len(table.rows) < need_rows + 1:
+ table.add_row()
+
+def fill_vehicle_table(table: Table, want: str, arrays: Dict[str, List[str]]):
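+    """
+    Rewrite the table body from `arrays` (one JSON list per column), keeping the header row and
+    resizing the data rows to the number of registrations. Shape only, values come from the JSON:
+        {"Registration Number": [...], "Maintenance Records": [...], ...}
+    """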
+ colmap = map_cols(table, want)
+ if "reg" not in colmap:
+ return
+ if want == "maintenance":
+ regs = arrays.get("Registration Number", [])
+ rw = arrays.get("Roadworthiness Certificates", [])
+ mr = arrays.get("Maintenance Records", [])
+ daily= arrays.get("Daily Checks", [])
+ fr = arrays.get("Fault Recording/ Reporting", [])
+ rep = arrays.get("Fault Repair", [])
+ n = len(regs)
+ # keep header row(s), then fill N rows
+ clear_data_rows_keep_headers(table, header_rows=1)
+ ensure_rows(table, n)
+ for i in range(n):
+ row = table.rows[i+1]
+ def put(col_key, vals):
+ if col_key not in colmap or i >= len(vals): return
+ c = row.cells[colmap[col_key]]
+ replace_red_in_cell(c, nz(vals[i]))
+ # write each col
+            c_reg = row.cells[colmap["reg"]]
+            replace_red_in_cell(c_reg, nz(regs[i]))
+ put("rw", rw)
+ put("mr", mr)
+ put("daily",daily)
+ put("fr", fr)
+ put("rep", rep)
else:
- print(f" ❌ No suitable replacement found for: '{combined_red_text}'")
-
- return replacements_made
-
-# ============================================================================
-# Orchestrator
-# ============================================================================
-def process_hf(json_file, docx_file, output_file):
- try:
- if hasattr(json_file, "read"):
- json_data = json.load(json_file)
- else:
- with open(json_file, 'r', encoding='utf-8') as f:
- json_data = json.load(f)
-
- flat_json = flatten_json(json_data)
- print("📄 Available JSON keys (sample):")
- for i, (key, value) in enumerate(sorted(flat_json.items())):
- if i < 10:
- print(f" - {key}: {value}")
- print(f" ... and {len(flat_json) - 10} more keys\n")
-
- if hasattr(docx_file, "read"):
- doc = Document(docx_file)
- else:
- doc = Document(docx_file)
-
- print("🚀 Starting comprehensive document processing...")
- table_replacements = process_tables(doc, flat_json)
- paragraph_replacements = process_paragraphs(doc, flat_json)
- heading_replacements = process_headings(doc, flat_json)
- total_replacements = table_replacements + paragraph_replacements + heading_replacements
-
- # Save unmatched headers for iterative improvement
- if _unmatched_headers:
- try:
- tmp_path = "/tmp/unmatched_headers.json"
- with open(tmp_path, 'w', encoding='utf-8') as f:
- json.dump(_unmatched_headers, f, indent=2, ensure_ascii=False)
- print(f"✅ Unmatched headers saved to {tmp_path}")
- except Exception as e:
- print(f"⚠️ Could not save unmatched headers: {e}")
-
- if hasattr(output_file, "write"):
- doc.save(output_file)
- else:
- doc.save(output_file)
-
- print(f"\n✅ Document saved as: {output_file}")
- print(f"✅ Total replacements: {total_replacements}")
- print(f" 📊 Tables: {table_replacements}")
- print(f" 📝 Paragraphs: {paragraph_replacements}")
- print(f" 📋 Headings: {heading_replacements}")
- print(f"🎉 Processing complete!")
-
- except FileNotFoundError as e:
- print(f"❌ File not found: {e}")
- except Exception as e:
- print(f"❌ Error: {e}")
- import traceback
- traceback.print_exc()
+ regs = arrays.get("Registration Number", [])
+ wv = arrays.get("Weight Verification Records", [])
+ rfs = arrays.get("RFS Suspension Certification #", [])
+ susp = arrays.get("Suspension System Maintenance", [])
+ trip = arrays.get("Trip Records", [])
+ frs = arrays.get("Fault Recording/ Reporting on Suspension System", [])
+ n = len(regs)
+ clear_data_rows_keep_headers(table, header_rows=1)
+ ensure_rows(table, n)
+ for i in range(n):
+ row = table.rows[i+1]
+ def put(col_key, vals):
+ if col_key not in colmap or i >= len(vals): return
+ c = row.cells[colmap[col_key]]
+ replace_red_in_cell(c, nz(vals[i]))
+            c_reg = row.cells[colmap["reg"]]
+            replace_red_in_cell(c_reg, nz(regs[i]))
+ put("wv", wv)
+ put("rfs", rfs)
+ put("susp", susp)
+ put("trip", trip)
+ put("frs", frs)
+
+# ----------------------------- driver table -----------------------------
+def find_driver_table(doc: Document) -> Optional[Table]:
+ for t in iter_tables(doc):
+ h = table_header_text(t)
+ if "driver / scheduler" in h and ("fit for duty" in h or "work diary" in h):
+ return t
+ return None
+
+def map_driver_cols(table: Table) -> Dict[str, int]:
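+    """Locate the driver-table columns ("name", "roster", "fit", "wd") by header keywords; mirrors map_cols()."""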
+ header_rows = table.rows[:2]
+ cols = len(table.rows[0].cells)
+ col_texts = []
+ for j in range(cols):
+ txt = " ".join(cell_text(r.cells[j]) for r in header_rows if j < len(r.cells))
+ col_texts.append(canon(txt))
+ idx = {}
+ def first_col(*needles):
+ for j, t in enumerate(col_texts):
+ if all(n in t for n in needles):
+ return j
+ return None
+ idx["name"] = first_col("driver", "name")
+ idx["roster"]= first_col("roster", "safe")
+ idx["fit"] = first_col("fit for duty")
+    # Work diary may be split across two headers; check for None explicitly so a match in column 0 is kept.
+    wd = first_col("work diary")
+    if wd is None:
+        wd = first_col("electronic work diary")
+    if wd is not None: idx["wd"] = wd
+ return {k:v for k,v in idx.items() if v is not None}
+
+def fill_driver_table(table: Table, arrays: Dict[str, List[str]]):
+ colmap = map_driver_cols(table)
+ if not colmap:
+ return
+
+ names = arrays.get("Driver / Scheduler Name", [])
+ rosters = arrays.get("Roster / Schedule / Safe Driving Plan (Date Range)", [])
+ fit = arrays.get("Fit for Duty Statement Completed (Yes/No)", [])
+ wd = arrays.get("Work Diary Pages (Page Numbers) Electronic Work Diary Records (Date Range)", [])
+
+ n = max(len(rosters), len(fit), len(wd), len(names))
+ clear_data_rows_keep_headers(table, header_rows=1)
+ ensure_rows(table, n)
+
+ has_any_name = any(str(x).strip() for x in names)
+
+ for i in range(n):
+ row = table.rows[i+1]
+ if "name" in colmap and has_any_name:
+ replace_red_in_cell(row.cells[colmap["name"]], names[i] if i < len(names) else "")
+ if "roster" in colmap:
+ replace_red_in_cell(row.cells[colmap["roster"]], rosters[i] if i < len(rosters) else "")
+ if "fit" in colmap:
+ replace_red_in_cell(row.cells[colmap["fit"]], fit[i] if i < len(fit) else "")
+ if "wd" in colmap:
+ replace_red_in_cell(row.cells[colmap["wd"]], wd[i] if i < len(wd) else "")
+
+
+
+# ----------------------------- main mapping -----------------------------
+def flatten_simple_sections(data: Dict) -> Dict[str, str]:
+ """Collect simple label->single value mappings from top-level sections other than tables."""
+ out = {}
+ skip_sections = {
+ "Vehicle Registration Numbers Maintenance",
+ "Vehicle Registration Numbers Mass",
+ "Driver / Scheduler Records Examined",
+ "paragraphs",
+ "Attendance List (Names and Position Titles)",
+ "Nature of the Operators Business (Summary)",
+ "Maintenance Management Summary",
+ "Mass Management Summary",
+ "Fatigue Management Summary",
+ }
+ for sec, kv in data.items():
+ if sec in skip_sections: continue
+ if not isinstance(kv, dict): continue
+ for label, val in kv.items():
+ out[f"{sec}::{label}"] = join_value(val)
+ return out
+
+def run(input_json: Path, template_docx: Path, output_docx: Path):
+ with open(input_json, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ doc = Document(str(template_docx))
+
+ # 1) simple label/value tables
+ simple = flatten_simple_sections(data)
+
+ # Map by (section::label). We try: (a) find exact label cell somewhere and write in the adjacent cell;
+ # (b) if not found, search by heading then the next red run below the heading.
+ for k, v in simple.items():
+ # use the part after '::' as the label
+ label = k.split("::", 1)[1] if "::" in k else k
+
+ # SPECIAL: skip ACN here; we'll fill per-digit later
+ if canon_label(label) == "australian company number":
+ continue
+
+ ok = update_label_value_in_tables(doc, label, v)
+ if not ok:
+ sec = k.split("::", 1)[0] if "::" in k else k
+ update_heading_followed_red(doc, sec, v)
+
+
+ # 2) paragraphs block
+ paras = data.get("paragraphs", {})
+
+    # 2a) Management headings above the three tables: replace the first red run
+    #     below each heading with the value stored under that heading in "paragraphs".
+ for head in ("MAINTENANCE MANAGEMENT", "MASS MANAGEMENT", "FATIGUE MANAGEMENT"):
+ name_val = join_value(paras.get(head, ""))
+ if name_val:
+ update_heading_followed_red(doc, head, name_val, max_scan=6)
+
+    # 2b) Declaration dates
+    # second-last page: the date under the page heading
+ aud_head = "NHVAS APPROVED AUDITOR DECLARATION"
+ aud_date = join_value(paras.get(aud_head, ""))
+ if aud_date:
+ set_date_by_heading_from_end(doc, aud_head, aud_date, max_scan=40)
+
+ # last page: date under the long acknowledgement paragraph
+ ack_head = ("I hereby acknowledge and agree with the findings detailed in this NHVAS Audit Summary Report. "
+ "I have read and understand the conditions applicable to the Scheme, including the NHVAS Business Rules and Standards.")
+ ack_date = join_value(paras.get(ack_head, ""))
+ if ack_date:
+ set_date_by_paragraph_from_end(doc, ack_head, ack_date, max_scan=40)
+
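+    # 2c) Third heading line under each management heading: overwrite it with the corresponding name from "paragraphs".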
+ maint_name = join_value(paras.get("MAINTENANCE MANAGEMENT", ""))
+ if maint_name:
+ set_layer3_name_after_management_heading(
+ doc,
+ "MAINTENANCE MANAGEMENT",
+ ["Vehicle Registration Numbers of Records Examined"],
+ maint_name,
+ )
+
+ mass_name = join_value(paras.get("MASS MANAGEMENT", ""))
+ if mass_name:
+ set_layer3_name_after_management_heading(
+ doc,
+ "MASS MANAGEMENT",
+ ["Vehicle Registration Numbers of Records Examined"],
+ mass_name,
+ )
+
+ fat_name = join_value(paras.get("FATIGUE MANAGEMENT", ""))
+ if fat_name:
+ set_layer3_name_after_management_heading(
+ doc,
+ "FATIGUE MANAGEMENT",
+ ["Driver / Scheduler Records Examined"],
+ fat_name,
+ )
+
+
+ # 3) ACN digits
+ op_info = data.get("Operator Information", {})
+ acn_val = join_value(op_info.get("Australian Company Number", ""))
+ if acn_val:
+ fill_acn_digits(doc, acn_val)
+
+ # 4) Vehicle tables
+ maint = data.get("Vehicle Registration Numbers Maintenance", {})
+ mass = data.get("Vehicle Registration Numbers Mass", {})
+ t_m = find_vehicle_table(doc, "maintenance")
+ if t_m and maint:
+ fill_vehicle_table(t_m, "maintenance", maint)
+ t_ms = find_mass_vehicle_numbers_table(doc)
+ if t_ms and mass:
+ fill_mass_vehicle_table_preserve_headers(t_ms, mass)
+
+ # 5) Driver table
+ drivers = data.get("Driver / Scheduler Records Examined", {})
+ t_d = find_driver_table(doc)
+ if t_d and drivers:
+ fill_driver_table(t_d, drivers)
+
+ # 6) Special: Audit Declaration dates via heading
+ decl = data.get("Audit Declaration dates", {})
+ if decl.get("Audit was conducted on"):
+ update_heading_followed_red(doc, "Audit was conducted on", decl["Audit was conducted on"])
+
+ # 7) Operator Declaration (last page, bottom row only), and fix Auditor table header
+ op_decl = data.get("Operator Declaration", {})
+ if op_decl:
+ fill_operator_declaration(
+ doc,
+ join_value(op_decl.get("Print Name", "")),
+ join_value(op_decl.get("Position Title", "")),
+ )
+
+    # Make sure the "NHVAS APPROVED AUDITOR DECLARATION" header row on the second-last page shows the label text
+ ensure_auditor_decl_headers(doc)
+
+
+ # 8) Attendance List
+ # Attendance: replace red lines only
+ atts = data.get("Attendance List (Names and Position Titles)", {})
+ att_val = atts.get("Attendance List (Names and Position Titles)")
+ if att_val:
+ fill_attendance_block(doc, att_val)
+
+ # 9) Nature of the Operators Business (Summary): write once (no duplicates)
+ biz = data.get("Nature of the Operators Business (Summary)", {})
+ if biz:
+ val = biz.get("Nature of the Operators Business (Summary):") or next(iter(biz.values()), "")
+ if val:
+ update_business_summary_once(doc, val)
+
+ # 10) Summary tables: FULL OVERWRITE of DETAILS from JSON
+ mm_sum = data.get("Maintenance Management Summary", {})
+ if mm_sum:
+ overwrite_summary_details_cells(doc, "Maintenance Management Summary", mm_sum)
+
+ mass_sum = data.get("Mass Management Summary", {})
+ if mass_sum:
+ overwrite_summary_details_cells(doc, "Mass Management Summary", mass_sum)
+
+ fat_sum = data.get("Fatigue Management Summary", {})
+ if fat_sum:
+ overwrite_summary_details_cells(doc, "Fatigue Management Summary", fat_sum)
+
+
+ doc.save(str(output_docx))
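+
+# Illustrative call (paths are placeholders):
+#   run(Path("data.json"), Path("template.docx"), Path("filled.docx"))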
+
+# ----------------------------- CLI -----------------------------
if __name__ == "__main__":
import sys
+ from pathlib import Path
+
if len(sys.argv) != 4:
- print("Usage: python pipeline.py ")
- exit(1)
- docx_path = sys.argv[1]
- json_path = sys.argv[2]
- output_path = sys.argv[3]
- process_hf(json_path, docx_path, output_path)
\ No newline at end of file
+ print("Usage: python updated_word.py ")
+ sys.exit(1)
+
+ a, b, c = map(Path, sys.argv[1:4])
+ files = [a, b, c]
+
+ json_path = next((p for p in files if p.suffix.lower() == ".json"), None)
+ docx_paths = [p for p in files if p.suffix.lower() == ".docx"]
+
+ if not json_path or len(docx_paths) < 2:
+ print("Error: provide one .json and two .docx (template + output).")
+ sys.exit(1)
+
+ # Template = the .docx that already exists; Output = the other .docx
+ template_docx = next((p for p in docx_paths if p.exists()), docx_paths[0])
+ output_docx = docx_paths[1] if docx_paths[0] == template_docx else docx_paths[0]
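+    # Arguments can be given in any order: the .json is detected by extension and an existing .docx is treated as the template.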
+
+ run(json_path, template_docx, output_docx)
\ No newline at end of file