import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer
import os
import numpy as np
from typing import Union

# --- Configuration ---
MODEL_PATH = "./ocr_model_output/checkpoint-441"

class OCRInference:
    """A class to perform OCR inference using a trained model."""

    def __init__(self, model_path: str, encoder_id: str = "google/vit-base-patch16-224-in21k", decoder_id: str = "prajjwal1/bert-tiny"):
        """
        Initializes the OCRInference class by loading the model, image processor, and tokenizer.

        Args:
            model_path (str): The path to the trained model checkpoint.
            encoder_id (str): The encoder ID to load the image processor from.
            decoder_id (str): The decoder ID to load the tokenizer from.
        """
        print(f"Loading model from: {model_path}")
        self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
        self.model.eval()  # inference only: disable dropout for deterministic outputs
        
        # Load image processor and save it if not present
        try:
            self.image_processor = AutoImageProcessor.from_pretrained(model_path)
        except OSError:
            print("Image processor not found locally. Loading from encoder ID and saving.")
            self.image_processor = AutoImageProcessor.from_pretrained(encoder_id)
            self.image_processor.save_pretrained(model_path)
            
        # Load tokenizer and save it if not present
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        except (KeyError, OSError):
            print("Tokenizer not found locally. Loading from decoder ID and saving.")
            self.tokenizer = AutoTokenizer.from_pretrained(decoder_id)
            self.tokenizer.save_pretrained(model_path)

        # --- Set special tokens and generation parameters ---
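        # BERT-style decoders have no dedicated BOS/EOS tokens, so [CLS]
        # serves as the decoder start token and [SEP] as end-of-sequence.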
        self.model.config.decoder_start_token_id = self.tokenizer.cls_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.vocab_size = self.tokenizer.vocab_size

        self.model.config.eos_token_id = self.tokenizer.sep_token_id
        self.model.config.max_length = 64
        self.model.config.early_stopping = True
        self.model.config.no_repeat_ngram_size = 3
        self.model.config.length_penalty = 2.0
        self.model.config.num_beams = 4
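        # With these defaults, generate() runs beam search over 4 beams; the
        # length penalty of 2.0 favors longer hypotheses, and
        # no_repeat_ngram_size=3 prevents any trigram from repeating.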

        print("Model, image processor, and tokenizer loaded.")

    def perform_inference(self, image_input: Union[str, np.ndarray]) -> str:
        """
        Performs inference on a single image, which can be a file path or a NumPy array.

        Args:
            image_input (Union[str, np.ndarray]): Path to the input image or a NumPy array representing the image.

        Returns:
            str: The predicted text.
        """
        if isinstance(image_input, str):
            if not os.path.exists(image_input):
                raise FileNotFoundError(f"Image file not found at: {image_input}")
            image = Image.open(image_input).convert("RGB")
        elif isinstance(image_input, np.ndarray):
            image = Image.fromarray(image_input).convert("RGB")
        else:
            raise TypeError("image_input must be a file path (str) or a NumPy array.")

        # Process the image
        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values

        # Generate text
        with torch.no_grad():
            output_ids = self.model.generate(pixel_values, max_length=64, num_beams=4, early_stopping=True)

        # Decode the generated ids to text
        predicted_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return predicted_text
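
    def perform_inference_batch(self, image_paths: list[str]) -> list[str]:
        """
        Batched variant of perform_inference (a minimal sketch, not part of the
        original pipeline). Assumes Python 3.9+ for built-in generic type
        hints and that all images fit in memory in a single batch.

        Args:
            image_paths (list[str]): Paths to the input images.

        Returns:
            list[str]: The predicted text for each image, in input order.
        """
        images = [Image.open(path).convert("RGB") for path in image_paths]
        pixel_values = self.image_processor(images=images, return_tensors="pt").pixel_values
        with torch.no_grad():
            output_ids = self.model.generate(pixel_values, max_length=64, num_beams=4, early_stopping=True)
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)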

if __name__ == '__main__':
    # Provide a path to an image for inference
    # Using an example image from the dataset
    image_path = "../ai_augment_output/20250901_115123_336458_ccd9d646-fc99-4d27-8076-0c17d0dba784.png"

    # --- Initialize the Inference Class ---
    ocr_engine = OCRInference(model_path=MODEL_PATH)

    # --- Perform Inference from a file path ---
    try:
        predicted_text = ocr_engine.perform_inference(image_path)
        print(f"\n--- Inference from file path ---")
        print(f"Image: {image_path}")
        print(f"Predicted Text: {predicted_text}")
    except FileNotFoundError as e:
        print(e)
        print("Please update the 'image_path' variable in the script with a valid image path.")

    # --- Perform Inference from a NumPy array (example) ---
    try:
        # Convert the same image to a NumPy array to demonstrate array input
        if os.path.exists(image_path):
            image_array = np.array(Image.open(image_path))
            predicted_text_from_array = ocr_engine.perform_inference(image_array)
            print("\n--- Inference from NumPy array ---")
            print(f"Predicted Text: {predicted_text_from_array}")
    except Exception as e:
        print(f"An error occurred during inference from NumPy array: {e}")