| """ | |
| Finetune a base model to be a chat model. | |
| Run on one GPU e.g. for debugging: | |
| python -m scripts.chat_sft | |
| Or torchrun for training: | |
| torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft | |
| """ | |
| import os | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" | |
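# (expandable_segments lets the CUDA caching allocator grow memory segments in place,
# which reduces fragmentation from the variable-length batches used during SFT)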
import wandb
import torch
import torch.distributed as dist

from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
from nanochat.checkpoint_manager import load_model, save_checkpoint
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval

from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
# -----------------------------------------------------------------------------
# SFT Hyperparameters
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# input model options
source = "mid" # base|mid , which checkpoint to load the model from (base model or midtrained model)
model_tag = None # model tag to load the model from
step = None # step to load the model from
# compute/precision
dtype = "bfloat16"
device_batch_size = 4 # max batch size per device that avoids OOM
# optimization
num_epochs = 1
max_iterations = -1 # override the number of iterations (-1 = use num_epochs * num_iterations)
target_examples_per_step = 32
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
weight_decay = 0.0
init_lr_frac = 0.02
# evaluation and logging thereof
eval_every = 100
eval_steps = 100
eval_metrics_every = 200
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
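# (the configurator pattern above: every simple-typed global defined before this point can be
# overridden from the command line or a config file, e.g. something like --device_batch_size=2,
# and user_config snapshots the final values so they can be logged with the run)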
# -----------------------------------------------------------------------------
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
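# (autocast_ctx only affects the forward passes it wraps in the loops below: eligible ops run
# in bfloat16 while the stored parameter dtype is left unchanged)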
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)

# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of the variable lengths of the inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# -----------------------------------------------------------------------------
# DataLoader
def sft_data_generator(dataset, batch_size):
    # using <|assistant_end|> as the pad token is fine because these positions are masked out in the loss
    pad_token_id = tokenizer.encode_special("<|assistant_end|>")
    # collate a list of tokenized conversations into a padded batch of inputs/targets and yield it
    def collate_and_yield(batch):
        nrows = len(batch)
        ncols = max(len(ids) for ids, mask in batch) - 1 # a sequence of n tokens creates inputs/targets of length n-1
        inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
        targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is the ignore index
        for i, (ids, mask) in enumerate(batch):
            n = len(ids)
            ids_tensor = torch.tensor(ids, dtype=torch.long)
            inputs[i, :n-1] = ids_tensor[:-1]
            # recall -1 is the ignore index, so mask out targets where the mask is 0
            row_targets = ids_tensor[1:]
            # mask[1:] omits the mask for the BOS token, which is never a target anyway, so it's ok
            mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
            row_targets[mask_tensor == 0] = -1
            targets[i, :n-1] = row_targets
        inputs = inputs.to(device) # move to device
        targets = targets.to(device)
        return inputs, targets
    # iterate over the dataset in epochs, tokenizing each conversation as we go
    batch = []
    while True:
        for i in range(ddp_rank, len(dataset), ddp_world_size): # each rank takes a strided slice of the dataset
            doc = dataset[i]
            ids, mask = tokenizer.render_conversation(doc)
            batch.append((ids, mask))
            if len(batch) == batch_size:
                yield collate_and_yield(batch)
                batch = []
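# A tiny worked example of the collation above (hypothetical token ids): a conversation tokenized
# as ids=[bos, u1, u2, a1, a2, end] with mask=[0, 0, 0, 1, 1, 1] becomes
# inputs=[bos, u1, u2, a1, a2] and targets=[-1, -1, a1, a2, end],
# so the loss is only computed on the assistant tokens.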
examples_per_step = device_batch_size * ddp_world_size
print0(f"Target examples per step: {target_examples_per_step}")
print0(f"Device batch size: {device_batch_size}")
print0(f"Examples per step is device_batch_size * ddp_world_size: {examples_per_step}")
assert target_examples_per_step % examples_per_step == 0, "Target examples per step must be divisible by examples per step"
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
if max_iterations >= 0 and num_iterations > max_iterations:
    print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
    num_iterations = max_iterations
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
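# (build_val_loader is a factory rather than a single generator so that every evaluation
# restarts from the beginning of the validation set and sees the same eval_steps batches)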
# -----------------------------------------------------------------------------
# Initialize the Optimizer
optimizers = model.setup_optimizers(
    unembedding_lr=unembedding_lr,
    embedding_lr=embedding_lr,
    matrix_lr=matrix_lr,
    weight_decay=weight_decay,
)
# Set the initial learning rate as a fraction of the base learning rate
for opt in optimizers:
    for group in opt.param_groups:
        group["lr"] = group["lr"] * init_lr_frac
        group["initial_lr"] = group["lr"] # save the initial learning rate so we can decay it easily later
# -----------------------------------------------------------------------------
# Training loop

# Learning rate scheduler
def get_lr_multiplier(it):
    lrm = 1.0 - it / num_iterations
    return lrm
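# (the multiplier decays linearly from 1.0 at step 0 down to 1/num_iterations at the
# final step, i.e. the learning rate ramps down to ~0 by the end of training)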
# Go!
step = 0
train_iter = iter(train_loader)
for step in range(num_iterations):
    last_step = step == num_iterations - 1

    # evaluate the validation loss
    if last_step or step % eval_every == 0:
        model.eval()
        val_iter = iter(build_val_loader())
        losses = []
        for _ in range(eval_steps):
            val_inputs, val_targets = next(val_iter)
            with torch.no_grad(), autocast_ctx:
                loss = model(val_inputs, val_targets)
            losses.append(loss)
        val_loss = torch.stack(losses).mean() # average over eval_steps
        if ddp:
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # average over ranks
        val_loss = val_loss.item()
        print0(f"Step {step:05d} | Validation loss: {val_loss:.6f}")
        wandb_run.log({
            "step": step,
            "val_loss": val_loss,
        })
        model.train()
    # evaluate accuracy on the multiple choice tasks (which are quick to run)
    if last_step or (step > 0 and step % eval_metrics_every == 0):
        model.eval()
        metrics = {}
        with torch.no_grad(), autocast_ctx:
            # note that because these run under no_grad, we can usually afford at least ~2X the batch size
            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
        print0(f"Step {step:05d} | {metrics_str}")
        wandb_run.log({
            "step": step,
            **metrics,
        })
        model.train()

    if last_step:
        break
    # forward/backward to compute the gradient, accumulating over micro-steps
    num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
    for micro_step in range(grad_accum_steps):
        train_inputs, train_targets = next(train_iter)
        with autocast_ctx:
            loss = model(train_inputs, train_targets)
        train_loss = loss.detach() # for logging
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize the loss here
        loss.backward() # accumulate the gradient
        num_tokens += (train_targets >= 0).sum()
    if ddp:
        dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
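    # (num_tokens counts only positions with targets >= 0, i.e. the unmasked assistant tokens
    # that actually contribute to the loss; it is summed across ranks purely for logging)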
    # learning rate scheduler
    lrm = get_lr_multiplier(step)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * lrm
    # step the optimizers
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
    # logging
    train_loss_item = train_loss.item()
    num_tokens_item = num_tokens.item()
    print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f} | lrm: {lrm:.6f} | num_tokens: {num_tokens_item:,}")
    wandb_run.log({
        "step": step,
        "lrm": lrm,
        "train_loss": train_loss_item,
        "num_tokens": num_tokens_item,
    })
    step += 1
# Save the model at the end of the run
if master_process:
    base_dir = get_base_dir()
    depth = model.config.n_layer
    model_tag = f"d{depth}" # base the model tag on the depth of the base model
    checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
    model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
    save_checkpoint(
        checkpoint_dir,
        step,
        model.state_dict(),
        None, # note: we don't bother to save the optimizer state
        {
            "step": step,
            "val_loss": val_loss,
            **metrics,
            "model_config": model_config_kwargs,
        }
    )
    print(f"✅ Saved model checkpoint to {checkpoint_dir}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Chat SFT", data=[
    user_config, # CLI args
    {
        "Training rows": len(train_ds),
        "Number of iterations": num_iterations,
        "Training loss": train_loss_item,
        "Validation loss": val_loss,
    },
])
# Cleanup
wandb_run.finish()
compute_cleanup()