# Author: cgeorgiaw (HF Staff) - commit 49e2b29, "first push"
"""
train.py — Fine-tune a Hugging Face vision model (e.g., ViT) on breast ultrasound images
"""
import inspect

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)
# ---- 1. Load dataset ----
dataset = load_dataset("gymprathap/Breast-Cancer-Ultrasound-Images-Dataset")

# Class names are read from the train split's "label" feature: entry i of
# `.names` is the name for label id i.
label_feature = dataset["train"].features["label"]
labels = label_feature.names
num_labels = len(labels)
print(f"Classes: {labels}")
# ---- 2. Preprocessing ----
# Hub id of the pretrained backbone; swap in another image-classification
# checkpoint here if desired (its matching processor is loaded below).
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
def transform_examples(examples):
    """Convert a batch of raw examples into model inputs.

    Registered below via ``with_transform``, so it is called on the fly with a
    *batch*: a dict mapping each column name to a list of values.

    Args:
        examples: Batch with an ``"image"`` column of PIL images and a
            ``"label"`` column of integer class ids.

    Returns:
        The image processor's output (PyTorch tensors) with an added
        ``"labels"`` entry, the argument name the model uses for its loss.
    """
    # Ensure 3-channel input: some scans may be stored as grayscale, while the
    # processor/backbone are fed RGB.
    images = [image.convert("RGB") for image in examples["image"]]
    inputs = image_processor(images, return_tensors="pt")
    inputs["labels"] = examples["label"]
    return inputs
# with_transform applies transform_examples lazily, whenever rows are read.
prepared_ds = dataset.with_transform(transform_examples)
# Split dataset: hold out 20% of "train" for validation. Stratifying on the
# label keeps the class proportions the same in both splits, so validation
# metrics are not skewed by an unlucky split if the classes are imbalanced.
splits = prepared_ds["train"].train_test_split(
    test_size=0.2,
    seed=42,
    stratify_by_column="label",
)
train_ds, val_ds = splits["train"], splits["test"]
# ---- 3. Load model ----
# Record the id <-> class-name mapping in the model config, so the saved and
# pushed model reports real class names instead of generic LABEL_<i> names.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # allow a classification head sized for num_labels
)
# ---- 4. Metrics ----
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for an evaluation pass.

    Args:
        eval_pred: ``(logits, label_ids)`` pair passed in by the Trainer;
            the class dimension of ``logits`` is the last axis.

    Returns:
        Dict with ``"accuracy"`` and ``"f1"``. F1 is macro-averaged, so each
        class contributes equally regardless of how often it occurs.
    """
    # `label_ids` (not `labels`) so the module-level class-name list is not shadowed.
    logits, label_ids = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy.compute(predictions=preds, references=label_ids)["accuracy"]
    f1_score = f1.compute(predictions=preds, references=label_ids, average="macro")["f1"]
    return {"accuracy": acc, "f1": f1_score}
# ---- 5. Training setup ----
training_args = TrainingArguments(
    # Where checkpoints and logs go; no external experiment tracker.
    output_dir="./results",
    logging_dir="./logs",
    report_to="none",
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint once per epoch; matching strategies are what
    # allow the best checkpoint to be restored at the end.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Keep the raw "image"/"label" columns: transform_examples reads them, and
    # the Trainer would otherwise drop columns the model's forward() lacks.
    remove_unused_columns=False,
    # Publish to the Hugging Face Hub under this repo id.
    push_to_hub=True,
    hub_model_id="hugging-science/sample-breast-cancer-classification",
)
# ---- 6. Trainer ----
# Recent transformers releases take the preprocessor as `processing_class`;
# `tokenizer` is the older name, deprecated and scheduled for removal. Use
# whichever keyword this installation's Trainer accepts.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_processor_kwarg = (
    "processing_class" if "processing_class" in _trainer_params else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    # Handing the processor to the Trainer lets it be saved with checkpoints.
    **{_processor_kwarg: image_processor},
)
# ---- 7. Train ----
trainer.train()
# ---- 8. Save locally ----
output_dir = "./finetuned-ultrasound-model"
# trainer.save_model saves the model the Trainer currently holds (the best
# checkpoint, since load_best_model_at_end=True) and writes only from the main
# process, avoiding concurrent writes in multi-process runs.
trainer.save_model(output_dir)
image_processor.save_pretrained(output_dir)
# Checkpoints are pushed as they are saved during training, but the restored
# best model and the model card are uploaded by an explicit push_to_hub().
if trainer.args.push_to_hub:
    trainer.push_to_hub()
print(f"✅ Training complete. Model saved to {output_dir}")