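"""Finetune google/siglip-base-patch16-512 on COCO Captions (jxie/coco_captions)
with Sentence Transformers and evaluate it on text -> image retrieval."""
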
import logging

from datasets import load_dataset

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "google/siglip-base-patch16-512",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="Google SigLIP (512x512 resolution) model trained on COCO Captions",
    ),
    tokenizer_kwargs={"do_convert_rgb": True}
)

# 3. Load a dataset to finetune on
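# Each example provides the fields used below: a PIL `image`, a text `caption`,
# and a `cocoid` identifying the underlying COCO image.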
dataset = load_dataset("jxie/coco_captions")
train_dataset = dataset["train"].select(range(10_000))
eval_dataset = dataset["validation"].select(range(1_000))
test_dataset = dataset["test"].select(range(1_000))

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)
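# With (image, caption) pairs, every other caption in the batch serves as an
# in-batch negative for each image, so no explicitly mined negatives are
# needed; see the NO_DUPLICATES batch sampler below.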

# 5. (Optional) Specify training arguments
run_name = "google-siglip-base-coco"
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=False,  # Use FP16 only if your GPU supports it but BF16 is unavailable
    bf16=True,  # Requires a GPU with BF16 support (e.g. NVIDIA Ampere or newer)
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=0.1,  # Float values in (0, 1) are interpreted as a fraction of total training steps
    save_strategy="steps",
    save_steps=0.1,
    save_total_limit=2,
    logging_steps=0.01,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
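# Captions are the queries, images form the corpus, and each caption's only
# relevant document is the `cocoid` of the image it describes (text -> image retrieval).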
eval_queries = {qid: sample["caption"] for qid, sample in enumerate(eval_dataset)}
eval_corpus = {sample["cocoid"]: sample["image"] for sample in eval_dataset}
eval_relevant_docs = {qid: [sample["cocoid"]] for qid, sample in enumerate(eval_dataset)}
eval_evaluator = InformationRetrievalEvaluator(
    queries=eval_queries,
    corpus=eval_corpus,
    relevant_docs=eval_relevant_docs,
    name="coco-eval",
)
eval_evaluator(model)
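# Calling the evaluator logs and returns retrieval metrics (e.g. accuracy@k,
# MRR@k, NDCG@k, MAP), providing a baseline before finetuning.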

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset.select_columns(["image", "caption"]),
    eval_dataset=eval_dataset.select_columns(["image", "caption"]),
    loss=loss,
    evaluator=eval_evaluator,  # Also run the IR evaluator at each evaluation step
)
trainer.train()

# (Optional) Evaluate the trained model on the test set
test_queries = {qid: sample["caption"] for qid, sample in enumerate(test_dataset)}
test_corpus = {sample["cocoid"]: sample["image"] for sample in test_dataset}
test_relevant_docs = {qid: [sample["cocoid"]] for qid, sample in enumerate(test_dataset)}
test_evaluator = InformationRetrievalEvaluator(
    queries=test_queries,
    corpus=test_corpus,
    relevant_docs=test_relevant_docs,
    name="coco-test",
)
test_evaluator(model)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name, private=True)
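
# (Optional) Usage sketch for the saved model. The query text and image path
# below are illustrative placeholders, not files shipped with this script.
# from PIL import Image
#
# trained_model = SentenceTransformer(f"models/{run_name}/final")
# text_emb = trained_model.encode(["two dogs playing in the snow"])
# image_emb = trained_model.encode([Image.open("path/to/image.jpg")])  # hypothetical path
# print(trained_model.similarity(text_emb, image_emb))  # cosine similarity by default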