|
|
"""
train.py — Finetune a Hugging Face vision model (e.g., ViT) on breast ultrasound images
"""
|
|
|
|
|
import inspect

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)
|
|
|
|
|
|
|
|
# Load the breast ultrasound dataset by its Hugging Face Hub id.
dataset = load_dataset("gymprathap/Breast-Cancer-Ultrasound-Images-Dataset")

# The class names are read from the "label" feature of the train split;
# their count sizes the classification head later in the script.
label_feature = dataset["train"].features["label"]
labels = label_feature.names
num_labels = len(labels)
print(f"Classes: {labels}")
|
|
|
|
|
|
|
|
# Pretrained checkpoint id; the image processor is loaded from it here and the
# same id is reused when the model is instantiated.
checkpoint = "google/vit-base-patch16-224-in21k"

# Processor that turns PIL images into the model's input tensors.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
|
|
|
|
|
def transform_examples(examples):
    """Turn a batch of raw dataset rows into model inputs.

    Args:
        examples: A batch mapping with an "image" column (objects exposing
            ``.convert``, presumably PIL images - confirm against the dataset)
            and a "label" column.

    Returns:
        The output of the module-level ``image_processor`` for the batch,
        with the batch's labels stored under the "labels" key.
    """
    # Force every image to three channels before handing it to the processor.
    rgb_images = []
    for picture in examples["image"]:
        rgb_images.append(picture.convert("RGB"))

    batch = image_processor(rgb_images, return_tensors="pt")
    batch["labels"] = examples["label"]
    return batch
|
|
|
|
|
# Attach the preprocessing function as an on-access transform.
prepared_ds = dataset.with_transform(transform_examples)

# Hold out 20% of the train split for validation; the fixed seed keeps the
# split reproducible between runs.
splits = prepared_ds["train"].train_test_split(test_size=0.2, seed=42)
train_ds = splits["train"]
val_ds = splits["test"]
|
|
|
|
|
|
|
|
# Map class indices to the dataset's class names (and back) so the saved and
# pushed model config carries human-readable labels instead of the generic
# "LABEL_0", "LABEL_1", ... placeholders.
id2label = {index: name for index, name in enumerate(labels)}
label2id = {name: index for index, name in id2label.items()}

# Load the pretrained backbone with a classification head sized for this
# dataset. ignore_mismatched_sizes lets a checkpoint head of a different
# shape be re-initialised rather than raising on load.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
|
|
|
|
|
|
|
|
# Metric objects used by compute_metrics; loaded once at module level, in the
# order accuracy then F1.
accuracy, f1 = (evaluate.load(metric_name) for metric_name in ("accuracy", "f1"))
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for one evaluation pass.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``.

    Returns:
        A dict with the keys "accuracy" and "f1".
    """
    scores, references = eval_pred
    # The predicted class is the index of the largest logit on the last axis.
    predicted = np.argmax(scores, axis=-1)

    accuracy_result = accuracy.compute(predictions=predicted, references=references)
    f1_result = f1.compute(
        predictions=predicted,
        references=references,
        average="macro",
    )
    return {"accuracy": accuracy_result["accuracy"], "f1": f1_result["f1"]}
|
|
|
|
|
|
|
|
training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="./results",
    logging_dir="./logs",
    report_to="none",
    # Optimisation schedule.
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint once per epoch; the two strategies are kept
    # identical because load_best_model_at_end is enabled.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Keep the raw "image"/"label" columns: the on-access transform needs them.
    remove_unused_columns=False,
    # Hub upload target.
    push_to_hub=True,
    hub_model_id="hugging-science/sample-breast-cancer-classification",
)
|
|
|
|
|
|
|
|
# Newer transformers releases renamed Trainer's `tokenizer` argument to
# `processing_class` and deprecated the old name. Pick whichever keyword the
# installed Trainer accepts so the script runs on both sides of the rename.
trainer_params = inspect.signature(Trainer.__init__).parameters
processor_kwarg = (
    "processing_class" if "processing_class" in trainer_params else "tokenizer"
)

# Wire model, data, metrics and the image processor into the training loop.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    **{processor_kwarg: image_processor},
)
|
|
|
|
|
|
|
|
# Run fine-tuning.
trainer.train()

# Save the model weights/config and the image processor into one directory so
# the result can be reloaded from a single path.
final_model_dir = "./finetuned-ultrasound-model"
model.save_pretrained(final_model_dir)
image_processor.save_pretrained(final_model_dir)

# The original message was split across two lines by a mis-decoded check-mark
# character, which made the string literal unterminated; it is restored here.
print(f"✅ Training complete. Model saved to {final_model_dir}")
|
|
|