from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import logging
import os, sys
import copy

import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, Trainer
from torch.utils.data import Dataset

sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from utils.special_tok_llama2 import (
    B_CODE,
    E_CODE,
    B_RESULT,
    E_RESULT,
    B_INST,
    E_INST,
    B_SYS,
    E_SYS,
    DEFAULT_PAD_TOKEN,
    DEFAULT_BOS_TOKEN,
    DEFAULT_EOS_TOKEN,
    DEFAULT_UNK_TOKEN,
    IGNORE_INDEX,
)

from conversation_template import json_to_code_result_tok_temp

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="./ckpt/llama-2-13b-chat")
    peft: bool = field(default=False)


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=4096,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )

def create_peft_config(model):
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_int8_training,
    )

    # LoRA configuration: rank-8 adapters on the attention query/value projections
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    # prepare the int-8 model for training, then wrap it with the LoRA adapters
    model = prepare_model_for_int8_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    print("Using PEFT")
    return model, peft_config

def _tokenize_fn(
    strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        input_ids_lens=input_ids_lens,
    )

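# A rough sketch of what _tokenize_fn returns for a single string (shapes only;
# the actual ids depend on the tokenizer):
#   _tokenize_fn(["Hello world"], tokenizer)
#   -> {"input_ids": [tensor of length seq_len], "input_ids_lens": [seq_len]}
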
def find_all_sublist_end(main_list, sublist):
    """Find the index just past the end of every occurrence of sublist in main_list."""
    sublist_len = len(sublist)
    main_list = main_list.tolist()
    indices = []
    for index in (i for i, e in enumerate(main_list) if e == sublist[0]):
        if main_list[index : index + sublist_len] == sublist:
            indices.append(index + sublist_len)
    return indices


def find_all_sublist_start(main_list, sublist):
    """Find the starting index of every occurrence of sublist in main_list."""
    sublist_len = len(sublist)
    main_list = main_list.tolist()
    indices = []
    for index in (i for i, e in enumerate(main_list) if e == sublist[0]):
        if main_list[index : index + sublist_len] == sublist:
            indices.append(index)
    return indices

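# A quick worked example for the two helpers above, using hypothetical token ids
# (not real vocabulary entries):
#   find_all_sublist_start(torch.tensor([5, 1, 2, 3, 1, 2]), [1, 2]) -> [1, 4]
#   find_all_sublist_end(torch.tensor([5, 1, 2, 3, 1, 2]), [1, 2])   -> [3, 6]
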
def preprocess(
    trajs: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing and masking non-assistant spans."""
    # token ids that delimit user turns ([INST] ... [/INST]) and execution results
    INST_START_INDEX = tokenizer.encode(f"{B_INST}")[-1]
    INST_END_INDEX = tokenizer.encode(f"{E_INST}")[-1]
    RESULT_START_INDEX = tokenizer.encode(f"{B_RESULT}")[-1]
    RESULT_END_INDEX = tokenizer.encode(f"{E_RESULT}")[-1]

    examples_tokenized = _tokenize_fn(trajs, tokenizer)
    input_ids_lens = examples_tokenized["input_ids_lens"]
    input_ids = examples_tokenized["input_ids"]  # [torch.tensor, torch.tensor, ...]
    labels = copy.deepcopy(input_ids)

    # set IGNORE_INDEX on user prompts and execution results so the loss is
    # computed only on the assistant's own tokens
    for i, label in enumerate(labels):
        user_start_inds = find_all_sublist_start(label, [INST_START_INDEX])
        assistant_start_inds = find_all_sublist_end(label, [INST_END_INDEX])

        result_start_inds = find_all_sublist_start(label, [RESULT_START_INDEX])
        result_end_inds = find_all_sublist_end(label, [RESULT_END_INDEX])

        # for debug
        # for len_i, ind in enumerate(label):
        #     print(f'{len_i}|{ind} -> "{tokenizer.decode(ind)}"')

        assert len(user_start_inds) == len(
            assistant_start_inds
        ), f"User and Assistant pair should be equal ::\n\tUser [{user_start_inds}]/\n\tAssistant [{assistant_start_inds}]\n\nText :\n{trajs[i]}"

        assert len(result_start_inds) == len(
            result_end_inds
        ), f"Start and End indices pairs do not match :\nText :\n{trajs[i]}"

        for user_start_ind, assistant_start_ind in zip(
            user_start_inds, assistant_start_inds
        ):
            label[user_start_ind + 1 : assistant_start_ind - 1] = IGNORE_INDEX

        for start, end in zip(result_start_inds, result_end_inds):
            label[start + 1 : end - 1] = IGNORE_INDEX

    # hard cap on sequence length
    input_ids = [i[:1500] for i in input_ids]
    labels = [i[:1500] for i in labels]

    return dict(input_ids=input_ids, labels=labels)

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning(f"Loading data from data path: {data_path}")

        all_json = os.listdir(data_path)
        trajs = list()
        for json_file_name in all_json:
            traj = json_to_code_result_tok_temp(json_file_name=json_file_name)
            trajs.append(traj)

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(trajs, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(
        tokenizer=tokenizer, data_path=data_args.data_path
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
    )

def build_model_from_hf_path(
    hf_model_path: str = "./ckpt/llama-2-13b-chat", peft: bool = False
):
    # build tokenizer
    tokenizer = LlamaTokenizer.from_pretrained(
        hf_model_path,
        padding_side="right",
        use_fast=False,
    )

    # register the default special tokens if the checkpoint does not define them
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN  # 32000
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN  # 2
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN  # 1
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    tokenizer.add_special_tokens(special_tokens_dict)

    tokenizer.add_tokens(
        [
            B_CODE,  # 32001
            E_CODE,  # 32002
            B_RESULT,  # 32003
            E_RESULT,  # 32004
            B_INST,
            E_INST,
            B_SYS,
            E_SYS,  # 32008
        ],
        special_tokens=True,
    )

    # build model
    if peft:
        model = LlamaForCausalLM.from_pretrained(
            hf_model_path,
            load_in_8bit=True,
            device_map="auto",
            ignore_mismatched_sizes=True,
            torch_dtype=torch.float16,
        )
    else:
        # for llama
        # model = LlamaForCausalLM.from_pretrained(
        #     hf_model_path, ignore_mismatched_sizes=True
        # )
        # for codellama
        from codellama_wrapper import CodeLlamaForCausalLM

        model = CodeLlamaForCausalLM.from_pretrained(hf_model_path)

    # the vocabulary grew by the tokens added above, so resize the embedding table
    model.resize_token_embeddings(len(tokenizer))

    return {"tokenizer": tokenizer, "model": model}

def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model_dict = build_model_from_hf_path(
        hf_model_path=model_args.model_name_or_path, peft=model_args.peft
    )
    model, tokenizer = model_dict["model"], model_dict["tokenizer"]

    # PEFT setting
    model.train()
    if model_args.peft:
        model, lora_config = create_peft_config(model)

    # make dataset
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    # train
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
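
# Example launch (a minimal sketch; the script name, data path, and output
# directory below are placeholders, not paths shipped with this code):
#
#   python finetune.py \
#       --model_name_or_path ./ckpt/llama-2-13b-chat \
#       --peft True \
#       --data_path ./data/trajectories \
#       --output_dir ./output \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 1 \
#       --learning_rate 2e-5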