Spaces:
Paused
Paused
feat(api): fast FastAPI app + model loader refactor; add mock mode for tests\n\n- Add pyproject + setuptools config and console entrypoint\n- Implement enhanced field extraction + MRZ heuristics\n- Add response builder with compatibility for legacy MRZ fields\n- New preprocessing pipeline for PDFs/images\n- HF Spaces GPU: cache ENV, optional flash-attn, configurable base image\n- Add Make targets for Spaces GPU and local CPU\n- Add httpx for TestClient; tests pass in mock mode\n- Remove embedded model files and legacy app/modules
Browse files- .gitignore +37 -177
- Dockerfile +19 -4
- Makefile +104 -0
- README.md +34 -10
- main.py +27 -0
- pyproject.toml +83 -0
- requirements.txt +7 -1
- scripts/README_TESTING.md +215 -0
- scripts/quick_test.py +75 -0
- scripts/run_tests.sh +163 -0
- scripts/test_api_endpoint.py +407 -0
- scripts/test_config.json +54 -0
- scripts/test_production.py +79 -0
- scripts/test_production_curl.sh +89 -0
- setup_dev.py +57 -0
- src/kybtech_dots_ocr/__init__.py +32 -0
- app.py → src/kybtech_dots_ocr/api_models.py +34 -164
- src/kybtech_dots_ocr/app.py +217 -0
- src/kybtech_dots_ocr/enhanced_field_extraction.py +403 -0
- field_extraction.py → src/kybtech_dots_ocr/field_extraction.py +1 -1
- src/kybtech_dots_ocr/model_loader.py +313 -0
- models.py → src/kybtech_dots_ocr/models.py +0 -0
- src/kybtech_dots_ocr/preprocessing.py +333 -0
- src/kybtech_dots_ocr/response_builder.py +321 -0
- tests/__init__.py +1 -0
- tests/test_app.py +38 -0
- tests/test_field_extraction.py +88 -0
.gitignore
CHANGED
|
@@ -1,12 +1,8 @@
|
|
| 1 |
-
#
|
| 2 |
__pycache__/
|
| 3 |
*.py[cod]
|
| 4 |
*$py.class
|
| 5 |
-
|
| 6 |
-
# C extensions
|
| 7 |
*.so
|
| 8 |
-
|
| 9 |
-
# Distribution / packaging
|
| 10 |
.Python
|
| 11 |
build/
|
| 12 |
develop-eggs/
|
|
@@ -20,156 +16,78 @@ parts/
|
|
| 20 |
sdist/
|
| 21 |
var/
|
| 22 |
wheels/
|
| 23 |
-
pip-wheel-metadata/
|
| 24 |
-
share/python-wheels/
|
| 25 |
*.egg-info/
|
| 26 |
.installed.cfg
|
| 27 |
*.egg
|
| 28 |
MANIFEST
|
| 29 |
|
| 30 |
-
#
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
#
|
|
|
|
|
|
|
| 41 |
htmlcov/
|
| 42 |
.tox/
|
| 43 |
.nox/
|
| 44 |
-
.coverage
|
| 45 |
-
.coverage.*
|
| 46 |
-
.cache
|
| 47 |
-
nosetests.xml
|
| 48 |
coverage.xml
|
| 49 |
*.cover
|
| 50 |
-
*.py,cover
|
| 51 |
.hypothesis/
|
| 52 |
-
.pytest_cache/
|
| 53 |
-
|
| 54 |
-
# Translations
|
| 55 |
-
*.mo
|
| 56 |
-
*.pot
|
| 57 |
-
|
| 58 |
-
# Django stuff:
|
| 59 |
-
*.log
|
| 60 |
-
local_settings.py
|
| 61 |
-
db.sqlite3
|
| 62 |
-
db.sqlite3-journal
|
| 63 |
-
|
| 64 |
-
# Flask stuff:
|
| 65 |
-
instance/
|
| 66 |
-
.webassets-cache
|
| 67 |
-
|
| 68 |
-
# Scrapy stuff:
|
| 69 |
-
.scrapy
|
| 70 |
-
|
| 71 |
-
# Sphinx documentation
|
| 72 |
-
docs/_build/
|
| 73 |
-
|
| 74 |
-
# PyBuilder
|
| 75 |
-
target/
|
| 76 |
|
| 77 |
# Jupyter Notebook
|
| 78 |
.ipynb_checkpoints
|
| 79 |
|
| 80 |
-
# IPython
|
| 81 |
-
profile_default/
|
| 82 |
-
ipython_config.py
|
| 83 |
-
|
| 84 |
# pyenv
|
| 85 |
.python-version
|
| 86 |
|
| 87 |
-
# pipenv
|
| 88 |
-
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
-
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
-
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
-
# install all needed dependencies.
|
| 92 |
-
#Pipfile.lock
|
| 93 |
-
|
| 94 |
-
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
-
__pypackages__/
|
| 96 |
-
|
| 97 |
-
# Celery stuff
|
| 98 |
-
celerybeat-schedule
|
| 99 |
-
celerybeat.pid
|
| 100 |
-
|
| 101 |
-
# SageMath parsed files
|
| 102 |
-
*.sage.py
|
| 103 |
-
|
| 104 |
# Environments
|
| 105 |
.env
|
| 106 |
-
.
|
| 107 |
-
env
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
env.bak/
|
| 111 |
-
venv.bak/
|
| 112 |
-
|
| 113 |
-
# Spyder project settings
|
| 114 |
-
.spyderproject
|
| 115 |
-
.spyproject
|
| 116 |
-
|
| 117 |
-
# Rope project settings
|
| 118 |
-
.ropeproject
|
| 119 |
-
|
| 120 |
-
# mkdocs documentation
|
| 121 |
-
/site
|
| 122 |
|
| 123 |
# mypy
|
| 124 |
.mypy_cache/
|
| 125 |
.dmypy.json
|
| 126 |
dmypy.json
|
| 127 |
|
| 128 |
-
#
|
| 129 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
# Data files
|
| 132 |
-
*.csv
|
| 133 |
-
*.json
|
| 134 |
-
*.jsonl
|
| 135 |
-
*.parquet
|
| 136 |
-
*.feather
|
| 137 |
-
*.arrow
|
| 138 |
data/
|
| 139 |
-
datasets/
|
| 140 |
-
raw_data/
|
| 141 |
-
processed_data/
|
| 142 |
-
|
| 143 |
-
# Hugging Face specific
|
| 144 |
-
.cache/
|
| 145 |
-
huggingface_hub/
|
| 146 |
-
transformers_cache/
|
| 147 |
-
|
| 148 |
-
# OpenCV and image processing
|
| 149 |
*.jpg
|
| 150 |
*.jpeg
|
| 151 |
*.png
|
| 152 |
-
*.
|
| 153 |
-
*.
|
| 154 |
-
*.tiff
|
| 155 |
-
*.tif
|
| 156 |
-
*.webp
|
| 157 |
-
*.svg
|
| 158 |
-
test_images/
|
| 159 |
-
sample_images/
|
| 160 |
-
uploads/
|
| 161 |
-
temp_images/
|
| 162 |
|
| 163 |
-
#
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
*.swp
|
| 167 |
-
*.swo
|
| 168 |
-
*~
|
| 169 |
-
.DS_Store
|
| 170 |
-
Thumbs.db
|
| 171 |
|
| 172 |
-
# OS
|
| 173 |
.DS_Store
|
| 174 |
.DS_Store?
|
| 175 |
._*
|
|
@@ -178,66 +96,8 @@ Thumbs.db
|
|
| 178 |
ehthumbs.db
|
| 179 |
Thumbs.db
|
| 180 |
|
| 181 |
-
# Logs
|
| 182 |
-
*.log
|
| 183 |
-
logs/
|
| 184 |
-
log/
|
| 185 |
-
|
| 186 |
-
# Temporary files
|
| 187 |
-
tmp/
|
| 188 |
-
temp/
|
| 189 |
-
.tmp/
|
| 190 |
-
|
| 191 |
# Docker
|
| 192 |
.dockerignore
|
| 193 |
|
| 194 |
-
# Local configuration
|
| 195 |
-
config.local.py
|
| 196 |
-
settings.local.py
|
| 197 |
-
.env.local
|
| 198 |
-
.env.development
|
| 199 |
-
.env.test
|
| 200 |
-
.env.production
|
| 201 |
-
|
| 202 |
-
# Backup files
|
| 203 |
-
*.bak
|
| 204 |
-
*.backup
|
| 205 |
-
*.old
|
| 206 |
-
|
| 207 |
-
# Runtime files
|
| 208 |
-
*.pid
|
| 209 |
-
*.sock
|
| 210 |
-
|
| 211 |
-
# Coverage reports
|
| 212 |
-
htmlcov/
|
| 213 |
-
.coverage
|
| 214 |
-
coverage.xml
|
| 215 |
-
|
| 216 |
-
# Profiling
|
| 217 |
-
*.prof
|
| 218 |
-
|
| 219 |
-
# Jupyter notebook checkpoints
|
| 220 |
-
.ipynb_checkpoints/
|
| 221 |
-
|
| 222 |
-
# pytest
|
| 223 |
-
.pytest_cache/
|
| 224 |
-
|
| 225 |
-
# Ruff
|
| 226 |
-
.ruff_cache/
|
| 227 |
-
|
| 228 |
-
# Black
|
| 229 |
-
.black/
|
| 230 |
-
|
| 231 |
-
# isort
|
| 232 |
-
.isort.cfg
|
| 233 |
-
|
| 234 |
-
# Pre-commit
|
| 235 |
-
.pre-commit-config.yaml
|
| 236 |
-
|
| 237 |
-
# Local development
|
| 238 |
-
local/
|
| 239 |
-
dev/
|
| 240 |
-
development/
|
| 241 |
-
|
| 242 |
.cursor/
|
| 243 |
docs/
|
|
|
|
| 1 |
+
# Python
|
| 2 |
__pycache__/
|
| 3 |
*.py[cod]
|
| 4 |
*$py.class
|
|
|
|
|
|
|
| 5 |
*.so
|
|
|
|
|
|
|
| 6 |
.Python
|
| 7 |
build/
|
| 8 |
develop-eggs/
|
|
|
|
| 16 |
sdist/
|
| 17 |
var/
|
| 18 |
wheels/
|
|
|
|
|
|
|
| 19 |
*.egg-info/
|
| 20 |
.installed.cfg
|
| 21 |
*.egg
|
| 22 |
MANIFEST
|
| 23 |
|
| 24 |
+
# Virtual environments
|
| 25 |
+
.env
|
| 26 |
+
.venv
|
| 27 |
+
env/
|
| 28 |
+
venv/
|
| 29 |
+
ENV/
|
| 30 |
+
env.bak/
|
| 31 |
+
venv.bak/
|
| 32 |
|
| 33 |
+
# IDE
|
| 34 |
+
.vscode/
|
| 35 |
+
.idea/
|
| 36 |
+
*.swp
|
| 37 |
+
*.swo
|
| 38 |
+
*~
|
| 39 |
|
| 40 |
+
# Testing
|
| 41 |
+
.pytest_cache/
|
| 42 |
+
.coverage
|
| 43 |
htmlcov/
|
| 44 |
.tox/
|
| 45 |
.nox/
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
coverage.xml
|
| 47 |
*.cover
|
|
|
|
| 48 |
.hypothesis/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
# Jupyter Notebook
|
| 51 |
.ipynb_checkpoints
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# pyenv
|
| 54 |
.python-version
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
# Environments
|
| 57 |
.env
|
| 58 |
+
.env.local
|
| 59 |
+
.env.development.local
|
| 60 |
+
.env.test.local
|
| 61 |
+
.env.production.local
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# mypy
|
| 64 |
.mypy_cache/
|
| 65 |
.dmypy.json
|
| 66 |
dmypy.json
|
| 67 |
|
| 68 |
+
# Ruff
|
| 69 |
+
.ruff_cache/
|
| 70 |
+
|
| 71 |
+
# Model files (if stored locally)
|
| 72 |
+
models/
|
| 73 |
+
*.bin
|
| 74 |
+
*.safetensors
|
| 75 |
+
*.pt
|
| 76 |
+
*.pth
|
| 77 |
|
| 78 |
# Data files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
data/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
*.jpg
|
| 81 |
*.jpeg
|
| 82 |
*.png
|
| 83 |
+
*.pdf
|
| 84 |
+
*.mp4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
# Logs
|
| 87 |
+
*.log
|
| 88 |
+
logs/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
# OS
|
| 91 |
.DS_Store
|
| 92 |
.DS_Store?
|
| 93 |
._*
|
|
|
|
| 96 |
ehthumbs.db
|
| 97 |
Thumbs.db
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
# Docker
|
| 100 |
.dockerignore
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
.cursor/
|
| 103 |
docs/
|
Dockerfile
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
# Build args to optionally enable flash-attn installation and override wheel URL
|
| 4 |
# Enable by default for Hugging Face Spaces GPU builds; override locally with
|
|
@@ -6,6 +7,13 @@ FROM pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime
|
|
| 6 |
ARG INSTALL_FLASH_ATTN=true
|
| 7 |
ARG FLASH_ATTN_WHEEL_URL=https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
# Install system dependencies as root
|
| 10 |
RUN apt-get update && apt-get install -y \
|
| 11 |
libgl1-mesa-dri \
|
|
@@ -48,6 +56,9 @@ RUN pip install --no-cache-dir --upgrade pip
|
|
| 48 |
COPY --chown=user requirements.txt .
|
| 49 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 50 |
|
|
|
|
|
|
|
|
|
|
| 51 |
# Optionally install flash-attn wheel (requires Python/torch/CUDA compatibility)
|
| 52 |
# Will auto-skip if the wheel's Python tag does not match this image's Python.
|
| 53 |
RUN if [ "$INSTALL_FLASH_ATTN" = "true" ]; then \
|
|
@@ -63,11 +74,15 @@ RUN if [ "$INSTALL_FLASH_ATTN" = "true" ]; then \
|
|
| 63 |
echo "Skipping flash-attn installation"; \
|
| 64 |
fi
|
| 65 |
|
| 66 |
-
# Copy
|
| 67 |
-
COPY --chown=user
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
# Expose port
|
| 70 |
EXPOSE 7860
|
| 71 |
|
| 72 |
# Run the application
|
| 73 |
-
CMD ["python", "
|
|
|
|
| 1 |
+
ARG BASE_IMAGE=pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime
|
| 2 |
+
FROM ${BASE_IMAGE}
|
| 3 |
|
| 4 |
# Build args to optionally enable flash-attn installation and override wheel URL
|
| 5 |
# Enable by default for Hugging Face Spaces GPU builds; override locally with
|
|
|
|
| 7 |
ARG INSTALL_FLASH_ATTN=true
|
| 8 |
ARG FLASH_ATTN_WHEEL_URL=https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl
|
| 9 |
|
| 10 |
+
# Persist caches and model storage in Spaces, and enable fast transfers
|
| 11 |
+
ENV HF_HUB_ENABLE_HF_TRANSFER=1 \
|
| 12 |
+
HUGGINGFACE_HUB_CACHE=/data/.cache/huggingface \
|
| 13 |
+
HF_HOME=/data/.cache/huggingface \
|
| 14 |
+
DOTS_OCR_LOCAL_DIR=/data/models/dots-ocr \
|
| 15 |
+
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb=512
|
| 16 |
+
|
| 17 |
# Install system dependencies as root
|
| 18 |
RUN apt-get update && apt-get install -y \
|
| 19 |
libgl1-mesa-dri \
|
|
|
|
| 56 |
COPY --chown=user requirements.txt .
|
| 57 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 58 |
|
| 59 |
+
# Copy pyproject.toml for package installation
|
| 60 |
+
COPY --chown=user pyproject.toml .
|
| 61 |
+
|
| 62 |
# Optionally install flash-attn wheel (requires Python/torch/CUDA compatibility)
|
| 63 |
# Will auto-skip if the wheel's Python tag does not match this image's Python.
|
| 64 |
RUN if [ "$INSTALL_FLASH_ATTN" = "true" ]; then \
|
|
|
|
| 74 |
echo "Skipping flash-attn installation"; \
|
| 75 |
fi
|
| 76 |
|
| 77 |
+
# Copy source code
|
| 78 |
+
COPY --chown=user src/ ./src/
|
| 79 |
+
COPY --chown=user main.py .
|
| 80 |
+
|
| 81 |
+
# Install the package in development mode
|
| 82 |
+
RUN pip install --no-cache-dir -e .
|
| 83 |
|
| 84 |
# Expose port
|
| 85 |
EXPOSE 7860
|
| 86 |
|
| 87 |
# Run the application
|
| 88 |
+
CMD ["python", "main.py"]
|
Makefile
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: help install dev test lint format clean run
|
| 2 |
+
|
| 3 |
+
help: ## Show this help message
|
| 4 |
+
@echo "KYB Tech Dots.OCR - Development Commands"
|
| 5 |
+
@echo "=========================================="
|
| 6 |
+
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
| 7 |
+
|
| 8 |
+
install: ## Install the package
|
| 9 |
+
uv pip install -e .
|
| 10 |
+
|
| 11 |
+
dev: ## Install development dependencies
|
| 12 |
+
uv pip install -e .[dev]
|
| 13 |
+
|
| 14 |
+
test: ## Run tests
|
| 15 |
+
pytest
|
| 16 |
+
|
| 17 |
+
test-verbose: ## Run tests with verbose output
|
| 18 |
+
pytest -v
|
| 19 |
+
|
| 20 |
+
lint: ## Run linting
|
| 21 |
+
ruff check .
|
| 22 |
+
mypy src/
|
| 23 |
+
|
| 24 |
+
format: ## Format code
|
| 25 |
+
black .
|
| 26 |
+
ruff check --fix .
|
| 27 |
+
|
| 28 |
+
clean: ## Clean up build artifacts
|
| 29 |
+
rm -rf build/
|
| 30 |
+
rm -rf dist/
|
| 31 |
+
rm -rf *.egg-info/
|
| 32 |
+
rm -rf .pytest_cache/
|
| 33 |
+
rm -rf .mypy_cache/
|
| 34 |
+
rm -rf .ruff_cache/
|
| 35 |
+
find . -type d -name __pycache__ -exec rm -rf {} +
|
| 36 |
+
find . -type f -name "*.pyc" -delete
|
| 37 |
+
|
| 38 |
+
run: ## Run the application
|
| 39 |
+
python main.py
|
| 40 |
+
|
| 41 |
+
run-dev: ## Run the application in development mode
|
| 42 |
+
uvicorn src.kybtech_dots_ocr.app:app --host 0.0.0.0 --port 7860 --reload
|
| 43 |
+
|
| 44 |
+
setup: ## Set up development environment
|
| 45 |
+
python setup_dev.py
|
| 46 |
+
|
| 47 |
+
check: ## Run all checks (lint, format, test)
|
| 48 |
+
$(MAKE) lint
|
| 49 |
+
$(MAKE) test
|
| 50 |
+
|
| 51 |
+
build: ## Build the Docker image
|
| 52 |
+
docker build -t kybtech-dots-ocr .
|
| 53 |
+
|
| 54 |
+
# Build for Hugging Face Spaces GPU (CUDA, optional flash-attn)
|
| 55 |
+
build-spaces-gpu: ## Build Docker image for HF Spaces GPU
|
| 56 |
+
# Use CUDA runtime base; leave flash-attn on by default (can disable with ARGS)
|
| 57 |
+
docker build \
|
| 58 |
+
--build-arg BASE_IMAGE=pytorch/pytorch:2.7.0-cuda12.6-cudnn9-runtime \
|
| 59 |
+
--build-arg INSTALL_FLASH_ATTN=true \
|
| 60 |
+
-t kybtech-dots-ocr:spaces-gpu .
|
| 61 |
+
|
| 62 |
+
# Build for local Apple Silicon CPU (no CUDA, no flash-attn)
|
| 63 |
+
build-local-cpu: ## Build Docker image for local CPU (arm64)
|
| 64 |
+
docker build \
|
| 65 |
+
--platform=linux/arm64 \
|
| 66 |
+
--build-arg BASE_IMAGE=python:3.12-slim \
|
| 67 |
+
--build-arg INSTALL_FLASH_ATTN=false \
|
| 68 |
+
-t kybtech-dots-ocr:cpu .
|
| 69 |
+
|
| 70 |
+
run-docker: ## Run the Docker container locally
|
| 71 |
+
docker run -p 7860:7860 kybtech-dots-ocr
|
| 72 |
+
|
| 73 |
+
deploy-staging: ## Deploy to staging (requires HF CLI)
|
| 74 |
+
@echo "Deploying to staging..."
|
| 75 |
+
@echo "Make sure you have HF CLI installed and are logged in"
|
| 76 |
+
@echo "Then push to your staging space repository"
|
| 77 |
+
|
| 78 |
+
deploy-production: ## Deploy to production (requires HF CLI)
|
| 79 |
+
@echo "Deploying to production..."
|
| 80 |
+
@echo "Make sure you have HF CLI installed and are logged in"
|
| 81 |
+
@echo "Then push to your production space repository"
|
| 82 |
+
|
| 83 |
+
test-local: ## Test the local API endpoint
|
| 84 |
+
cd scripts && ./run_tests.sh -e local
|
| 85 |
+
|
| 86 |
+
test-production: ## Test the production API endpoint
|
| 87 |
+
cd scripts && ./run_tests.sh -e production
|
| 88 |
+
|
| 89 |
+
test-staging: ## Test the staging API endpoint
|
| 90 |
+
cd scripts && ./run_tests.sh -e staging
|
| 91 |
+
|
| 92 |
+
test-quick: ## Quick test with curl (no Python dependencies)
|
| 93 |
+
cd scripts && ./test_production_curl.sh
|
| 94 |
+
|
| 95 |
+
logs: ## Show application logs (if running in Docker)
|
| 96 |
+
docker logs -f kybtech-dots-ocr
|
| 97 |
+
|
| 98 |
+
stop: ## Stop the Docker container
|
| 99 |
+
docker stop kybtech-dots-ocr || true
|
| 100 |
+
|
| 101 |
+
clean-docker: ## Clean up Docker images and containers
|
| 102 |
+
docker stop kybtech-dots-ocr || true
|
| 103 |
+
docker rm kybtech-dots-ocr || true
|
| 104 |
+
docker rmi kybtech-dots-ocr || true
|
README.md
CHANGED
|
@@ -178,24 +178,48 @@ The Space will be available at `https://algoryn-dots-ocr-idcard.hf.space` after
|
|
| 178 |
|
| 179 |
## 🐳 Local Development
|
| 180 |
|
| 181 |
-
###
|
| 182 |
```bash
|
| 183 |
-
#
|
| 184 |
-
|
| 185 |
|
| 186 |
-
#
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
```
|
| 189 |
|
| 190 |
-
###
|
| 191 |
```bash
|
| 192 |
-
#
|
| 193 |
-
|
|
|
|
| 194 |
|
| 195 |
-
#
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
```
|
| 198 |
|
|
|
|
|
|
|
| 199 |
## 📚 Documentation
|
| 200 |
|
| 201 |
- [Hugging Face Spaces Documentation](https://huggingface.co/docs/hub/spaces)
|
|
|
|
| 178 |
|
| 179 |
## 🐳 Local Development
|
| 180 |
|
| 181 |
+
### Quick Start with uv
|
| 182 |
```bash
|
| 183 |
+
# Set up development environment
|
| 184 |
+
make setup
|
| 185 |
|
| 186 |
+
# Activate virtual environment
|
| 187 |
+
source .venv/bin/activate # On Unix/macOS
|
| 188 |
+
# or
|
| 189 |
+
.venv\Scripts\activate # On Windows
|
| 190 |
+
|
| 191 |
+
# Run the application
|
| 192 |
+
make run-dev
|
| 193 |
```
|
| 194 |
|
| 195 |
+
### Docker Development
|
| 196 |
```bash
|
| 197 |
+
# Build and run with Docker
|
| 198 |
+
make build
|
| 199 |
+
make run-docker
|
| 200 |
|
| 201 |
+
# View logs
|
| 202 |
+
make logs
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
### Development Commands
|
| 206 |
+
```bash
|
| 207 |
+
# Run tests
|
| 208 |
+
make test
|
| 209 |
+
|
| 210 |
+
# Format code
|
| 211 |
+
make format
|
| 212 |
+
|
| 213 |
+
# Run linting
|
| 214 |
+
make lint
|
| 215 |
+
|
| 216 |
+
# Test API endpoints
|
| 217 |
+
make test-local
|
| 218 |
+
make test-production
|
| 219 |
```
|
| 220 |
|
| 221 |
+
For detailed development instructions, see the documentation in `docs/`.
|
| 222 |
+
|
| 223 |
## 📚 Documentation
|
| 224 |
|
| 225 |
- [Hugging Face Spaces Documentation](https://huggingface.co/docs/hub/spaces)
|
main.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Main entry point for the KYB Tech Dots.OCR application.
|
| 3 |
+
|
| 4 |
+
Provides a callable ``main()`` for console_scripts and direct execution.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import uvicorn
|
| 9 |
+
from src.kybtech_dots_ocr.app import app
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def main() -> None:
|
| 13 |
+
"""Start the FastAPI server with sensible defaults.
|
| 14 |
+
|
| 15 |
+
This function is exposed as a console script via pyproject.toml.
|
| 16 |
+
Set DOTS_OCR_SKIP_MODEL_LOAD=1 to skip heavy model download for local testing.
|
| 17 |
+
"""
|
| 18 |
+
# Respect environment overrides for host/port
|
| 19 |
+
host = os.getenv("HOST", "0.0.0.0")
|
| 20 |
+
port = int(os.getenv("PORT", "7860"))
|
| 21 |
+
log_level = os.getenv("LOG_LEVEL", "info")
|
| 22 |
+
|
| 23 |
+
uvicorn.run(app, host=host, port=port, log_level=log_level)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if __name__ == "__main__":
|
| 27 |
+
main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "kybtech-dots-ocr"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "Dots.OCR model integration for KYB text extraction"
|
| 5 |
+
authors = [
|
| 6 |
+
{name = "Algoryn", email = "info@algoryn.com"}
|
| 7 |
+
]
|
| 8 |
+
readme = "README.md"
|
| 9 |
+
requires-python = ">=3.9"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"fastapi>=0.112.1",
|
| 12 |
+
"uvicorn[standard]>=0.30.6",
|
| 13 |
+
"python-multipart>=0.0.9",
|
| 14 |
+
"pydantic>=2.2.0,<3.0.0",
|
| 15 |
+
"opencv-python>=4.9.0.80",
|
| 16 |
+
"numpy>=1.26.0",
|
| 17 |
+
"pillow>=10.3.0",
|
| 18 |
+
"huggingface_hub",
|
| 19 |
+
"PyMuPDF>=1.23.0",
|
| 20 |
+
"torch>=2.0.0",
|
| 21 |
+
"transformers>=4.40.0",
|
| 22 |
+
"accelerate>=0.20.0",
|
| 23 |
+
"qwen-vl-utils",
|
| 24 |
+
"requests>=2.31.0",
|
| 25 |
+
"httpx>=0.27.0",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.optional-dependencies]
|
| 29 |
+
dev = [
|
| 30 |
+
"pytest>=7.0.0",
|
| 31 |
+
"pytest-asyncio>=0.21.0",
|
| 32 |
+
"black>=23.0.0",
|
| 33 |
+
"ruff>=0.1.0",
|
| 34 |
+
"mypy>=1.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
kyb-ocr = "main:main"
|
| 39 |
+
|
| 40 |
+
[build-system]
|
| 41 |
+
requires = ["setuptools>=68", "wheel"]
|
| 42 |
+
build-backend = "setuptools.build_meta"
|
| 43 |
+
|
| 44 |
+
[tool.setuptools]
|
| 45 |
+
package-dir = {"" = "src"}
|
| 46 |
+
|
| 47 |
+
[tool.setuptools.packages.find]
|
| 48 |
+
where = ["src"]
|
| 49 |
+
|
| 50 |
+
[tool.black]
|
| 51 |
+
line-length = 88
|
| 52 |
+
target-version = ['py39']
|
| 53 |
+
|
| 54 |
+
[tool.ruff]
|
| 55 |
+
line-length = 88
|
| 56 |
+
target-version = "py39"
|
| 57 |
+
select = ["E", "F", "W", "C90", "I", "N", "UP", "YTT", "S", "BLE", "FBT", "B", "A", "COM", "C4", "DTZ", "T10", "EM", "EXE", "FA", "ISC", "ICN", "G", "INP", "PIE", "T20", "PYI", "PT", "Q", "RSE", "RET", "SLF", "SLOT", "SIM", "TID", "TCH", "INT", "ARG", "PTH", "TD", "FIX", "ERA", "PD", "PGH", "PL", "TRY", "FLY", "NPY", "AIR", "PERF", "FURB", "LOG", "RUF"]
|
| 58 |
+
ignore = ["S101", "PLR0913", "PLR0912", "PLR0915"]
|
| 59 |
+
|
| 60 |
+
[tool.ruff.per-file-ignores]
|
| 61 |
+
"__init__.py" = ["F401"]
|
| 62 |
+
|
| 63 |
+
[tool.mypy]
|
| 64 |
+
python_version = "3.9"
|
| 65 |
+
warn_return_any = true
|
| 66 |
+
warn_unused_configs = true
|
| 67 |
+
disallow_untyped_defs = true
|
| 68 |
+
disallow_incomplete_defs = true
|
| 69 |
+
check_untyped_defs = true
|
| 70 |
+
disallow_untyped_decorators = true
|
| 71 |
+
no_implicit_optional = true
|
| 72 |
+
warn_redundant_casts = true
|
| 73 |
+
warn_unused_ignores = true
|
| 74 |
+
warn_no_return = true
|
| 75 |
+
warn_unreachable = true
|
| 76 |
+
strict_equality = true
|
| 77 |
+
|
| 78 |
+
[tool.pytest.ini_options]
|
| 79 |
+
testpaths = ["tests"]
|
| 80 |
+
python_files = ["test_*.py"]
|
| 81 |
+
python_classes = ["Test*"]
|
| 82 |
+
python_functions = ["test_*"]
|
| 83 |
+
addopts = "-v --tb=short"
|
requirements.txt
CHANGED
|
@@ -6,4 +6,10 @@ opencv-python>=4.9.0.80
|
|
| 6 |
numpy>=1.26.0
|
| 7 |
pillow>=10.3.0
|
| 8 |
huggingface_hub
|
| 9 |
-
PyMuPDF
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
numpy>=1.26.0
|
| 7 |
pillow>=10.3.0
|
| 8 |
huggingface_hub
|
| 9 |
+
PyMuPDF>=1.23.0
|
| 10 |
+
torch>=2.0.0
|
| 11 |
+
transformers>=4.40.0
|
| 12 |
+
accelerate>=0.20.0
|
| 13 |
+
qwen-vl-utils
|
| 14 |
+
requests>=2.31.0
|
| 15 |
+
httpx>=0.27.0
|
scripts/README_TESTING.md
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dots.OCR API Testing
|
| 2 |
+
|
| 3 |
+
This directory contains comprehensive testing scripts for the Dots.OCR API endpoint.
|
| 4 |
+
|
| 5 |
+
## Test Scripts
|
| 6 |
+
|
| 7 |
+
### 1. `test_api_endpoint.py` - Comprehensive API Testing
|
| 8 |
+
|
| 9 |
+
The main testing script that provides full API validation capabilities.
|
| 10 |
+
|
| 11 |
+
**Features:**
|
| 12 |
+
- Health check validation
|
| 13 |
+
- Single and multiple image testing
|
| 14 |
+
- ROI (Region of Interest) testing
|
| 15 |
+
- Field extraction validation
|
| 16 |
+
- Response structure validation
|
| 17 |
+
- Performance metrics
|
| 18 |
+
- Detailed error reporting
|
| 19 |
+
|
| 20 |
+
**Usage:**
|
| 21 |
+
```bash
|
| 22 |
+
# Basic test with default settings
|
| 23 |
+
python test_api_endpoint.py
|
| 24 |
+
|
| 25 |
+
# Test with custom API URL
|
| 26 |
+
python test_api_endpoint.py --url https://your-api.example.com
|
| 27 |
+
|
| 28 |
+
# Test with ROI
|
| 29 |
+
python test_api_endpoint.py --roi '{"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}'
|
| 30 |
+
|
| 31 |
+
# Test with specific expected fields
|
| 32 |
+
python test_api_endpoint.py --expected-fields document_number surname given_names
|
| 33 |
+
|
| 34 |
+
# Verbose output
|
| 35 |
+
python test_api_endpoint.py --verbose
|
| 36 |
+
|
| 37 |
+
# Custom timeout
|
| 38 |
+
python test_api_endpoint.py --timeout 60
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
**Options:**
|
| 42 |
+
- `--url`: API base URL (default: http://localhost:7860)
|
| 43 |
+
- `--timeout`: Request timeout in seconds (default: 30)
|
| 44 |
+
- `--roi`: ROI coordinates as JSON string
|
| 45 |
+
- `--expected-fields`: List of expected field names to validate
|
| 46 |
+
- `--verbose`: Enable verbose logging
|
| 47 |
+
|
| 48 |
+
### 2. `quick_test.py` - Quick Validation
|
| 49 |
+
|
| 50 |
+
A simple script for quick API validation after deployment.
|
| 51 |
+
|
| 52 |
+
**Usage:**
|
| 53 |
+
```bash
|
| 54 |
+
# Test local API
|
| 55 |
+
python quick_test.py
|
| 56 |
+
|
| 57 |
+
# Test remote API
|
| 58 |
+
python quick_test.py https://your-api.example.com
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## Test Configuration
|
| 62 |
+
|
| 63 |
+
### `test_config.json`
|
| 64 |
+
|
| 65 |
+
Configuration file for test parameters and thresholds.
|
| 66 |
+
|
| 67 |
+
**Configuration sections:**
|
| 68 |
+
- `api_endpoints`: Different API URLs for various environments
|
| 69 |
+
- `test_images`: List of test image files
|
| 70 |
+
- `expected_fields`: Fields that should be extracted
|
| 71 |
+
- `roi_test_cases`: Different ROI configurations to test
|
| 72 |
+
- `performance_thresholds`: Performance validation criteria
|
| 73 |
+
- `test_timeout`: Default timeout for requests
|
| 74 |
+
|
| 75 |
+
## Test Images
|
| 76 |
+
|
| 77 |
+
The following test images are used for validation:
|
| 78 |
+
|
| 79 |
+
- `tom_id_card_front.jpg` - Front of Dutch ID card
|
| 80 |
+
- `tom_id_card_back.jpg` - Back of Dutch ID card
|
| 81 |
+
|
| 82 |
+
## Testing Scenarios
|
| 83 |
+
|
| 84 |
+
### 1. Basic Functionality Test
|
| 85 |
+
```bash
|
| 86 |
+
python test_api_endpoint.py
|
| 87 |
+
```
|
| 88 |
+
Tests basic API functionality with default settings.
|
| 89 |
+
|
| 90 |
+
### 2. ROI Testing
|
| 91 |
+
```bash
|
| 92 |
+
python test_api_endpoint.py --roi '{"x1": 0.25, "y1": 0.25, "x2": 0.75, "y2": 0.75}'
|
| 93 |
+
```
|
| 94 |
+
Tests Region of Interest cropping functionality.
|
| 95 |
+
|
| 96 |
+
### 3. Field Validation Test
|
| 97 |
+
```bash
|
| 98 |
+
python test_api_endpoint.py --expected-fields document_number surname given_names nationality
|
| 99 |
+
```
|
| 100 |
+
Tests that specific fields are extracted correctly.
|
| 101 |
+
|
| 102 |
+
### 4. Performance Test
|
| 103 |
+
```bash
|
| 104 |
+
python test_api_endpoint.py --timeout 60 --verbose
|
| 105 |
+
```
|
| 106 |
+
Tests API performance with extended timeout and detailed logging.
|
| 107 |
+
|
| 108 |
+
## Expected Results
|
| 109 |
+
|
| 110 |
+
### Successful Test Output
|
| 111 |
+
```
|
| 112 |
+
🔍 Checking API health...
|
| 113 |
+
✅ API is healthy: {'status': 'healthy', 'version': '1.0.0', 'model_loaded': True}
|
| 114 |
+
🚀 Starting API tests with 2 images...
|
| 115 |
+
✅ tom_id_card_front.jpg: 2.45s
|
| 116 |
+
✅ tom_id_card_back.jpg: 1.23s
|
| 117 |
+
📊 Test Results:
|
| 118 |
+
Total images: 2
|
| 119 |
+
Successful: 2
|
| 120 |
+
Failed: 0
|
| 121 |
+
Success rate: 100.0%
|
| 122 |
+
Average processing time: 1.84s
|
| 123 |
+
🎉 All tests completed successfully!
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### Field Extraction Example
|
| 127 |
+
```
|
| 128 |
+
Page 1: 11 fields extracted
|
| 129 |
+
document_number: NLD123456789 (confidence: 0.90)
|
| 130 |
+
surname: MULDER (confidence: 0.90)
|
| 131 |
+
given_names: THOMAS JAN (confidence: 0.90)
|
| 132 |
+
nationality: NLD (confidence: 0.95)
|
| 133 |
+
date_of_birth: 15-03-1990 (confidence: 0.90)
|
| 134 |
+
gender: M (confidence: 0.95)
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
## Troubleshooting
|
| 138 |
+
|
| 139 |
+
### Common Issues
|
| 140 |
+
|
| 141 |
+
1. **Connection Refused**
|
| 142 |
+
- Check if the API is running
|
| 143 |
+
- Verify the correct URL and port
|
| 144 |
+
- Check firewall settings
|
| 145 |
+
|
| 146 |
+
2. **Timeout Errors**
|
| 147 |
+
- Increase timeout with `--timeout` parameter
|
| 148 |
+
- Check API performance and resource usage
|
| 149 |
+
|
| 150 |
+
3. **Missing Fields**
|
| 151 |
+
- Verify test images contain the expected text
|
| 152 |
+
- Check field extraction patterns in the code
|
| 153 |
+
- Review API logs for processing errors
|
| 154 |
+
|
| 155 |
+
4. **Validation Errors**
|
| 156 |
+
- Check API response format
|
| 157 |
+
- Verify model is loaded correctly
|
| 158 |
+
- Review error logs for details
|
| 159 |
+
|
| 160 |
+
### Debug Mode
|
| 161 |
+
|
| 162 |
+
Enable verbose logging for detailed debugging:
|
| 163 |
+
```bash
|
| 164 |
+
python test_api_endpoint.py --verbose
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
## Integration with CI/CD
|
| 168 |
+
|
| 169 |
+
The test scripts can be integrated into CI/CD pipelines:
|
| 170 |
+
|
| 171 |
+
```yaml
|
| 172 |
+
# Example GitHub Actions step
|
| 173 |
+
- name: Test API Endpoint
|
| 174 |
+
run: |
|
| 175 |
+
python scripts/test_api_endpoint.py --url ${{ env.API_URL }} --timeout 60
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
## Performance Monitoring
|
| 179 |
+
|
| 180 |
+
The scripts provide performance metrics that can be used for monitoring:
|
| 181 |
+
|
| 182 |
+
- Processing time per image
|
| 183 |
+
- Success rate
|
| 184 |
+
- Field extraction accuracy
|
| 185 |
+
- Response validation results
|
| 186 |
+
|
| 187 |
+
These metrics can be integrated with monitoring systems like Prometheus or DataDog.
|
| 188 |
+
|
| 189 |
+
## 🚀 Production API Testing
|
| 190 |
+
|
| 191 |
+
### Current Production Endpoint
|
| 192 |
+
- **URL**: https://algoryn-dots-ocr-idcard.hf.space
|
| 193 |
+
- **Health Check**: https://algoryn-dots-ocr-idcard.hf.space/health
|
| 194 |
+
- **API Docs**: https://algoryn-dots-ocr-idcard.hf.space/docs
|
| 195 |
+
|
| 196 |
+
### Quick Production Test
|
| 197 |
+
```bash
|
| 198 |
+
# Test production API
|
| 199 |
+
./run_tests.sh -e production
|
| 200 |
+
|
| 201 |
+
# Quick test with curl (no Python dependencies)
|
| 202 |
+
./test_production_curl.sh
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
### Staging Environment
|
| 206 |
+
- **Staging URL**: https://algoryn-dots-ocr-idcard-staging.hf.space (to be created)
|
| 207 |
+
- **Purpose**: Safe testing before production deployment
|
| 208 |
+
|
| 209 |
+
### Environment-Specific Testing
|
| 210 |
+
```bash
|
| 211 |
+
# Test different environments
|
| 212 |
+
./run_tests.sh -e local # Local development
|
| 213 |
+
./run_tests.sh -e staging # Staging environment
|
| 214 |
+
./run_tests.sh -e production # Production environment
|
| 215 |
+
```
|
scripts/quick_test.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Quick API Test Script
|
| 3 |
+
|
| 4 |
+
A simple script to quickly test the deployed Dots.OCR API endpoint.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
import json
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
def test_api(base_url="http://localhost:7860"):
    """Quick smoke test of a deployed Dots.OCR API.

    Hits ``/health`` first, then posts the bundled sample ID-card image to
    ``/v1/id/ocr`` and prints a short summary of the extracted fields.

    Args:
        base_url: Base URL of the API (defaults to the local dev server).

    Returns:
        True when both the health check and the OCR request succeed.
    """
    print(f"🔍 Testing API at {base_url}")

    # Health check first: fail fast if the service is unreachable.
    try:
        health_response = requests.get(f"{base_url}/health", timeout=10)
        health_response.raise_for_status()
        health_data = health_response.json()
        print(f"✅ Health check passed: {health_data}")
    except Exception as e:
        print(f"❌ Health check failed: {e}")
        return False

    # The sample image ships alongside this script.
    front_image = Path(__file__).parent / "tom_id_card_front.jpg"
    if not front_image.exists():
        print(f"❌ Test image not found: {front_image}")
        return False

    print(f"📸 Testing with {front_image.name}")

    try:
        with open(front_image, 'rb') as f:
            files = {'file': f}
            response = requests.post(
                f"{base_url}/v1/id/ocr",
                files=files,
                timeout=30
            )
        response.raise_for_status()
        result = response.json()

        print(f"✅ OCR test passed")
        print(f"   Request ID: {result.get('request_id')}")
        print(f"   Media type: {result.get('media_type')}")
        # Guard the format: formatting a missing/None processing_time with
        # ':.2f' would raise TypeError and mask an otherwise-passing test.
        processing_time = result.get('processing_time')
        if isinstance(processing_time, (int, float)):
            print(f"   Processing time: {processing_time:.2f}s")
        else:
            print("   Processing time: N/A")
        print(f"   Detections: {len(result.get('detections', []))}")

        # Show extracted fields per detection. The key-field display is kept
        # inside the loop so `fields` always refers to the current detection
        # rather than leaking the last loop iteration's value.
        key_fields = ['document_number', 'surname', 'given_names', 'nationality']
        for i, detection in enumerate(result.get('detections', [])):
            fields = detection.get('extracted_fields', {})
            field_count = len([f for f in fields.values() if f is not None])
            print(f"   Page {i+1}: {field_count} fields extracted")

            for field in key_fields:
                if fields.get(field) is not None:
                    value = fields[field].get('value', 'N/A') if isinstance(fields[field], dict) else str(fields[field])
                    confidence = fields[field].get('confidence', 'N/A') if isinstance(fields[field], dict) else 'N/A'
                    print(f"     {field}: {value} (confidence: {confidence})")

        return True

    except Exception as e:
        print(f"❌ OCR test failed: {e}")
        return False

if __name__ == "__main__":
    base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860"
    success = test_api(base_url)
    sys.exit(0 if success else 1)
|
scripts/run_tests.sh
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Dots.OCR API Test Runner
#
# Wraps quick_test.py / test_api_endpoint.py with environment selection
# (local / staging / production), timeout and ROI options.

set -e

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Default values
API_URL="http://localhost:7860"
TIMEOUT=30
VERBOSE=false
ROI=""
ENVIRONMENT="local"
QUICK=false   # explicit default; previously relied on the variable being unset

# Function to print colored output
print_status() {
    echo -e "${BLUE}[INFO]${NC} $1"
}

print_success() {
    echo -e "${GREEN}[SUCCESS]${NC} $1"
}

print_warning() {
    echo -e "${YELLOW}[WARNING]${NC} $1"
}

print_error() {
    echo -e "${RED}[ERROR]${NC} $1"
}

# Function to show usage
show_usage() {
    echo "Usage: $0 [OPTIONS]"
    echo ""
    echo "Options:"
    echo "  -u, --url URL           API base URL (default: http://localhost:7860)"
    echo "  -e, --env ENV           Environment: local, staging, production (default: local)"
    echo "  -t, --timeout SECONDS   Request timeout (default: 30)"
    echo "  -r, --roi JSON          ROI coordinates as JSON string"
    echo "  -v, --verbose           Enable verbose output"
    echo "  -q, --quick             Run quick test only"
    echo "  -h, --help              Show this help message"
    echo ""
    echo "Examples:"
    echo "  $0                                    # Basic test (local)"
    echo "  $0 -e production                      # Test production API"
    echo "  $0 -e staging                         # Test staging API"
    echo "  $0 -u https://api.example.com         # Test custom API URL"
    echo "  $0 -r '{\"x1\":0.1,\"y1\":0.1,\"x2\":0.9,\"y2\":0.9}'  # Test with ROI"
    echo "  $0 -v -t 60                           # Verbose with 60s timeout"
    echo "  $0 -q                                 # Quick test only"
}

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        -u|--url)
            API_URL="$2"
            shift 2
            ;;
        -e|--env)
            ENVIRONMENT="$2"
            shift 2
            ;;
        -t|--timeout)
            TIMEOUT="$2"
            shift 2
            ;;
        -r|--roi)
            ROI="$2"
            shift 2
            ;;
        -v|--verbose)
            VERBOSE=true
            shift
            ;;
        -q|--quick)
            QUICK=true
            shift
            ;;
        -h|--help)
            show_usage
            exit 0
            ;;
        *)
            print_error "Unknown option: $1"
            show_usage
            exit 1
            ;;
    esac
done

# Set API URL based on environment if not explicitly provided
if [ "$API_URL" = "http://localhost:7860" ] && [ "$ENVIRONMENT" != "local" ]; then
    case $ENVIRONMENT in
        "staging")
            API_URL="https://algoryn-dots-ocr-idcard-staging.hf.space"
            ;;
        "production")
            API_URL="https://algoryn-dots-ocr-idcard.hf.space"
            ;;
        *)
            print_error "Unknown environment: $ENVIRONMENT. Use: local, staging, production"
            exit 1
            ;;
    esac
fi

# Check if Python is available
if ! command -v python3 &> /dev/null; then
    print_error "Python 3 is required but not installed"
    exit 1
fi

# Check if test images exist
if [ ! -f "tom_id_card_front.jpg" ] || [ ! -f "tom_id_card_back.jpg" ]; then
    print_error "Test images not found. Please ensure tom_id_card_front.jpg and tom_id_card_back.jpg are in the scripts directory"
    exit 1
fi

print_status "Starting Dots.OCR API Tests"
print_status "Environment: $ENVIRONMENT"
print_status "API URL: $API_URL"
print_status "Timeout: $TIMEOUT seconds"

# Run quick test if requested
if [ "$QUICK" = true ]; then
    print_status "Running quick test..."
    if python3 quick_test.py "$API_URL"; then
        print_success "Quick test passed"
        exit 0
    else
        print_error "Quick test failed"
        exit 1
    fi
fi

# Build the test command as an argv array instead of a string + eval.
# The previous `eval $TEST_CMD` re-parsed the ROI JSON through the shell,
# which breaks on embedded spaces and allows command injection.
TEST_CMD=(python3 test_api_endpoint.py --url "$API_URL" --timeout "$TIMEOUT")

if [ "$VERBOSE" = true ]; then
    TEST_CMD+=(--verbose)
fi

if [ -n "$ROI" ]; then
    TEST_CMD+=(--roi "$ROI")
fi

# Run comprehensive test
print_status "Running comprehensive API test..."
if "${TEST_CMD[@]}"; then
    print_success "All tests passed successfully!"
    exit 0
else
    print_error "Tests failed"
    exit 1
fi
|
scripts/test_api_endpoint.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""API Endpoint Test Script for Dots.OCR
|
| 3 |
+
|
| 4 |
+
This script tests the deployed Dots.OCR API endpoint using real ID card images.
|
| 5 |
+
It can be used to validate the complete pipeline in a production environment.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import json
|
| 11 |
+
import time
|
| 12 |
+
import requests
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Dict, Any, Optional, List
|
| 16 |
+
import argparse
|
| 17 |
+
|
| 18 |
+
# Configure logging
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
level=logging.INFO,
|
| 21 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DotsOCRAPITester:
    """Test client for the Dots.OCR API endpoint.

    Wraps a persistent ``requests.Session`` and provides health-check,
    single-image, and multi-image test helpers plus response validation.
    """

    def __init__(self, base_url: str, timeout: int = 30):
        """Initialize the API tester.

        Args:
            base_url: Base URL of the deployed API (e.g., "http://localhost:7860")
            timeout: Request timeout in seconds
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.session = requests.Session()

        # Set common headers
        self.session.headers.update({
            'User-Agent': 'DotsOCR-APITester/1.0'
        })

    def health_check(self) -> Dict[str, Any]:
        """Check API health status.

        Returns:
            Parsed JSON on success, or ``{"error": <message>}`` on any failure.
        """
        try:
            response = self.session.get(
                f"{self.base_url}/health",
                timeout=self.timeout
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {"error": str(e)}

    def test_ocr_endpoint(
        self,
        image_path: str,
        roi: Optional[Dict[str, float]] = None,
        expected_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Test the OCR endpoint with an image file.

        Args:
            image_path: Path to the image file
            roi: Optional ROI coordinates as {x1, y1, x2, y2}
            expected_fields: List of expected field names to validate

        Returns:
            Test results dictionary: on success it contains the parsed
            response plus validation results; on failure, an error message
            and (when available) the HTTP status code.
        """
        logger.info(f"Testing OCR endpoint with {image_path}")

        data = {}
        if roi:
            data['roi'] = json.dumps(roi)
            logger.info(f"Using ROI: {roi}")

        try:
            # The context manager closes the upload handle exactly once.
            # The previous implementation opened the file before the try
            # (an unreadable file escaped all handlers) and closed the
            # handle in both the success path and the finally clause.
            with open(image_path, 'rb') as image_file:
                start_time = time.time()
                response = self.session.post(
                    f"{self.base_url}/v1/id/ocr",
                    files={'file': image_file},
                    data=data,
                    timeout=self.timeout
                )
                request_time = time.time() - start_time

            # Check response
            response.raise_for_status()
            result = response.json()

            # Validate response structure and expected fields
            validation_result = self._validate_response(result)
            field_validation = self._validate_expected_fields(result, expected_fields)

            return {
                "success": True,
                "request_time": request_time,
                "response": result,
                "validation": validation_result,
                "field_validation": field_validation,
                "status_code": response.status_code
            }

        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed: {e}")
            return {
                "success": False,
                "error": str(e),
                # e.response is None for connection-level failures
                "status_code": getattr(e.response, 'status_code', None)
            }
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    def _validate_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the API response structure.

        Args:
            response: API response dictionary

        Returns:
            Dict with "valid" flag, "errors" (structural failures) and
            "warnings" (missing optional data).
        """
        validation = {
            "valid": True,
            "errors": [],
            "warnings": []
        }

        # Required top-level fields
        required_fields = ['request_id', 'media_type', 'processing_time', 'detections']
        for field in required_fields:
            if field not in response:
                validation["errors"].append(f"Missing required field: {field}")
                validation["valid"] = False

        # Validate detections
        if 'detections' in response:
            if not isinstance(response['detections'], list):
                validation["errors"].append("detections must be a list")
                validation["valid"] = False
            else:
                for i, detection in enumerate(response['detections']):
                    if not isinstance(detection, dict):
                        validation["errors"].append(f"detection {i} must be a dictionary")
                        validation["valid"] = False
                    else:
                        # Missing payload sections are warnings, not errors
                        if 'extracted_fields' not in detection:
                            validation["warnings"].append(f"detection {i} missing extracted_fields")
                        if 'mrz_data' not in detection:
                            validation["warnings"].append(f"detection {i} missing mrz_data")

        # Validate processing time
        if 'processing_time' in response:
            if not isinstance(response['processing_time'], (int, float)):
                validation["errors"].append("processing_time must be a number")
                validation["valid"] = False
            elif response['processing_time'] < 0:
                validation["warnings"].append("processing_time is negative")

        return validation

    def _validate_expected_fields(
        self,
        response: Dict[str, Any],
        expected_fields: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Validate that expected fields are present in the response.

        Args:
            response: API response dictionary
            expected_fields: List of expected field names

        Returns:
            Dict with "valid" (all fields found in every detection),
            "found_fields" and "missing_fields", each entry tagged with
            the detection index it refers to.
        """
        if not expected_fields:
            return {"valid": True, "found_fields": [], "missing_fields": []}

        found_fields = []
        missing_fields = []

        # Check every detection for every expected field
        for i, detection in enumerate(response.get('detections', [])):
            extracted_fields = detection.get('extracted_fields', {})

            for field_name in expected_fields:
                if field_name in extracted_fields and extracted_fields[field_name] is not None:
                    found_fields.append(f"{field_name} (detection {i})")
                else:
                    missing_fields.append(f"{field_name} (detection {i})")

        return {
            "valid": len(missing_fields) == 0,
            "found_fields": found_fields,
            "missing_fields": missing_fields
        }

    def test_multiple_images(
        self,
        image_paths: List[str],
        roi: Optional[Dict[str, float]] = None
    ) -> Dict[str, Any]:
        """Test multiple images and return aggregated results.

        Args:
            image_paths: List of image file paths
            roi: Optional ROI coordinates

        Returns:
            Aggregated test results (counts, success rate, average request
            time over successful tests, and per-image detail).
        """
        logger.info(f"Testing {len(image_paths)} images")

        results = []
        successful_tests = 0
        total_processing_time = 0

        for image_path in image_paths:
            if not os.path.exists(image_path):
                logger.warning(f"Image not found: {image_path}")
                results.append({
                    "image": image_path,
                    "success": False,
                    "error": "File not found"
                })
                continue

            result = self.test_ocr_endpoint(image_path, roi)
            results.append({
                "image": image_path,
                **result
            })

            if result.get("success", False):
                successful_tests += 1
                total_processing_time += result.get("request_time", 0)

        return {
            "total_images": len(image_paths),
            "successful_tests": successful_tests,
            "failed_tests": len(image_paths) - successful_tests,
            "success_rate": successful_tests / len(image_paths) if image_paths else 0,
            "average_processing_time": total_processing_time / successful_tests if successful_tests > 0 else 0,
            "results": results
        }
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def main():
    """CLI driver: health-check the API, then run OCR tests on the sample images.

    Exit codes: 0 on full success; 1 on invalid ROI JSON, failed health
    check, missing test images, or any failed OCR test.
    """
    parser = argparse.ArgumentParser(description="Test Dots.OCR API endpoint")
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="API base URL (default: http://localhost:7860)"
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=30,
        help="Request timeout in seconds (default: 30)"
    )
    parser.add_argument(
        "--roi",
        type=str,
        help="ROI coordinates as JSON string (e.g., '{\"x1\": 0.1, \"y1\": 0.1, \"x2\": 0.9, \"y2\": 0.9}')"
    )
    parser.add_argument(
        "--expected-fields",
        nargs="+",
        help="Expected field names to validate (e.g., document_number surname given_names)"
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )

    args = parser.parse_args()

    if args.verbose:
        # Raise the root logger so DEBUG records from all modules show up.
        logging.getLogger().setLevel(logging.DEBUG)

    # Parse ROI if provided; malformed JSON is a hard error.
    roi = None
    if args.roi:
        try:
            roi = json.loads(args.roi)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid ROI JSON: {e}")
            sys.exit(1)

    # Initialize tester
    tester = DotsOCRAPITester(args.url, args.timeout)

    # Health check before spending time on uploads.
    logger.info("🔍 Checking API health...")
    health = tester.health_check()
    if "error" in health:
        logger.error(f"❌ API health check failed: {health['error']}")
        sys.exit(1)

    logger.info(f"✅ API is healthy: {health}")

    # Sample images expected next to this script.
    test_images = [
        "tom_id_card_front.jpg",
        "tom_id_card_back.jpg"
    ]

    # Keep only the images that actually exist; a missing one is a warning.
    existing_images = []
    for image in test_images:
        image_path = Path(__file__).parent / image
        if image_path.exists():
            existing_images.append(str(image_path))
        else:
            logger.warning(f"Test image not found: {image_path}")

    if not existing_images:
        logger.error("❌ No test images found")
        sys.exit(1)

    # Default set of fields to validate when none are given on the CLI.
    expected_fields = args.expected_fields or [
        "document_number",
        "surname",
        "given_names",
        "nationality",
        "date_of_birth",
        "gender"
    ]

    # Run tests
    logger.info(f"🚀 Starting API tests with {len(existing_images)} images...")

    if len(existing_images) == 1:
        # Single image test: field-level validation is applied.
        result = tester.test_ocr_endpoint(existing_images[0], roi, expected_fields)

        if result["success"]:
            logger.info("✅ Single image test passed")
            logger.info(f"⏱️ Processing time: {result['request_time']:.2f}s")
            logger.info(f"📄 Detections: {len(result['response']['detections'])}")

            # Print field validation results
            field_validation = result.get("field_validation", {})
            if field_validation.get("found_fields"):
                logger.info(f"✅ Found fields: {', '.join(field_validation['found_fields'])}")
            if field_validation.get("missing_fields"):
                logger.warning(f"⚠️ Missing fields: {', '.join(field_validation['missing_fields'])}")
        else:
            logger.error(f"❌ Single image test failed: {result.get('error', 'Unknown error')}")
            sys.exit(1)

    else:
        # Multiple images test.
        # NOTE(review): expected_fields is not passed here, so field-level
        # validation only runs in the single-image path — confirm intended.
        results = tester.test_multiple_images(existing_images, roi)

        logger.info(f"📊 Test Results:")
        logger.info(f"   Total images: {results['total_images']}")
        logger.info(f"   Successful: {results['successful_tests']}")
        logger.info(f"   Failed: {results['failed_tests']}")
        logger.info(f"   Success rate: {results['success_rate']:.1%}")
        logger.info(f"   Average processing time: {results['average_processing_time']:.2f}s")

        # Print detailed results
        for result in results["results"]:
            image_name = Path(result["image"]).name
            if result["success"]:
                logger.info(f"   ✅ {image_name}: {result['request_time']:.2f}s")
            else:
                logger.error(f"   ❌ {image_name}: {result.get('error', 'Unknown error')}")

        if results["failed_tests"] > 0:
            sys.exit(1)

    logger.info("🎉 All tests completed successfully!")


if __name__ == "__main__":
    main()
|
scripts/test_config.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"api_endpoints": {
|
| 3 |
+
"local": "http://localhost:7860",
|
| 4 |
+
"staging": "https://algoryn-dots-ocr-idcard-staging.hf.space",
|
| 5 |
+
"production": "https://algoryn-dots-ocr-idcard.hf.space"
|
| 6 |
+
},
|
| 7 |
+
"test_images": [
|
| 8 |
+
"tom_id_card_front.jpg",
|
| 9 |
+
"tom_id_card_back.jpg"
|
| 10 |
+
],
|
| 11 |
+
"expected_fields": [
|
| 12 |
+
"document_number",
|
| 13 |
+
"surname",
|
| 14 |
+
"given_names",
|
| 15 |
+
"nationality",
|
| 16 |
+
"date_of_birth",
|
| 17 |
+
"gender",
|
| 18 |
+
"date_of_issue",
|
| 19 |
+
"date_of_expiry"
|
| 20 |
+
],
|
| 21 |
+
"roi_test_cases": [
|
| 22 |
+
{
|
| 23 |
+
"name": "full_image",
|
| 24 |
+
"roi": null,
|
| 25 |
+
"description": "Process entire image"
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"name": "center_crop",
|
| 29 |
+
"roi": {
|
| 30 |
+
"x1": 0.25,
|
| 31 |
+
"y1": 0.25,
|
| 32 |
+
"x2": 0.75,
|
| 33 |
+
"y2": 0.75
|
| 34 |
+
},
|
| 35 |
+
"description": "Process center 50% of image"
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"name": "top_half",
|
| 39 |
+
"roi": {
|
| 40 |
+
"x1": 0.0,
|
| 41 |
+
"y1": 0.0,
|
| 42 |
+
"x2": 1.0,
|
| 43 |
+
"y2": 0.5
|
| 44 |
+
},
|
| 45 |
+
"description": "Process top half of image"
|
| 46 |
+
}
|
| 47 |
+
],
|
| 48 |
+
"performance_thresholds": {
|
| 49 |
+
"max_processing_time": 10.0,
|
| 50 |
+
"min_confidence": 0.7,
|
| 51 |
+
"min_fields_extracted": 3
|
| 52 |
+
},
|
| 53 |
+
"test_timeout": 30
|
| 54 |
+
}
|
scripts/test_production.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Production API Test Script
|
| 3 |
+
|
| 4 |
+
Quick test script specifically for the production Dots.OCR API.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
import json
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
def test_production_api():
    """Smoke-test the production Dots.OCR API.

    Checks ``/health`` on the production Space, then posts the bundled
    sample ID-card image to ``/v1/id/ocr`` and prints a field summary.

    Returns:
        True when both the health check and the OCR request succeed.
    """

    api_url = "https://algoryn-dots-ocr-idcard.hf.space"
    print(f"🔍 Testing Production API at {api_url}")

    # Health check first: fail fast if the Space is down or sleeping.
    try:
        print("📡 Checking API health...")
        health_response = requests.get(f"{api_url}/health", timeout=10)
        health_response.raise_for_status()
        health_data = health_response.json()
        print(f"✅ Health check passed: {health_data}")
    except Exception as e:
        print(f"❌ Health check failed: {e}")
        return False

    # The sample image ships alongside this script.
    front_image = Path(__file__).parent / "tom_id_card_front.jpg"
    if not front_image.exists():
        print(f"❌ Test image not found: {front_image}")
        return False

    print(f"📸 Testing OCR with {front_image.name}")

    try:
        with open(front_image, 'rb') as f:
            files = {'file': f}
            response = requests.post(
                f"{api_url}/v1/id/ocr",
                files=files,
                timeout=60  # Longer timeout for production
            )
        response.raise_for_status()
        result = response.json()

        print(f"✅ OCR test passed")
        print(f"   Request ID: {result.get('request_id')}")
        print(f"   Media type: {result.get('media_type')}")
        # Guard the format: formatting a missing/None processing_time with
        # ':.2f' would raise TypeError and mask an otherwise-passing test.
        processing_time = result.get('processing_time')
        if isinstance(processing_time, (int, float)):
            print(f"   Processing time: {processing_time:.2f}s")
        else:
            print("   Processing time: N/A")
        print(f"   Detections: {len(result.get('detections', []))}")

        # Show extracted fields per detection. The key-field display is kept
        # inside the loop so `fields` always refers to the current detection
        # rather than leaking the last loop iteration's value.
        key_fields = ['document_number', 'surname', 'given_names', 'nationality']
        for i, detection in enumerate(result.get('detections', [])):
            fields = detection.get('extracted_fields', {})
            field_count = len([f for f in fields.values() if f is not None])
            print(f"   Page {i+1}: {field_count} fields extracted")

            for field in key_fields:
                if fields.get(field) is not None:
                    value = fields[field].get('value', 'N/A') if isinstance(fields[field], dict) else str(fields[field])
                    confidence = fields[field].get('confidence', 'N/A') if isinstance(fields[field], dict) else 'N/A'
                    print(f"     {field}: {value} (confidence: {confidence})")

        return True

    except Exception as e:
        print(f"❌ OCR test failed: {e}")
        # requests exceptions carry the server response (may be None for
        # connection-level failures).
        if hasattr(e, 'response') and e.response is not None:
            print(f"   Status code: {e.response.status_code}")
            print(f"   Response: {e.response.text}")
        return False

if __name__ == "__main__":
    success = test_production_api()
    sys.exit(0 if success else 1)
|
scripts/test_production_curl.sh
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Production API Test using curl

# Abort on the first failing command so a broken step is not silently skipped.
set -e

# Colors for output (ANSI escape sequences; NC resets the terminal color)
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Production API URL
API_URL="https://algoryn-dots-ocr-idcard.hf.space"
|
| 15 |
+
|
| 16 |
+
# Function to print colored output
|
| 17 |
+
# Print an informational message prefixed with a blue [INFO] tag.
print_status() {
    printf '%b\n' "${BLUE}[INFO]${NC} $1"
}
|
| 20 |
+
|
| 21 |
+
# Print a success message prefixed with a green [SUCCESS] tag.
print_success() {
    printf '%b\n' "${GREEN}[SUCCESS]${NC} $1"
}
|
| 24 |
+
|
| 25 |
+
# Print an error message prefixed with a red [ERROR] tag.
print_error() {
    printf '%b\n' "${RED}[ERROR]${NC} $1"
}
|
| 28 |
+
|
| 29 |
+
# Print a warning message prefixed with a yellow [WARNING] tag.
print_warning() {
    printf '%b\n' "${YELLOW}[WARNING]${NC} $1"
}
|
| 32 |
+
|
| 33 |
+
# Check if test image exists
if [ ! -f "tom_id_card_front.jpg" ]; then
    print_error "Test image not found: tom_id_card_front.jpg"
    exit 1
fi

# jq is required to parse the JSON response below; fail early with a clear
# message instead of a confusing mid-script error.
if ! command -v jq > /dev/null 2>&1; then
    print_error "jq is required but not installed"
    exit 1
fi

print_status "Testing Production API at $API_URL"

# Health check
print_status "Checking API health..."
if curl -s -f "$API_URL/health" > /dev/null; then
    print_success "Health check passed"
else
    print_error "Health check failed"
    exit 1
fi

# Test OCR endpoint
print_status "Testing OCR endpoint with tom_id_card_front.jpg"

# Make the API request; -w appends the HTTP status code on its own line.
response=$(curl -s -w "\n%{http_code}" -X POST \
    -F "file=@tom_id_card_front.jpg" \
    "$API_URL/v1/id/ocr")

# Split response body and status code.
# NOTE: `sed '$d'` (delete last line) is POSIX-portable; the previous
# `head -n -1` is a GNU extension that fails on BSD/macOS head.
http_code=$(echo "$response" | tail -n1)
response_body=$(echo "$response" | sed '$d')

if [ "$http_code" -eq 200 ]; then
    print_success "OCR request successful"

    # Parse and display results. Single scalar values don't need the former
    # `| while read` subshell loops; command substitution is simpler.
    echo "Request ID: $(echo "$response_body" | jq -r '.request_id')"
    echo "Processing time: $(echo "$response_body" | jq -r '.processing_time')s"
    echo "Detections: $(echo "$response_body" | jq -r '.detections | length')"

    # Show extracted fields (one "key: value (confidence: c)" line per field)
    echo "$response_body" | jq -r '.detections[0].extracted_fields | to_entries[] | select(.value != null) | "  \(.key): \(.value.value) (confidence: \(.value.confidence))"'

    print_success "Production API test completed successfully!"

else
    print_error "OCR request failed with status code: $http_code"
    echo "Response: $response_body"
    exit 1
fi
|
setup_dev.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Development setup script for KYB Tech Dots.OCR."""
|
| 3 |
+
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
def run_command(cmd: str, description: str) -> bool:
    """Run a shell command and report success or failure.

    Args:
        cmd: Shell command string. Executed with ``shell=True``, so this
            must only be used with trusted, hard-coded commands.
        description: Human-readable label used in the progress messages.

    Returns:
        True if the command exited with status 0, False otherwise.
    """
    print(f"🔄 {description}...")
    try:
        # check=True raises CalledProcessError on a non-zero exit status.
        # The captured result itself is not needed (the original bound it to
        # an unused local); only success/failure matters here.
        subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
        print(f"✅ {description} completed")
        return True
    except subprocess.CalledProcessError as e:
        print(f"❌ {description} failed: {e}")
        print(f"Error output: {e.stderr}")
        return False
|
| 19 |
+
|
| 20 |
+
def main():
    """Set up development environment.

    Installs uv if it is missing, creates a virtual environment, installs
    the package (and its dev extras) in editable mode, and prints the next
    steps. Exits with status 1 on the first failed step.
    """
    print("🚀 Setting up KYB Tech Dots.OCR development environment...")

    # Check if uv is installed
    if not run_command("uv --version", "Checking uv installation"):
        print("📦 Installing uv...")
        # NOTE(review): piping a remote script to sh is a supply-chain risk;
        # acceptable for a dev bootstrap, but verify the URL stays official.
        if not run_command("curl -LsSf https://astral.sh/uv/install.sh | sh", "Installing uv"):
            print("❌ Failed to install uv. Please install it manually from https://github.com/astral-sh/uv")
            sys.exit(1)

    # Create virtual environment
    if not run_command("uv venv", "Creating virtual environment"):
        sys.exit(1)

    # Install dependencies
    if not run_command("uv pip install -e .", "Installing package in development mode"):
        sys.exit(1)

    # Install development dependencies
    if not run_command("uv pip install -e .[dev]", "Installing development dependencies"):
        sys.exit(1)

    print("\n🎉 Development environment setup complete!")
    print("\n📋 Next steps:")
    print("1. Activate the virtual environment:")
    print(" source .venv/bin/activate # On Unix/macOS")
    print(" .venv\\Scripts\\activate # On Windows")
    print("\n2. Run the application:")
    print(" python main.py")
    print("\n3. Run tests:")
    print(" pytest")
    print("\n4. Run linting:")
    print(" ruff check .")
    print(" black .")
|
| 55 |
+
|
| 56 |
+
# Script entry point: run the full development setup.
if __name__ == "__main__":
    main()
|
src/kybtech_dots_ocr/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""KYB Tech Dots.OCR Package
|
| 2 |
+
|
| 3 |
+
A FastAPI application for identity document text extraction using Dots.OCR model.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
__version__ = "1.0.0"
|
| 7 |
+
__author__ = "Algoryn"
|
| 8 |
+
__email__ = "info@algoryn.com"
|
| 9 |
+
|
| 10 |
+
from .app import app
|
| 11 |
+
from .api_models import OCRResponse, OCRDetection, ExtractedFields, MRZData, ExtractedField
|
| 12 |
+
from .model_loader import load_model, extract_text, is_model_loaded, get_model_info
|
| 13 |
+
from .preprocessing import process_document, validate_file_size, get_document_info
|
| 14 |
+
from .response_builder import build_ocr_response, build_error_response
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"app",
|
| 18 |
+
"OCRResponse",
|
| 19 |
+
"OCRDetection",
|
| 20 |
+
"ExtractedFields",
|
| 21 |
+
"MRZData",
|
| 22 |
+
"ExtractedField",
|
| 23 |
+
"load_model",
|
| 24 |
+
"extract_text",
|
| 25 |
+
"is_model_loaded",
|
| 26 |
+
"get_model_info",
|
| 27 |
+
"process_document",
|
| 28 |
+
"validate_file_size",
|
| 29 |
+
"get_document_info",
|
| 30 |
+
"build_ocr_response",
|
| 31 |
+
"build_error_response",
|
| 32 |
+
]
|
app.py → src/kybtech_dots_ocr/api_models.py
RENAMED
|
@@ -1,44 +1,11 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
-
This
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
import logging
|
| 8 |
-
import time
|
| 9 |
-
import uuid
|
| 10 |
-
import json
|
| 11 |
-
import re
|
| 12 |
from typing import List, Optional, Dict, Any
|
| 13 |
-
from contextlib import asynccontextmanager
|
| 14 |
-
|
| 15 |
-
import cv2
|
| 16 |
-
import numpy as np
|
| 17 |
-
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
| 18 |
-
from fastapi.responses import JSONResponse
|
| 19 |
from pydantic import BaseModel, Field
|
| 20 |
-
import torch
|
| 21 |
-
from PIL import Image
|
| 22 |
-
import io
|
| 23 |
-
import base64
|
| 24 |
-
|
| 25 |
-
# Dots.OCR imports
|
| 26 |
-
try:
|
| 27 |
-
from dots_ocr import DotsOCR
|
| 28 |
-
DOTS_OCR_AVAILABLE = True
|
| 29 |
-
except ImportError:
|
| 30 |
-
DOTS_OCR_AVAILABLE = False
|
| 31 |
-
logging.warning("Dots.OCR not available - using mock implementation")
|
| 32 |
-
|
| 33 |
-
# Import local field extraction utilities
|
| 34 |
-
from field_extraction import FieldExtractor
|
| 35 |
-
|
| 36 |
-
# Configure logging
|
| 37 |
-
logging.basicConfig(level=logging.INFO)
|
| 38 |
-
logger = logging.getLogger(__name__)
|
| 39 |
-
|
| 40 |
-
# Global model instance
|
| 41 |
-
dots_ocr_model = None
|
| 42 |
|
| 43 |
|
| 44 |
class BoundingBox(BaseModel):
|
|
@@ -57,6 +24,31 @@ class ExtractedField(BaseModel):
|
|
| 57 |
source: str = Field(..., description="Extraction source (e.g., 'ocr')")
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
class ExtractedFields(BaseModel):
|
| 61 |
"""All extracted fields from identity document."""
|
| 62 |
document_number: Optional[ExtractedField] = None
|
|
@@ -78,6 +70,7 @@ class ExtractedFields(BaseModel):
|
|
| 78 |
|
| 79 |
class MRZData(BaseModel):
|
| 80 |
"""Machine Readable Zone data."""
|
|
|
|
| 81 |
document_type: Optional[str] = Field(None, description="MRZ document type (TD1|TD2|TD3)")
|
| 82 |
issuing_country: Optional[str] = Field(None, description="Issuing country code")
|
| 83 |
surname: Optional[str] = Field(None, description="Surname from MRZ")
|
|
@@ -91,6 +84,11 @@ class MRZData(BaseModel):
|
|
| 91 |
raw_mrz: Optional[str] = Field(None, description="Raw MRZ text")
|
| 92 |
confidence: float = Field(0.0, ge=0.0, le=1.0, description="MRZ extraction confidence")
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
class OCRDetection(BaseModel):
|
| 96 |
"""Single OCR detection result."""
|
|
@@ -104,131 +102,3 @@ class OCRResponse(BaseModel):
|
|
| 104 |
media_type: str = Field(..., description="Media type processed")
|
| 105 |
processing_time: float = Field(..., description="Processing time in seconds")
|
| 106 |
detections: List[OCRDetection] = Field(..., description="List of OCR detections")
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
# FieldExtractor is now imported from the shared module
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def crop_image_by_roi(image: np.ndarray, roi: BoundingBox) -> np.ndarray:
|
| 113 |
-
"""Crop image using ROI coordinates."""
|
| 114 |
-
h, w = image.shape[:2]
|
| 115 |
-
x1 = int(roi.x1 * w)
|
| 116 |
-
y1 = int(roi.y1 * h)
|
| 117 |
-
x2 = int(roi.x2 * w)
|
| 118 |
-
y2 = int(roi.y2 * h)
|
| 119 |
-
|
| 120 |
-
# Ensure coordinates are within image bounds
|
| 121 |
-
x1 = max(0, min(x1, w))
|
| 122 |
-
y1 = max(0, min(y1, h))
|
| 123 |
-
x2 = max(x1, min(x2, w))
|
| 124 |
-
y2 = max(y1, min(y2, h))
|
| 125 |
-
|
| 126 |
-
return image[y1:y2, x1:x2]
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
@asynccontextmanager
|
| 130 |
-
async def lifespan(app: FastAPI):
|
| 131 |
-
"""Application lifespan manager for model loading."""
|
| 132 |
-
global dots_ocr_model
|
| 133 |
-
|
| 134 |
-
logger.info("Loading Dots.OCR model...")
|
| 135 |
-
try:
|
| 136 |
-
if DOTS_OCR_AVAILABLE:
|
| 137 |
-
# Load Dots.OCR model
|
| 138 |
-
dots_ocr_model = DotsOCR()
|
| 139 |
-
logger.info("Dots.OCR model loaded successfully")
|
| 140 |
-
else:
|
| 141 |
-
logger.warning("Dots.OCR not available - using mock implementation")
|
| 142 |
-
dots_ocr_model = "mock"
|
| 143 |
-
except Exception as e:
|
| 144 |
-
logger.error(f"Failed to load Dots.OCR model: {e}")
|
| 145 |
-
# Don't raise - allow mock mode for development
|
| 146 |
-
dots_ocr_model = "mock"
|
| 147 |
-
|
| 148 |
-
yield
|
| 149 |
-
|
| 150 |
-
logger.info("Shutting down Dots.OCR endpoint...")
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
app = FastAPI(
|
| 154 |
-
title="KYB Dots.OCR Text Extraction",
|
| 155 |
-
description="Dots.OCR for identity document text extraction with ROI support",
|
| 156 |
-
version="1.0.0",
|
| 157 |
-
lifespan=lifespan
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
@app.get("/health")
|
| 162 |
-
async def health_check():
|
| 163 |
-
"""Health check endpoint."""
|
| 164 |
-
return {"status": "healthy", "version": "1.0.0"}
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
@app.post("/v1/id/ocr", response_model=OCRResponse)
|
| 168 |
-
async def extract_text(
|
| 169 |
-
file: UploadFile = File(..., description="Image file to process"),
|
| 170 |
-
roi: Optional[str] = Form(None, description="ROI coordinates as JSON string")
|
| 171 |
-
):
|
| 172 |
-
"""Extract text from identity document image."""
|
| 173 |
-
if dots_ocr_model is None:
|
| 174 |
-
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 175 |
-
|
| 176 |
-
start_time = time.time()
|
| 177 |
-
request_id = str(uuid.uuid4())
|
| 178 |
-
|
| 179 |
-
try:
|
| 180 |
-
# Read and validate image
|
| 181 |
-
image_data = await file.read()
|
| 182 |
-
image = Image.open(io.BytesIO(image_data))
|
| 183 |
-
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 184 |
-
|
| 185 |
-
# Parse ROI if provided
|
| 186 |
-
roi_bbox = None
|
| 187 |
-
if roi:
|
| 188 |
-
try:
|
| 189 |
-
roi_data = json.loads(roi)
|
| 190 |
-
roi_bbox = BoundingBox(**roi_data)
|
| 191 |
-
# Crop image to ROI
|
| 192 |
-
image_cv = crop_image_by_roi(image_cv, roi_bbox)
|
| 193 |
-
except Exception as e:
|
| 194 |
-
logger.warning(f"Invalid ROI provided: {e}")
|
| 195 |
-
|
| 196 |
-
# Run OCR
|
| 197 |
-
if DOTS_OCR_AVAILABLE and dots_ocr_model != "mock":
|
| 198 |
-
# Use real Dots.OCR model
|
| 199 |
-
ocr_results = dots_ocr_model(image_cv)
|
| 200 |
-
ocr_text = " ".join([result.text for result in ocr_results])
|
| 201 |
-
else:
|
| 202 |
-
# Mock implementation for development
|
| 203 |
-
ocr_text = "MOCK OCR TEXT - Document Number: NLD123456789 Surname: MULDER Given Names: THOMAS"
|
| 204 |
-
logger.info("Using mock OCR implementation")
|
| 205 |
-
|
| 206 |
-
# Extract structured fields
|
| 207 |
-
extracted_fields = FieldExtractor.extract_fields(ocr_text)
|
| 208 |
-
|
| 209 |
-
# Extract MRZ data
|
| 210 |
-
mrz_data = FieldExtractor.extract_mrz(ocr_text)
|
| 211 |
-
|
| 212 |
-
# Create detection
|
| 213 |
-
detection = OCRDetection(
|
| 214 |
-
mrz_data=mrz_data,
|
| 215 |
-
extracted_fields=extracted_fields
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
processing_time = time.time() - start_time
|
| 219 |
-
|
| 220 |
-
return OCRResponse(
|
| 221 |
-
request_id=request_id,
|
| 222 |
-
media_type="image",
|
| 223 |
-
processing_time=processing_time,
|
| 224 |
-
detections=[detection]
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
except Exception as e:
|
| 228 |
-
logger.error(f"OCR extraction failed: {e}")
|
| 229 |
-
raise HTTPException(status_code=500, detail=f"OCR extraction failed: {str(e)}")
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
if __name__ == "__main__":
|
| 233 |
-
import uvicorn
|
| 234 |
-
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 1 |
+
"""API models for Dots.OCR text extraction service.
|
| 2 |
|
| 3 |
+
This module defines the data structures used for API requests,
|
| 4 |
+
responses, and internal data processing.
|
| 5 |
"""
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from typing import List, Optional, Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class BoundingBox(BaseModel):
|
|
|
|
| 24 |
source: str = Field(..., description="Extraction source (e.g., 'ocr')")
|
| 25 |
|
| 26 |
|
| 27 |
+
class IdCardFields(BaseModel):
    """Structured fields extracted from identity documents.

    Every field is an optional ExtractedField so that missing data is
    represented as None rather than an empty value.
    """
    # Document Information
    document_number: Optional[ExtractedField] = Field(None, description="Document number/ID")
    document_type: Optional[ExtractedField] = Field(None, description="Type of document")
    issuing_country: Optional[ExtractedField] = Field(None, description="Issuing country code")
    issuing_authority: Optional[ExtractedField] = Field(None, description="Issuing authority")

    # Personal Information
    surname: Optional[ExtractedField] = Field(None, description="Family name/surname")
    given_names: Optional[ExtractedField] = Field(None, description="Given names")
    nationality: Optional[ExtractedField] = Field(None, description="Nationality code")
    date_of_birth: Optional[ExtractedField] = Field(None, description="Date of birth")
    gender: Optional[ExtractedField] = Field(None, description="Gender")
    place_of_birth: Optional[ExtractedField] = Field(None, description="Place of birth")

    # Validity Information
    date_of_issue: Optional[ExtractedField] = Field(None, description="Date of issue")
    date_of_expiry: Optional[ExtractedField] = Field(None, description="Date of expiry")
    personal_number: Optional[ExtractedField] = Field(None, description="Personal number")

    # Additional fields for specific document types
    optional_data_1: Optional[ExtractedField] = Field(None, description="Optional data field 1")
    optional_data_2: Optional[ExtractedField] = Field(None, description="Optional data field 2")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
class ExtractedFields(BaseModel):
|
| 53 |
"""All extracted fields from identity document."""
|
| 54 |
document_number: Optional[ExtractedField] = None
|
|
|
|
| 70 |
|
| 71 |
class MRZData(BaseModel):
|
| 72 |
"""Machine Readable Zone data."""
|
| 73 |
+
# Primary canonical fields
|
| 74 |
document_type: Optional[str] = Field(None, description="MRZ document type (TD1|TD2|TD3)")
|
| 75 |
issuing_country: Optional[str] = Field(None, description="Issuing country code")
|
| 76 |
surname: Optional[str] = Field(None, description="Surname from MRZ")
|
|
|
|
| 84 |
raw_mrz: Optional[str] = Field(None, description="Raw MRZ text")
|
| 85 |
confidence: float = Field(0.0, ge=0.0, le=1.0, description="MRZ extraction confidence")
|
| 86 |
|
| 87 |
+
# Backwards compatibility fields (some older code/tests expect these names)
|
| 88 |
+
# These duplicate information from the canonical fields above.
|
| 89 |
+
format_type: Optional[str] = Field(None, description="Alias of document_type for backward compatibility")
|
| 90 |
+
raw_text: Optional[str] = Field(None, description="Alias of raw_mrz for backward compatibility")
|
| 91 |
+
|
| 92 |
|
| 93 |
class OCRDetection(BaseModel):
|
| 94 |
"""Single OCR detection result."""
|
|
|
|
| 102 |
media_type: str = Field(..., description="Media type processed")
|
| 103 |
processing_time: float = Field(..., description="Processing time in seconds")
|
| 104 |
detections: List[OCRDetection] = Field(..., description="List of OCR detections")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/kybtech_dots_ocr/app.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HF Dots.OCR Text Extraction Endpoint
|
| 2 |
+
|
| 3 |
+
This FastAPI application provides a Hugging Face Space endpoint for Dots.OCR
|
| 4 |
+
text extraction with ROI support and standardized field extraction schema.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
import uuid
|
| 11 |
+
import json
|
| 12 |
+
import re
|
| 13 |
+
from typing import List, Optional, Dict, Any
|
| 14 |
+
from contextlib import asynccontextmanager
|
| 15 |
+
|
| 16 |
+
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
| 17 |
+
from fastapi.responses import JSONResponse
|
| 18 |
+
|
| 19 |
+
# Import local modules
|
| 20 |
+
from .api_models import BoundingBox, ExtractedField, ExtractedFields, MRZData, OCRDetection, OCRResponse
|
| 21 |
+
from .enhanced_field_extraction import EnhancedFieldExtractor
|
| 22 |
+
from .model_loader import load_model, extract_text, is_model_loaded, get_model_info
|
| 23 |
+
from .preprocessing import process_document, validate_file_size, get_document_info
|
| 24 |
+
from .response_builder import build_ocr_response, build_error_response
|
| 25 |
+
|
| 26 |
+
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model state
# Flipped to True by lifespan() once load_model() succeeds; read by the
# health and OCR endpoints to decide between real and mock behavior.
model_loaded = False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# FieldExtractor is now imported from the shared module
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for model loading.

    Loads the Dots.OCR model at startup unless DOTS_OCR_SKIP_MODEL_LOAD=1
    is set. On any failure the app keeps serving in mock mode
    (model_loaded stays False) instead of crashing at startup.
    """
    global model_loaded

    # Allow tests and lightweight environments to skip model loading
    # Set DOTS_OCR_SKIP_MODEL_LOAD=1 to bypass heavy downloads during tests/CI
    skip_model_load = os.getenv("DOTS_OCR_SKIP_MODEL_LOAD", "0") == "1"

    logger.info("Loading Dots.OCR model...")
    try:
        if skip_model_load:
            # Explicitly skip model loading for fast startup in tests/CI
            model_loaded = False
            logger.warning("DOTS_OCR_SKIP_MODEL_LOAD=1 set - skipping model load (mock mode)")
        else:
            # Load the model using the new model loader
            load_model()
            model_loaded = True
            logger.info("Dots.OCR model loaded successfully")

            # Log model information
            model_info = get_model_info()
            logger.info(f"Model info: {model_info}")

    except Exception as e:
        logger.error(f"Failed to load Dots.OCR model: {e}")
        # Don't raise - allow mock mode for development
        model_loaded = False
        logger.warning("Model loading failed - using mock implementation")

    # Application serves requests while suspended here.
    yield

    logger.info("Shutting down Dots.OCR endpoint...")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# FastAPI application; `lifespan` handles model load/unload around the
# server's lifetime.
app = FastAPI(
    title="KYB Dots.OCR Text Extraction",
    description="Dots.OCR for identity document text extraction with ROI support",
    version="1.0.0",
    lifespan=lifespan
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@app.get("/health")
|
| 84 |
+
async def health_check():
|
| 85 |
+
"""Health check endpoint."""
|
| 86 |
+
global model_loaded
|
| 87 |
+
|
| 88 |
+
status = "healthy" if model_loaded else "degraded"
|
| 89 |
+
model_info = get_model_info() if model_loaded else None
|
| 90 |
+
|
| 91 |
+
return {
|
| 92 |
+
"status": status,
|
| 93 |
+
"version": "1.0.0",
|
| 94 |
+
"model_loaded": model_loaded,
|
| 95 |
+
"model_info": model_info
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@app.post("/v1/id/ocr", response_model=OCRResponse)
|
| 100 |
+
async def extract_text_endpoint(
|
| 101 |
+
file: UploadFile = File(..., description="Image or PDF file to process"),
|
| 102 |
+
roi: Optional[str] = Form(None, description="ROI coordinates as JSON string")
|
| 103 |
+
):
|
| 104 |
+
"""Extract text from identity document image or PDF."""
|
| 105 |
+
global model_loaded
|
| 106 |
+
|
| 107 |
+
# Allow mock mode when model isn't loaded to support tests/CI and dev flows
|
| 108 |
+
allow_mock = os.getenv("DOTS_OCR_ALLOW_MOCK", "1") == "1"
|
| 109 |
+
is_mock_mode = (not model_loaded) and allow_mock
|
| 110 |
+
if not model_loaded and not allow_mock:
|
| 111 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 112 |
+
|
| 113 |
+
start_time = time.time()
|
| 114 |
+
request_id = str(uuid.uuid4())
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
# Read file data
|
| 118 |
+
file_data = await file.read()
|
| 119 |
+
|
| 120 |
+
# Validate file size
|
| 121 |
+
if not validate_file_size(file_data):
|
| 122 |
+
raise HTTPException(status_code=413, detail="File size exceeds limit")
|
| 123 |
+
|
| 124 |
+
# Get document information
|
| 125 |
+
doc_info = get_document_info(file_data)
|
| 126 |
+
logger.info(f"Processing document: {doc_info}")
|
| 127 |
+
|
| 128 |
+
# Parse ROI if provided
|
| 129 |
+
roi_coords = None
|
| 130 |
+
if roi:
|
| 131 |
+
try:
|
| 132 |
+
roi_data = json.loads(roi)
|
| 133 |
+
roi_bbox = BoundingBox(**roi_data)
|
| 134 |
+
roi_coords = (roi_bbox.x1, roi_bbox.y1, roi_bbox.x2, roi_bbox.y2)
|
| 135 |
+
logger.info(f"Using ROI: {roi_coords}")
|
| 136 |
+
except Exception as e:
|
| 137 |
+
logger.warning(f"Invalid ROI provided: {e}")
|
| 138 |
+
raise HTTPException(status_code=400, detail=f"Invalid ROI format: {e}")
|
| 139 |
+
|
| 140 |
+
# Process document (PDF to images or single image)
|
| 141 |
+
try:
|
| 142 |
+
processed_images = process_document(file_data, roi_coords)
|
| 143 |
+
logger.info(f"Processed {len(processed_images)} images from document")
|
| 144 |
+
except Exception as e:
|
| 145 |
+
logger.error(f"Document processing failed: {e}")
|
| 146 |
+
raise HTTPException(status_code=400, detail=f"Document processing failed: {e}")
|
| 147 |
+
|
| 148 |
+
# Process each image and extract text
|
| 149 |
+
ocr_texts = []
|
| 150 |
+
page_metadata = []
|
| 151 |
+
|
| 152 |
+
for i, image in enumerate(processed_images):
|
| 153 |
+
try:
|
| 154 |
+
# Extract text using the loaded model, or produce mock output in mock mode
|
| 155 |
+
if is_mock_mode:
|
| 156 |
+
# In mock mode, we skip model inference and return empty text
|
| 157 |
+
ocr_text = ""
|
| 158 |
+
else:
|
| 159 |
+
ocr_text = extract_text(image)
|
| 160 |
+
logger.info(f"Page {i + 1} - Extracted text length: {len(ocr_text)} characters")
|
| 161 |
+
|
| 162 |
+
ocr_texts.append(ocr_text)
|
| 163 |
+
|
| 164 |
+
# Collect page metadata
|
| 165 |
+
page_meta = {
|
| 166 |
+
"page_index": i,
|
| 167 |
+
"image_size": image.size,
|
| 168 |
+
"text_length": len(ocr_text),
|
| 169 |
+
"processing_successful": True
|
| 170 |
+
}
|
| 171 |
+
page_metadata.append(page_meta)
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Text extraction failed for page {i + 1}: {e}")
|
| 175 |
+
# Add empty text for failed page
|
| 176 |
+
ocr_texts.append("")
|
| 177 |
+
|
| 178 |
+
page_meta = {
|
| 179 |
+
"page_index": i,
|
| 180 |
+
"image_size": image.size if hasattr(image, 'size') else (0, 0),
|
| 181 |
+
"text_length": 0,
|
| 182 |
+
"processing_successful": False,
|
| 183 |
+
"error": str(e)
|
| 184 |
+
}
|
| 185 |
+
page_metadata.append(page_meta)
|
| 186 |
+
|
| 187 |
+
# Determine media type for response
|
| 188 |
+
media_type = "pdf" if doc_info["is_pdf"] else "image"
|
| 189 |
+
|
| 190 |
+
processing_time = time.time() - start_time
|
| 191 |
+
|
| 192 |
+
# Build response using the response builder
|
| 193 |
+
return build_ocr_response(
|
| 194 |
+
request_id=request_id,
|
| 195 |
+
media_type=media_type,
|
| 196 |
+
processing_time=processing_time,
|
| 197 |
+
ocr_texts=ocr_texts,
|
| 198 |
+
page_metadata=page_metadata
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
except HTTPException:
|
| 202 |
+
# Re-raise HTTP exceptions as-is
|
| 203 |
+
raise
|
| 204 |
+
except Exception as e:
|
| 205 |
+
logger.error(f"OCR extraction failed: {e}")
|
| 206 |
+
processing_time = time.time() - start_time
|
| 207 |
+
error_response = build_error_response(
|
| 208 |
+
request_id=request_id,
|
| 209 |
+
error_message=f"OCR extraction failed: {str(e)}",
|
| 210 |
+
processing_time=processing_time
|
| 211 |
+
)
|
| 212 |
+
raise HTTPException(status_code=500, detail=error_response.dict())
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
if __name__ == "__main__":
    # Local development entry point; port 7860 is the HF Spaces convention.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
src/kybtech_dots_ocr/enhanced_field_extraction.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Enhanced field extraction utilities for Dots.OCR text processing.
|
| 2 |
+
|
| 3 |
+
This module provides improved field extraction and mapping from OCR results
|
| 4 |
+
to structured KYB field formats with better confidence scoring and validation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Optional, Dict, List, Tuple, Any
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from .api_models import ExtractedField, IdCardFields, MRZData
|
| 12 |
+
|
| 13 |
+
# Configure logging
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class EnhancedFieldExtractor:
    """Enhanced field extraction with improved confidence scoring and validation."""

    # Label-anchored regexes per KYB field; each entry is (pattern, base confidence).
    FIELD_PATTERNS = {
        "document_number": [
            (r"documentnummer[:\s]*([A-Z0-9]{6,15})", 0.9),  # Dutch label
            (r"document\s*number[:\s]*([A-Z0-9]{6,15})", 0.85),  # English label
            (r"nr[:\s]*([A-Z0-9]{6,15})", 0.7),  # Abbreviated label
            (r"ID[:\s]*([A-Z0-9]{6,15})", 0.8),  # "ID" label
            (r"([A-Z]{3}\d{9})", 0.75),  # Bare passport number (3 letters + 9 digits)
        ],
        "surname": [
            # Line-anchored, value captured up to newline so it cannot spill
            # into the next printed label.
            (r"^\s*achternaam[:\s]*([^\r\n]+)", 0.95),  # Dutch
            (r"^\s*surname[:\s]*([^\r\n]+)", 0.9),  # English
            (r"^\s*family\s*name[:\s]*([^\r\n]+)", 0.85),  # Full English
            (r"^\s*last\s*name[:\s]*([^\r\n]+)", 0.85),  # Alternative English
        ],
        "given_names": [
            (r"^\s*voornamen[:\s]*([^\r\n]+)", 0.95),  # Dutch
            (r"^\s*given\s*names[:\s]*([^\r\n]+)", 0.9),  # English
            (r"^\s*first\s*name[:\s]*([^\r\n]+)", 0.85),  # First name only
            (r"^\s*voorletters[:\s]*([^\r\n]+)", 0.75),  # Dutch initials
        ],
        "nationality": [
            (r"nationaliteit[:\s]*([A-Z]{3})", 0.9),  # Dutch, 3-letter code
            (r"nationality[:\s]*([A-Z]{3})", 0.85),  # English, 3-letter code
            (r"nationality[:\s]*([A-Za-z\s]{3,20})", 0.7),  # Full country name
        ],
        "date_of_birth": [
            (r"geboortedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.9),  # Dutch
            (r"date\s*of\s*birth[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.85),  # English
            (r"born[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.8),  # Short English
            (r"(\d{2}[./-]\d{2}[./-]\d{4})", 0.6),  # Generic date, last resort
        ],
        "gender": [
            (r"geslacht[:\s]*([MF])", 0.9),  # Dutch, single letter
            (r"gender[:\s]*([MF])", 0.85),  # English, single letter
            (r"sex[:\s]*([MF])", 0.8),  # Alternative English
            (r"geslacht[:\s]*(man|vrouw)", 0.7),  # Dutch full words
            (r"gender[:\s]*(male|female)", 0.7),  # English full words
        ],
        "place_of_birth": [
            (r"geboorteplaats[:\s]*([A-Za-z\s]{2,30})", 0.9),  # Dutch
            (r"place\s*of\s*birth[:\s]*([A-Za-z\s]{2,30})", 0.85),  # English
            (r"born\s*in[:\s]*([A-Za-z\s]{2,30})", 0.8),  # Short English
        ],
        "date_of_issue": [
            (r"uitgiftedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.9),  # Dutch
            (r"date\s*of\s*issue[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.85),  # English
            (r"issued[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.8),  # Short English
        ],
        "date_of_expiry": [
            (r"vervaldatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.9),  # Dutch
            (r"date\s*of\s*expiry[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.85),  # English
            (r"expires[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.8),  # Short English
            (r"valid\s*until[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.8),  # Alternative English
        ],
        "personal_number": [
            (r"persoonsnummer[:\s]*(\d{9})", 0.9),  # Dutch
            (r"personal\s*number[:\s]*(\d{9})", 0.85),  # English
            (r"bsn[:\s]*(\d{9})", 0.9),  # Dutch BSN
            (r"social\s*security[:\s]*(\d{9})", 0.8),  # SSN-style label
        ],
        "document_type": [
            (r"document\s*type[:\s]*([A-Za-z\s]{3,20})", 0.8),  # English
            (r"soort\s*document[:\s]*([A-Za-z\s]{3,20})", 0.9),  # Dutch
            (r"(passport|paspoort)", 0.9),  # Passport keyword
            (r"(identity\s*card|identiteitskaart)", 0.9),  # ID card keyword
            (r"(driving\s*license|rijbewijs)", 0.9),  # Driving license keyword
        ],
        "issuing_country": [
            (r"issuing\s*country[:\s]*([A-Z]{3})", 0.85),  # English
            (r"uitgevende\s*land[:\s]*([A-Z]{3})", 0.9),  # Dutch
            (r"country[:\s]*([A-Z]{3})", 0.7),  # Short label
        ],
        "issuing_authority": [
            (r"issuing\s*authority[:\s]*([A-Za-z\s]{3,30})", 0.8),  # English
            (r"uitgevende\s*autoriteit[:\s]*([A-Za-z\s]{3,30})", 0.9),  # Dutch
            (r"authority[:\s]*([A-Za-z\s]{3,30})", 0.7),  # Short label
        ]
    }

    # MRZ candidates, strictest first; each entry is (pattern, base confidence).
    MRZ_PATTERNS = [
        # Strict formats first, tolerating leading/trailing whitespace per line.
        (r"^\s*((?:[A-Z0-9<]{44})\s*\n\s*(?:[A-Z0-9<]{44}))\s*$", 0.95),  # TD3: Passport (2 x 44)
        (r"^\s*((?:[A-Z0-9<]{36})\s*\n\s*(?:[A-Z0-9<]{36}))\s*$", 0.9),  # TD2: ID card (2 x 36)
        (r"^\s*((?:[A-Z0-9<]{30})\s*\n\s*(?:[A-Z0-9<]{30})\s*\n\s*(?:[A-Z0-9<]{30}))\s*$", 0.85),  # TD1: (3 x 30)
        # Generic fallback: a line starting with P< followed by another MRZ-like line.
        (r"(P<[^\r\n]+\n[^\r\n]+)", 0.85),
    ]

    @classmethod
    def extract_fields(cls, ocr_text: str) -> "IdCardFields":
        """Extract structured fields from OCR text with enhanced confidence scoring.

        Args:
            ocr_text: Raw OCR text from document processing

        Returns:
            IdCardFields object with extracted field data
        """
        logger.info(f"Extracting fields from text of length: {len(ocr_text)}")

        extracted = {}
        stats = {"total_patterns": 0, "matches_found": 0}

        for name, candidates in cls.FIELD_PATTERNS.items():
            hit = None  # (value, base confidence) of the first validated match
            for regex, score in candidates:
                stats["total_patterns"] += 1
                found = re.search(regex, ocr_text, re.IGNORECASE | re.MULTILINE)
                if not found:
                    continue
                text_value = found.group(1).strip()
                # Reject matches that fail the per-field sanity checks and
                # keep trying lower-confidence patterns.
                if not cls._validate_field_value(name, text_value):
                    continue
                stats["matches_found"] += 1
                hit = (text_value, score)
                logger.debug(f"Found {name}: '{text_value}' (confidence: {score:.2f})")
                break

            if hit:
                value, score = hit
                # Context- and length-based confidence adjustments.
                score = cls._adjust_confidence(name, value, score, ocr_text)
                extracted[name] = ExtractedField(
                    field_name=name,
                    value=value,
                    confidence=score,
                    source="ocr"
                )

        logger.info(f"Field extraction complete: {stats['matches_found']}/{stats['total_patterns']} patterns matched")
        return IdCardFields(**extracted)

    @classmethod
    def _validate_field_value(cls, field_name: str, value: str) -> bool:
        """Validate extracted field value based on field type.

        Args:
            field_name: Name of the field
            value: Extracted value to validate

        Returns:
            True if value is valid
        """
        if not value or not value.strip():
            return False

        # Per-field sanity checks; unknown fields pass by default.
        if field_name == "document_number":
            return 6 <= len(value) <= 15
        if field_name in ("surname", "given_names", "place_of_birth"):
            return 2 <= len(value) <= 50
        if field_name in ("nationality", "issuing_country"):
            return len(value) == 3 and value.isalpha()
        if field_name in ("date_of_birth", "date_of_issue", "date_of_expiry"):
            return cls._validate_date_format(value)
        if field_name == "gender":
            return value.upper() in ("M", "F", "MALE", "FEMALE", "MAN", "VROUW")
        if field_name == "personal_number":
            return len(value) == 9 and value.isdigit()

        return True

    @classmethod
    def _validate_date_format(cls, date_str: str) -> bool:
        """Validate date format and basic date logic.

        Args:
            date_str: Date string to validate

        Returns:
            True if date format is valid
        """
        try:
            for separator in (".", "/", "-"):
                if separator not in date_str:
                    continue
                pieces = date_str.split(separator)
                if len(pieces) != 3:
                    continue
                day, month, year = pieces
                # Coarse range checks only — no calendar validation.
                if (1 <= int(day) <= 31
                        and 1 <= int(month) <= 12
                        and 1900 <= int(year) <= 2100):
                    return True
        except (ValueError, IndexError):
            # Non-numeric pieces: treat as an invalid date.
            pass
        return False

    @classmethod
    def _adjust_confidence(cls, field_name: str, value: str, base_confidence: float, full_text: str) -> float:
        """Adjust confidence based on additional factors.

        Args:
            field_name: Name of the field
            value: Extracted value
            base_confidence: Base confidence from pattern matching
            full_text: Full OCR text for context

        Returns:
            Adjusted confidence score
        """
        score = base_confidence

        # Very short names are less reliable.
        if field_name in ("surname", "given_names") and len(value) < 3:
            score *= 0.8

        # Passport context raises trust in a document number.
        if field_name == "document_number" and "passport" in full_text.lower():
            score *= 1.1

        # Small bonus when the same value appears more than once.
        if value in full_text and full_text.count(value) > 1:
            score *= 1.05

        # Clamp to [0, 1].
        return min(max(score, 0.0), 1.0)

    @classmethod
    def extract_mrz(cls, ocr_text: str) -> Optional["MRZData"]:
        """Extract MRZ data from OCR text with enhanced validation.

        Args:
            ocr_text: Raw OCR text from document processing

        Returns:
            MRZData object if MRZ detected, None otherwise
        """
        logger.info("Extracting MRZ data from OCR text")

        winner = None
        winner_score = 0.0

        for regex, score in cls.MRZ_PATTERNS:
            found = re.search(regex, ocr_text, re.MULTILINE)
            if not found:
                continue
            candidate = found.group(1)
            if not cls._validate_mrz_format(candidate):
                continue
            adjusted = cls._adjust_mrz_confidence(candidate, score)
            if adjusted > winner_score:
                winner = candidate
                winner_score = adjusted
                logger.debug(f"Found MRZ with confidence {adjusted:.2f}")

        if winner is None:
            logger.info("No MRZ data found in OCR text")
            return None

        fmt = cls._determine_mrz_format(winner)
        is_valid, _errors = cls._validate_mrz_checksums(winner, fmt)
        logger.info(f"MRZ extracted: {fmt} format, valid: {is_valid}")

        # Populate both canonical and legacy alias fields for compatibility.
        # Individual MRZ fields would be parsed in a full implementation.
        return MRZData(
            document_type=fmt,
            format_type=fmt,  # legacy alias
            issuing_country=None,
            surname=None,
            given_names=None,
            document_number=None,
            nationality=None,
            date_of_birth=None,
            gender=None,
            date_of_expiry=None,
            personal_number=None,
            raw_mrz=winner,
            raw_text=winner,  # legacy alias
            confidence=winner_score,
        )

    @classmethod
    def _validate_mrz_format(cls, mrz_text: str) -> bool:
        """Validate basic MRZ format.

        Args:
            mrz_text: Raw MRZ text

        Returns:
            True if format is valid
        """
        rows = mrz_text.strip().split('\n')
        if len(rows) < 2:
            return False

        # Drop internal whitespace, then require the MRZ alphabet on every row.
        for row in rows:
            if not re.match(r'^[A-Z0-9<]+$', re.sub(r"\s+", "", row)):
                return False
        return True

    @classmethod
    def _determine_mrz_format(cls, mrz_text: str) -> str:
        """Determine MRZ format type.

        Args:
            mrz_text: Raw MRZ text

        Returns:
            Format type (TD1, TD2, TD3, etc.)
        """
        rows = [re.sub(r"\s+", "", row) for row in mrz_text.strip().split('\n')]
        width = len(rows[0]) if rows else 0

        # Heuristic mapping: semantics over exact lengths, for robustness
        # against OCR length errors.
        if len(rows) == 2 and rows[0].startswith("P<"):
            return "TD3"  # Passport MRZ commonly starts with P<
        if len(rows) == 2 and width == 36:
            return "TD2"  # ID card format
        if len(rows) == 3:
            return "TD1"
        return "UNKNOWN"

    @classmethod
    def _adjust_mrz_confidence(cls, mrz_text: str, base_confidence: float) -> float:
        """Adjust MRZ confidence based on quality indicators.

        Args:
            mrz_text: Raw MRZ text
            base_confidence: Base confidence from pattern matching

        Returns:
            Adjusted confidence
        """
        score = base_confidence

        # Consistent line lengths are a quality signal.
        rows = mrz_text.strip().split('\n')
        if len({len(row) for row in rows}) == 1:
            score *= 1.05

        return min(max(score, 0.0), 1.0)

    @classmethod
    def _validate_mrz_checksums(cls, mrz_text: str, format_type: str) -> Tuple[bool, List[str]]:
        """Validate MRZ checksums (simplified implementation).

        Args:
            mrz_text: Raw MRZ text
            format_type: MRZ format type

        Returns:
            Tuple of (is_valid, list_of_errors)
        """
        # Simplified: full ICAO 9303 check-digit validation is not implemented.
        problems: List[str] = []

        # Reject implausible fill-character density.
        if mrz_text.count('<') > len(mrz_text) * 0.3:
            problems.append("Too many fill characters")

        # Assume valid when the basic shape checks pass.
        return (not problems, problems)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# Backward compatibility - use enhanced extractor as default
class FieldExtractor(EnhancedFieldExtractor):
    """Legacy-named alias that delegates everything to the enhanced extractor.

    Kept so existing imports of ``FieldExtractor`` continue to work; all
    behavior comes from :class:`EnhancedFieldExtractor`.
    """
|
field_extraction.py → src/kybtech_dots_ocr/field_extraction.py
RENAMED
|
@@ -6,7 +6,7 @@ to structured KYB field formats.
|
|
| 6 |
|
| 7 |
import re
|
| 8 |
from typing import Optional
|
| 9 |
-
from
|
| 10 |
|
| 11 |
|
| 12 |
class FieldExtractor:
|
|
|
|
| 6 |
|
| 7 |
import re
|
| 8 |
from typing import Optional
|
| 9 |
+
from .api_models import ExtractedField, IdCardFields, MRZData
|
| 10 |
|
| 11 |
|
| 12 |
class FieldExtractor:
|
src/kybtech_dots_ocr/model_loader.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dots.OCR Model Loader
|
| 2 |
+
|
| 3 |
+
This module handles downloading and loading the Dots.OCR model using Hugging Face's
|
| 4 |
+
snapshot_download functionality. It provides device selection, dtype configuration,
|
| 5 |
+
and model initialization with proper error handling.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import logging
|
| 10 |
+
import torch
|
| 11 |
+
from typing import Optional, Tuple, Dict, Any
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
from huggingface_hub import snapshot_download
|
| 15 |
+
from transformers import AutoModelForCausalLM, AutoProcessor
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
# Configure logging
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Environment variable configuration
|
| 22 |
+
REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr")
|
| 23 |
+
LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
|
| 24 |
+
DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")
|
| 25 |
+
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
|
| 26 |
+
USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "1") == "1"
|
| 27 |
+
MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56
|
| 28 |
+
MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360
|
| 29 |
+
CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")
|
| 30 |
+
|
| 31 |
+
# Default transcription prompt for faithful text extraction
|
| 32 |
+
DEFAULT_PROMPT = (
|
| 33 |
+
"Transcribe all visible text in the image in the original language. "
|
| 34 |
+
"Do not translate. Preserve natural reading order. Output plain text only."
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class DotsOCRModelLoader:
    """Handles Dots.OCR model downloading, loading, and inference.

    Configuration comes entirely from the module-level environment constants
    (repo id, cache dir, device policy, generation limits). The loader keeps
    the model/processor pair as instance state; call ``load_model()`` before
    ``extract_text()``.
    """

    def __init__(self):
        """Initialize an empty loader; nothing is downloaded or loaded yet."""
        self.model = None       # transformers model, set by load_model()
        self.processor = None   # matching AutoProcessor, set by load_model()
        self.device = None      # "cuda" or "cpu", chosen at load time
        self.dtype = None       # torch dtype matching the chosen device
        self.local_dir = None   # local snapshot path after download
        self.prompt = CUSTOM_PROMPT or DEFAULT_PROMPT

    def _determine_device_and_dtype(self) -> Tuple[str, torch.dtype]:
        """Pick device/dtype from DOTS_OCR_DEVICE and CUDA availability.

        Returns:
            ``("cuda", torch.bfloat16)`` when CUDA is usable and allowed,
            otherwise ``("cpu", torch.float32)``.
        """
        if DEVICE_CONFIG == "cpu":
            device, dtype = "cpu", torch.float32
        elif DEVICE_CONFIG == "cuda" and torch.cuda.is_available():
            device, dtype = "cuda", torch.bfloat16
        elif DEVICE_CONFIG == "auto":
            if torch.cuda.is_available():
                device, dtype = "cuda", torch.bfloat16
            else:
                device, dtype = "cpu", torch.float32
        else:
            # Reached both for "cuda" without a GPU and for unrecognized
            # config values — say which, instead of always blaming CUDA.
            logger.warning(
                "Device config %r not usable (CUDA available: %s), falling back to CPU",
                DEVICE_CONFIG, torch.cuda.is_available(),
            )
            device, dtype = "cpu", torch.float32

        logger.info(f"Selected device: {device}, dtype: {dtype}")
        return device, dtype

    def _download_model(self) -> str:
        """Download the model snapshot via huggingface_hub.

        Returns:
            Local filesystem path of the downloaded snapshot.

        Raises:
            RuntimeError: if the download fails for any reason.
        """
        logger.info(f"Downloading model from {REPO_ID} to {LOCAL_DIR}")

        try:
            # Ensure the target directory exists.
            Path(LOCAL_DIR).mkdir(parents=True, exist_ok=True)

            # local_dir_use_symlinks=False avoids symlink issues in containers
            # (deprecated/no-op in newer huggingface_hub, harmless to pass).
            local_path = snapshot_download(
                repo_id=REPO_ID,
                local_dir=LOCAL_DIR,
                local_dir_use_symlinks=False,
            )

            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path

        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            raise RuntimeError(f"Model download failed: {e}") from e

    def load_model(self) -> None:
        """Load the Dots.OCR model and processor, downloading if needed.

        Raises:
            RuntimeError: if download or initialization fails.
        """
        try:
            self.device, self.dtype = self._determine_device_and_dtype()
            self.local_dir = self._download_model()

            logger.info("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.local_dir,
                trust_remote_code=True,
            )

            model_kwargs = {
                "torch_dtype": self.dtype,
                "trust_remote_code": True,
            }

            if self.device == "cuda":
                # NOTE: this only *requests* flash attention; transformers
                # raises at from_pretrained() time if flash-attn is missing.
                # (The previous try/except around this assignment was dead
                # code — assigning a dict key cannot fail.)
                if USE_FLASH_ATTENTION:
                    model_kwargs["attn_implementation"] = "flash_attention_2"
                    logger.info("Using flash attention 2")
                # Let accelerate place layers across available GPU memory.
                model_kwargs["device_map"] = "auto"
            else:
                model_kwargs["device_map"] = None

            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.local_dir,
                **model_kwargs,
            )

            # Without a device_map the model must be moved explicitly.
            if model_kwargs.get("device_map") is None:
                self.model = self.model.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _preprocess_image(self, image: Image.Image) -> Image.Image:
        """Preprocess image to meet model requirements.

        Converts to RGB, rescales so the pixel count falls inside
        [MIN_PIXELS, MAX_PIXELS], then rounds both dimensions up to a
        multiple of 28 (a common vision-model patch size).

        Raises:
            ValueError: if the image has zero area.
        """
        # Convert to RGB if necessary.
        if image.mode != "RGB":
            image = image.convert("RGB")

        width, height = image.size
        current_pixels = width * height
        if current_pixels == 0:
            # Guard against division by zero below.
            raise ValueError("Image has zero area and cannot be processed")

        if current_pixels < MIN_PIXELS:
            # Scale up uniformly to meet the minimum pixel requirement.
            scale_factor = (MIN_PIXELS / current_pixels) ** 0.5
            new_width = int(width * scale_factor)
            new_height = int(height * scale_factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Scaled up image from {width}x{height} to {new_width}x{new_height}")

        elif current_pixels > MAX_PIXELS:
            # Scale down uniformly to meet the maximum pixel requirement.
            scale_factor = (MAX_PIXELS / current_pixels) ** 0.5
            new_width = int(width * scale_factor)
            new_height = int(height * scale_factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Scaled down image from {width}x{height} to {new_width}x{new_height}")

        # Round dimensions up to a multiple of 28.
        # NOTE: rounding up can push the area slightly above MAX_PIXELS.
        width, height = image.size
        new_width = ((width + 27) // 28) * 28
        new_height = ((height + 27) // 28) * 28

        if new_width != width or new_height != height:
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Adjusted image dimensions to be divisible by 28: {new_width}x{new_height}")

        return image

    @torch.inference_mode()
    def extract_text(self, image: Image.Image, prompt: Optional[str] = None) -> str:
        """Extract text from an image using the loaded model.

        Args:
            image: Source image.
            prompt: Optional prompt override; defaults to the configured one.

        Returns:
            Decoded model output (generated text only, prompt stripped).

        Raises:
            RuntimeError: if the model is not loaded or generation fails.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            processed_image = self._preprocess_image(image)
            text_prompt = prompt or self.prompt

            # Chat-style message carrying both the image and the instruction.
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": processed_image},
                    {"type": "text", "text": text_prompt},
                ],
            }]

            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Some models need qwen_vl_utils to split out vision inputs.
            try:
                from qwen_vl_utils import process_vision_info
                image_inputs, video_inputs = process_vision_info(messages)
            except ImportError:
                logger.warning("qwen_vl_utils not available, using basic processing")
                image_inputs = [processed_image]
                video_inputs = []

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt"
            ).to(self.device)

            # Greedy decoding. (temperature is deliberately not passed:
            # it is ignored when do_sample=False and only triggers a
            # transformers warning.)
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                pad_token_id=self.processor.tokenizer.eos_token_id
            )

            # Strip the prompt tokens before decoding.
            trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
            decoded = self.processor.batch_decode(
                trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )

            return decoded[0] if decoded else ""

        except Exception as e:
            logger.error(f"Text extraction failed: {e}")
            raise RuntimeError(f"Text extraction failed: {e}") from e

    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self.model is not None and self.processor is not None

    def get_model_info(self) -> Dict[str, Any]:
        """Return a metadata snapshot of the loader's current configuration."""
        return {
            "device": self.device,
            "dtype": str(self.dtype),
            "local_dir": self.local_dir,
            "repo_id": REPO_ID,
            "max_new_tokens": MAX_NEW_TOKENS,
            "use_flash_attention": USE_FLASH_ATTENTION,
            "prompt": self.prompt,
            "is_loaded": self.is_loaded()
        }
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Global model instance
|
| 279 |
+
_model_loader: Optional[DotsOCRModelLoader] = None
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def get_model_loader() -> DotsOCRModelLoader:
|
| 283 |
+
"""Get the global model loader instance."""
|
| 284 |
+
global _model_loader
|
| 285 |
+
if _model_loader is None:
|
| 286 |
+
_model_loader = DotsOCRModelLoader()
|
| 287 |
+
return _model_loader
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def load_model() -> None:
|
| 291 |
+
"""Load the Dots.OCR model."""
|
| 292 |
+
loader = get_model_loader()
|
| 293 |
+
loader.load_model()
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def extract_text(image: Image.Image, prompt: Optional[str] = None) -> str:
|
| 297 |
+
"""Extract text from an image using the loaded model."""
|
| 298 |
+
loader = get_model_loader()
|
| 299 |
+
if not loader.is_loaded():
|
| 300 |
+
raise RuntimeError("Model not loaded. Call load_model() first.")
|
| 301 |
+
return loader.extract_text(image, prompt)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def is_model_loaded() -> bool:
|
| 305 |
+
"""Check if the model is loaded and ready."""
|
| 306 |
+
loader = get_model_loader()
|
| 307 |
+
return loader.is_loaded()
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def get_model_info() -> Dict[str, Any]:
    """Get information about the loaded model."""
    return get_model_loader().get_model_info()
|
models.py → src/kybtech_dots_ocr/models.py
RENAMED
|
File without changes
|
src/kybtech_dots_ocr/preprocessing.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image and PDF preprocessing utilities for Dots.OCR.
|
| 2 |
+
|
| 3 |
+
This module handles PDF to image conversion, image preprocessing,
|
| 4 |
+
and multi-page document processing for the Dots.OCR model.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Tuple, Optional, Union
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import io
|
| 12 |
+
|
| 13 |
+
import fitz # PyMuPDF
|
| 14 |
+
import numpy as np
|
| 15 |
+
from PIL import Image, ImageOps
|
| 16 |
+
import cv2
|
| 17 |
+
|
| 18 |
+
# Configure logging
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# Environment variable configuration
# Rendering resolution used when rasterizing PDF pages (dots per inch).
PDF_DPI = int(os.getenv("DOTS_OCR_PDF_DPI", "300"))
# Upper bound on the number of PDF pages converted for a single document.
PDF_MAX_PAGES = int(os.getenv("DOTS_OCR_PDF_MAX_PAGES", "10"))
# Maximum accepted upload size in bytes; the env var is interpreted in megabytes.
IMAGE_MAX_SIZE = int(os.getenv("DOTS_OCR_IMAGE_MAX_SIZE", "10")) * 1024 * 1024  # 10MB default
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ImagePreprocessor:
    """Handles image preprocessing for Dots.OCR model."""

    def __init__(self, min_pixels: int = 3136, max_pixels: int = 11289600, divisor: int = 28):
        """Initialize the image preprocessor.

        Args:
            min_pixels: Minimum pixel count for images
            max_pixels: Maximum pixel count for images
            divisor: Required divisor for image dimensions
        """
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.divisor = divisor

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Normalize an image so it satisfies the model's input constraints.

        Converts to RGB, honours EXIF orientation, rescales into the
        [min_pixels, max_pixels] budget, and rounds dimensions up to a
        multiple of ``divisor``.

        Args:
            image: Input PIL Image

        Returns:
            Preprocessed PIL Image
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Apply any rotation/flip encoded in EXIF metadata.
        image = ImageOps.exif_transpose(image)

        width, height = image.size
        pixel_count = width * height
        logger.info(f"Original image size: {width}x{height} ({pixel_count} pixels)")

        # Rescale uniformly (sqrt of the area ratio) to land inside the pixel budget.
        if pixel_count < self.min_pixels:
            factor = (self.min_pixels / pixel_count) ** 0.5
            new_width, new_height = int(width * factor), int(height * factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Scaled up image to {new_width}x{new_height}")
        elif pixel_count > self.max_pixels:
            factor = (self.max_pixels / pixel_count) ** 0.5
            new_width, new_height = int(width * factor), int(height * factor)
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Scaled down image to {new_width}x{new_height}")

        # Round both dimensions up to the nearest multiple of the divisor.
        width, height = image.size
        new_width = ((width + self.divisor - 1) // self.divisor) * self.divisor
        new_height = ((height + self.divisor - 1) // self.divisor) * self.divisor
        if (new_width, new_height) != (width, height):
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Adjusted dimensions to be divisible by {self.divisor}: {new_width}x{new_height}")

        return image

    def crop_by_roi(self, image: Image.Image, roi: Tuple[float, float, float, float]) -> Image.Image:
        """Crop image using ROI coordinates.

        Args:
            image: Input PIL Image
            roi: ROI coordinates as (x1, y1, x2, y2) normalized to [0, 1]

        Returns:
            Cropped PIL Image
        """
        width, height = image.size

        # Map normalized coordinates onto pixel space.
        left = int(roi[0] * width)
        top = int(roi[1] * height)
        right = int(roi[2] * width)
        bottom = int(roi[3] * height)

        # Clamp to the image and keep the box non-inverted.
        left = max(0, min(left, width))
        top = max(0, min(top, height))
        right = max(left, min(right, width))
        bottom = max(top, min(bottom, height))

        cropped = image.crop((left, top, right, bottom))
        logger.info(f"Cropped image to {right - left}x{bottom - top} pixels")
        return cropped
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class PDFProcessor:
    """Handles PDF to image conversion and multi-page processing."""

    def __init__(self, dpi: int = PDF_DPI, max_pages: int = PDF_MAX_PAGES):
        """Initialize the PDF processor.

        Args:
            dpi: DPI for PDF to image conversion
            max_pages: Maximum number of pages to process
        """
        self.dpi = dpi
        self.max_pages = max_pages

    def pdf_to_images(self, pdf_data: bytes) -> List[Image.Image]:
        """Convert PDF to list of images.

        Args:
            pdf_data: PDF file data as bytes

        Returns:
            List of PIL Images, one per page

        Raises:
            RuntimeError: If the PDF cannot be opened or rendered.
        """
        try:
            pdf_document = fitz.open(stream=pdf_data, filetype="pdf")
            try:
                images = []

                # Limit number of pages to process
                num_pages = min(len(pdf_document), self.max_pages)
                logger.info(f"Processing {num_pages} pages from PDF")

                # 72 is the PDF default DPI; the matrix scales to the target DPI.
                mat = fitz.Matrix(self.dpi / 72, self.dpi / 72)
                for page_num in range(num_pages):
                    pix = pdf_document[page_num].get_pixmap(matrix=mat)
                    image = Image.open(io.BytesIO(pix.tobytes("png")))
                    images.append(image)
                    logger.info(f"Converted page {page_num + 1} to image: {image.size}")

                return images
            finally:
                # Ensure the document handle is released even if rendering fails.
                pdf_document.close()

        except Exception as e:
            logger.error(f"Failed to convert PDF to images: {e}")
            # Chain the original exception so callers can inspect the root cause.
            raise RuntimeError(f"PDF conversion failed: {e}") from e

    def is_pdf(self, file_data: bytes) -> bool:
        """Check if file data is a PDF.

        Args:
            file_data: File data as bytes

        Returns:
            True if file is a PDF
        """
        # PDF files start with the "%PDF-" magic header.
        return file_data.startswith(b'%PDF-')

    def get_pdf_page_count(self, pdf_data: bytes) -> int:
        """Get the number of pages in a PDF.

        Args:
            pdf_data: PDF file data as bytes

        Returns:
            Number of pages in the PDF, or 0 if the PDF cannot be read.
        """
        try:
            pdf_document = fitz.open(stream=pdf_data, filetype="pdf")
            try:
                return len(pdf_document)
            finally:
                pdf_document.close()
        except Exception as e:
            # Best-effort: a broken PDF reports zero pages rather than raising.
            logger.error(f"Failed to get PDF page count: {e}")
            return 0
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class DocumentProcessor:
    """Main document processing class that handles both images and PDFs."""

    def __init__(self):
        """Initialize the document processor with its image and PDF helpers."""
        self.image_preprocessor = ImagePreprocessor()
        self.pdf_processor = PDFProcessor()

    def process_document(
        self,
        file_data: bytes,
        roi: Optional[Tuple[float, float, float, float]] = None
    ) -> List[Image.Image]:
        """Process a document (image or PDF) and return preprocessed images.

        Args:
            file_data: Document file data as bytes
            roi: Optional ROI coordinates as (x1, y1, x2, y2) normalized to [0, 1]

        Returns:
            List of preprocessed PIL Images

        Raises:
            RuntimeError: If the payload cannot be decoded, or if no page
                survives preprocessing.
        """
        # Dispatch on content: PDFs are rasterized page by page, everything
        # else is treated as a single image.
        if self.pdf_processor.is_pdf(file_data):
            logger.info("Processing PDF document")
            images = self.pdf_processor.pdf_to_images(file_data)
        else:
            logger.info("Processing image document")
            try:
                images = [Image.open(io.BytesIO(file_data))]
            except Exception as e:
                logger.error(f"Failed to open image: {e}")
                # Chain the original exception for easier debugging upstream.
                raise RuntimeError(f"Image processing failed: {e}") from e

        # Preprocess each image; best-effort — a single bad page does not
        # abort the whole document.
        processed_images = []
        for i, image in enumerate(images):
            try:
                if roi is not None:
                    image = self.image_preprocessor.crop_by_roi(image, roi)

                processed_image = self.image_preprocessor.preprocess_image(image)
                processed_images.append(processed_image)

                logger.info(f"Processed image {i + 1}: {processed_image.size}")

            except Exception as e:
                logger.error(f"Failed to preprocess image {i + 1}: {e}")
                continue

        if not processed_images:
            raise RuntimeError("No images could be processed from the document")

        logger.info(f"Successfully processed {len(processed_images)} images")
        return processed_images

    def validate_file_size(self, file_data: bytes) -> bool:
        """Validate that file size is within limits.

        Args:
            file_data: File data as bytes

        Returns:
            True if file size is acceptable
        """
        file_size = len(file_data)
        if file_size > IMAGE_MAX_SIZE:
            logger.warning(f"File size {file_size} exceeds limit {IMAGE_MAX_SIZE}")
            return False
        return True

    def get_document_info(self, file_data: bytes) -> dict:
        """Get information about the document.

        Args:
            file_data: Document file data as bytes

        Returns:
            Dictionary with keys ``file_size``, ``is_pdf``, and ``page_count``
            (1 for plain images).
        """
        is_pdf = self.pdf_processor.is_pdf(file_data)
        return {
            "file_size": len(file_data),
            "is_pdf": is_pdf,
            "page_count": self.pdf_processor.get_pdf_page_count(file_data) if is_pdf else 1,
        }
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# Global document processor instance (lazily created singleton)
_document_processor: Optional[DocumentProcessor] = None


def get_document_processor() -> DocumentProcessor:
    """Return the shared DocumentProcessor, instantiating it on first call."""
    global _document_processor
    if _document_processor is None:
        _document_processor = DocumentProcessor()
    return _document_processor
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def process_document(
    file_data: bytes,
    roi: Optional[Tuple[float, float, float, float]] = None
) -> List[Image.Image]:
    """Process a document and return preprocessed images."""
    return get_document_processor().process_document(file_data, roi)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def validate_file_size(file_data: bytes) -> bool:
    """Validate that file size is within limits."""
    return get_document_processor().validate_file_size(file_data)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def get_document_info(file_data: bytes) -> dict:
    """Get information about the document."""
    return get_document_processor().get_document_info(file_data)
|
src/kybtech_dots_ocr/response_builder.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Response builder for Dots.OCR API responses.
|
| 2 |
+
|
| 3 |
+
This module handles the construction and validation of OCR API responses
|
| 4 |
+
according to the specified schema with proper error handling and metadata.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import time
|
| 9 |
+
from typing import List, Optional, Dict, Any
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
from .api_models import OCRResponse, OCRDetection, ExtractedFields, MRZData, ExtractedField
|
| 13 |
+
from .enhanced_field_extraction import EnhancedFieldExtractor
|
| 14 |
+
|
| 15 |
+
# Configure logging
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class OCRResponseBuilder:
    """Builds OCR API responses with proper validation and metadata.

    Converts raw per-page OCR text into validated ``OCRResponse`` objects,
    running field/MRZ extraction per page and degrading gracefully (empty
    detection) when a page fails.
    """

    def __init__(self):
        """Initialize the response builder."""
        self.field_extractor = EnhancedFieldExtractor()

    def build_response(
        self,
        request_id: str,
        media_type: str,
        processing_time: float,
        ocr_texts: List[str],
        page_metadata: Optional[List[Dict[str, Any]]] = None
    ) -> OCRResponse:
        """Build a complete OCR response from extracted texts.

        Args:
            request_id: Unique request identifier
            media_type: Type of media processed ("image" or "pdf")
            processing_time: Total processing time in seconds
            ocr_texts: List of OCR text results (one per page)
            page_metadata: Optional metadata for each page

        Returns:
            Complete OCRResponse object
        """
        logger.info(f"Building response for {len(ocr_texts)} pages")

        detections = []

        for i, ocr_text in enumerate(ocr_texts):
            try:
                # Extract fields and MRZ data
                extracted_fields = self.field_extractor.extract_fields(ocr_text)
                mrz_data = self.field_extractor.extract_mrz(ocr_text)

                # Create detection for this page
                detection = self._create_detection(extracted_fields, mrz_data, i, page_metadata)
                detections.append(detection)

                # NOTE(review): len(__dict__) counts all attributes including
                # None-valued ones, so this logs the attribute count, not the
                # number of populated fields.
                logger.info(f"Page {i + 1}: {len(extracted_fields.__dict__)} fields, MRZ: {mrz_data is not None}")

            except Exception as e:
                logger.error(f"Failed to process page {i + 1}: {e}")
                # Create empty detection for failed page
                detection = self._create_empty_detection(i)
                detections.append(detection)

        # Build final response
        response = OCRResponse(
            request_id=request_id,
            media_type=media_type,
            processing_time=processing_time,
            detections=detections
        )

        # Validate response
        self._validate_response(response)

        logger.info(f"Response built successfully: {len(detections)} detections")
        return response

    def _create_detection(
        self,
        extracted_fields: ExtractedFields,
        mrz_data: Optional[MRZData],
        page_index: int,
        page_metadata: Optional[List[Dict[str, Any]]] = None
    ) -> OCRDetection:
        """Create an OCR detection from extracted data.

        Args:
            extracted_fields: Extracted field data
            mrz_data: MRZ data if available
            page_index: Index of the page
            page_metadata: Optional metadata for the page

        Returns:
            OCRDetection object
        """
        # Convert IdCardFields to ExtractedFields format expected by OCRDetection
        converted_fields = self._convert_fields_format(extracted_fields)

        # Enhance MRZ data if available
        enhanced_mrz = self._enhance_mrz_data(mrz_data, page_index, page_metadata)

        return OCRDetection(
            mrz_data=enhanced_mrz,
            extracted_fields=converted_fields
        )

    def _convert_fields_format(self, id_card_fields) -> ExtractedFields:
        """Convert IdCardFields to the format expected by OCRDetection.

        Args:
            id_card_fields: IdCardFields object

        Returns:
            ExtractedFields object
        """
        # Convert IdCardFields to ExtractedFields by mapping the fields
        field_dict = {}

        # NOTE(review): iterating __dict__ assumes the field container stores
        # its values as instance attributes — confirm for the actual model type.
        for field_name, field_value in id_card_fields.__dict__.items():
            if field_value is not None:
                # Convert ExtractedField to dict for Pydantic validation
                # NOTE(review): .dict() is the Pydantic v1 API; under Pydantic v2
                # this relies on the deprecated shim — prefer model_dump() there.
                field_dict[field_name] = field_value.dict() if hasattr(field_value, 'dict') else field_value

        return ExtractedFields(**field_dict)

    def _enhance_mrz_data(
        self,
        mrz_data: Optional[MRZData],
        page_index: int,
        page_metadata: Optional[List[Dict[str, Any]]] = None
    ) -> Optional[MRZData]:
        """Enhance MRZ data with additional context if available.

        Args:
            mrz_data: Original MRZ data
            page_index: Index of the page
            page_metadata: Optional metadata for the page

        Returns:
            Enhanced MRZ data or None
        """
        if mrz_data is None:
            return None

        # Add page context if available
        if page_metadata and page_index < len(page_metadata):
            metadata = page_metadata[page_index]
            # Could add page-specific confidence adjustments here
            # (currently a placeholder: the MRZ data is returned unchanged).
            pass

        return mrz_data

    def _create_empty_detection(self, page_index: int) -> OCRDetection:
        """Create an empty detection for failed pages.

        Args:
            page_index: Index of the failed page

        Returns:
            Empty OCRDetection object
        """
        logger.warning(f"Creating empty detection for failed page {page_index + 1}")

        return OCRDetection(
            mrz_data=None,
            extracted_fields=ExtractedFields()
        )

    def _validate_response(self, response: OCRResponse) -> None:
        """Validate the response structure and data.

        Args:
            response: OCRResponse to validate

        Raises:
            ValueError: If response validation fails
        """
        # Validate request_id
        if not response.request_id or len(response.request_id) == 0:
            raise ValueError("Request ID cannot be empty")

        # Validate media_type
        if response.media_type not in ["image", "pdf"]:
            raise ValueError(f"Invalid media_type: {response.media_type}")

        # Validate processing_time
        if response.processing_time < 0:
            raise ValueError("Processing time cannot be negative")

        # Validate detections (empty is allowed, e.g. for error responses,
        # but worth surfacing in the logs)
        if not response.detections:
            logger.warning("Response has no detections")

        # Validate each detection
        for i, detection in enumerate(response.detections):
            self._validate_detection(detection, i)

        logger.debug("Response validation passed")

    def _validate_detection(self, detection: OCRDetection, index: int) -> None:
        """Validate a single detection.

        Args:
            detection: OCRDetection to validate
            index: Index of the detection

        Raises:
            ValueError: If detection validation fails
        """
        # Validate MRZ data if present
        if detection.mrz_data:
            self._validate_mrz_data(detection.mrz_data, index)

        # Validate extracted fields
        if detection.extracted_fields:
            self._validate_extracted_fields(detection.extracted_fields, index)

    def _validate_mrz_data(self, mrz_data: MRZData, index: int) -> None:
        """Validate MRZ data.

        Args:
            mrz_data: MRZ data to validate
            index: Index of the detection

        Raises:
            ValueError: If MRZ data validation fails
        """
        # Support both canonical and legacy attribute names
        raw_text_value = getattr(mrz_data, "raw_text", None) or getattr(mrz_data, "raw_mrz", None)
        if not raw_text_value:
            raise ValueError(f"MRZ raw text cannot be empty for detection {index}")

        format_type_value = getattr(mrz_data, "format_type", None) or getattr(mrz_data, "document_type", None)
        if not format_type_value:
            raise ValueError(f"MRZ format type cannot be empty for detection {index}")

        if not (0.0 <= mrz_data.confidence <= 1.0):
            raise ValueError(f"MRZ confidence must be between 0.0 and 1.0 for detection {index}")

    def _validate_extracted_fields(self, fields: ExtractedFields, index: int) -> None:
        """Validate extracted fields.

        Args:
            fields: Extracted fields to validate
            index: Index of the detection

        Raises:
            ValueError: If fields validation fails
        """
        # Validate each field if present
        for field_name, field_value in fields.__dict__.items():
            if field_value is not None:
                if not isinstance(field_value, ExtractedField):
                    raise ValueError(f"Field {field_name} must be ExtractedField instance for detection {index}")

                # Validate field content
                if not (0.0 <= field_value.confidence <= 1.0):
                    raise ValueError(f"Field {field_name} confidence must be between 0.0 and 1.0 for detection {index}")

    def build_error_response(
        self,
        request_id: str,
        error_message: str,
        processing_time: float = 0.0
    ) -> OCRResponse:
        """Build an error response.

        Args:
            request_id: Unique request identifier
            error_message: Error message
            processing_time: Processing time before error

        Returns:
            Error OCRResponse object (no detections; the error itself is only
            logged, not embedded in the response)
        """
        logger.error(f"Building error response: {error_message}")

        return OCRResponse(
            request_id=request_id,
            media_type="image",  # Default media type
            processing_time=processing_time,
            detections=[]  # Empty detections for error
        )
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# Global response builder instance (lazily created singleton)
_response_builder: Optional[OCRResponseBuilder] = None


def get_response_builder() -> OCRResponseBuilder:
    """Return the shared OCRResponseBuilder, creating it on first use."""
    global _response_builder
    if _response_builder is None:
        _response_builder = OCRResponseBuilder()
    return _response_builder
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def build_ocr_response(
    request_id: str,
    media_type: str,
    processing_time: float,
    ocr_texts: List[str],
    page_metadata: Optional[List[Dict[str, Any]]] = None
) -> OCRResponse:
    """Build a complete OCR response from extracted texts."""
    return get_response_builder().build_response(
        request_id, media_type, processing_time, ocr_texts, page_metadata
    )
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def build_error_response(
    request_id: str,
    error_message: str,
    processing_time: float = 0.0
) -> OCRResponse:
    """Build an error response."""
    return get_response_builder().build_error_response(
        request_id, error_message, processing_time
    )
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Test package for KYB Tech Dots.OCR."""
|
tests/test_app.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the main FastAPI application."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from fastapi.testclient import TestClient
|
| 5 |
+
from src.kybtech_dots_ocr.app import app
|
| 6 |
+
|
| 7 |
+
client = TestClient(app)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_health_check():
    """Test the health check endpoint."""
    response = client.get("/health")
    assert response.status_code == 200
    payload = response.json()
    # The health payload must always advertise status and version.
    for key in ("status", "version"):
        assert key in payload
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_ocr_endpoint_missing_file():
    """Test OCR endpoint with missing file."""
    # FastAPI rejects the request before the handler runs: 422 validation error.
    response = client.post("/v1/id/ocr")
    assert response.status_code == 422
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_ocr_endpoint_invalid_file():
    """Test OCR endpoint with invalid file."""
    bogus_upload = {"file": ("test.txt", b"not an image", "text/plain")}
    response = client.post("/v1/id/ocr", files=bogus_upload)
    # Any controlled failure mode is acceptable; the service must not hang or crash.
    assert response.status_code in (400, 422, 500)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@pytest.mark.skip(reason="Requires model to be loaded")
def test_ocr_endpoint_with_image():
    """Test OCR endpoint with actual image (requires model).

    Placeholder: skipped in CI/mock mode because it needs the real
    Dots.OCR model weights and a sample document image.
    """
    # This test would require the model to be loaded
    # and actual image data
    pass
|
tests/test_field_extraction.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for field extraction functionality."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from src.kybtech_dots_ocr.enhanced_field_extraction import EnhancedFieldExtractor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TestEnhancedFieldExtractor:
    """Behavioural tests for EnhancedFieldExtractor."""

    def test_extract_fields_dutch_id(self):
        """Dutch-labelled ID card text yields the expected core fields."""
        ex = EnhancedFieldExtractor()
        text = """
        IDENTITEITSKAART
        Documentnummer: NLD123456789
        Achternaam: MULDER
        Voornamen: THOMAS JAN
        Nationaliteit: NLD
        Geboortedatum: 15-03-1990
        Geslacht: M
        """

        fields = ex.extract_fields(text)

        doc_no = fields.document_number
        assert doc_no is not None
        assert doc_no.value == "NLD123456789"
        surname = fields.surname
        assert surname is not None
        assert surname.value == "MULDER"
        given = fields.given_names
        assert given is not None
        assert given.value == "THOMAS JAN"

    def test_extract_fields_english_id(self):
        """English-labelled ID card text is recognised as well."""
        ex = EnhancedFieldExtractor()
        text = """
        IDENTITY CARD
        Document Number: NLD123456789
        Surname: MULDER
        Given Names: THOMAS JAN
        Nationality: NLD
        Date of Birth: 15-03-1990
        Gender: M
        """

        fields = ex.extract_fields(text)

        doc_no = fields.document_number
        assert doc_no is not None
        assert doc_no.value == "NLD123456789"
        surname = fields.surname
        assert surname is not None
        assert surname.value == "MULDER"

    def test_extract_mrz_data(self):
        """A two-line TD3 machine-readable zone is detected with high confidence."""
        ex = EnhancedFieldExtractor()
        text = """
        P<NLDMULDER<<THOMAS<<<<<<<<<<<<<<<<<<<<<<<<<
        NLD123456789NLD9003151M300101123456789<<<<<<<<
        """

        mrz_data = ex.extract_mrz(text)

        assert mrz_data is not None
        assert mrz_data.format_type == "TD3"
        assert mrz_data.confidence > 0.8

    def test_extract_fields_empty_text(self):
        """Empty input produces a result object with no populated fields."""
        ex = EnhancedFieldExtractor()
        fields = ex.extract_fields("")

        assert fields.document_number is None
        assert fields.surname is None

    def test_confidence_scoring(self):
        """A cleanly labelled field scores at least as high as a sloppy one."""
        ex = EnhancedFieldExtractor()

        # Well-formed, fully labelled line.
        clean = "Documentnummer: NLD123456789 Achternaam: MULDER"
        fields_clean = ex.extract_fields(clean)

        # Abbreviated label and truncated value.
        sloppy = "doc nr: NLD123"
        fields_sloppy = ex.extract_fields(sloppy)

        # Only comparable when both inputs yielded a document number at all.
        if fields_clean.document_number and fields_sloppy.document_number:
            assert (
                fields_clean.document_number.confidence
                >= fields_sloppy.document_number.confidence
            )