import os
import json
import logging

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer, TrainingArguments, Trainer
from jiwer import cer # For CER calculation

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- OCRDataset (Adapted for Hugging Face) ---
class OCRDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, "images")

        # Load the JSON mapping file
        mapping_file_path = os.path.join(root_dir, "annotations.json") # Assuming the JSON is named annotations.json
        logging.info(f"Loading mapping file from: {mapping_file_path}")
        with open(mapping_file_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        logging.info(f"Loaded {len(self.data)} entries from mapping file.")

        # Store image filenames and their corresponding texts
        self.image_filenames = list(self.data.keys())

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        text = self.data[image_filename]

        image_path = os.path.join(self.image_dir, image_filename)
        image = Image.open(image_path).convert("RGB") # Ensure image is in RGB format

        # Return raw PIL Image and text string
        return image, text
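
# A hedged example of the on-disk layout OCRDataset expects (the "images/"
# subfolder and the "annotations.json" filename come from the code above;
# the sample filenames are hypothetical):
#
#   text_dataset/
#     images/
#       sample_001.png
#       sample_002.png
#     annotations.json
#
# annotations.json maps each image filename to its transcription, e.g.:
#   {"sample_001.png": "hello world", "sample_002.png": "lorem ipsum"}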

# --- Custom Collate Function for Hugging Face Processors ---
# This function is passed to the Trainer as its data_collator
def collate_fn_hf(batch, image_processor, tokenizer):
    images, texts = zip(*batch)

    # Process images using AutoImageProcessor
    # This handles resizing, normalization, and converting to tensor
    pixel_values = image_processor(images=list(images), return_tensors="pt").pixel_values

    # Tokenize texts using AutoTokenizer
    # This handles tokenization, padding, and converting to tensor
    labels = tokenizer(text=list(texts), padding="longest", return_tensors="pt").input_ids

    # Mask padding positions with -100 so the loss ignores them
    # (compute_metrics below converts -100 back before decoding)
    labels[labels == tokenizer.pad_token_id] = -100

    # Return a dictionary expected by the Hugging Face Trainer
    return {"pixel_values": pixel_values, "labels": labels}

# --- Define compute_metrics for Trainer ---
# Note: relies on the module-level `tokenizer` created in the __main__ block.
def compute_metrics(pred):
    labels_ids = pred.label_ids
    # Without predict_with_generate, the Trainer hands over teacher-forced logits,
    # so we take the argmax over the vocabulary to get predicted token ids
    pred_logits = pred.predictions[0]
    pred_ids = np.argmax(pred_logits, axis=-1)

    # Replace -100 (loss-ignore positions) with the pad token so labels can be decoded
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # Calculate CER
    cer_score = cer(label_str, pred_str)
    logging.info(f"Validation CER: {cer_score}")
    return {"cer": cer_score}

# --- Main Training Script ---
if __name__ == '__main__':
    logging.info("Starting OCR training script.")
    data_root_dir = "text_dataset"
    logging.info(f"Using dataset at: {os.path.abspath(data_root_dir)}")

    # --- Hugging Face Model and Processor Loading ---
    #encoder_id = "google/mobilenet_v3_small_100_224"
    encoder_id = "google/vit-base-patch16-224-in21k"
    decoder_id = "prajjwal1/bert-tiny"

    logging.info(f"Loading encoder: {encoder_id}")
    logging.info(f"Loading decoder: {decoder_id}")

    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=encoder_id,
        decoder_pretrained_model_name_or_path=decoder_id,
    )

    image_processor = AutoImageProcessor.from_pretrained(encoder_id)
    tokenizer = AutoTokenizer.from_pretrained(decoder_id)
    logging.info("Model, image processor, and tokenizer loaded.")

    # --- Set special tokens and generation parameters ---
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size # Ensure model knows decoder vocab size

    model.config.eos_token_id = tokenizer.sep_token_id
    model.config.max_length = 64
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4
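    # Note: recent transformers releases prefer generation settings on
    # model.generation_config rather than model.config; mirroring them there is
    # an optional step (a hedged sketch, depends on your transformers version):
    # model.generation_config.max_length = model.config.max_length
    # model.generation_config.num_beams = model.config.num_beams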
    logging.info("Model configuration set.")

    # --- Dataset and DataLoader Setup ---
    logging.info("Setting up datasets.")
    train_dataset = OCRDataset(root_dir=data_root_dir)
    # For a real project, you'd split your data into train/val/test;
    # for this example we simply reuse the same dataset for validation
    # (see the commented random_split sketch below)
    val_dataset = OCRDataset(root_dir=data_root_dir)
    logging.info(f"Training dataset size: {len(train_dataset)}")
    logging.info(f"Validation dataset size: {len(val_dataset)}")

    # --- Training Arguments ---
    training_args = TrainingArguments(
        output_dir="./ocr_model_output", # Output directory for checkpoints and logs
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3, # Small number for quick demo
        logging_dir="./logs",
        logging_steps=10,
        # save_steps=500, # Save checkpoint every 500 steps
        eval_strategy="epoch", # Evaluate at the end of each epoch
        save_strategy="epoch", # Save a checkpoint at the end of each epoch
        save_total_limit=2, # Only keep the last 2 checkpoints
        report_to="none", # Disable reporting to W&B, MLflow etc. for simplicity
        # predict_with_generate=True would evaluate with model.generate(), but it
        # requires Seq2SeqTrainingArguments/Seq2SeqTrainer; with the plain Trainer,
        # compute_metrics works on teacher-forced logits instead (see above)
        load_best_model_at_end=True, # Reload the best checkpoint (by CER) after training
        metric_for_best_model="cer", # Metric to monitor for best model
        greater_is_better=False, # Lower CER is better
    )
    logging.info("Training arguments set.")

    # --- Trainer Initialization ---
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # Pass image_processor and tokenizer to collate_fn using a lambda
        data_collator=lambda batch: collate_fn_hf(batch, image_processor, tokenizer),
        compute_metrics=compute_metrics,
    )
    logging.info("Trainer initialized.")

    # --- Start Training ---
    logging.info("--- Starting Training ---")
    trainer.train()
    logging.info("--- Training finished! ---")