import os
import json
import logging
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer, TrainingArguments, Trainer
from jiwer import cer  # For CER calculation
# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- OCRDataset (Adapted for Hugging Face) ---
class OCRDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, "images")
        # Load the JSON mapping file (assumed to be named annotations.json)
        mapping_file_path = os.path.join(root_dir, "annotations.json")
        logging.info(f"Loading mapping file from: {mapping_file_path}")
        with open(mapping_file_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        logging.info(f"Loaded {len(self.data)} entries from mapping file.")
        # Store the image filenames; the corresponding texts are looked up in __getitem__
        self.image_filenames = list(self.data.keys())

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        text = self.data[image_filename]
        image_path = os.path.join(self.image_dir, image_filename)
        image = Image.open(image_path).convert("RGB")  # Ensure image is in RGB format
        # Return the raw PIL image and text string; processing happens in the collator
        return image, text
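
# Expected on-disk layout (an assumption based on the paths above, not spelled
# out in the original script): <root_dir>/images/ holds the image files and
# <root_dir>/annotations.json maps each filename to its transcription, e.g.
#   {"img_0001.png": "hello world", "img_0002.png": "foo bar"}
# The helper below is a hypothetical sketch that writes a tiny dummy dataset in
# that layout, e.g. create_dummy_dataset("text_dataset") for a quick smoke test.
def create_dummy_dataset(root_dir, num_samples=4):
    os.makedirs(os.path.join(root_dir, "images"), exist_ok=True)
    annotations = {}
    for i in range(num_samples):
        filename = f"dummy_{i}.png"
        # Plain white 224x224 canvas; real data would contain rendered text
        Image.new("RGB", (224, 224), color="white").save(
            os.path.join(root_dir, "images", filename)
        )
        annotations[filename] = f"sample text {i}"
    with open(os.path.join(root_dir, "annotations.json"), "w", encoding="utf-8") as f:
        json.dump(annotations, f, ensure_ascii=False)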

# --- Custom Collate Function for Hugging Face Processors ---
# This function is passed to the Trainer as its data_collator
def collate_fn_hf(batch, image_processor, tokenizer):
    images, texts = zip(*batch)
    # Process images with AutoImageProcessor: resizing, normalization, tensor conversion
    pixel_values = image_processor(images=list(images), return_tensors="pt").pixel_values
    # Tokenize texts with AutoTokenizer: tokenization, padding, tensor conversion
    labels = tokenizer(text=list(texts), padding="longest", return_tensors="pt").input_ids
    # Replace padding token ids with -100 so they are ignored by the loss
    # (compute_metrics below assumes this convention and reverses it for decoding)
    labels[labels == tokenizer.pad_token_id] = -100
    # Return a dictionary in the format expected by the Hugging Face Trainer
    return {"pixel_values": pixel_values, "labels": labels}

# --- Define compute_metrics for Trainer ---
def compute_metrics(pred):
    labels_ids = pred.label_ids
    # Without predict_with_generate, the Trainer passes raw logits; taking the
    # argmax is a cheap proxy for greedy decoding (no beam search here)
    pred_logits = pred.predictions[0]
    pred_ids = np.argmax(pred_logits, axis=-1)
    # Replace -100 in the labels (loss-masked padding) so they can be decoded
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    # NOTE: relies on the module-level `tokenizer` created in the main block below
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    # Character Error Rate: jiwer's cer expects (reference, hypothesis)
    cer_score = cer(label_str, pred_str)
    logging.info(f"Validation CER: {cer_score}")
    return {"cer": cer_score}

# --- Main Training Script ---
if __name__ == '__main__':
    logging.info("Starting OCR training script.")
    data_root_dir = "text_dataset"
    logging.info(f"Using dataset at: {os.path.abspath(data_root_dir)}")

    # --- Hugging Face Model and Processor Loading ---
    # encoder_id = "google/mobilenet_v3_small_100_224"
    encoder_id = "google/vit-base-patch16-224-in21k"
    decoder_id = "prajjwal1/bert-tiny"
    logging.info(f"Loading encoder: {encoder_id}")
    logging.info(f"Loading decoder: {decoder_id}")
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=encoder_id,
        decoder_pretrained_model_name_or_path=decoder_id,
    )
    image_processor = AutoImageProcessor.from_pretrained(encoder_id)
    tokenizer = AutoTokenizer.from_pretrained(decoder_id)
    logging.info("Model, image processor, and tokenizer loaded.")

    # --- Set special tokens and generation parameters ---
    # BERT-style decoders use [CLS] as start-of-sequence and [SEP] as end-of-sequence
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.vocab_size = tokenizer.vocab_size  # Ensure the model knows the decoder vocab size
    model.config.eos_token_id = tokenizer.sep_token_id
    model.config.max_length = 64
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4
    logging.info("Model configuration set.")

    # --- Dataset and DataLoader Setup ---
    logging.info("Setting up datasets.")
    train_dataset = OCRDataset(root_dir=data_root_dir)
    # For a real project, split the data into train/val/test;
    # for this example, the same dummy data is reused for simplicity
    val_dataset = OCRDataset(root_dir=data_root_dir)
    logging.info(f"Training dataset size: {len(train_dataset)}")
    logging.info(f"Validation dataset size: {len(val_dataset)}")

    # --- Training Arguments ---
    training_args = TrainingArguments(
        output_dir="./ocr_model_output",   # Output directory for checkpoints and logs
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,                # Small number for a quick demo
        logging_dir="./logs",
        logging_steps=10,
        # save_steps=500,                  # Alternatively, save a checkpoint every 500 steps
        eval_strategy="epoch",             # Evaluate at the end of each epoch
        save_strategy="epoch",             # Save at the end of each epoch (must match eval_strategy for load_best_model_at_end)
        save_total_limit=2,                # Only keep the last 2 checkpoints
        report_to="none",                  # Disable reporting to W&B, MLflow etc. for simplicity
        # predict_with_generate=True,      # Not valid on plain TrainingArguments; see the Seq2Seq note below
        load_best_model_at_end=True,       # Reload the best checkpoint (by CER) at the end of training
        metric_for_best_model="cer",       # Metric to monitor for the best model
        greater_is_better=False,           # Lower CER is better
    )
    logging.info("Training arguments set.")

    # --- Trainer Initialization ---
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # Bind image_processor and tokenizer into the collator with a lambda
        data_collator=lambda batch: collate_fn_hf(batch, image_processor, tokenizer),
        compute_metrics=compute_metrics,
    )
    logging.info("Trainer initialized.")

    # --- Start Training ---
    logging.info("--- Starting Training ---")
    trainer.train()
    logging.info("--- Training finished! ---")