cho-sr committed on
Commit cc9aa53 · verified · 1 Parent(s): 9317120

Upload train.py

Files changed (1)
  1. train.py +181 -0
train.py ADDED
@@ -0,0 +1,181 @@
+ from datasets import load_dataset
+
+ dataset_name = "ayoubkirouane/llava-instruct-small"
+
+ # Load Dataset
+ dataset = load_dataset(dataset_name)
+
+ # import os
+ # import zipfile
+ # import io
+ # # from datasets import DatasetDict
+ # from huggingface_hub import hf_hub_download, list_repo_files
+ # from PIL import Image
+
+ # dataset_train_split = "test"
+
+ # def format_data(samples: dict[str, any]) -> dict[str, list]:
+ #     formatted_samples = {"messages": []}
+ #     for cont in range(len(samples["question"])):
+ #         images = []
+ #         for img_path in samples["input_image_path"][cont]:
+ #             try:
+ #                 with open(img_path, "rb") as f:
+ #                     img_bytes = f.read()
+ #                 image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
+ #                 images.append({"type": "image", "image": image})
+ #             except Exception as e:
+ #                 print(f"Error processing image {img_path}: {e}")
+ #                 continue
+
+ #         formatted_samples["messages"].append(
+ #             [
+ #                 {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
+ #                 {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
+ #                 {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
+ #             ]
+ #         )
+ #     return formatted_samples
+
+ # For multi-image example
+ # def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
+ #     all_files = list_repo_files(dataset_name, repo_type="dataset")
+ #     zip_files = [f for f in all_files if f.endswith(".zip")]
+
+ #     for zip_filename in zip_files:
+ #         zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
+ #         extract_folder = zip_filename.replace(".zip", "")
+ #         os.makedirs(extract_folder, exist_ok=True)
+
+ #         with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ #             zip_ref.extractall(extract_folder)
+
+ #     dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
+ #     return dataset
+
+ # dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
+ import torch
+ from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
+
+ model_id = "HuggingFaceTB/SmolVLM-256M-Instruct"
+
+ # BitsAndBytesConfig int-4 config
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_quant_storage=torch.bfloat16,
+ )
+
+ # Load model and processor
+ model = AutoModelForImageTextToText.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+     attn_implementation="eager",  # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
+     quantization_config=bnb_config,
+ )
+ processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+ processor.tokenizer.padding_side = "right"
+
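+ # Optional: report the 4-bit base model's memory footprint before adding adapters (a rough check;
+ # get_memory_footprint() is available on transformers PreTrainedModel instances).
+ # print(f"Base model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
+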
+ from peft import LoraConfig
+
+ # Configure QLoRA
+ peft_config = LoraConfig(
+     lora_alpha=16,
+     lora_dropout=0.05,
+     r=16,
+     bias="none",
+     target_modules="all-linear",
+     task_type="CAUSAL_LM",
+     modules_to_save=[
+         "lm_head",
+         "embed_tokens",
+     ],
+ )
+
+ from trl import SFTConfig
+
+ training_args = SFTConfig(
+     output_dir="smolvlm-trl-sft-test",  # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
+     num_train_epochs=1,  # Number of epochs to train the model.
+     per_device_train_batch_size=2,  # Batch size per device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
+     gradient_accumulation_steps=32,  # Number of steps to accumulate gradients before a backward/update pass. multi-image -> gradient_accumulation_steps=1
+     gradient_checkpointing=True,  # Enable gradient checkpointing to reduce memory usage during training.
+     optim="adamw_torch_fused",  # Use the fused AdamW optimizer for better performance.
+     save_strategy="epoch",  # Save checkpoints at the end of each epoch.
+     learning_rate=2e-05,  # Learning rate for training.
+     bf16=True,  # Enable bfloat16 precision for training to save memory and speed up computation.
+     push_to_hub=False,  # Do not push the fine-tuned model to the Hugging Face Hub after training.
+     report_to="tensorboard",  # Report metrics to TensorBoard.
+     gradient_checkpointing_kwargs={"use_reentrant": False},  # Use non-reentrant gradient checkpointing to avoid issues.
+     dataset_kwargs={"skip_prepare_dataset": True},  # Skip TRL's dataset preparation; preprocessing is handled manually in the collator.
+     remove_unused_columns=False,  # Keep unused columns so the collator still sees the raw fields (important for batch processing).
+ )
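+ # Note: with these settings the effective batch size is
+ # per_device_train_batch_size * gradient_accumulation_steps = 2 * 32 = 64 samples per optimizer step (per device).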
+ import io
+ from PIL import Image
+
+ # For multi-image cases
+ def process_vision_info(messages: list[dict]) -> list[Image.Image]:
+     image_inputs = []
+     for msg in messages:
+         content = msg.get("content", [])
+         if not isinstance(content, list):
+             content = [content]
+
+         for element in content:
+             if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
+                 if "image" in element:
+                     image = element["image"]
+                 else:
+                     image = element
+                 if image is not None:
+                     image = Image.open(io.BytesIO(image["bytes"]))
+                     image_inputs.append(image.convert("RGB"))
+     return image_inputs
+
+ def collate_fn(examples):
+     texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
+     if "images" in examples[0]:  # single-image
+         images = [
+             [img.convert("RGB") for img in example["images"]]
+             for example in examples
+         ]
+     else:  # multi-image
+         images = [process_vision_info(example["messages"]) for example in examples]
+
+     # Tokenize the texts and process the images
+     batch = processor(
+         images=images, text=texts, return_tensors="pt", padding=True
+     )  # Encode texts and images into tensors
+
+     # The labels are the input_ids, and we mask the padding tokens in the loss computation
+     labels = batch["input_ids"].clone()  # Clone input IDs for labels
+     # Mask image tokens
+     image_token_id = getattr(model.config, "image_token_id", None)
+     if image_token_id is None:
+         image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
+     # Mask tokens so they are not used in the loss computation
+     labels[labels == processor.tokenizer.pad_token_id] = -100
+     labels[labels == image_token_id] = -100
+     # labels[labels == 262144] = -100
+
+     batch["labels"] = labels
+     return batch  # Return the prepared batch
+
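+ # Optional sanity check (a rough sketch; assumes each row of dataset["train"] already carries the
+ # "messages" / "images" fields the collator expects): run the collator on two samples and inspect
+ # the resulting tensor shapes before launching a full run.
+ # sample_batch = collate_fn([dataset["train"][0], dataset["train"][1]])
+ # print({k: v.shape for k, v in sample_batch.items()})
+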
+ # Training
+ from trl import SFTTrainer
+
+ trainer = SFTTrainer(
+     model=model,
+     args=training_args,
+     data_collator=collate_fn,
+     train_dataset=dataset["train"],  # multi-image -> train_dataset=dataset["test"],
+     processing_class=processor,
+     peft_config=peft_config,
+ )
+
+ trainer.train()
+
+ # Save the final model
+ trainer.save_model()
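+
+ # Optional inference sketch (hedged; assumes the adapter was written to the output_dir above and that
+ # a local image such as "example.jpg" exists; adjust names and paths to your setup):
+ # from peft import PeftModel
+ # base = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+ # tuned = PeftModel.from_pretrained(base, training_args.output_dir)
+ # messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}]
+ # prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+ # inputs = processor(text=prompt, images=[Image.open("example.jpg").convert("RGB")], return_tensors="pt").to(tuned.device)
+ # generated = tuned.generate(**inputs, max_new_tokens=64)
+ # print(processor.batch_decode(generated, skip_special_tokens=True)[0])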