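# Fine-tune a Mixtral/Mistral base model on image-edit instruction data with QLoRA:
# the base model is loaded in 4-bit via bitsandbytes, LoRA adapters are attached with
# peft, and training runs through the Hugging Face Trainer.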
import os

import torch
import transformers
import matplotlib.pyplot as plt
from datetime import datetime
from functools import partial
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training
from datasets import load_dataset
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
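
# Prompt formatters: turn one JSONL record into training text. formatting_func_QA builds a
# question/answer pair for edit generation; formatting_func_Edit builds an edit-classification
# prompt and, at train time, appends the gold "Edit Class" answer.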
def formatting_func_QA(example):
    text = f"### Question: Given an image prompt {example['input']}\n give me random Edit Action and the output prompt \n ### Answer: Here is the edit action {example['edit']}, and here is the output {example['output']}"
    return text


def formatting_func_Edit(example, is_train=True):
    text = f"### Categorizes image editing actions, outputting classifications in the format 'Edit Class: A,B,C'. In this format, 'A' represents whether the edit is 'Global' or 'Local', and 'B' denotes the specific type of manipulation, such as 'Filter', 'Stylization', 'SceneChange', etc. 'C' denotes a specified 'B' such as 'FujiFilter', 'Part' etc. This structured approach provides clear and concise information, facilitating easy understanding of the edit class. The GPT remains committed to a formal, user-friendly communication style, ensuring the classifications are accessible and precise, without delving into technical complexities.\
Question: Given the Edit Action {example['edit']}, what is its edit type?\n"
    if is_train:
        text = text + f"### Answer: Edit Class: {example['class']}"
    return text
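
# Diagnostic helper: plot a histogram of tokenized sequence lengths so a sensible
# max_length can be chosen for padding/truncation.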
def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset):
    lengths = [len(x['input_ids']) for x in tokenized_train_dataset]
    lengths += [len(x['input_ids']) for x in tokenized_val_dataset]
    print(len(lengths))

    # Plotting the histogram
    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=10, alpha=0.7, color='blue')
    plt.xlabel('Length of input_ids')
    plt.ylabel('Frequency')
    plt.title('Distribution of Lengths of input_ids')

    # Saving the figure to a file
    os.makedirs('./experiments', exist_ok=True)  # make sure the output directory exists
    plt.savefig('./experiments/figure.png')
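
# Tokenization helpers. The tokenizer is passed in explicitly (rather than relying on a
# module-level variable) so these functions can be bound with functools.partial and used
# with datasets.map. generate_and_tokenize_prompt2 additionally pads/truncates to a fixed
# length and copies input_ids into labels for causal-LM training.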
def generate_and_tokenize_prompt(prompt, formatting=None, tokenizer=None):
    return tokenizer(formatting(prompt))


def generate_and_tokenize_prompt2(prompt, max_length=512, formatting=None, tokenizer=None):
    result = tokenizer(
        formatting(prompt),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    result["labels"] = result["input_ids"].copy()
    return result
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
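
# End-to-end fine-tuning: load the edit-instruction JSONL, tokenize it, load the base model
# in 4-bit, attach LoRA adapters, and train with the Hugging Face Trainer.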
def train():
    # Paths are environment-specific; adjust them for your setup.
    model_root = "/mnt/bn/wp-maliva-bytenas/mlx/users/peng.wang/playground/model/checkpoint_bk/"
    # output_root = "/mlx/users/peng.wang/playground/data/chat_edit/models/llm"
    output_root = "/opt/tiger/llm"
    os.makedirs(output_root, exist_ok=True)
    ######### Tune model with Mixtral MoE #########
    # base_model_id = f"{model_root}/Mixtral-8x7B-v0.1"
    base_model_id = f"{model_root}/Mixtral-8x7B-Instruct-v0.1"
    base_model_name = "mixtral-8x7b"

    # ######### Tune model with Mistral 7B Instruct #########
    # base_model_id = f"{model_root}/Mistral-7B-Instruct-v0.2"
    # base_model_name = "mistral-7b"

    ######### Instructions #########
    train_json = "./data/chat_edit/assets/test200/edit_instructions_v0.jsonl"
    val_json = train_json
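    # NOTE: validation currently reuses the training file; point val_json at a held-out
    # split for a meaningful eval.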
    project = "edit-finetune"
    run_name = base_model_name + "-" + project
    output_dir = f"{output_root}/{run_name}"
    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

    train_dataset = load_dataset('json', data_files=train_json, split='train')
    eval_dataset = load_dataset('json', data_files=val_json, split='train')
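    # Left padding with EOS reused as the pad token; BOS/EOS are added so each example is
    # a complete sequence for causal-LM training.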
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token

    # Bind the tokenizer and formatting function now that the tokenizer exists, then tokenize.
    generate_and_tokenize = partial(generate_and_tokenize_prompt2,
                                    tokenizer=tokenizer,
                                    max_length=128,
                                    formatting=formatting_func_Edit)
    tokenized_train_dataset = train_dataset.map(generate_and_tokenize)
    tokenized_val_dataset = eval_dataset.map(generate_and_tokenize)
    print(tokenized_train_dataset[1]['input_ids'])
    plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)
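    # QLoRA setup: load the base model in 4-bit with bitsandbytes (double quantization,
    # bfloat16 compute) so the 8x7B checkpoint fits in GPU memory, then prepare it for
    # k-bit training before attaching adapters.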
    # load model and do finetune
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id, quantization_config=bnb_config, device_map="auto")
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)
    print(model)
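    # LoRA adapters (rank 32, alpha 64) on the attention projections, the Mixtral expert
    # MLP weights (w1/w2/w3), and lm_head; only these adapter weights are trained.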
    config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "w1",
            "w2",
            "w3",
            "lm_head",
        ],
        bias="none",
        lora_dropout=0.01,  # Conventional
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    print_trainable_parameters(model)
    print(model)
    ## RUN training ##
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token

    if torch.cuda.device_count() > 1:  # If more than 1 GPU
        model.is_parallelizable = True
        model.model_parallel = True
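    # Hugging Face Trainer with a causal-LM collator (mlm=False). The hyperparameters
    # below are for a short run (max_steps=100); scale them up for a full fine-tune.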
    trainer = transformers.Trainer(
        model=model,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_val_dataset,
        args=transformers.TrainingArguments(
            output_dir=output_dir,
            warmup_steps=1,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=1,
            gradient_checkpointing=True,
            max_steps=100,
            learning_rate=2.5e-5,              # small learning rate for fine-tuning
            bf16=True,                         # match the bfloat16 compute dtype used for 4-bit quantization
            optim="paged_adamw_8bit",
            logging_steps=25,                  # report loss every 25 steps
            logging_dir="./experiments/logs",  # directory for storing logs
            save_strategy="steps",             # save checkpoints on a step schedule
            save_steps=100,                    # save a checkpoint every 100 steps
            evaluation_strategy="steps",       # evaluate on a step schedule
            eval_steps=25,                     # evaluate every 25 steps
            do_eval=True,                      # run evaluation during training
            report_to="wandb",                 # comment this out if you don't want to use Weights & Biases
            run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"  # name of the W&B run (optional)
        ),
        data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    model.config.use_cache = False  # silence the warnings; re-enable for inference!
    trainer.train()


if __name__ == '__main__':
    train()