SHIROI-07 commited on
Commit
32e454f
·
verified ·
1 Parent(s): f1c845c

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +42 -0
train.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
2
+
3
+ model_id = "mistralai/Mistral-7B-Instruct" # if you have resources
4
+
5
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
6
+ model = AutoModelForCausalLM.from_pretrained(model_id)
7
+
8
+ # Tokenize your dataset
9
+ def load_dataset(file_path, tokenizer, block_size=512):
10
+ return TextDataset(
11
+ tokenizer=tokenizer,
12
+ file_path=file_path,
13
+ block_size=block_size
14
+ )
15
+
16
+ train_dataset = load_dataset("coaching_data.txt", tokenizer)
17
+
18
+ data_collator = DataCollatorForLanguageModeling(
19
+ tokenizer=tokenizer, mlm=False,
20
+ )
21
+
22
+ training_args = TrainingArguments(
23
+ output_dir="./skilllink-coach",
24
+ overwrite_output_dir=True,
25
+ num_train_epochs=3,
26
+ per_device_train_batch_size=2,
27
+ save_steps=100,
28
+ save_total_limit=1,
29
+ logging_dir="./logs",
30
+ fp16=True, # If using GPU
31
+ )
32
+
33
+ trainer = Trainer(
34
+ model=model,
35
+ args=training_args,
36
+ data_collator=data_collator,
37
+ train_dataset=train_dataset,
38
+ )
39
+
40
+ trainer.train()
41
+ trainer.save_model("./skilllink-coach")
42
+ tokenizer.save_pretrained("./skilllink-coach")