|
|
import logging |
|
|
import os |
|
|
from enum import Enum |
|
|
from pathlib import Path |
|
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Union |
|
|
|
|
|
from pydantic import AnyUrl, BaseModel, ConfigDict, Field, model_validator |
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict |
|
|
|
|
|
# Module-level logger, named after this module.
_log = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
class AcceleratorDevice(str, Enum):
    """Hardware targets available for running model inference.

    Members subclass ``str`` and therefore compare equal to their plain values,
    e.g. ``AcceleratorDevice.CPU == "cpu"``.
    """

    AUTO = "auto"  # automatic device selection
    CPU = "cpu"  # plain CPU execution
    CUDA = "cuda"  # NVIDIA GPUs through CUDA
    MPS = "mps"  # Apple Metal Performance Shaders
|
|
|
|
|
|
|
|
class AcceleratorOptions(BaseSettings):
    """Inference hardware settings: thread count and target device.

    Being a pydantic ``BaseSettings`` with ``env_prefix="DOCLING_"``, the fields
    can also be populated from environment variables such as
    ``DOCLING_NUM_THREADS`` and ``DOCLING_DEVICE``. When neither an explicit
    ``num_threads`` nor ``DOCLING_NUM_THREADS`` is given, a valid
    ``OMP_NUM_THREADS`` is used instead (see ``check_alternative_envvars``).
    """

    model_config = SettingsConfigDict(
        env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
    )

    # Number of threads used for inference (default 4 unless overridden).
    num_threads: int = 4
    # Device on which inference should run.
    device: AcceleratorDevice = AcceleratorDevice.AUTO

    @model_validator(mode="before")
    @classmethod
    def check_alternative_envvars(cls, data: Any) -> Any:
        r"""
        Set num_threads from the "alternative" envvar OMP_NUM_THREADS.

        The alternative envvar is used only if it is valid (a positive integer,
        as required by the OpenMP specification), no ``num_threads`` value is
        present in the input, and the regular envvar DOCLING_NUM_THREADS is not set.

        Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
        the same functionality. In case the alias envvar is set and the user tries to override the
        parameter in settings initialization, Pydantic treats the parameter provided in __init__()
        as an extra input instead of simply overwriting the envvar value for that parameter.

        Args:
            data: Raw model input; only ``dict`` input is inspected.

        Returns:
            ``data`` unchanged, or a shallow copy of it whose ``num_threads``
            is taken from OMP_NUM_THREADS.
        """
        if not isinstance(data, dict) or data.get("num_threads") is not None:
            return data
        # The regular envvar always takes precedence over the alternative one.
        if os.getenv("DOCLING_NUM_THREADS") is not None:
            return data
        omp_num_threads = os.getenv("OMP_NUM_THREADS")
        if omp_num_threads is None:
            return data

        try:
            num_threads = int(omp_num_threads)
        except ValueError:
            _log.error(
                "Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
                omp_num_threads,
            )
            return data
        if num_threads < 1:
            # A zero/negative thread count is unusable; keep the field default
            # instead of propagating it.
            _log.error(
                "Ignoring non-positive envvar OMP_NUM_THREADS '%s'",
                omp_num_threads,
            )
            return data

        # Return a copy rather than mutating the mapping handed to the validator.
        return {**data, "num_threads": num_threads}
|
|
|
|
|
|
|
|
class TableFormerMode(str, Enum):
    """Operating mode of the TableFormer table-structure model.

    ``str``-valued, so members compare equal to their plain values.
    """

    FAST = "fast"  # the speed-oriented mode
    ACCURATE = "accurate"  # the accuracy-oriented mode
|
|
|
|
|
|
|
|
class TableStructureOptions(BaseModel):
    """Settings for the table-structure recognition step."""

    # Cell-matching toggle. NOTE(review): presumably controls whether predicted
    # table cells are matched back to the document's own cells — confirm with
    # the consumer of this option.
    do_cell_matching: bool = True
    # TableFormer operating mode (speed vs. accuracy).
    mode: TableFormerMode = TableFormerMode.FAST
|
|
|
|
|
|
|
|
class OcrOptions(BaseModel):
    """Settings shared by every OCR engine option class.

    ``kind`` names the engine; each subclass pins it to a ``Literal`` value so
    it can serve as the discriminator of the OCR options union.
    """

    kind: str
    lang: List[str]
    force_full_page_ocr: bool = False
    # NOTE(review): presumably the minimum fraction of the page area a bitmap
    # must cover to be sent to OCR — confirm with the consumer.
    bitmap_area_threshold: float = 0.05
|
|
|
|
|
|
|
|
class RapidOcrOptions(OcrOptions):
    """Settings for the RapidOCR engine (unknown keys are rejected)."""

    model_config = ConfigDict(extra="forbid")

    kind: Literal["rapidocr"] = "rapidocr"
    lang: List[str] = ["english", "chinese"]

    # NOTE(review): presumably the minimum recognition score for a text result
    # to be kept — confirm against the RapidOCR adapter.
    text_score: float = 0.5

    # Per-stage toggles (detection / classification / recognition); None means
    # no explicit value is set here.
    use_det: Union[bool, None] = None
    use_cls: Union[bool, None] = None
    use_rec: Union[bool, None] = None

    print_verbose: bool = False

    # Optional custom model / key files; None means no override is provided.
    det_model_path: Union[str, None] = None
    cls_model_path: Union[str, None] = None
    rec_model_path: Union[str, None] = None
    rec_keys_path: Union[str, None] = None
|
|
|
|
|
|
|
|
class EasyOcrOptions(OcrOptions):
    """Settings for the EasyOCR engine (unknown keys are rejected)."""

    # ``protected_namespaces=()`` silences pydantic's warning about fields that
    # start with the reserved ``model_`` prefix (``model_storage_directory``).
    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    kind: Literal["easyocr"] = "easyocr"
    lang: List[str] = ["fr", "de", "es", "en"]

    # None means no explicit GPU preference is given here.
    use_gpu: Union[bool, None] = None

    confidence_threshold: float = 0.5

    model_storage_directory: Union[str, None] = None
    recog_network: Union[str, None] = "standard"
    download_enabled: bool = True
|
|
|
|
|
|
|
|
class TesseractCliOcrOptions(OcrOptions):
    """Settings for Tesseract driven through its command-line executable.

    Unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    kind: Literal["tesseract"] = "tesseract"
    lang: List[str] = ["fra", "deu", "spa", "eng"]
    # Command used to invoke the executable.
    tesseract_cmd: str = "tesseract"
    # NOTE(review): presumably the tessdata directory — confirm with the consumer.
    path: Union[str, None] = None
|
|
|
|
|
|
|
|
class TesseractOcrOptions(OcrOptions):
    """Settings for the Tesseract engine (``kind`` "tesserocr").

    Unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    kind: Literal["tesserocr"] = "tesserocr"
    lang: List[str] = ["fra", "deu", "spa", "eng"]
    # NOTE(review): presumably the tessdata directory — confirm with the consumer.
    path: Union[str, None] = None
|
|
|
|
|
|
|
|
class OcrMacOptions(OcrOptions):
    """Settings for the macOS OCR engine (unknown keys are rejected)."""

    model_config = ConfigDict(extra="forbid")

    kind: Literal["ocrmac"] = "ocrmac"
    # Languages are given as locale identifiers (e.g. "en-US").
    lang: List[str] = ["fr-FR", "de-DE", "es-ES", "en-US"]
    # NOTE(review): recognition level and backend framework names are passed
    # through to the engine adapter — confirm the accepted values there.
    recognition: str = "accurate"
    framework: str = "vision"
|
|
|
|
|
|
|
|
class PictureDescriptionBaseOptions(BaseModel):
    """Settings common to all picture-description backends.

    ``kind`` is pinned to a ``Literal`` by each subclass and acts as the
    discriminator of the picture-description options union.
    """

    kind: str
    batch_size: int = 8
    scale: float = 2

    # NOTE(review): presumably the minimum fraction of the page area a picture
    # must cover to be described — confirm with the consumer.
    bitmap_area_threshold: float = 0.2
|
|
|
|
|
|
|
|
class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
    """Picture description obtained from an HTTP API endpoint."""

    kind: Literal["api"] = "api"

    # Request settings: endpoint, extra headers, extra parameters and timeout
    # (NOTE(review): timeout presumably in seconds — confirm with the client).
    url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
    headers: Dict[str, str] = {}
    params: Dict[str, Any] = {}
    timeout: float = 20

    # Prompt sent with each picture and a free-form provenance label.
    prompt: str = "Describe this image in a few sentences."
    provenance: str = ""
|
|
|
|
|
|
|
|
class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
    """Picture description with a vision-language model identified by a repo id."""

    kind: Literal["vlm"] = "vlm"

    repo_id: str
    prompt: str = "Describe this image in a few sentences."

    # Keyword arguments for text generation.
    generation_config: Dict[str, Any] = {"max_new_tokens": 200, "do_sample": False}

    @property
    def repo_cache_folder(self) -> str:
        """Return ``repo_id`` as a folder name, each ``/`` replaced by ``--``."""
        return "--".join(self.repo_id.split("/"))
|
|
|
|
|
|
|
|
# Ready-made picture-description presets.

# SmolVLM preset using the default prompt.
smolvlm_picture_description = PictureDescriptionVlmOptions(
    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
)

# Granite Vision preset with its own prompt.
granite_picture_description = PictureDescriptionVlmOptions(
    prompt="What is shown in this image?",
    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
)
|
|
|
|
|
|
|
|
|
|
|
class PdfBackend(str, Enum):
    """Supported PDF parsing backends (``str``-valued)."""

    PYPDFIUM2 = "pypdfium2"
    DLPARSE_V1 = "dlparse_v1"
    DLPARSE_V2 = "dlparse_v2"
|
|
|
|
|
|
|
|
|
|
|
class OcrEngine(str, Enum):
    """Supported OCR engines (``str``-valued).

    NOTE: these values are not the same strings as the ``kind`` literals of the
    OCR options classes: ``TESSERACT_CLI`` ("tesseract_cli") vs.
    ``TesseractCliOcrOptions.kind == "tesseract"``, and ``TESSERACT``
    ("tesseract") vs. ``TesseractOcrOptions.kind == "tesserocr"``. Map between
    them explicitly rather than by string equality.
    """

    EASYOCR = "easyocr"
    TESSERACT_CLI = "tesseract_cli"
    TESSERACT = "tesseract"
    OCRMAC = "ocrmac"
    RAPIDOCR = "rapidocr"
|
|
|
|
|
|
|
|
class PipelineOptions(BaseModel):
    """Base pipeline options."""

    # Toggle for legacy output creation; its exact effect is defined by the
    # consuming pipeline.
    create_legacy_output: bool = True
    # NOTE(review): presumably a per-document time limit in seconds; None
    # disables it — confirm with the pipeline implementation.
    document_timeout: Optional[float] = None
    # Built per instance through ``default_factory`` so the DOCLING_* /
    # OMP_NUM_THREADS environment is read when the options are created, not
    # frozen once at import time (a plain instance default would be evaluated
    # when this class is defined and merely deep-copied afterwards).
    accelerator_options: AcceleratorOptions = Field(default_factory=AcceleratorOptions)
|
|
|
|
|
|
|
|
class PdfPipelineOptions(PipelineOptions):
    """Configuration of the PDF conversion pipeline.

    Adds, on top of the base pipeline options: the model artifacts location,
    on/off switches for individual processing stages, the options of the nested
    table / OCR / picture-description models, and image-generation settings.
    """

    # Where model artifacts live; None when not given explicitly.
    artifacts_path: Optional[Union[Path, str]] = None

    # Stage switches.
    do_table_structure: bool = True
    do_ocr: bool = True
    do_code_enrichment: bool = False
    do_formula_enrichment: bool = False
    do_picture_classification: bool = False
    do_picture_description: bool = False

    # Nested model options; the OCR and picture-description unions are
    # resolved through their ``kind`` discriminator.
    table_structure_options: TableStructureOptions = TableStructureOptions()
    ocr_options: Union[
        EasyOcrOptions,
        TesseractCliOcrOptions,
        TesseractOcrOptions,
        OcrMacOptions,
        RapidOcrOptions,
    ] = Field(default=EasyOcrOptions(), discriminator="kind")
    picture_description_options: Annotated[
        Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions],
        Field(discriminator="kind"),
    ] = smolvlm_picture_description

    # Image generation.
    images_scale: float = 1.0
    generate_page_images: bool = False
    generate_picture_images: bool = False
    # Deprecated; accessing it emits the message below.
    generate_table_images: bool = Field(
        default=False,
        deprecated=(
            "Field `generate_table_images` is deprecated. "
            "To obtain table images, set `PdfPipelineOptions.generate_page_images = True` "
            "before conversion and then use the `TableItem.get_image` function."
        ),
    )
|
|
|