Spaces:
Runtime error
Runtime error
| """ | |
| Borrowed from modded-nanogpt. By Keller, @vagrawal, et al. | |
| Not a general optimizer! But works for our specific use. | |
| """ | |
| import torch | |
| import torch.distributed as dist | |
| from torch import Tensor | |
class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction.

    Each rank owns one contiguous slice along dim 0 of every parameter and
    keeps the Adam moments (`exp_avg`, `exp_avg_sq`) only for that slice.

    Assumptions made by `step` (not validated here):
      * `torch.distributed` has been initialised before `step` is called.
      * every parameter has a populated `.grad`.
      * every parameter's first dimension is divisible by the world size
        (the slice length is `shape[0] // world_size`, so any remainder rows
        would not be covered by any rank).

    Optional per-parameter attributes read via `getattr`:
      * `lr_mul` - multiplier applied to the group's learning rate (default 1.0).
      * `wd_mul` - multiplier applied to the effective weight decay (default 1.0).
    """

    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        """
        Args:
            param_groups: parameters or param-group dicts, forwarded unchanged
                to `torch.optim.Optimizer`.
            lr: default learning rate for groups that do not set one.
            betas: default (beta1, beta2) decay rates for the running averages.
            eps: default term added to the denominator of the update.
            weight_decay: default decoupled weight-decay coefficient.
        """
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(param_groups, defaults)

    # The update mutates slices (views) of the parameters in place and uses the
    # parameters as all-gather output buffers; this must not be recorded by
    # autograd, so the whole step runs with gradient tracking disabled.
    @torch.no_grad()
    def step(self):
        """
        Perform one optimization step across all ranks.

        Phase 1 launches an asynchronous averaging reduce-scatter for every
        gradient, so each rank receives only its own slice of the averaged
        gradient. Phase 2 waits for each slice in launch order, applies the
        AdamW update to the locally owned parameter slice, and launches an
        asynchronous all-gather that writes the updated slices back into the
        full parameter. The method returns once all all-gathers have completed.
        """
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        all_gather_futures: list[torch.Future] = []
        grad_slices: list[Tensor] = []

        # Phase 1: kick off all gradient reduce-scatters before doing any math,
        # so communication can overlap with the updates below.
        for group in self.param_groups:
            for p in group["params"]:
                grad = p.grad
                rank_size = grad.shape[0] // world_size
                grad_slice = torch.empty_like(grad[:rank_size])
                work = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True)
                reduce_scatter_futures.append(work.get_future())
                grad_slices.append(grad_slice)

        # Phase 2: walk the parameters in the same order as phase 1 so that
        # `idx` lines up with the futures / slices collected above.
        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            for p in group['params']:
                reduce_scatter_futures[idx].wait()
                g_slice = grad_slices[idx]
                idx += 1

                rank_size = p.shape[0] // world_size
                # View into the rows of `p` owned by this rank; in-place ops on
                # it modify `p` directly.
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)

                # State init (lazily, and only for the local slice)
                state = self.state[p]
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                t = state['step']

                # weight decay (decoupled: shrinks the weights directly)
                if wd != 0:
                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
                    p_slice.mul_(1 - eff_weight_decay)

                # update running averages
                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)

                # bias corrections (`t` is a tensor, so these are tensors too)
                bias1 = 1 - beta1 ** t
                bias2 = 1 - beta2 ** t

                # compute step
                denom = exp_avg_sq.sqrt().add_(eps)
                step_size = lr * (torch.sqrt(bias2) / bias1)
                update = exp_avg.div(denom).mul_(step_size)
                p_slice.add_(other=update, alpha=-1.0)

                # Publish the updated slice: every rank contributes its slice
                # and receives the full, updated parameter in `p`.
                work = dist.all_gather_into_tensor(p, p_slice, async_op=True)
                all_gather_futures.append(work.get_future())

        torch.futures.collect_all(all_gather_futures).wait()