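"""Fine-tune a VisionEncoderDecoderModel (ViT encoder + BERT-tiny decoder) for OCR
with the Hugging Face Seq2SeqTrainer, reporting character error rate (CER) on the
validation set."""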
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import logging
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from jiwer import cer # For CER calculation
# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- OCRDataset (Adapted for Hugging Face) ---
class OCRDataset(Dataset):
def __init__(self, root_dir):
self.root_dir = root_dir
self.image_dir = os.path.join(root_dir, "images")
# Load the JSON mapping file
mapping_file_path = os.path.join(root_dir, "annotations.json") # Assuming the JSON is named annotations.json
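        # The mapping file is assumed to pair image filenames with their transcriptions,
        # e.g. {"line_0001.png": "hello world", "line_0002.png": "second sample"}
        # (filenames here are illustrative)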
logging.info(f"Loading mapping file from: {mapping_file_path}")
with open(mapping_file_path, 'r', encoding='utf-8') as f:
self.data = json.load(f)
logging.info(f"Loaded {len(self.data)} entries from mapping file.")
# Store image filenames and their corresponding texts
self.image_filenames = list(self.data.keys())
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
image_filename = self.image_filenames[idx]
text = self.data[image_filename]
image_path = os.path.join(self.image_dir, image_filename)
image = Image.open(image_path).convert("RGB") # Ensure image is in RGB format
# Return raw PIL Image and text string
return image, text
# --- Custom Collate Function for Hugging Face Processors ---
# This function is passed to the Trainer as data_collator (and is equally usable with a plain DataLoader)
def collate_fn_hf(batch, image_processor, tokenizer):
images, texts = zip(*batch)
# Process images using AutoImageProcessor
# This handles resizing, normalization, and converting to tensor
pixel_values = image_processor(images=list(images), return_tensors="pt").pixel_values
# Tokenize texts using AutoTokenizer
    # This handles tokenization, padding, and converting to tensor
    labels = tokenizer(text=list(texts), padding="longest", return_tensors="pt").input_ids
    # Replace padding token ids with -100 so they are ignored by the cross-entropy loss
    # (compute_metrics below relies on this convention when decoding labels)
    labels[labels == tokenizer.pad_token_id] = -100
# Return a dictionary expected by the Hugging Face Trainer
return {"pixel_values": pixel_values, "labels": labels}
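# Standalone usage sketch (the Trainer below invokes the collator itself, so this
# is only needed for manual batching; batch_size here is illustrative):
#   loader = DataLoader(dataset, batch_size=2,
#                       collate_fn=lambda b: collate_fn_hf(b, image_processor, tokenizer))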
# --- Define compute_metrics for Trainer ---
def compute_metrics(pred):
    labels_ids = pred.label_ids
    # With predict_with_generate=True, predictions are the generated token ids,
    # not logits, so no argmax over a logits tensor is needed
    pred_ids = pred.predictions
# Replace -100 in labels as we can't decode them (they are padding tokens)
labels_ids[labels_ids == -100] = tokenizer.pad_token_id
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
# Calculate CER
cer_score = cer(label_str, pred_str)
logging.info(f"Validation CER: {cer_score}")
return {"cer": cer_score}
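# jiwer's cer() is edit distance / reference length over characters, e.g.
# cer("abc", "abd") == 1/3 (one substitution across three reference characters).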
# --- Main Training Script ---
if __name__ == '__main__':
logging.info("Starting OCR training script.")
data_root_dir = "text_dataset"
logging.info(f"Using dataset at: {os.path.abspath(data_root_dir)}")
# --- Hugging Face Model and Processor Loading ---
#encoder_id = "google/mobilenet_v3_small_100_224"
encoder_id = "google/vit-base-patch16-224-in21k"
decoder_id = "prajjwal1/bert-tiny"
logging.info(f"Loading encoder: {encoder_id}")
logging.info(f"Loading decoder: {decoder_id}")
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_pretrained_model_name_or_path=encoder_id,
decoder_pretrained_model_name_or_path=decoder_id,
)
image_processor = AutoImageProcessor.from_pretrained(encoder_id)
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
logging.info("Model, image processor, and tokenizer loaded.")
# --- Set special tokens and generation parameters ---
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = tokenizer.vocab_size # Ensure model knows decoder vocab size
    model.config.eos_token_id = tokenizer.sep_token_id  # BERT tokenizers use [SEP] as the end-of-sequence token
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
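    # These generation settings only take effect when model.generate() is called,
    # i.e. during predict_with_generate evaluation and at inference time.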
logging.info("Model configuration set.")
# --- Dataset and DataLoader Setup ---
logging.info("Setting up datasets.")
train_dataset = OCRDataset(root_dir=data_root_dir)
    # For a real project, split your data into train/val/test sets.
    # For this example we reuse the training set for validation, so the CER
    # reported here measures memorization rather than generalization.
    val_dataset = OCRDataset(root_dir=data_root_dir)
logging.info(f"Training dataset size: {len(train_dataset)}")
logging.info(f"Validation dataset size: {len(val_dataset)}")
# --- Training Arguments ---
    training_args = Seq2SeqTrainingArguments(
        output_dir="./ocr_model_output",  # Output directory for checkpoints and logs
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
num_train_epochs=3, # Small number for quick demo
logging_dir="./logs",
logging_steps=10,
        # save_steps=500,  # Alternatively, save a checkpoint every 500 steps
        eval_strategy="epoch",  # Evaluate at the end of each epoch
        save_strategy="epoch",  # Save a checkpoint at the end of each epoch
        save_total_limit=2,  # Only keep the last 2 checkpoints
        report_to="none",  # Disable reporting to W&B, MLflow etc. for simplicity
        predict_with_generate=True,  # Use model.generate() for evaluation, so CER is computed on generated text
        load_best_model_at_end=True,  # Reload the best checkpoint (by CER) when training ends
        metric_for_best_model="cer",  # Metric to monitor for the best model
        greater_is_better=False,  # Lower CER is better
)
logging.info("Training arguments set.")
# --- Trainer Initialization ---
    trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
# Pass image_processor and tokenizer to collate_fn using a lambda
data_collator=lambda batch: collate_fn_hf(batch, image_processor, tokenizer),
compute_metrics=compute_metrics,
)
logging.info("Trainer initialized.")
# --- Start Training ---
logging.info("--- Starting Training ---")
trainer.train()
logging.info("--- Training finished! ---")