# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
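
# Fine-tunes a T5 / FLAN-T5 (seq2seq) model with LoRA adapters, optionally in 4-bit
# QLoRA mode, on FastChat-style supervised conversation data.
#
# Illustrative launch command (a sketch only: the script path, model, data paths, and
# hyperparameters are placeholders; the flags map onto the dataclasses defined in this
# file plus transformers.TrainingArguments):
#
#   deepspeed fastchat/train/train_lora_t5.py \
#       --model_name_or_path google/flan-t5-xl \
#       --data_path data/dummy_conversation.json \
#       --preprocessed_path data/preprocessed.json \
#       --output_dir ./checkpoints_flant5_lora \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 1 \
#       --learning_rate 2e-5 \
#       --model_max_length 2048 \
#       --q_lora True \
#       --deepspeed playground/deepspeed_config_s2.json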
from collections import defaultdict
import copy
import os
from dataclasses import dataclass, field
import random
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch
import torch.distributed as dist

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
import transformers
from torch.utils.data import Dataset
from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed

from fastchat.train.train_flant5 import (
    smart_tokenizer_and_embedding_resize,
    make_supervised_data_module,
)
from fastchat.train.train_lora import get_peft_state_maybe_zero_3
from fastchat.model.model_adapter import get_conversation_template

default_conversation = get_conversation_template("t5")

# TODO: import and use code from ../data/dataset.py

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
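

# The dataclasses below are parsed from the command line by transformers.HfArgumentParser
# in train(). LoraArguments controls the adapter shape (rank, alpha, dropout, target
# modules) and whether to train in 4-bit QLoRA mode.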
@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"])
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    lazy_preprocess: bool = False
    num_data: int = -1
    preprocessed_path: str = field(
        default=None, metadata={"help": "Path to the preprocessed training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )


def safe_save_model_for_hf_trainer(
    trainer: transformers.Trainer, output_dir: str, state_dict: dict
):
    """Collect the state dict and dump it to disk."""
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP and ZeRO3 are both currently incompatible with QLoRA."
            )

    compute_dtype = (
        torch.float16
        if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )
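
    # When --q_lora is set, load the base model in 4-bit NF4 with double quantization
    # via bitsandbytes, so the frozen base weights stay quantized and only the LoRA
    # adapters train in compute_dtype; otherwise no quantization_config is passed and
    # the model is loaded as usual.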
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if lora_args.q_lora
        else None,
    )
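
    # With the default target_modules ["q", "v"], LoRA adapters are attached to the
    # query and value projection matrices of the T5 attention blocks.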
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.lora_bias,
        task_type=TaskType.SEQ_2_SEQ_LM,
    )

    if lora_args.q_lora:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=training_args.gradient_checkpointing
        )
        if not ddp and torch.cuda.device_count() > 1:
            # Keep the Trainer from trying its own DataParallel when more than one GPU is available.
            model.is_parallelizable = True
            model.model_parallel = True

    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend
    # a space before special tokens.
    tokenizer = transformers.T5Tokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
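
    # Add a dedicated [PAD] token plus a handful of literal tokens (braces, newline,
    # backslash, ...) that the stock T5 SentencePiece vocabulary does not cover, then
    # resize the model's embeddings to match the enlarged tokenizer.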
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
        other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
        tokenizer=tokenizer,
        model=model,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
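
    # Resume from the most recent checkpoint-* directory in output_dir if one exists,
    # so interrupted runs can be restarted with the same launch command.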
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Check whether DeepSpeed ZeRO-3 mode is enabled.
    if deepspeed.is_deepspeed_zero3_enabled():
        # Use the DeepSpeed engine's internal function to gather the state dict.
        # state_dict_zero3 contains the full parameters of the base model and the LoRA
        # adapters; we do not extract the LoRA parameters here since peft's
        # save_pretrained will do that:
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
        # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
        state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
        if training_args.local_rank == 0:
            state_dict = state_dict_zero3
    else:
        # In other modes we keep the original FastChat code to keep our changes minimal.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), lora_args.lora_bias
        )

    if training_args.local_rank == 0:
        safe_save_model_for_hf_trainer(
            trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict
        )


if __name__ == "__main__":
    train()