import os
import json
import logging

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import VisionEncoderDecoderModel, AutoImageProcessor, AutoTokenizer, TrainingArguments, Trainer
from jiwer import cer  # character error rate

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- OCRDataset (Adapted for Hugging Face) ---
class OCRDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, "images")
        # Load the JSON mapping file (image filename -> transcription),
        # assumed to be named annotations.json
        mapping_file_path = os.path.join(root_dir, "annotations.json")
        logging.info(f"Loading mapping file from: {mapping_file_path}")
        with open(mapping_file_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        logging.info(f"Loaded {len(self.data)} entries from mapping file.")
        # Store image filenames; their texts are looked up in __getitem__
        self.image_filenames = list(self.data.keys())

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        text = self.data[image_filename]
        image_path = os.path.join(self.image_dir, image_filename)
        image = Image.open(image_path).convert("RGB")  # ensure 3-channel input
        # Return the raw PIL image and text string; all processor/tokenizer
        # work happens batch-wise in the collate function
        return image, text
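
# The expected on-disk layout, implied by OCRDataset above. The helper below
# is an illustrative sketch for smoke-testing the pipeline with dummy data,
# not part of the training flow:
#
#   text_dataset/
#   ├── annotations.json   # {"img_0.png": "dummy text 0", ...}
#   └── images/
#       ├── img_0.png
#       └── ...
def make_dummy_dataset(root_dir="text_dataset", num_samples=4):
    """Build a tiny dummy dataset in the layout OCRDataset expects."""
    os.makedirs(os.path.join(root_dir, "images"), exist_ok=True)
    annotations = {}
    for i in range(num_samples):
        filename = f"img_{i}.png"
        # Blank white canvas as a stand-in; real data would contain rendered text
        Image.new("RGB", (224, 64), "white").save(os.path.join(root_dir, "images", filename))
        annotations[filename] = f"dummy text {i}"
    with open(os.path.join(root_dir, "annotations.json"), "w", encoding="utf-8") as f:
        json.dump(annotations, f)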
# --- Custom Collate Function for Hugging Face Processors ---
# Passed to the Trainer as data_collator; receives a list of (image, text)
# tuples from OCRDataset
def collate_fn_hf(batch, image_processor, tokenizer):
    images, texts = zip(*batch)
    # Process images with AutoImageProcessor: resizing, normalization,
    # and conversion to a stacked tensor
    pixel_values = image_processor(images=list(images), return_tensors="pt").pixel_values
    # Tokenize texts with AutoTokenizer: tokenization, padding, tensors
    labels = tokenizer(text=list(texts), padding="longest", return_tensors="pt").input_ids
    # Replace padding token ids with -100 so the loss ignores them
    # (compute_metrics maps them back to pad_token_id before decoding)
    labels[labels == tokenizer.pad_token_id] = -100
    # Return the dictionary format expected by the Hugging Face Trainer
    return {"pixel_values": pixel_values, "labels": labels}
# --- Define compute_metrics for Trainer ---
# Relies on the module-level tokenizer created in the main block. With the
# plain Trainer, predictions are teacher-forced logits, so argmax gives an
# approximate CER; exact decoding needs Seq2SeqTrainer (see the sketch below)
def compute_metrics(pred):
    labels_ids = pred.label_ids
    # predictions may be a tuple (logits, ...) depending on model outputs
    pred_logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    # Replace -100 in labels (ignored padding positions) so they can be decoded
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    # Calculate CER: jiwer takes reference(s) first, hypothesis(es) second
    cer_score = cer(label_str, pred_str)
    logging.info(f"Validation CER: {cer_score}")
    return {"cer": cer_score}
# --- Main Training Script ---
if __name__ == '__main__':
    logging.info("Starting OCR training script.")
    data_root_dir = "text_dataset"
    logging.info(f"Using dataset at: {os.path.abspath(data_root_dir)}")
    # --- Hugging Face Model and Processor Loading ---
    # encoder_id = "google/mobilenet_v3_small_100_224"  # lighter alternative encoder
    encoder_id = "google/vit-base-patch16-224-in21k"
    decoder_id = "prajjwal1/bert-tiny"
    logging.info(f"Loading encoder: {encoder_id}")
    logging.info(f"Loading decoder: {decoder_id}")
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=encoder_id,
        decoder_pretrained_model_name_or_path=decoder_id,
    )
    image_processor = AutoImageProcessor.from_pretrained(encoder_id)
    tokenizer = AutoTokenizer.from_pretrained(decoder_id)
    logging.info("Model, image processor, and tokenizer loaded.")
    # --- Set special tokens and generation parameters ---
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.vocab_size = tokenizer.vocab_size  # ensure the model knows the decoder vocab size
    model.config.eos_token_id = tokenizer.sep_token_id
    model.config.max_length = 64
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4
    logging.info("Model configuration set.")
    # --- Dataset Setup ---
    logging.info("Setting up datasets.")
    train_dataset = OCRDataset(root_dir=data_root_dir)
    # For a real project, split the data into train/val/test; this demo
    # reuses the same data for both splits (see the split sketch below)
    val_dataset = OCRDataset(root_dir=data_root_dir)
    logging.info(f"Training dataset size: {len(train_dataset)}")
    logging.info(f"Validation dataset size: {len(val_dataset)}")
    # --- Training Arguments ---
    training_args = TrainingArguments(
        output_dir="./ocr_model_output",  # output directory for checkpoints and logs
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_train_epochs=3,               # small number for a quick demo
        logging_dir="./logs",
        logging_steps=10,
        # save_steps=500,                 # only relevant with save_strategy="steps"
        eval_strategy="epoch",            # evaluate at the end of each epoch
        save_strategy="epoch",            # save a checkpoint at the end of each epoch
        save_total_limit=2,               # only keep the last 2 checkpoints
        report_to="none",                 # disable reporting to W&B, MLflow etc. for simplicity
        # predict_with_generate=True      # only exists on Seq2SeqTrainingArguments; see the sketch below
        load_best_model_at_end=True,      # reload the best checkpoint when training ends
        metric_for_best_model="cer",      # metric (from compute_metrics) to monitor
        greater_is_better=False,          # lower CER is better
    )
    logging.info("Training arguments set.")
    # --- Trainer Initialization ---
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        # Bind image_processor and tokenizer to the collate function via a lambda
        data_collator=lambda batch: collate_fn_hf(batch, image_processor, tokenizer),
        compute_metrics=compute_metrics,
    )
    logging.info("Trainer initialized.")
    # --- Start Training ---
    logging.info("--- Starting Training ---")
    trainer.train()
    logging.info("--- Training finished! ---")