PaperShow/Paper2Video/src/evaluation/PresentQuiz/docling/models/document_picture_classifier.py
| from pathlib import Path | |
| from typing import Iterable, List, Literal, Optional, Tuple, Union | |
| import numpy as np | |
| from docling_core.types.doc import ( | |
| DoclingDocument, | |
| NodeItem, | |
| PictureClassificationClass, | |
| PictureClassificationData, | |
| PictureItem, | |
| ) | |
| from PIL import Image | |
| from pydantic import BaseModel | |
| from docling.datamodel.pipeline_options import AcceleratorOptions | |
| from docling.models.base_model import BaseEnrichmentModel | |
| from docling.utils.accelerator_utils import decide_device | |
class DocumentPictureClassifierOptions(BaseModel):
    """
    Configuration model for :class:`DocumentPictureClassifier`.

    Attributes
    ----------
    kind : Literal["document_picture_classifier"]
        Fixed tag naming this options type; its only allowed value (and its
        default) is ``"document_picture_classifier"``.
    """

    # Single-valued Literal: the field can only ever hold this one string.
    kind: Literal["document_picture_classifier"] = "document_picture_classifier"
class DocumentPictureClassifier(BaseEnrichmentModel):
    """
    Enrichment model that classifies pictures in a document.

    Every processed ``PictureItem`` receives a ``PictureClassificationData``
    annotation listing the classes and confidences returned by the
    underlying predictor.

    Attributes
    ----------
    enabled : bool
        Whether the classifier is enabled for use.
    options : DocumentPictureClassifierOptions
        Configuration options for the classifier.
    document_picture_classifier : DocumentFigureClassifierPredictor
        The underlying prediction model; only created when ``enabled`` is True.

    Methods
    -------
    __init__(enabled, artifacts_path, options, accelerator_options)
        Initializes the classifier with specified configurations.
    download_models(local_dir, force, progress)
        Downloads the model weights from the Hugging Face Hub.
    is_processable(doc, element)
        Checks if the given element can be processed by the classifier.
    __call__(doc, element_batch)
        Processes a batch of elements and adds classification annotations.
    """

    # Sub-folder of a user-supplied ``artifacts_path`` that holds the weights.
    _model_repo_folder = "ds4sd--DocumentFigureClassifier"
    # NOTE(review): presumably the scale at which the pipeline renders picture
    # images for this model; it is not read in this class. Confirm with callers.
    images_scale = 2

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Union[Path, str]],
        options: DocumentPictureClassifierOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Initializes the DocumentPictureClassifier.

        Parameters
        ----------
        enabled : bool
            Indicates whether the classifier is enabled. When False, no model
            is loaded and ``__call__`` passes elements through unchanged.
        artifacts_path : Optional[Union[Path, str]]
            Directory containing model artifacts. The weights are expected in
            its ``ds4sd--DocumentFigureClassifier`` sub-folder. If None, the
            weights are downloaded via :meth:`download_models`.
        options : DocumentPictureClassifierOptions
            Configuration options for the classifier.
        accelerator_options : AcceleratorOptions
            Options for configuring the device and parallelism.
        """
        self.enabled = enabled
        self.options = options

        if self.enabled:
            device = decide_device(accelerator_options.device)
            # Imported inside this branch so the dependency is only needed
            # when the classifier is actually enabled.
            from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
                DocumentFigureClassifierPredictor,
            )

            if artifacts_path is None:
                artifacts_path = self.download_models()
            else:
                # Path() also accepts str, which the docstring has always
                # advertised but the bare `/` operator did not support.
                artifacts_path = Path(artifacts_path) / self._model_repo_folder

            self.document_picture_classifier = DocumentFigureClassifierPredictor(
                artifacts_path=str(artifacts_path),
                device=device,
                num_threads=accelerator_options.num_threads,
            )

    # Must be a staticmethod: it takes no `self`, yet it is invoked as
    # `self.download_models()` above. Without the decorator the instance would
    # be bound to `local_dir`.
    @staticmethod
    def download_models(
        local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
    ) -> Path:
        """
        Download the classifier weights from the Hugging Face Hub.

        Parameters
        ----------
        local_dir : Optional[Path]
            Target directory, forwarded to ``snapshot_download``. None lets
            ``huggingface_hub`` choose its default location.
        force : bool
            Forwarded as ``force_download`` to re-download even if cached.
        progress : bool
            When False, ``huggingface_hub`` progress bars are disabled first.

        Returns
        -------
        Path
            The local path returned by ``snapshot_download`` for the pinned
            ``v1.0.0`` revision of ``ds4sd/DocumentFigureClassifier``.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        if not progress:
            # NOTE(review): this likely disables progress bars for all of
            # huggingface_hub, not only this call. Confirm if that matters.
            disable_progress_bars()
        download_path = snapshot_download(
            repo_id="ds4sd/DocumentFigureClassifier",
            force_download=force,
            local_dir=local_dir,
            revision="v1.0.0",
        )
        return Path(download_path)

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Determines if the given element can be processed by the classifier.

        Parameters
        ----------
        doc : DoclingDocument
            The document containing the element.
        element : NodeItem
            The element to be checked.

        Returns
        -------
        bool
            True if the element is a PictureItem and processing is enabled;
            False otherwise.
        """
        return self.enabled and isinstance(element, PictureItem)

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[NodeItem],
    ) -> Iterable[NodeItem]:
        """
        Processes a batch of elements and enriches them with classification predictions.

        Parameters
        ----------
        doc : DoclingDocument
            The document containing the elements to be processed.
        element_batch : Iterable[NodeItem]
            A batch of pictures to classify.

        Returns
        -------
        Iterable[NodeItem]
            The processed elements. When enabled, each picture gets a
            ``PictureClassificationData`` entry appended to its
            ``annotations`` list.

        Raises
        ------
        TypeError
            If an element of the batch is not a ``PictureItem``.
        ValueError
            If no image can be obtained for a picture.
        """
        if not self.enabled:
            # Disabled: pass everything through untouched.
            yield from element_batch
            return

        images: List[Union[Image.Image, np.ndarray]] = []
        elements: List[PictureItem] = []
        for el in element_batch:
            # Explicit checks instead of `assert`, which is stripped under -O.
            if not isinstance(el, PictureItem):
                raise TypeError(
                    f"DocumentPictureClassifier expects PictureItem elements, "
                    f"got {type(el).__name__}"
                )
            img = el.get_image(doc)
            if img is None:
                raise ValueError("Could not obtain an image for a PictureItem")
            elements.append(el)
            images.append(img)

        # Nothing to classify: avoid calling the predictor on an empty batch.
        if not elements:
            return

        outputs = self.document_picture_classifier.predict(images)
        # Each output is iterated as (class_name, confidence) pairs.
        for element, output in zip(elements, outputs):
            element.annotations.append(
                PictureClassificationData(
                    provenance="DocumentPictureClassifier",
                    predicted_classes=[
                        PictureClassificationClass(
                            class_name=pred[0],
                            confidence=pred[1],
                        )
                        for pred in output
                    ],
                )
            )
            yield element