import os import json from PIL import Image import torch from torch.utils.data import Dataset, DataLoader import math # For ceil in input_lengths calculation import shutil # For cleaning up dummy data import logging import numpy as np from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer, TrainingArguments, Trainer from jiwer import cer # For CER calculation # --- Setup Logging --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # --- OCRDataset (Adapted for Hugging Face) --- class OCRDataset(Dataset): def __init__(self, root_dir): self.root_dir = root_dir self.image_dir = os.path.join(root_dir, "images") # Load the JSON mapping file mapping_file_path = os.path.join(root_dir, "annotations.json") # Assuming the JSON is named annotations.json logging.info(f"Loading mapping file from: {mapping_file_path}") with open(mapping_file_path, 'r', encoding='utf-8') as f: self.data = json.load(f) logging.info(f"Loaded {len(self.data)} entries from mapping file.") # Store image filenames and their corresponding texts self.image_filenames = list(self.data.keys()) def __len__(self): return len(self.image_filenames) def __getitem__(self, idx): image_filename = self.image_filenames[idx] text = self.data[image_filename] image_path = os.path.join(self.image_dir, image_filename) image = Image.open(image_path).convert("RGB") # Ensure image is in RGB format # Return raw PIL Image and text string return image, text # --- Custom Collate Function for Hugging Face Processors --- # This function will be passed to the DataLoader def collate_fn_hf(batch, image_processor, tokenizer): images, texts = zip(*batch) # Process images using AutoImageProcessor # This handles resizing, normalization, and converting to tensor pixel_values = image_processor(images=list(images), return_tensors="pt").pixel_values # Tokenize texts using AutoTokenizer # This handles tokenization, padding, and converting to tensor labels = tokenizer(text=list(texts), padding="longest", return_tensors="pt").input_ids # Return a dictionary expected by the Hugging Face Trainer return {"pixel_values": pixel_values, "labels": labels} # --- Define compute_metrics for Trainer --- def compute_metrics(pred): labels_ids = pred.label_ids pred_logits = pred.predictions[0] pred_ids = np.argmax(pred_logits, axis=-1) # Replace -100 in labels as we can't decode them (they are padding tokens) labels_ids[labels_ids == -100] = tokenizer.pad_token_id pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) # Calculate CER cer_score = cer(label_str, pred_str) logging.info(f"Validation CER: {cer_score}") return {"cer": cer_score} # --- Main Training Script --- if __name__ == '__main__': logging.info("Starting OCR training script.") data_root_dir = "text_dataset" logging.info(f"Using dataset at: {os.path.abspath(data_root_dir)}") # --- Hugging Face Model and Processor Loading --- #encoder_id = "google/mobilenet_v3_small_100_224" encoder_id = "google/vit-base-patch16-224-in21k" decoder_id = "prajjwal1/bert-tiny" logging.info(f"Loading encoder: {encoder_id}") logging.info(f"Loading decoder: {decoder_id}") model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained( encoder_pretrained_model_name_or_path=encoder_id, decoder_pretrained_model_name_or_path=decoder_id, ) image_processor = AutoImageProcessor.from_pretrained(encoder_id) tokenizer = AutoTokenizer.from_pretrained(decoder_id) logging.info("Model, image processor, and tokenizer loaded.") # --- Set special tokens and generation parameters --- model.config.decoder_start_token_id = tokenizer.cls_token_id model.config.pad_token_id = tokenizer.pad_token_id model.config.vocab_size = tokenizer.vocab_size # Ensure model knows decoder vocab size model.config.eos_token_id = tokenizer.sep_token_id model.config.max_length = 64 model.config.early_stopping = True model.config.no_repeat_ngram_size = 3 model.config.length_penalty = 2.0 model.config.num_beams = 4 logging.info("Model configuration set.") # --- Dataset and DataLoader Setup --- logging.info("Setting up datasets.") train_dataset = OCRDataset(root_dir=data_root_dir) # For a real project, you'd split your data into train/val/test # For this example, we'll use the same dummy data for simplicity val_dataset = OCRDataset(root_dir=data_root_dir) logging.info(f"Training dataset size: {len(train_dataset)}") logging.info(f"Validation dataset size: {len(val_dataset)}") # --- Training Arguments --- training_args = TrainingArguments( output_dir="./ocr_model_output", # Output directory for checkpoints and logs per_device_train_batch_size=2, per_device_eval_batch_size=2, num_train_epochs=3, # Small number for quick demo logging_dir="./logs", logging_steps=10, # save_steps=500, # Save checkpoint every 500 steps eval_strategy ="epoch", # Evaluate at the end of each epoch save_strategy ="epoch", # Evaluate at the end of each epoch save_total_limit=2, # Only keep the last 2 checkpoints report_to="none", # Disable reporting to W&B, MLflow etc. for simplicity # predict_with_generate=True, # Crucial for generation tasks (uses model.generate() for eval) load_best_model_at_end=True, # Load the best model based on eval_loss at the end of training metric_for_best_model="cer", # Metric to monitor for best model greater_is_better=False, # Lower CER is better ) logging.info("Training arguments set.") # --- Trainer Initialization --- trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, # Pass image_processor and tokenizer to collate_fn using a lambda data_collator=lambda batch: collate_fn_hf(batch, image_processor, tokenizer), compute_metrics=compute_metrics, ) logging.info("Trainer initialized.") # --- Start Training --- logging.info("--- Starting Training ---") trainer.train() logging.info("--- Training finished! ---")