style: format Python files with Black
- src/kybtech_dots_ocr/__init__.py +8 -2
- src/kybtech_dots_ocr/api_models.py +40 -11
- src/kybtech_dots_ocr/app.py +78 -40
- src/kybtech_dots_ocr/enhanced_field_extraction.py +118 -80
- src/kybtech_dots_ocr/field_extraction.py +28 -32
- src/kybtech_dots_ocr/models.py +33 -10
- src/kybtech_dots_ocr/preprocessing.py +72 -67
- src/kybtech_dots_ocr/response_builder.py +53 -3
src/kybtech_dots_ocr/__init__.py
CHANGED
@@ -8,7 +8,13 @@ __author__ = "Algoryn"
 __email__ = "info@algoryn.com"
 
 from .app import app
-from .api_models import
+from .api_models import (
+    OCRResponse,
+    OCRDetection,
+    ExtractedFields,
+    MRZData,
+    ExtractedField,
+)
 from .model_loader import load_model, extract_text, is_model_loaded, get_model_info
 from .preprocessing import process_document, validate_file_size, get_document_info
 from .response_builder import build_ocr_response, build_error_response
@@ -16,7 +22,7 @@ from .response_builder import build_ocr_response, build_error_response
 __all__ = [
     "app",
     "OCRResponse",
-    "OCRDetection",
+    "OCRDetection",
     "ExtractedFields",
     "MRZData",
     "ExtractedField",
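For orientation, a minimal sketch of how the re-exported symbols above are typically consumed (illustrative only and not part of the commit; only the package path and the names listed in the diff come from the source, the constructor arguments mirror how the extractors build the model):

# Illustrative sketch, not part of the commit.
from kybtech_dots_ocr import OCRResponse, ExtractedField

field = ExtractedField(
    field_name="surname", value="JANSEN", confidence=0.92, source="ocr"
)
print(field.field_name, field.confidence)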
src/kybtech_dots_ocr/api_models.py
CHANGED
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
 
 class BoundingBox(BaseModel):
     """Normalized bounding box coordinates."""
+
     x1: float = Field(..., ge=0.0, le=1.0, description="Top-left x coordinate")
     y1: float = Field(..., ge=0.0, le=1.0, description="Top-left y coordinate")
     x2: float = Field(..., ge=0.0, le=1.0, description="Bottom-right x coordinate")
@@ -18,6 +19,7 @@ class BoundingBox(BaseModel):
 
 class ExtractedField(BaseModel):
     """Individual extracted field with confidence and source."""
+
     field_name: str = Field(..., description="Standardized field name")
     value: Optional[str] = Field(None, description="Extracted field value")
     confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence")
@@ -26,10 +28,19 @@ class ExtractedField(BaseModel):
 
 class IdCardFields(BaseModel):
     """Structured fields extracted from identity documents."""
-
-
-
-
+
+    document_number: Optional[ExtractedField] = Field(
+        None, description="Document number/ID"
+    )
+    document_type: Optional[ExtractedField] = Field(
+        None, description="Type of document"
+    )
+    issuing_country: Optional[ExtractedField] = Field(
+        None, description="Issuing country code"
+    )
+    issuing_authority: Optional[ExtractedField] = Field(
+        None, description="Issuing authority"
+    )
 
     # Personal Information
     surname: Optional[ExtractedField] = Field(None, description="Family name/surname")
@@ -42,15 +53,22 @@ class IdCardFields(BaseModel):
     # Validity Information
     date_of_issue: Optional[ExtractedField] = Field(None, description="Date of issue")
     date_of_expiry: Optional[ExtractedField] = Field(None, description="Date of expiry")
-    personal_number: Optional[ExtractedField] = Field(
+    personal_number: Optional[ExtractedField] = Field(
+        None, description="Personal number"
+    )
 
     # Additional fields for specific document types
-    optional_data_1: Optional[ExtractedField] = Field(
-
+    optional_data_1: Optional[ExtractedField] = Field(
+        None, description="Optional data field 1"
+    )
+    optional_data_2: Optional[ExtractedField] = Field(
+        None, description="Optional data field 2"
+    )
 
 
 class ExtractedFields(BaseModel):
     """All extracted fields from identity document."""
+
     document_number: Optional[ExtractedField] = None
     document_type: Optional[ExtractedField] = None
     issuing_country: Optional[ExtractedField] = None
@@ -70,8 +88,11 @@
 
 class MRZData(BaseModel):
     """Machine Readable Zone data."""
+
     # Primary canonical fields
-    document_type: Optional[str] = Field(
+    document_type: Optional[str] = Field(
+        None, description="MRZ document type (TD1|TD2|TD3)"
+    )
     issuing_country: Optional[str] = Field(None, description="Issuing country code")
     surname: Optional[str] = Field(None, description="Surname from MRZ")
     given_names: Optional[str] = Field(None, description="Given names from MRZ")
@@ -82,22 +103,30 @@
     date_of_expiry: Optional[str] = Field(None, description="Date of expiry from MRZ")
     personal_number: Optional[str] = Field(None, description="Personal number from MRZ")
     raw_mrz: Optional[str] = Field(None, description="Raw MRZ text")
-    confidence: float = Field(
+    confidence: float = Field(
+        0.0, ge=0.0, le=1.0, description="MRZ extraction confidence"
+    )
 
     # Backwards compatibility fields (some older code/tests expect these names)
     # These duplicate information from the canonical fields above.
-    format_type: Optional[str] = Field(
-
+    format_type: Optional[str] = Field(
+        None, description="Alias of document_type for backward compatibility"
+    )
+    raw_text: Optional[str] = Field(
+        None, description="Alias of raw_mrz for backward compatibility"
+    )
 
 
 class OCRDetection(BaseModel):
     """Single OCR detection result."""
+
     mrz_data: Optional[MRZData] = Field(None, description="MRZ data if detected")
     extracted_fields: ExtractedFields = Field(..., description="Extracted field data")
 
 
 class OCRResponse(BaseModel):
     """OCR API response."""
+
     request_id: str = Field(..., description="Unique request identifier")
     media_type: str = Field(..., description="Media type processed")
     processing_time: float = Field(..., description="Processing time in seconds")
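As a quick illustration of the reformatted models (a sketch, not part of the commit; the sample values are invented), MRZData can carry both the canonical fields and the legacy aliases, and Pydantic enforces the 0.0-1.0 confidence bounds declared above:

# Illustrative sketch based on the MRZData model shown in this diff.
from kybtech_dots_ocr.api_models import MRZData

mrz = MRZData(
    document_type="TD3",
    issuing_country="NLD",
    surname="JANSEN",
    raw_mrz="P<NLDJANSEN<<...",
    format_type="TD3",  # legacy alias of document_type
    raw_text="P<NLDJANSEN<<...",  # legacy alias of raw_mrz
    confidence=0.93,
)
# A confidence outside [0.0, 1.0] raises a pydantic ValidationError.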
src/kybtech_dots_ocr/app.py
CHANGED
@@ -17,7 +17,14 @@ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
 from fastapi.responses import JSONResponse
 
 # Import local modules
-from .api_models import
+from .api_models import (
+    BoundingBox,
+    ExtractedField,
+    ExtractedFields,
+    MRZData,
+    OCRDetection,
+    OCRResponse,
+)
 from .enhanced_field_extraction import EnhancedFieldExtractor
 from .model_loader import load_model, extract_text, is_model_loaded, get_model_info
 from .preprocessing import process_document, validate_file_size, get_document_info
@@ -27,6 +34,13 @@ from .response_builder import build_ocr_response, build_error_response
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Enable verbose logging globally if DOTS_OCR_DEBUG env var is set.
+_env_debug = os.getenv("DOTS_OCR_DEBUG", "0").lower() in {"1", "true", "yes"}
+if _env_debug:
+    # Elevate root logger to DEBUG to include lower-level events from submodules
+    logging.getLogger().setLevel(logging.DEBUG)
+    logger.info("DOTS_OCR_DEBUG enabled via environment — verbose logging active")
+
 # Global model state
 model_loaded = False
 
@@ -34,13 +48,11 @@ model_loaded = False
 # FieldExtractor is now imported from the shared module
 
 
-
-
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Application lifespan manager for model loading."""
     global model_loaded
-
+
     # Allow tests and lightweight environments to skip model loading
     # Set DOTS_OCR_SKIP_MODEL_LOAD=1 to bypass heavy downloads during tests/CI
     skip_model_load = os.getenv("DOTS_OCR_SKIP_MODEL_LOAD", "0") == "1"
@@ -50,25 +62,27 @@ async def lifespan(app: FastAPI):
         if skip_model_load:
             # Explicitly skip model loading for fast startup in tests/CI
             model_loaded = False
-            logger.warning(
+            logger.warning(
+                "DOTS_OCR_SKIP_MODEL_LOAD=1 set - skipping model load (mock mode)"
+            )
         else:
             # Load the model using the new model loader
             load_model()
             model_loaded = True
             logger.info("Dots.OCR model loaded successfully")
-
+
             # Log model information
            model_info = get_model_info()
            logger.info(f"Model info: {model_info}")
-
+
     except Exception as e:
         logger.error(f"Failed to load Dots.OCR model: {e}")
         # Don't raise - allow mock mode for development
         model_loaded = False
         logger.warning("Model loading failed - using mock implementation")
-
+
     yield
-
+
     logger.info("Shutting down Dots.OCR endpoint...")
 
 
@@ -76,61 +90,79 @@ app = FastAPI(
     title="KYB Dots.OCR Text Extraction",
     description="Dots.OCR for identity document text extraction with ROI support",
     version="1.0.0",
-    lifespan=lifespan
+    lifespan=lifespan,
 )
 
 
 @app.get("/")
 async def root():
     """Root route for uptime checks."""
-    return {"status": "ok"
+    return {"status": "ok"}
 
 
 @app.get("/health")
 async def health_check():
     """Health check endpoint."""
     global model_loaded
-
+
     status = "healthy" if model_loaded else "degraded"
     model_info = get_model_info() if model_loaded else None
-
+
     return {
-        "status": status,
+        "status": status,
         "version": "1.0.0",
         "model_loaded": model_loaded,
-        "model_info": model_info
+        "model_info": model_info,
     }
 
 
 @app.post("/v1/id/ocr", response_model=OCRResponse)
 async def extract_text_endpoint(
     file: UploadFile = File(..., description="Image or PDF file to process"),
-    roi: Optional[str] = Form(None, description="ROI coordinates as JSON string")
+    roi: Optional[str] = Form(None, description="ROI coordinates as JSON string"),
+    debug: Optional[bool] = Form(
+        None,
+        description=(
+            "Enable verbose debug logging for this request. Overrides env when True."
+        ),
+    ),
 ):
     """Extract text from identity document image or PDF."""
     global model_loaded
-
+
     # Allow mock mode when model isn't loaded to support tests/CI and dev flows
     allow_mock = os.getenv("DOTS_OCR_ALLOW_MOCK", "1") == "1"
     is_mock_mode = (not model_loaded) and allow_mock
     if not model_loaded and not allow_mock:
         raise HTTPException(status_code=503, detail="Model not loaded")
-
+
+    # Determine effective debug mode for this request
+    env_debug = os.getenv("DOTS_OCR_DEBUG", "0").lower() in {"1", "true", "yes"}
+    debug_enabled = bool(debug) if debug is not None else env_debug
+    if debug_enabled:
+        logger.info(
+            f"[debug] Request {request_id}: debug logging enabled (env={env_debug}, form={debug})"
+        )
+    if is_mock_mode:
+        logger.warning(
+            "Using mock mode — OCR text will be empty. To enable real inference, ensure the model loads successfully (unset DOTS_OCR_SKIP_MODEL_LOAD and provide resources)."
+        )
+
     start_time = time.time()
     request_id = str(uuid.uuid4())
-
+
     try:
         # Read file data
         file_data = await file.read()
-
+
         # Validate file size
         if not validate_file_size(file_data):
             raise HTTPException(status_code=413, detail="File size exceeds limit")
-
+
         # Get document information
         doc_info = get_document_info(file_data)
         logger.info(f"Processing document: {doc_info}")
-
+
         # Parse ROI if provided
         roi_coords = None
         if roi:
@@ -142,19 +174,21 @@ async def extract_text_endpoint(
             except Exception as e:
                 logger.warning(f"Invalid ROI provided: {e}")
                 raise HTTPException(status_code=400, detail=f"Invalid ROI format: {e}")
-
+
         # Process document (PDF to images or single image)
         try:
             processed_images = process_document(file_data, roi_coords)
             logger.info(f"Processed {len(processed_images)} images from document")
         except Exception as e:
             logger.error(f"Document processing failed: {e}")
-            raise HTTPException(
-
+            raise HTTPException(
+                status_code=400, detail=f"Document processing failed: {e}"
+            )
+
         # Process each image and extract text
         ocr_texts = []
         page_metadata = []
-
+
         for i, image in enumerate(processed_images):
             try:
                 # Extract text using the loaded model, or produce mock output in mock mode
@@ -163,47 +197,50 @@ async def extract_text_endpoint(
                     ocr_text = ""
                 else:
                     ocr_text = extract_text(image)
-                logger.info(
-
+                logger.info(
+                    f"Page {i + 1} - Extracted text length: {len(ocr_text)} characters"
+                )
+
                 ocr_texts.append(ocr_text)
-
+
                 # Collect page metadata
                 page_meta = {
                     "page_index": i,
                     "image_size": image.size,
                     "text_length": len(ocr_text),
-                    "processing_successful": True
+                    "processing_successful": True,
                 }
                 page_metadata.append(page_meta)
-
+
             except Exception as e:
                 logger.error(f"Text extraction failed for page {i + 1}: {e}")
                 # Add empty text for failed page
                 ocr_texts.append("")
-
+
                 page_meta = {
                     "page_index": i,
-                    "image_size": image.size if hasattr(image,
+                    "image_size": image.size if hasattr(image, "size") else (0, 0),
                     "text_length": 0,
                     "processing_successful": False,
-                    "error": str(e)
+                    "error": str(e),
                 }
                 page_metadata.append(page_meta)
-
+
         # Determine media type for response
         media_type = "pdf" if doc_info["is_pdf"] else "image"
-
+
         processing_time = time.time() - start_time
-
+
         # Build response using the response builder
         return build_ocr_response(
             request_id=request_id,
             media_type=media_type,
             processing_time=processing_time,
             ocr_texts=ocr_texts,
-            page_metadata=page_metadata
+            page_metadata=page_metadata,
+            debug=debug_enabled,
         )
-
+
     except HTTPException:
         # Re-raise HTTP exceptions as-is
         raise
@@ -213,11 +250,12 @@ async def extract_text_endpoint(
         error_response = build_error_response(
             request_id=request_id,
             error_message=f"OCR extraction failed: {str(e)}",
-            processing_time=processing_time
+            processing_time=processing_time,
         )
         raise HTTPException(status_code=500, detail=error_response.dict())
 
 
 if __name__ == "__main__":
     import uvicorn
+
     uvicorn.run(app, host="0.0.0.0", port=7860)
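A minimal client sketch for the endpoint above (the host, port, file name, and ROI payload shape are assumptions for illustration; only the route, form field names, and response fields come from the diff):

# Illustrative client sketch; values are placeholders, not part of the commit.
import json
import requests

with open("id_front.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/v1/id/ocr",
        files={"file": ("id_front.jpg", f, "image/jpeg")},
        data={
            "roi": json.dumps({"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}),
            "debug": "true",  # per-request debug form field added in this commit
        },
    )
print(resp.status_code, resp.json().get("request_id"))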
src/kybtech_dots_ocr/enhanced_field_extraction.py
CHANGED
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 class EnhancedFieldExtractor:
     """Enhanced field extraction with improved confidence scoring and validation."""
-
+
     # Enhanced field mapping patterns with confidence scoring
     FIELD_PATTERNS = {
         "document_number": [
@@ -35,7 +35,10 @@ class EnhancedFieldExtractor:
         ],
         "given_names": [
             (r"^\s*voornamen[:\s]*([^\r\n]+)", 0.95),  # Dutch format (line-anchored)
-            (
+            (
+                r"^\s*given\s*names[:\s]*([^\r\n]+)",
+                0.9,
+            ),  # English format (line-anchored)
             (r"^\s*first\s*name[:\s]*([^\r\n]+)", 0.85),  # First name only
             (r"^\s*voorletters[:\s]*([^\r\n]+)", 0.75),  # Dutch initials
         ],
@@ -46,7 +49,10 @@
         ],
         "date_of_birth": [
             (r"geboortedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.9),  # Dutch format
-            (
+            (
+                r"date\s*of\s*birth[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
+                0.85,
+            ),  # English format
             (r"born[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.8),  # Short English
             (r"(\d{2}[./-]\d{2}[./-]\d{4})", 0.6),  # Generic date pattern
         ],
@@ -64,14 +70,23 @@
         ],
         "date_of_issue": [
             (r"uitgiftedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.9),  # Dutch format
-            (
+            (
+                r"date\s*of\s*issue[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
+                0.85,
+            ),  # English format
             (r"issued[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.8),  # Short English
         ],
         "date_of_expiry": [
             (r"vervaldatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.9),  # Dutch format
-            (
+            (
+                r"date\s*of\s*expiry[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
+                0.85,
+            ),  # English format
             (r"expires[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})", 0.8),  # Short English
-            (
+            (
+                r"valid\s*until[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
+                0.8,
+            ),  # Alternative English
         ],
         "personal_number": [
             (r"persoonsnummer[:\s]*(\d{9})", 0.9),  # Dutch format
@@ -95,39 +110,48 @@
             (r"issuing\s*authority[:\s]*([A-Za-z\s]{3,30})", 0.8),  # English format
             (r"uitgevende\s*autoriteit[:\s]*([A-Za-z\s]{3,30})", 0.9),  # Dutch format
             (r"authority[:\s]*([A-Za-z\s]{3,30})", 0.7),  # Short format
-        ]
+        ],
     }
-
+
     # MRZ patterns with confidence scoring
     MRZ_PATTERNS = [
         # Strict formats first, allowing leading/trailing whitespace per line
-        (
-
-
+        (
+            r"^\s*((?:[A-Z0-9<]{44})\s*\n\s*(?:[A-Z0-9<]{44}))\s*$",
+            0.95,
+        ),  # TD3: Passport (2 x 44)
+        (
+            r"^\s*((?:[A-Z0-9<]{36})\s*\n\s*(?:[A-Z0-9<]{36}))\s*$",
+            0.9,
+        ),  # TD2: ID card (2 x 36)
+        (
+            r"^\s*((?:[A-Z0-9<]{30})\s*\n\s*(?:[A-Z0-9<]{30})\s*\n\s*(?:[A-Z0-9<]{30}))\s*$",
+            0.85,
+        ),  # TD1: (3 x 30)
         # Fallback generic: a line starting with P< followed by another MRZ-like line
         (r"(P<[^\r\n]+\n[^\r\n]+)", 0.85),
     ]
-
+
     @classmethod
     def extract_fields(cls, ocr_text: str) -> IdCardFields:
         """Extract structured fields from OCR text with enhanced confidence scoring.
-
+
         Args:
             ocr_text: Raw OCR text from document processing
-
+
         Returns:
             IdCardFields object with extracted field data
         """
         logger.info(f"Extracting fields from text of length: {len(ocr_text)}")
-
+
         fields = {}
         extraction_stats = {"total_patterns": 0, "matches_found": 0}
-
+
         for field_name, patterns in cls.FIELD_PATTERNS.items():
             value = None
             confidence = 0.0
             best_pattern = None
-
+
             for pattern, base_confidence in patterns:
                 extraction_stats["total_patterns"] += 1
                 match = re.search(pattern, ocr_text, re.IGNORECASE | re.MULTILINE)
@@ -139,37 +163,43 @@
                     confidence = base_confidence
                     best_pattern = pattern
                     extraction_stats["matches_found"] += 1
-                    logger.debug(
+                    logger.debug(
+                        f"Found {field_name}: '{value}' (confidence: {confidence:.2f})"
+                    )
                     break
-
+
             if value:
                 # Apply additional confidence adjustments
-                confidence = cls._adjust_confidence(
-
+                confidence = cls._adjust_confidence(
+                    field_name, value, confidence, ocr_text
+                )
+
                 fields[field_name] = ExtractedField(
                     field_name=field_name,
                     value=value,
                     confidence=confidence,
-                    source="ocr"
+                    source="ocr",
                 )
-
-        logger.info(
+
+        logger.info(
+            f"Field extraction complete: {extraction_stats['matches_found']}/{extraction_stats['total_patterns']} patterns matched"
+        )
         return IdCardFields(**fields)
-
+
     @classmethod
     def _validate_field_value(cls, field_name: str, value: str) -> bool:
         """Validate extracted field value based on field type.
-
+
         Args:
             field_name: Name of the field
             value: Extracted value to validate
-
+
         Returns:
             True if value is valid
         """
         if not value or len(value.strip()) == 0:
             return False
-
+
         # Field-specific validation
         if field_name == "document_number":
             return len(value) >= 6 and len(value) <= 15
@@ -185,16 +215,16 @@
             return len(value) == 9 and value.isdigit()
         elif field_name == "issuing_country":
             return len(value) == 3 and value.isalpha()
-
+
         return True
-
+
     @classmethod
     def _validate_date_format(cls, date_str: str) -> bool:
         """Validate date format and basic date logic.
-
+
         Args:
             date_str: Date string to validate
-
+
         Returns:
             True if date format is valid
         """
@@ -206,59 +236,63 @@
             if len(parts) == 3:
                 day, month, year = parts
                 # Basic validation
-                if (
-                    1 <= int(
-
+                if (
+                    1 <= int(day) <= 31
+                    and 1 <= int(month) <= 12
+                    and 1900 <= int(year) <= 2100
+                ):
                     return True
         except (ValueError, IndexError):
             pass
         return False
-
+
     @classmethod
-    def _adjust_confidence(
+    def _adjust_confidence(
+        cls, field_name: str, value: str, base_confidence: float, full_text: str
+    ) -> float:
         """Adjust confidence based on additional factors.
-
+
         Args:
             field_name: Name of the field
             value: Extracted value
             base_confidence: Base confidence from pattern matching
             full_text: Full OCR text for context
-
+
         Returns:
             Adjusted confidence score
         """
         confidence = base_confidence
-
+
         # Length-based adjustments
         if field_name in ["surname", "given_names"] and len(value) < 3:
             confidence *= 0.8  # Shorter names are less reliable
-
+
         # Context-based adjustments
         if field_name == "document_number" and "passport" in full_text.lower():
             confidence *= 1.1  # Higher confidence in passport context
-
+
         # Multiple occurrence bonus
         if value in full_text and full_text.count(value) > 1:
             confidence *= 1.05  # Slight bonus for repeated values
-
+
         # Ensure confidence stays within bounds
         return min(max(confidence, 0.0), 1.0)
-
+
     @classmethod
     def extract_mrz(cls, ocr_text: str) -> Optional[MRZData]:
         """Extract MRZ data from OCR text with enhanced validation.
-
+
         Args:
             ocr_text: Raw OCR text from document processing
-
+
         Returns:
             MRZData object if MRZ detected, None otherwise
         """
         logger.info("Extracting MRZ data from OCR text")
-
+
         best_match = None
         best_confidence = 0.0
-
+
         for pattern, base_confidence in cls.MRZ_PATTERNS:
             match = re.search(pattern, ocr_text, re.MULTILINE)
             if match:
@@ -268,23 +302,24 @@
                 confidence = base_confidence
                 # Adjust confidence based on MRZ quality
                 confidence = cls._adjust_mrz_confidence(raw_mrz, confidence)
-
+
                 if confidence > best_confidence:
                     best_match = raw_mrz
                     best_confidence = confidence
                     logger.debug(f"Found MRZ with confidence {confidence:.2f}")
-
+
         if best_match:
             # Parse MRZ to determine format type
             format_type = cls._determine_mrz_format(best_match)
-
+
             # Basic checksum validation
             is_valid, errors = cls._validate_mrz_checksums(best_match, format_type)
-
+
             logger.info(f"MRZ extracted: {format_type} format, valid: {is_valid}")
-
+
             # Convert to the format expected by the API
             from .api_models import MRZData as APIMRZData
+
             # Populate both canonical and legacy alias fields for compatibility
             return APIMRZData(
                 document_type=format_type,
@@ -302,47 +337,47 @@
                 raw_text=best_match,  # legacy alias
                 confidence=best_confidence,
             )
-
+
         logger.info("No MRZ data found in OCR text")
         return None
-
+
     @classmethod
     def _validate_mrz_format(cls, mrz_text: str) -> bool:
         """Validate basic MRZ format.
-
+
         Args:
             mrz_text: Raw MRZ text
-
+
         Returns:
             True if format is valid
         """
-        lines = mrz_text.strip().split(
+        lines = mrz_text.strip().split("\n")
         if len(lines) < 2:
             return False
-
+
         # Normalize whitespace and validate character set only.
         normalized_lines = [re.sub(r"\s+", "", line) for line in lines]
         for line in normalized_lines:
-            if not re.match(r
+            if not re.match(r"^[A-Z0-9<]+$", line):
                 return False
-
+
         return True
-
+
     @classmethod
     def _determine_mrz_format(cls, mrz_text: str) -> str:
         """Determine MRZ format type.
-
+
         Args:
             mrz_text: Raw MRZ text
-
+
         Returns:
             Format type (TD1, TD2, TD3, etc.)
         """
-        lines = mrz_text.strip().split(
+        lines = mrz_text.strip().split("\n")
         lines = [re.sub(r"\s+", "", line) for line in lines]
         line_count = len(lines)
         line_length = len(lines[0]) if lines else 0
-
+
         # Heuristic mapping: prioritize semantics over exact lengths for robustness
         if line_count == 2 and lines[0].startswith("P<"):
             return "TD3"  # Passport format commonly starts with P<
@@ -351,53 +386,56 @@
         if line_count == 3:
             return "TD1"
         return "UNKNOWN"
-
+
     @classmethod
     def _adjust_mrz_confidence(cls, mrz_text: str, base_confidence: float) -> float:
         """Adjust MRZ confidence based on quality indicators.
-
+
         Args:
             mrz_text: Raw MRZ text
             base_confidence: Base confidence from pattern matching
-
+
         Returns:
             Adjusted confidence
         """
         confidence = base_confidence
-
+
         # Check line consistency
-        lines = mrz_text.strip().split(
+        lines = mrz_text.strip().split("\n")
         if len(set(len(line) for line in lines)) == 1:
             confidence *= 1.05  # Bonus for consistent line lengths
-
+
         return min(max(confidence, 0.0), 1.0)
-
+
     @classmethod
-    def _validate_mrz_checksums(
+    def _validate_mrz_checksums(
+        cls, mrz_text: str, format_type: str
+    ) -> Tuple[bool, List[str]]:
         """Validate MRZ checksums (simplified implementation).
-
+
         Args:
             mrz_text: Raw MRZ text
             format_type: MRZ format type
-
+
         Returns:
             Tuple of (is_valid, list_of_errors)
         """
         # This is a simplified implementation
         # In production, you would implement full MRZ checksum validation
         errors = []
-
+
         # Basic validation - check for reasonable character distribution
-        if mrz_text.count(
+        if mrz_text.count("<") > len(mrz_text) * 0.3:
             errors.append("Too many fill characters")
-
+
         # For now, assume valid if basic format is correct
         is_valid = len(errors) == 0
-
+
         return is_valid, errors
 
 
 # Backward compatibility - use enhanced extractor as default
 class FieldExtractor(EnhancedFieldExtractor):
     """Backward compatible field extractor using enhanced implementation."""
+
     pass
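A short usage sketch for the enhanced extractor (the sample text is invented; which fields match depends on FIELD_PATTERNS, parts of which lie outside the hunks shown above):

# Illustrative sketch, not part of the commit.
from kybtech_dots_ocr.enhanced_field_extraction import EnhancedFieldExtractor

sample = "Documentnummer: ABC123456\nGeboortedatum: 01-02-1990"
fields = EnhancedFieldExtractor.extract_fields(sample)
if fields.date_of_birth:
    # Matched by the Dutch "geboortedatum" pattern with base confidence 0.9
    print(fields.date_of_birth.value, fields.date_of_birth.confidence)

mrz = EnhancedFieldExtractor.extract_mrz(sample)  # None here: no MRZ-like lines present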
src/kybtech_dots_ocr/field_extraction.py
CHANGED
|
@@ -11,100 +11,96 @@ from .api_models import ExtractedField, IdCardFields, MRZData
|
|
| 11 |
|
| 12 |
class FieldExtractor:
|
| 13 |
"""Field extraction and mapping from OCR results."""
|
| 14 |
-
|
| 15 |
# Field mapping patterns for Dutch ID cards
|
| 16 |
FIELD_PATTERNS = {
|
| 17 |
"document_number": [
|
| 18 |
r"documentnummer[:\s]*([A-Z0-9]+)",
|
| 19 |
r"document\s*number[:\s]*([A-Z0-9]+)",
|
| 20 |
-
r"nr[:\s]*([A-Z0-9]+)"
|
| 21 |
],
|
| 22 |
"surname": [
|
| 23 |
r"achternaam[:\s]*([A-Z]+)",
|
| 24 |
r"surname[:\s]*([A-Z]+)",
|
| 25 |
-
r"family\s*name[:\s]*([A-Z]+)"
|
| 26 |
],
|
| 27 |
"given_names": [
|
| 28 |
r"voornamen[:\s]*([A-Z]+)",
|
| 29 |
r"given\s*names[:\s]*([A-Z]+)",
|
| 30 |
-
r"first\s*name[:\s]*([A-Z]+)"
|
| 31 |
],
|
| 32 |
"nationality": [
|
| 33 |
r"nationaliteit[:\s]*([A-Za-z]+)",
|
| 34 |
-
r"nationality[:\s]*([A-Za-z]+)"
|
| 35 |
],
|
| 36 |
"date_of_birth": [
|
| 37 |
r"geboortedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 38 |
r"date\s*of\s*birth[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 39 |
-
r"born[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})"
|
| 40 |
-
],
|
| 41 |
-
"gender": [
|
| 42 |
-
r"geslacht[:\s]*([MF])",
|
| 43 |
-
r"gender[:\s]*([MF])",
|
| 44 |
-
r"sex[:\s]*([MF])"
|
| 45 |
],
|
|
|
|
| 46 |
"place_of_birth": [
|
| 47 |
r"geboorteplaats[:\s]*([A-Za-z\s]+)",
|
| 48 |
r"place\s*of\s*birth[:\s]*([A-Za-z\s]+)",
|
| 49 |
-
r"born\s*in[:\s]*([A-Za-z\s]+)"
|
| 50 |
],
|
| 51 |
"date_of_issue": [
|
| 52 |
r"uitgiftedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 53 |
r"date\s*of\s*issue[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 54 |
-
r"issued[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})"
|
| 55 |
],
|
| 56 |
"date_of_expiry": [
|
| 57 |
r"vervaldatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 58 |
r"date\s*of\s*expiry[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 59 |
-
r"expires[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})"
|
| 60 |
],
|
| 61 |
"personal_number": [
|
| 62 |
r"persoonsnummer[:\s]*(\d{9})",
|
| 63 |
r"personal\s*number[:\s]*(\d{9})",
|
| 64 |
-
r"bsn[:\s]*(\d{9})"
|
| 65 |
-
]
|
| 66 |
}
|
| 67 |
-
|
| 68 |
@classmethod
|
| 69 |
def extract_fields(cls, ocr_text: str) -> IdCardFields:
|
| 70 |
"""Extract structured fields from OCR text.
|
| 71 |
-
|
| 72 |
Args:
|
| 73 |
ocr_text: Raw OCR text from document processing
|
| 74 |
-
|
| 75 |
Returns:
|
| 76 |
IdCardFields object with extracted field data
|
| 77 |
"""
|
| 78 |
fields = {}
|
| 79 |
-
|
| 80 |
for field_name, patterns in cls.FIELD_PATTERNS.items():
|
| 81 |
value = None
|
| 82 |
confidence = 0.0
|
| 83 |
-
|
| 84 |
for pattern in patterns:
|
| 85 |
match = re.search(pattern, ocr_text, re.IGNORECASE)
|
| 86 |
if match:
|
| 87 |
value = match.group(1).strip()
|
| 88 |
confidence = 0.8 # Base confidence for pattern match
|
| 89 |
break
|
| 90 |
-
|
| 91 |
if value:
|
| 92 |
fields[field_name] = ExtractedField(
|
| 93 |
field_name=field_name,
|
| 94 |
value=value,
|
| 95 |
confidence=confidence,
|
| 96 |
-
source="ocr"
|
| 97 |
)
|
| 98 |
-
|
| 99 |
return IdCardFields(**fields)
|
| 100 |
-
|
| 101 |
@classmethod
|
| 102 |
def extract_mrz(cls, ocr_text: str) -> Optional[MRZData]:
|
| 103 |
"""Extract MRZ data from OCR text.
|
| 104 |
-
|
| 105 |
Args:
|
| 106 |
ocr_text: Raw OCR text from document processing
|
| 107 |
-
|
| 108 |
Returns:
|
| 109 |
MRZData object if MRZ detected, None otherwise
|
| 110 |
"""
|
|
@@ -113,9 +109,9 @@ class FieldExtractor:
|
|
| 113 |
r"(P<[A-Z0-9<]+\n[A-Z0-9<]+)", # Generic passport format (try first)
|
| 114 |
r"([A-Z0-9<]{30}\n[A-Z0-9<]{30})", # TD1 format
|
| 115 |
r"([A-Z0-9<]{44}\n[A-Z0-9<]{44})", # TD2 format
|
| 116 |
-
r"([A-Z0-9<]{44}\n[A-Z0-9<]{44}\n[A-Z0-9<]{44})" # TD3 format
|
| 117 |
]
|
| 118 |
-
|
| 119 |
for pattern in mrz_patterns:
|
| 120 |
match = re.search(pattern, ocr_text, re.MULTILINE)
|
| 121 |
if match:
|
|
@@ -123,10 +119,10 @@ class FieldExtractor:
|
|
| 123 |
# Basic MRZ parsing (simplified)
|
| 124 |
return MRZData(
|
| 125 |
raw_text=raw_mrz,
|
| 126 |
-
format_type="TD3" if len(raw_mrz.split(
|
| 127 |
is_valid=True, # Assume valid if present
|
| 128 |
checksum_errors=[], # Not implemented in basic version
|
| 129 |
-
confidence=0.9
|
| 130 |
)
|
| 131 |
-
|
| 132 |
return None
|
|
|
|
| 11 |
|
| 12 |
class FieldExtractor:
|
| 13 |
"""Field extraction and mapping from OCR results."""
|
| 14 |
+
|
| 15 |
# Field mapping patterns for Dutch ID cards
|
| 16 |
FIELD_PATTERNS = {
|
| 17 |
"document_number": [
|
| 18 |
r"documentnummer[:\s]*([A-Z0-9]+)",
|
| 19 |
r"document\s*number[:\s]*([A-Z0-9]+)",
|
| 20 |
+
r"nr[:\s]*([A-Z0-9]+)",
|
| 21 |
],
|
| 22 |
"surname": [
|
| 23 |
r"achternaam[:\s]*([A-Z]+)",
|
| 24 |
r"surname[:\s]*([A-Z]+)",
|
| 25 |
+
r"family\s*name[:\s]*([A-Z]+)",
|
| 26 |
],
|
| 27 |
"given_names": [
|
| 28 |
r"voornamen[:\s]*([A-Z]+)",
|
| 29 |
r"given\s*names[:\s]*([A-Z]+)",
|
| 30 |
+
r"first\s*name[:\s]*([A-Z]+)",
|
| 31 |
],
|
| 32 |
"nationality": [
|
| 33 |
r"nationaliteit[:\s]*([A-Za-z]+)",
|
| 34 |
+
r"nationality[:\s]*([A-Za-z]+)",
|
| 35 |
],
|
| 36 |
"date_of_birth": [
|
| 37 |
r"geboortedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 38 |
r"date\s*of\s*birth[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 39 |
+
r"born[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
],
|
| 41 |
+
"gender": [r"geslacht[:\s]*([MF])", r"gender[:\s]*([MF])", r"sex[:\s]*([MF])"],
|
| 42 |
"place_of_birth": [
|
| 43 |
r"geboorteplaats[:\s]*([A-Za-z\s]+)",
|
| 44 |
r"place\s*of\s*birth[:\s]*([A-Za-z\s]+)",
|
| 45 |
+
r"born\s*in[:\s]*([A-Za-z\s]+)",
|
| 46 |
],
|
| 47 |
"date_of_issue": [
|
| 48 |
r"uitgiftedatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 49 |
r"date\s*of\s*issue[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 50 |
+
r"issued[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 51 |
],
|
| 52 |
"date_of_expiry": [
|
| 53 |
r"vervaldatum[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 54 |
r"date\s*of\s*expiry[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 55 |
+
r"expires[:\s]*(\d{2}[./-]\d{2}[./-]\d{4})",
|
| 56 |
],
|
| 57 |
"personal_number": [
|
| 58 |
r"persoonsnummer[:\s]*(\d{9})",
|
| 59 |
r"personal\s*number[:\s]*(\d{9})",
|
| 60 |
+
r"bsn[:\s]*(\d{9})",
|
| 61 |
+
],
|
| 62 |
}
|
| 63 |
+
|
| 64 |
@classmethod
|
| 65 |
def extract_fields(cls, ocr_text: str) -> IdCardFields:
|
| 66 |
"""Extract structured fields from OCR text.
|
| 67 |
+
|
| 68 |
Args:
|
| 69 |
ocr_text: Raw OCR text from document processing
|
| 70 |
+
|
| 71 |
Returns:
|
| 72 |
IdCardFields object with extracted field data
|
| 73 |
"""
|
| 74 |
fields = {}
|
| 75 |
+
|
| 76 |
for field_name, patterns in cls.FIELD_PATTERNS.items():
|
| 77 |
value = None
|
| 78 |
confidence = 0.0
|
| 79 |
+
|
| 80 |
for pattern in patterns:
|
| 81 |
match = re.search(pattern, ocr_text, re.IGNORECASE)
|
| 82 |
if match:
|
| 83 |
value = match.group(1).strip()
|
| 84 |
confidence = 0.8 # Base confidence for pattern match
|
| 85 |
break
|
| 86 |
+
|
| 87 |
if value:
|
| 88 |
fields[field_name] = ExtractedField(
|
| 89 |
field_name=field_name,
|
| 90 |
value=value,
|
| 91 |
confidence=confidence,
|
| 92 |
+
source="ocr",
|
| 93 |
)
|
| 94 |
+
|
| 95 |
return IdCardFields(**fields)
|
| 96 |
+
|
| 97 |
@classmethod
|
| 98 |
def extract_mrz(cls, ocr_text: str) -> Optional[MRZData]:
|
| 99 |
"""Extract MRZ data from OCR text.
|
| 100 |
+
|
| 101 |
Args:
|
| 102 |
ocr_text: Raw OCR text from document processing
|
| 103 |
+
|
| 104 |
Returns:
|
| 105 |
MRZData object if MRZ detected, None otherwise
|
| 106 |
"""
|
|
|
|
| 109 |
r"(P<[A-Z0-9<]+\n[A-Z0-9<]+)", # Generic passport format (try first)
|
| 110 |
r"([A-Z0-9<]{30}\n[A-Z0-9<]{30})", # TD1 format
|
| 111 |
r"([A-Z0-9<]{44}\n[A-Z0-9<]{44})", # TD2 format
|
| 112 |
+
r"([A-Z0-9<]{44}\n[A-Z0-9<]{44}\n[A-Z0-9<]{44})", # TD3 format
|
| 113 |
]
|
| 114 |
+
|
| 115 |
for pattern in mrz_patterns:
|
| 116 |
match = re.search(pattern, ocr_text, re.MULTILINE)
|
| 117 |
if match:
|
|
|
|
| 119 |
# Basic MRZ parsing (simplified)
|
| 120 |
return MRZData(
|
| 121 |
raw_text=raw_mrz,
|
| 122 |
+
format_type="TD3" if len(raw_mrz.split("\n")) == 3 else "TD2",
|
| 123 |
is_valid=True, # Assume valid if present
|
| 124 |
checksum_errors=[], # Not implemented in basic version
|
| 125 |
+
confidence=0.9,
|
| 126 |
)
|
| 127 |
+
|
| 128 |
return None
|
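Note: a minimal usage sketch of the extractor above. The sample text is made up, and the import path (field_extraction) is an assumption based on the changed-file list in this commit, not something the diff itself shows.

# Hypothetical sample; real OCR output will be noisier.
from kybtech_dots_ocr.field_extraction import FieldExtractor

sample_text = (
    "Documentnummer: SPECI2014\n"
    "Geboortedatum: 10-03-1965\n"
    "Nationaliteit: Nederlandse\n"
)

fields = FieldExtractor.extract_fields(sample_text)
if fields.document_number is not None:
    print(fields.document_number.value, fields.document_number.confidence)  # SPECI2014 0.8

mrz = FieldExtractor.extract_mrz(sample_text)
print("MRZ detected" if mrz else "no MRZ in this sample")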
src/kybtech_dots_ocr/models.py
CHANGED
|
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
|
|
| 10 |
|
| 11 |
class ExtractedField(BaseModel):
|
| 12 |
"""Individual extracted field from identity document."""
|
|
|
|
| 13 |
field_name: str = Field(..., description="Standardized field name")
|
| 14 |
value: Optional[str] = Field(None, description="Extracted field value")
|
| 15 |
confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence")
|
|
@@ -18,10 +19,19 @@ class ExtractedField(BaseModel):
|
|
| 18 |
|
| 19 |
class IdCardFields(BaseModel):
|
| 20 |
"""Structured fields extracted from identity documents."""
|
| 21 |
-
document_number: Optional[ExtractedField] = Field(None, description="Document number/ID")
|
| 22 |
-
document_type: Optional[ExtractedField] = Field(None, description="Type of document")
|
| 23 |
-
issuing_country: Optional[ExtractedField] = Field(None, description="Issuing country code")
|
| 24 |
-
issuing_authority: Optional[ExtractedField] = Field(None, description="Issuing authority")
|
| 25 |
|
| 26 |
# Personal Information
|
| 27 |
surname: Optional[ExtractedField] = Field(None, description="Family name/surname")
|
|
@@ -34,17 +44,30 @@ class IdCardFields(BaseModel):
|
|
| 34 |
# Validity Information
|
| 35 |
date_of_issue: Optional[ExtractedField] = Field(None, description="Date of issue")
|
| 36 |
date_of_expiry: Optional[ExtractedField] = Field(None, description="Date of expiry")
|
| 37 |
-
personal_number: Optional[ExtractedField] = Field(None, description="Personal number")
|
| 38 |
|
| 39 |
# Additional fields for specific document types
|
| 40 |
-
optional_data_1: Optional[ExtractedField] = Field(None, description="Optional data field 1")
|
| 41 |
-
optional_data_2: Optional[ExtractedField] = Field(None, description="Optional data field 2")
|
| 42 |
|
| 43 |
|
| 44 |
class MRZData(BaseModel):
|
| 45 |
"""Machine Readable Zone data extracted from identity documents."""
|
|
|
|
| 46 |
raw_text: str = Field(..., description="Raw MRZ text as extracted")
|
| 47 |
-
format_type: str = Field(..., description="MRZ format type (TD1, TD2, TD3, MRVA, MRVB)")
|
| 48 |
is_valid: bool = Field(..., description="Whether MRZ checksums are valid")
|
| 49 |
-
checksum_errors: List[str] = Field(default_factory=list, description="List of checksum validation errors")
|
| 50 |
-
confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence score")
|
| 10 |
|
| 11 |
class ExtractedField(BaseModel):
|
| 12 |
"""Individual extracted field from identity document."""
|
| 13 |
+
|
| 14 |
field_name: str = Field(..., description="Standardized field name")
|
| 15 |
value: Optional[str] = Field(None, description="Extracted field value")
|
| 16 |
confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence")
|
|
|
|
| 19 |
|
| 20 |
class IdCardFields(BaseModel):
|
| 21 |
"""Structured fields extracted from identity documents."""
|
| 22 |
+
|
| 23 |
+
document_number: Optional[ExtractedField] = Field(
|
| 24 |
+
None, description="Document number/ID"
|
| 25 |
+
)
|
| 26 |
+
document_type: Optional[ExtractedField] = Field(
|
| 27 |
+
None, description="Type of document"
|
| 28 |
+
)
|
| 29 |
+
issuing_country: Optional[ExtractedField] = Field(
|
| 30 |
+
None, description="Issuing country code"
|
| 31 |
+
)
|
| 32 |
+
issuing_authority: Optional[ExtractedField] = Field(
|
| 33 |
+
None, description="Issuing authority"
|
| 34 |
+
)
|
| 35 |
|
| 36 |
# Personal Information
|
| 37 |
surname: Optional[ExtractedField] = Field(None, description="Family name/surname")
|
|
|
|
| 44 |
# Validity Information
|
| 45 |
date_of_issue: Optional[ExtractedField] = Field(None, description="Date of issue")
|
| 46 |
date_of_expiry: Optional[ExtractedField] = Field(None, description="Date of expiry")
|
| 47 |
+
personal_number: Optional[ExtractedField] = Field(
|
| 48 |
+
None, description="Personal number"
|
| 49 |
+
)
|
| 50 |
|
| 51 |
# Additional fields for specific document types
|
| 52 |
+
optional_data_1: Optional[ExtractedField] = Field(
|
| 53 |
+
None, description="Optional data field 1"
|
| 54 |
+
)
|
| 55 |
+
optional_data_2: Optional[ExtractedField] = Field(
|
| 56 |
+
None, description="Optional data field 2"
|
| 57 |
+
)
|
| 58 |
|
| 59 |
|
| 60 |
class MRZData(BaseModel):
|
| 61 |
"""Machine Readable Zone data extracted from identity documents."""
|
| 62 |
+
|
| 63 |
raw_text: str = Field(..., description="Raw MRZ text as extracted")
|
| 64 |
+
format_type: str = Field(
|
| 65 |
+
..., description="MRZ format type (TD1, TD2, TD3, MRVA, MRVB)"
|
| 66 |
+
)
|
| 67 |
is_valid: bool = Field(..., description="Whether MRZ checksums are valid")
|
| 68 |
+
checksum_errors: List[str] = Field(
|
| 69 |
+
default_factory=list, description="List of checksum validation errors"
|
| 70 |
+
)
|
| 71 |
+
confidence: float = Field(
|
| 72 |
+
..., ge=0.0, le=1.0, description="Extraction confidence score"
|
| 73 |
+
)
|
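Note: a short sketch of constructing these models directly, with illustrative values only (the MRZ string is a dummy placeholder, not real data).

from kybtech_dots_ocr.models import IdCardFields, MRZData

card = IdCardFields()  # every field is Optional and defaults to None
print(card.surname)    # None until an extractor fills it in

mrz = MRZData(
    raw_text="P<NLDSAMPLE<<MRZ<LINE<ONE\nSAMPLE<MRZ<LINE<TWO",
    format_type="TD3",
    is_valid=True,
    confidence=0.9,    # checksum_errors is omitted; it defaults to an empty list
)
print(mrz.format_type, mrz.checksum_errors)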
src/kybtech_dots_ocr/preprocessing.py
CHANGED
|
@@ -21,15 +21,19 @@ logger = logging.getLogger(__name__)
|
|
| 21 |
# Environment variable configuration
|
| 22 |
PDF_DPI = int(os.getenv("DOTS_OCR_PDF_DPI", "300"))
|
| 23 |
PDF_MAX_PAGES = int(os.getenv("DOTS_OCR_PDF_MAX_PAGES", "10"))
|
| 24 |
-
IMAGE_MAX_SIZE = int(os.getenv("DOTS_OCR_IMAGE_MAX_SIZE", "10")) * 1024 * 1024  # 10MB default
|
| 25 |
|
| 26 |
|
| 27 |
class ImagePreprocessor:
|
| 28 |
"""Handles image preprocessing for Dots.OCR model."""
|
| 29 |
-
|
| 30 |
-
def __init__(self, min_pixels: int = 3136, max_pixels: int = 11289600, divisor: int = 28):
|
| 31 |
"""Initialize the image preprocessor.
|
| 32 |
-
|
| 33 |
Args:
|
| 34 |
min_pixels: Minimum pixel count for images
|
| 35 |
max_pixels: Maximum pixel count for images
|
|
@@ -38,29 +42,29 @@ class ImagePreprocessor:
|
|
| 38 |
self.min_pixels = min_pixels
|
| 39 |
self.max_pixels = max_pixels
|
| 40 |
self.divisor = divisor
|
| 41 |
-
|
| 42 |
def preprocess_image(self, image: Image.Image) -> Image.Image:
|
| 43 |
"""Preprocess an image to meet model requirements.
|
| 44 |
-
|
| 45 |
Args:
|
| 46 |
image: Input PIL Image
|
| 47 |
-
|
| 48 |
Returns:
|
| 49 |
Preprocessed PIL Image
|
| 50 |
"""
|
| 51 |
# Convert to RGB if necessary
|
| 52 |
if image.mode != "RGB":
|
| 53 |
image = image.convert("RGB")
|
| 54 |
-
|
| 55 |
# Auto-orient image based on EXIF data
|
| 56 |
image = ImageOps.exif_transpose(image)
|
| 57 |
-
|
| 58 |
# Calculate current pixel count
|
| 59 |
width, height = image.size
|
| 60 |
current_pixels = width * height
|
| 61 |
-
|
| 62 |
logger.info(f"Original image size: {width}x{height} ({current_pixels} pixels)")
|
| 63 |
-
|
| 64 |
# Resize if necessary to meet pixel requirements
|
| 65 |
if current_pixels < self.min_pixels:
|
| 66 |
# Scale up to meet minimum pixel requirement
|
|
@@ -69,7 +73,7 @@ class ImagePreprocessor:
|
|
| 69 |
new_height = int(height * scale_factor)
|
| 70 |
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 71 |
logger.info(f"Scaled up image to {new_width}x{new_height}")
|
| 72 |
-
|
| 73 |
elif current_pixels > self.max_pixels:
|
| 74 |
# Scale down to meet maximum pixel requirement
|
| 75 |
scale_factor = (self.max_pixels / current_pixels) ** 0.5
|
|
@@ -77,69 +81,73 @@ class ImagePreprocessor:
|
|
| 77 |
new_height = int(height * scale_factor)
|
| 78 |
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 79 |
logger.info(f"Scaled down image to {new_width}x{new_height}")
|
| 80 |
-
|
| 81 |
# Ensure dimensions are divisible by the required divisor
|
| 82 |
width, height = image.size
|
| 83 |
new_width = ((width + self.divisor - 1) // self.divisor) * self.divisor
|
| 84 |
new_height = ((height + self.divisor - 1) // self.divisor) * self.divisor
|
| 85 |
-
|
| 86 |
if new_width != width or new_height != height:
|
| 87 |
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 88 |
-
logger.info(
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
return image
|
| 91 |
-
|
| 92 |
-
def crop_by_roi(self, image: Image.Image, roi: Tuple[float, float, float, float]) -> Image.Image:
|
|
|
|
|
|
|
| 93 |
"""Crop image using ROI coordinates.
|
| 94 |
-
|
| 95 |
Args:
|
| 96 |
image: Input PIL Image
|
| 97 |
roi: ROI coordinates as (x1, y1, x2, y2) normalized to [0, 1]
|
| 98 |
-
|
| 99 |
Returns:
|
| 100 |
Cropped PIL Image
|
| 101 |
"""
|
| 102 |
x1, y1, x2, y2 = roi
|
| 103 |
width, height = image.size
|
| 104 |
-
|
| 105 |
# Convert normalized coordinates to pixel coordinates
|
| 106 |
x1_px = int(x1 * width)
|
| 107 |
y1_px = int(y1 * height)
|
| 108 |
x2_px = int(x2 * width)
|
| 109 |
y2_px = int(y2 * height)
|
| 110 |
-
|
| 111 |
# Ensure coordinates are within image bounds
|
| 112 |
x1_px = max(0, min(x1_px, width))
|
| 113 |
y1_px = max(0, min(y1_px, height))
|
| 114 |
x2_px = max(x1_px, min(x2_px, width))
|
| 115 |
y2_px = max(y1_px, min(y2_px, height))
|
| 116 |
-
|
| 117 |
# Crop the image
|
| 118 |
cropped = image.crop((x1_px, y1_px, x2_px, y2_px))
|
| 119 |
logger.info(f"Cropped image to {x2_px - x1_px}x{y2_px - y1_px} pixels")
|
| 120 |
-
|
| 121 |
return cropped
|
| 122 |
|
| 123 |
|
| 124 |
class PDFProcessor:
|
| 125 |
"""Handles PDF to image conversion and multi-page processing."""
|
| 126 |
-
|
| 127 |
def __init__(self, dpi: int = PDF_DPI, max_pages: int = PDF_MAX_PAGES):
|
| 128 |
"""Initialize the PDF processor.
|
| 129 |
-
|
| 130 |
Args:
|
| 131 |
dpi: DPI for PDF to image conversion
|
| 132 |
max_pages: Maximum number of pages to process
|
| 133 |
"""
|
| 134 |
self.dpi = dpi
|
| 135 |
self.max_pages = max_pages
|
| 136 |
-
|
| 137 |
def pdf_to_images(self, pdf_data: bytes) -> List[Image.Image]:
|
| 138 |
"""Convert PDF to list of images.
|
| 139 |
-
|
| 140 |
Args:
|
| 141 |
pdf_data: PDF file data as bytes
|
| 142 |
-
|
| 143 |
Returns:
|
| 144 |
List of PIL Images, one per page
|
| 145 |
"""
|
|
@@ -147,49 +155,49 @@ class PDFProcessor:
|
|
| 147 |
# Open PDF from bytes
|
| 148 |
pdf_document = fitz.open(stream=pdf_data, filetype="pdf")
|
| 149 |
images = []
|
| 150 |
-
|
| 151 |
# Limit number of pages to process
|
| 152 |
num_pages = min(len(pdf_document), self.max_pages)
|
| 153 |
logger.info(f"Processing {num_pages} pages from PDF")
|
| 154 |
-
|
| 155 |
for page_num in range(num_pages):
|
| 156 |
page = pdf_document[page_num]
|
| 157 |
-
|
| 158 |
# Convert page to image
|
| 159 |
mat = fitz.Matrix(self.dpi / 72, self.dpi / 72) # 72 is default DPI
|
| 160 |
pix = page.get_pixmap(matrix=mat)
|
| 161 |
-
|
| 162 |
# Convert to PIL Image
|
| 163 |
img_data = pix.tobytes("png")
|
| 164 |
image = Image.open(io.BytesIO(img_data))
|
| 165 |
images.append(image)
|
| 166 |
-
|
| 167 |
logger.info(f"Converted page {page_num + 1} to image: {image.size}")
|
| 168 |
-
|
| 169 |
pdf_document.close()
|
| 170 |
return images
|
| 171 |
-
|
| 172 |
except Exception as e:
|
| 173 |
logger.error(f"Failed to convert PDF to images: {e}")
|
| 174 |
raise RuntimeError(f"PDF conversion failed: {e}")
|
| 175 |
-
|
| 176 |
def is_pdf(self, file_data: bytes) -> bool:
|
| 177 |
"""Check if file data is a PDF.
|
| 178 |
-
|
| 179 |
Args:
|
| 180 |
file_data: File data as bytes
|
| 181 |
-
|
| 182 |
Returns:
|
| 183 |
True if file is a PDF
|
| 184 |
"""
|
| 185 |
-
return file_data.startswith(b'%PDF-')
|
| 186 |
-
|
| 187 |
def get_pdf_page_count(self, pdf_data: bytes) -> int:
|
| 188 |
"""Get the number of pages in a PDF.
|
| 189 |
-
|
| 190 |
Args:
|
| 191 |
pdf_data: PDF file data as bytes
|
| 192 |
-
|
| 193 |
Returns:
|
| 194 |
Number of pages in the PDF
|
| 195 |
"""
|
|
@@ -205,23 +213,21 @@ class PDFProcessor:
|
|
| 205 |
|
| 206 |
class DocumentProcessor:
|
| 207 |
"""Main document processing class that handles both images and PDFs."""
|
| 208 |
-
|
| 209 |
def __init__(self):
|
| 210 |
"""Initialize the document processor."""
|
| 211 |
self.image_preprocessor = ImagePreprocessor()
|
| 212 |
self.pdf_processor = PDFProcessor()
|
| 213 |
-
|
| 214 |
def process_document(
|
| 215 |
-
self,
|
| 216 |
-
file_data: bytes,
|
| 217 |
-
roi: Optional[Tuple[float, float, float, float]] = None
|
| 218 |
) -> List[Image.Image]:
|
| 219 |
"""Process a document (image or PDF) and return preprocessed images.
|
| 220 |
-
|
| 221 |
Args:
|
| 222 |
file_data: Document file data as bytes
|
| 223 |
roi: Optional ROI coordinates as (x1, y1, x2, y2) normalized to [0, 1]
|
| 224 |
-
|
| 225 |
Returns:
|
| 226 |
List of preprocessed PIL Images
|
| 227 |
"""
|
|
@@ -238,7 +244,7 @@ class DocumentProcessor:
|
|
| 238 |
except Exception as e:
|
| 239 |
logger.error(f"Failed to open image: {e}")
|
| 240 |
raise RuntimeError(f"Image processing failed: {e}")
|
| 241 |
-
|
| 242 |
# Preprocess each image
|
| 243 |
processed_images = []
|
| 244 |
for i, image in enumerate(images):
|
|
@@ -246,30 +252,30 @@ class DocumentProcessor:
|
|
| 246 |
# Apply ROI cropping if provided
|
| 247 |
if roi is not None:
|
| 248 |
image = self.image_preprocessor.crop_by_roi(image, roi)
|
| 249 |
-
|
| 250 |
# Preprocess image for model requirements
|
| 251 |
processed_image = self.image_preprocessor.preprocess_image(image)
|
| 252 |
processed_images.append(processed_image)
|
| 253 |
-
|
| 254 |
logger.info(f"Processed image {i + 1}: {processed_image.size}")
|
| 255 |
-
|
| 256 |
except Exception as e:
|
| 257 |
logger.error(f"Failed to preprocess image {i + 1}: {e}")
|
| 258 |
# Continue with other images even if one fails
|
| 259 |
continue
|
| 260 |
-
|
| 261 |
if not processed_images:
|
| 262 |
raise RuntimeError("No images could be processed from the document")
|
| 263 |
-
|
| 264 |
logger.info(f"Successfully processed {len(processed_images)} images")
|
| 265 |
return processed_images
|
| 266 |
-
|
| 267 |
def validate_file_size(self, file_data: bytes) -> bool:
|
| 268 |
"""Validate that file size is within limits.
|
| 269 |
-
|
| 270 |
Args:
|
| 271 |
file_data: File data as bytes
|
| 272 |
-
|
| 273 |
Returns:
|
| 274 |
True if file size is acceptable
|
| 275 |
"""
|
|
@@ -278,25 +284,25 @@ class DocumentProcessor:
|
|
| 278 |
logger.warning(f"File size {file_size} exceeds limit {IMAGE_MAX_SIZE}")
|
| 279 |
return False
|
| 280 |
return True
|
| 281 |
-
|
| 282 |
def get_document_info(self, file_data: bytes) -> dict:
|
| 283 |
"""Get information about the document.
|
| 284 |
-
|
| 285 |
Args:
|
| 286 |
file_data: Document file data as bytes
|
| 287 |
-
|
| 288 |
Returns:
|
| 289 |
Dictionary with document information
|
| 290 |
"""
|
| 291 |
info = {
|
| 292 |
"file_size": len(file_data),
|
| 293 |
"is_pdf": self.pdf_processor.is_pdf(file_data),
|
| 294 |
-
"page_count": 1
|
| 295 |
}
|
| 296 |
-
|
| 297 |
if info["is_pdf"]:
|
| 298 |
info["page_count"] = self.pdf_processor.get_pdf_page_count(file_data)
|
| 299 |
-
|
| 300 |
return info
|
| 301 |
|
| 302 |
|
|
@@ -313,8 +319,7 @@ def get_document_processor() -> DocumentProcessor:
|
|
| 313 |
|
| 314 |
|
| 315 |
def process_document(
|
| 316 |
-
file_data: bytes,
|
| 317 |
-
roi: Optional[Tuple[float, float, float, float]] = None
|
| 318 |
) -> List[Image.Image]:
|
| 319 |
"""Process a document and return preprocessed images."""
|
| 320 |
processor = get_document_processor()
|
|
|
|
| 21 |
# Environment variable configuration
|
| 22 |
PDF_DPI = int(os.getenv("DOTS_OCR_PDF_DPI", "300"))
|
| 23 |
PDF_MAX_PAGES = int(os.getenv("DOTS_OCR_PDF_MAX_PAGES", "10"))
|
| 24 |
+
IMAGE_MAX_SIZE = (
|
| 25 |
+
int(os.getenv("DOTS_OCR_IMAGE_MAX_SIZE", "10")) * 1024 * 1024
|
| 26 |
+
) # 10MB default
|
| 27 |
|
| 28 |
|
| 29 |
class ImagePreprocessor:
|
| 30 |
"""Handles image preprocessing for Dots.OCR model."""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self, min_pixels: int = 3136, max_pixels: int = 11289600, divisor: int = 28
|
| 34 |
+
):
|
| 35 |
"""Initialize the image preprocessor.
|
| 36 |
+
|
| 37 |
Args:
|
| 38 |
min_pixels: Minimum pixel count for images
|
| 39 |
max_pixels: Maximum pixel count for images
|
|
|
|
| 42 |
self.min_pixels = min_pixels
|
| 43 |
self.max_pixels = max_pixels
|
| 44 |
self.divisor = divisor
|
| 45 |
+
|
| 46 |
def preprocess_image(self, image: Image.Image) -> Image.Image:
|
| 47 |
"""Preprocess an image to meet model requirements.
|
| 48 |
+
|
| 49 |
Args:
|
| 50 |
image: Input PIL Image
|
| 51 |
+
|
| 52 |
Returns:
|
| 53 |
Preprocessed PIL Image
|
| 54 |
"""
|
| 55 |
# Convert to RGB if necessary
|
| 56 |
if image.mode != "RGB":
|
| 57 |
image = image.convert("RGB")
|
| 58 |
+
|
| 59 |
# Auto-orient image based on EXIF data
|
| 60 |
image = ImageOps.exif_transpose(image)
|
| 61 |
+
|
| 62 |
# Calculate current pixel count
|
| 63 |
width, height = image.size
|
| 64 |
current_pixels = width * height
|
| 65 |
+
|
| 66 |
logger.info(f"Original image size: {width}x{height} ({current_pixels} pixels)")
|
| 67 |
+
|
| 68 |
# Resize if necessary to meet pixel requirements
|
| 69 |
if current_pixels < self.min_pixels:
|
| 70 |
# Scale up to meet minimum pixel requirement
|
|
|
|
| 73 |
new_height = int(height * scale_factor)
|
| 74 |
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 75 |
logger.info(f"Scaled up image to {new_width}x{new_height}")
|
| 76 |
+
|
| 77 |
elif current_pixels > self.max_pixels:
|
| 78 |
# Scale down to meet maximum pixel requirement
|
| 79 |
scale_factor = (self.max_pixels / current_pixels) ** 0.5
|
|
|
|
| 81 |
new_height = int(height * scale_factor)
|
| 82 |
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 83 |
logger.info(f"Scaled down image to {new_width}x{new_height}")
|
| 84 |
+
|
| 85 |
# Ensure dimensions are divisible by the required divisor
|
| 86 |
width, height = image.size
|
| 87 |
new_width = ((width + self.divisor - 1) // self.divisor) * self.divisor
|
| 88 |
new_height = ((height + self.divisor - 1) // self.divisor) * self.divisor
|
| 89 |
+
|
| 90 |
if new_width != width or new_height != height:
|
| 91 |
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
| 92 |
+
logger.info(
|
| 93 |
+
f"Adjusted dimensions to be divisible by {self.divisor}: {new_width}x{new_height}"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
return image
|
| 97 |
+
|
| 98 |
+
def crop_by_roi(
|
| 99 |
+
self, image: Image.Image, roi: Tuple[float, float, float, float]
|
| 100 |
+
) -> Image.Image:
|
| 101 |
"""Crop image using ROI coordinates.
|
| 102 |
+
|
| 103 |
Args:
|
| 104 |
image: Input PIL Image
|
| 105 |
roi: ROI coordinates as (x1, y1, x2, y2) normalized to [0, 1]
|
| 106 |
+
|
| 107 |
Returns:
|
| 108 |
Cropped PIL Image
|
| 109 |
"""
|
| 110 |
x1, y1, x2, y2 = roi
|
| 111 |
width, height = image.size
|
| 112 |
+
|
| 113 |
# Convert normalized coordinates to pixel coordinates
|
| 114 |
x1_px = int(x1 * width)
|
| 115 |
y1_px = int(y1 * height)
|
| 116 |
x2_px = int(x2 * width)
|
| 117 |
y2_px = int(y2 * height)
|
| 118 |
+
|
| 119 |
# Ensure coordinates are within image bounds
|
| 120 |
x1_px = max(0, min(x1_px, width))
|
| 121 |
y1_px = max(0, min(y1_px, height))
|
| 122 |
x2_px = max(x1_px, min(x2_px, width))
|
| 123 |
y2_px = max(y1_px, min(y2_px, height))
|
| 124 |
+
|
| 125 |
# Crop the image
|
| 126 |
cropped = image.crop((x1_px, y1_px, x2_px, y2_px))
|
| 127 |
logger.info(f"Cropped image to {x2_px - x1_px}x{y2_px - y1_px} pixels")
|
| 128 |
+
|
| 129 |
return cropped
|
| 130 |
|
| 131 |
|
| 132 |
class PDFProcessor:
|
| 133 |
"""Handles PDF to image conversion and multi-page processing."""
|
| 134 |
+
|
| 135 |
def __init__(self, dpi: int = PDF_DPI, max_pages: int = PDF_MAX_PAGES):
|
| 136 |
"""Initialize the PDF processor.
|
| 137 |
+
|
| 138 |
Args:
|
| 139 |
dpi: DPI for PDF to image conversion
|
| 140 |
max_pages: Maximum number of pages to process
|
| 141 |
"""
|
| 142 |
self.dpi = dpi
|
| 143 |
self.max_pages = max_pages
|
| 144 |
+
|
| 145 |
def pdf_to_images(self, pdf_data: bytes) -> List[Image.Image]:
|
| 146 |
"""Convert PDF to list of images.
|
| 147 |
+
|
| 148 |
Args:
|
| 149 |
pdf_data: PDF file data as bytes
|
| 150 |
+
|
| 151 |
Returns:
|
| 152 |
List of PIL Images, one per page
|
| 153 |
"""
|
|
|
|
| 155 |
# Open PDF from bytes
|
| 156 |
pdf_document = fitz.open(stream=pdf_data, filetype="pdf")
|
| 157 |
images = []
|
| 158 |
+
|
| 159 |
# Limit number of pages to process
|
| 160 |
num_pages = min(len(pdf_document), self.max_pages)
|
| 161 |
logger.info(f"Processing {num_pages} pages from PDF")
|
| 162 |
+
|
| 163 |
for page_num in range(num_pages):
|
| 164 |
page = pdf_document[page_num]
|
| 165 |
+
|
| 166 |
# Convert page to image
|
| 167 |
mat = fitz.Matrix(self.dpi / 72, self.dpi / 72) # 72 is default DPI
|
| 168 |
pix = page.get_pixmap(matrix=mat)
|
| 169 |
+
|
| 170 |
# Convert to PIL Image
|
| 171 |
img_data = pix.tobytes("png")
|
| 172 |
image = Image.open(io.BytesIO(img_data))
|
| 173 |
images.append(image)
|
| 174 |
+
|
| 175 |
logger.info(f"Converted page {page_num + 1} to image: {image.size}")
|
| 176 |
+
|
| 177 |
pdf_document.close()
|
| 178 |
return images
|
| 179 |
+
|
| 180 |
except Exception as e:
|
| 181 |
logger.error(f"Failed to convert PDF to images: {e}")
|
| 182 |
raise RuntimeError(f"PDF conversion failed: {e}")
|
| 183 |
+
|
| 184 |
def is_pdf(self, file_data: bytes) -> bool:
|
| 185 |
"""Check if file data is a PDF.
|
| 186 |
+
|
| 187 |
Args:
|
| 188 |
file_data: File data as bytes
|
| 189 |
+
|
| 190 |
Returns:
|
| 191 |
True if file is a PDF
|
| 192 |
"""
|
| 193 |
+
return file_data.startswith(b"%PDF-")
|
| 194 |
+
|
| 195 |
def get_pdf_page_count(self, pdf_data: bytes) -> int:
|
| 196 |
"""Get the number of pages in a PDF.
|
| 197 |
+
|
| 198 |
Args:
|
| 199 |
pdf_data: PDF file data as bytes
|
| 200 |
+
|
| 201 |
Returns:
|
| 202 |
Number of pages in the PDF
|
| 203 |
"""
|
|
|
|
| 213 |
|
| 214 |
class DocumentProcessor:
|
| 215 |
"""Main document processing class that handles both images and PDFs."""
|
| 216 |
+
|
| 217 |
def __init__(self):
|
| 218 |
"""Initialize the document processor."""
|
| 219 |
self.image_preprocessor = ImagePreprocessor()
|
| 220 |
self.pdf_processor = PDFProcessor()
|
| 221 |
+
|
| 222 |
def process_document(
|
| 223 |
+
self, file_data: bytes, roi: Optional[Tuple[float, float, float, float]] = None
|
|
|
|
|
|
|
| 224 |
) -> List[Image.Image]:
|
| 225 |
"""Process a document (image or PDF) and return preprocessed images.
|
| 226 |
+
|
| 227 |
Args:
|
| 228 |
file_data: Document file data as bytes
|
| 229 |
roi: Optional ROI coordinates as (x1, y1, x2, y2) normalized to [0, 1]
|
| 230 |
+
|
| 231 |
Returns:
|
| 232 |
List of preprocessed PIL Images
|
| 233 |
"""
|
|
|
|
| 244 |
except Exception as e:
|
| 245 |
logger.error(f"Failed to open image: {e}")
|
| 246 |
raise RuntimeError(f"Image processing failed: {e}")
|
| 247 |
+
|
| 248 |
# Preprocess each image
|
| 249 |
processed_images = []
|
| 250 |
for i, image in enumerate(images):
|
|
|
|
| 252 |
# Apply ROI cropping if provided
|
| 253 |
if roi is not None:
|
| 254 |
image = self.image_preprocessor.crop_by_roi(image, roi)
|
| 255 |
+
|
| 256 |
# Preprocess image for model requirements
|
| 257 |
processed_image = self.image_preprocessor.preprocess_image(image)
|
| 258 |
processed_images.append(processed_image)
|
| 259 |
+
|
| 260 |
logger.info(f"Processed image {i + 1}: {processed_image.size}")
|
| 261 |
+
|
| 262 |
except Exception as e:
|
| 263 |
logger.error(f"Failed to preprocess image {i + 1}: {e}")
|
| 264 |
# Continue with other images even if one fails
|
| 265 |
continue
|
| 266 |
+
|
| 267 |
if not processed_images:
|
| 268 |
raise RuntimeError("No images could be processed from the document")
|
| 269 |
+
|
| 270 |
logger.info(f"Successfully processed {len(processed_images)} images")
|
| 271 |
return processed_images
|
| 272 |
+
|
| 273 |
def validate_file_size(self, file_data: bytes) -> bool:
|
| 274 |
"""Validate that file size is within limits.
|
| 275 |
+
|
| 276 |
Args:
|
| 277 |
file_data: File data as bytes
|
| 278 |
+
|
| 279 |
Returns:
|
| 280 |
True if file size is acceptable
|
| 281 |
"""
|
|
|
|
| 284 |
logger.warning(f"File size {file_size} exceeds limit {IMAGE_MAX_SIZE}")
|
| 285 |
return False
|
| 286 |
return True
|
| 287 |
+
|
| 288 |
def get_document_info(self, file_data: bytes) -> dict:
|
| 289 |
"""Get information about the document.
|
| 290 |
+
|
| 291 |
Args:
|
| 292 |
file_data: Document file data as bytes
|
| 293 |
+
|
| 294 |
Returns:
|
| 295 |
Dictionary with document information
|
| 296 |
"""
|
| 297 |
info = {
|
| 298 |
"file_size": len(file_data),
|
| 299 |
"is_pdf": self.pdf_processor.is_pdf(file_data),
|
| 300 |
+
"page_count": 1,
|
| 301 |
}
|
| 302 |
+
|
| 303 |
if info["is_pdf"]:
|
| 304 |
info["page_count"] = self.pdf_processor.get_pdf_page_count(file_data)
|
| 305 |
+
|
| 306 |
return info
|
| 307 |
|
| 308 |
|
|
|
|
| 319 |
|
| 320 |
|
| 321 |
def process_document(
|
| 322 |
+
file_data: bytes, roi: Optional[Tuple[float, float, float, float]] = None
|
|
|
|
| 323 |
) -> List[Image.Image]:
|
| 324 |
"""Process a document and return preprocessed images."""
|
| 325 |
processor = get_document_processor()
|
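Note: a usage sketch for the preprocessing pipeline above; the file name and ROI values are made up for illustration.

from kybtech_dots_ocr.preprocessing import process_document

with open("sample_id_card.jpg", "rb") as f:  # hypothetical input file
    file_data = f.read()

# Crop each page to its central region, then resize to the model's pixel budget;
# output dimensions are rounded up to multiples of 28 (the default divisor).
images = process_document(file_data, roi=(0.05, 0.05, 0.95, 0.95))
for i, img in enumerate(images):
    print(f"page {i + 1}: {img.size}")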
src/kybtech_dots_ocr/response_builder.py
CHANGED
|
@@ -2,9 +2,13 @@
|
|
| 2 |
|
| 3 |
This module handles the construction and validation of OCR API responses
|
| 4 |
according to the specified schema with proper error handling and metadata.
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import logging
|
|
|
|
| 8 |
import time
|
| 9 |
from typing import List, Optional, Dict, Any
|
| 10 |
from datetime import datetime
|
|
@@ -29,7 +33,8 @@ class OCRResponseBuilder:
|
|
| 29 |
media_type: str,
|
| 30 |
processing_time: float,
|
| 31 |
ocr_texts: List[str],
|
| 32 |
-
page_metadata: Optional[List[Dict[str, Any]]] = None
|
|
|
|
| 33 |
) -> OCRResponse:
|
| 34 |
"""Build a complete OCR response from extracted texts.
|
| 35 |
|
|
@@ -39,6 +44,7 @@ class OCRResponseBuilder:
|
|
| 39 |
processing_time: Total processing time in seconds
|
| 40 |
ocr_texts: List of OCR text results (one per page)
|
| 41 |
page_metadata: Optional metadata for each page
|
|
|
|
| 42 |
|
| 43 |
Returns:
|
| 44 |
Complete OCRResponse object
|
|
@@ -46,6 +52,8 @@ class OCRResponseBuilder:
|
|
| 46 |
logger.info(f"Building response for {len(ocr_texts)} pages")
|
| 47 |
|
| 48 |
detections = []
|
|
|
|
|
|
|
| 49 |
|
| 50 |
for i, ocr_text in enumerate(ocr_texts):
|
| 51 |
try:
|
|
@@ -53,6 +61,40 @@ class OCRResponseBuilder:
|
|
| 53 |
extracted_fields = self.field_extractor.extract_fields(ocr_text)
|
| 54 |
mrz_data = self.field_extractor.extract_mrz(ocr_text)
|
| 55 |
|
|
|
|
| 56 |
# Create detection for this page
|
| 57 |
detection = self._create_detection(extracted_fields, mrz_data, i, page_metadata)
|
| 58 |
detections.append(detection)
|
|
@@ -304,11 +346,19 @@ def build_ocr_response(
|
|
| 304 |
media_type: str,
|
| 305 |
processing_time: float,
|
| 306 |
ocr_texts: List[str],
|
| 307 |
-
page_metadata: Optional[List[Dict[str, Any]]] = None
|
|
|
|
| 308 |
) -> OCRResponse:
|
| 309 |
"""Build a complete OCR response from extracted texts."""
|
| 310 |
builder = get_response_builder()
|
| 311 |
-
return builder.build_response(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
|
| 314 |
def build_error_response(
|
|
|
|
| 2 |
|
| 3 |
This module handles the construction and validation of OCR API responses
|
| 4 |
according to the specified schema with proper error handling and metadata.
|
| 5 |
+
|
| 6 |
+
Debug-mode logging is supported to surface detailed information about
|
| 7 |
+
extraction results when troubleshooting in environments like Hugging Face.
|
| 8 |
"""
|
| 9 |
|
| 10 |
import logging
|
| 11 |
+
import os
|
| 12 |
import time
|
| 13 |
from typing import List, Optional, Dict, Any
|
| 14 |
from datetime import datetime
|
|
|
|
| 33 |
media_type: str,
|
| 34 |
processing_time: float,
|
| 35 |
ocr_texts: List[str],
|
| 36 |
+
page_metadata: Optional[List[Dict[str, Any]]] = None,
|
| 37 |
+
debug: bool = False,
|
| 38 |
) -> OCRResponse:
|
| 39 |
"""Build a complete OCR response from extracted texts.
|
| 40 |
|
|
|
|
| 44 |
processing_time: Total processing time in seconds
|
| 45 |
ocr_texts: List of OCR text results (one per page)
|
| 46 |
page_metadata: Optional metadata for each page
|
| 47 |
+
debug: When True, emit detailed logs about OCR text and mapping
|
| 48 |
|
| 49 |
Returns:
|
| 50 |
Complete OCRResponse object
|
|
|
|
| 52 |
logger.info(f"Building response for {len(ocr_texts)} pages")
|
| 53 |
|
| 54 |
detections = []
|
| 55 |
+
# Allow configuring the OCR text snippet length via env var. Defaults to 1200.
|
| 56 |
+
debug_snippet_len = int(os.getenv("DOTS_OCR_DEBUG_TEXT_SNIPPET_LEN", "1200"))
|
| 57 |
|
| 58 |
for i, ocr_text in enumerate(ocr_texts):
|
| 59 |
try:
|
|
|
|
| 61 |
extracted_fields = self.field_extractor.extract_fields(ocr_text)
|
| 62 |
mrz_data = self.field_extractor.extract_mrz(ocr_text)
|
| 63 |
|
| 64 |
+
# In debug mode, log OCR text snippet and extracted mapping details.
|
| 65 |
+
if debug:
|
| 66 |
+
# Log a bounded snippet of the OCR text to avoid overwhelming logs
|
| 67 |
+
snippet = ocr_text[:debug_snippet_len]
|
| 68 |
+
if len(ocr_text) > debug_snippet_len:
|
| 69 |
+
snippet += "\n...[truncated]"
|
| 70 |
+
logger.info(
|
| 71 |
+
f"[debug] Page {i + 1}: OCR text snippet (len={len(ocr_text)}):\n{snippet}"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Prepare a compact dict of non-null extracted fields
|
| 75 |
+
non_null_fields: Dict[str, Any] = {}
|
| 76 |
+
for fname, fval in extracted_fields.__dict__.items():
|
| 77 |
+
if fval is not None:
|
| 78 |
+
non_null_fields[fname] = {
|
| 79 |
+
"value": fval.value,
|
| 80 |
+
"confidence": fval.confidence,
|
| 81 |
+
"source": fval.source,
|
| 82 |
+
}
|
| 83 |
+
logger.info(
|
| 84 |
+
f"[debug] Page {i + 1}: Extracted fields (non-null): {non_null_fields}"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if mrz_data is not None:
|
| 88 |
+
# Support both canonical and legacy attribute names
|
| 89 |
+
raw_mrz = getattr(mrz_data, "raw_mrz", None) or getattr(mrz_data, "raw_text", None)
|
| 90 |
+
logger.info(
|
| 91 |
+
f"[debug] Page {i + 1}: MRZ detected — type={getattr(mrz_data, 'document_type', None) or getattr(mrz_data, 'format_type', None)}, confidence={mrz_data.confidence:.2f}"
|
| 92 |
+
)
|
| 93 |
+
if raw_mrz:
|
| 94 |
+
logger.info(f"[debug] Page {i + 1}: MRZ raw text:\n{raw_mrz}")
|
| 95 |
+
else:
|
| 96 |
+
logger.info(f"[debug] Page {i + 1}: No MRZ detected")
|
| 97 |
+
|
| 98 |
# Create detection for this page
|
| 99 |
detection = self._create_detection(extracted_fields, mrz_data, i, page_metadata)
|
| 100 |
detections.append(detection)
|
|
|
|
| 346 |
media_type: str,
|
| 347 |
processing_time: float,
|
| 348 |
ocr_texts: List[str],
|
| 349 |
+
page_metadata: Optional[List[Dict[str, Any]]] = None,
|
| 350 |
+
debug: bool = False,
|
| 351 |
) -> OCRResponse:
|
| 352 |
"""Build a complete OCR response from extracted texts."""
|
| 353 |
builder = get_response_builder()
|
| 354 |
+
return builder.build_response(
|
| 355 |
+
request_id=request_id,
|
| 356 |
+
media_type=media_type,
|
| 357 |
+
processing_time=processing_time,
|
| 358 |
+
ocr_texts=ocr_texts,
|
| 359 |
+
page_metadata=page_metadata,
|
| 360 |
+
debug=debug,
|
| 361 |
+
)
|
| 362 |
|
| 363 |
|
| 364 |
def build_error_response(
|
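Note: a sketch of calling the module-level wrapper with the new debug flag; the request values and OCR text are illustrative only.

import logging
import os

os.environ["DOTS_OCR_DEBUG_TEXT_SNIPPET_LEN"] = "500"  # bound the logged OCR snippet
logging.basicConfig(level=logging.INFO)

from kybtech_dots_ocr.response_builder import build_ocr_response

response = build_ocr_response(
    request_id="req-0001",
    media_type="image/jpeg",
    processing_time=1.42,
    ocr_texts=["Documentnummer: SPECI2014\nAchternaam: DE BRUIJN"],
    debug=True,  # logs an OCR text snippet, non-null fields, and MRZ details per page
)
print(response)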