"""train.py — Finetune a Hugging Face vision model (e.g., ViT) on breast ultrasound images."""

from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
)
import evaluate
import numpy as np
import torch

# ---- 1. Load dataset ----
dataset = load_dataset("gymprathap/Breast-Cancer-Ultrasound-Images-Dataset")

# Dataset info
labels = dataset["train"].features["label"].names
num_labels = len(labels)
# Explicit index <-> name maps so the saved/pushed model config reports the
# real class names instead of generic LABEL_0, LABEL_1, ...
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
print(f"Classes: {labels}")

# ---- 2. Preprocessing ----
checkpoint = "google/vit-base-patch16-224-in21k"  # choose your model
image_processor = AutoImageProcessor.from_pretrained(checkpoint)


def transform_examples(examples):
    """Turn a batch of raw examples into model inputs.

    Attached lazily with ``Dataset.with_transform``, so it runs each time rows
    are accessed rather than once up front.

    Args:
        examples: Batch dict with an ``"image"`` column of PIL images and a
            ``"label"`` column of class indices.

    Returns:
        The image processor's output (``pixel_values`` as PyTorch tensors)
        with the class indices added under ``"labels"``.
    """
    images = [img.convert("RGB") for img in examples["image"]]  # ensure 3-channel
    inputs = image_processor(images, return_tensors="pt")
    inputs["labels"] = examples["label"]
    return inputs


# Split first, then attach the lazy transform to each split. Stratifying on the
# label keeps per-class proportions the same in train and validation, so the
# macro-F1 measured on the validation set is not skewed by an unlucky split.
splits = dataset["train"].train_test_split(
    test_size=0.2, seed=42, stratify_by_column="label"
)
train_ds = splits["train"].with_transform(transform_examples)
val_ds = splits["test"].with_transform(transform_examples)

# ---- 3. Load model ----
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    # Lets loading succeed if the checkpoint already carries a classification
    # head of a different size; that head is re-initialised for num_labels.
    ignore_mismatched_sizes=True,
)

# ---- 4. Metrics ----
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for the Trainer's evaluation loop.

    Args:
        eval_pred: A ``(logits, label_ids)`` pair; ``logits`` holds one score
            per class along its last axis.

    Returns:
        Dict with ``"accuracy"`` and ``"f1"`` (macro average: every class is
        weighted equally regardless of how many examples it has).
    """
    # Named `references` so the module-level `labels` (class names) is not shadowed.
    logits, references = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy.compute(predictions=preds, references=references)["accuracy"]
    f1_score = f1.compute(
        predictions=preds, references=references, average="macro"
    )["f1"]
    return {"accuracy": acc, "f1": f1_score}
# ---- 5. Training setup ----
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    num_train_epochs=3,
    learning_rate=5e-5,
    logging_dir="./logs",
    load_best_model_at_end=True,
    # Pick the best checkpoint by the macro-F1 that compute_metrics reports,
    # rather than the default (eval loss).
    metric_for_best_model="f1",
    greater_is_better=True,
    # Keep the raw "image" column: the lazy transform reads it at access time.
    remove_unused_columns=False,
    push_to_hub=True,
    hub_model_id="hugging-science/sample-breast-cancer-classification",
    report_to="none",
)

# ---- 6. Trainer ----
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    # `tokenizer=` is deprecated since transformers 4.46 in favour of
    # `processing_class=`; the image processor is saved with each checkpoint.
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

# ---- 7. Train ----
trainer.train()

# ---- 8. Save locally ----
# With load_best_model_at_end=True, trainer.model now holds the best checkpoint.
output_path = "./finetuned-ultrasound-model"
trainer.model.save_pretrained(output_path)
image_processor.save_pretrained(output_path)
print("✅ Training complete. Model saved to ./finetuned-ultrasound-model")

# ---- 9. Push to the Hub ----
# Uploads during training follow checkpoint saves. Push explicitly afterwards
# so the Hub repo holds the reloaded best model together with a generated
# model card.
trainer.push_to_hub()