""" |
|
|
This module contains the implementation of the LoRA-FA optimizer. |
|
|
""" |

from __future__ import annotations

import math
from collections.abc import Iterable
from typing import Callable

import torch
import torch.nn as nn
from accelerate.utils.imports import is_bf16_available
from torch import autocast
from torch.optim import Optimizer

from ..peft_model import PeftModel
from ..utils.other import infer_device


class LoraFAOptimizer(Optimizer): |
|
|
""" |
|
|
Implements the LoRA-FA optimizer designed specifically for training Low-Rank Adaptation (LoRA) parameters |
|
|
efficiently. Note that LoraFAOptimizer is based on adamw-hf in transformers, with only LoRA part modified. Without |
|
|
LoRA it will fall back to adamw-hf. |
|
|
|
|
|
Args: |
|
|
params (Iterable[nn.parameter.Parameter]): Parameters to optimize. |
|
|
lr (float, optional): Learning rate (default: 1e-3). |
|
|
betas (Tuple[float, float], optional): |
|
|
Coefficients for computing running averages of gradient and squared gradient (default: (0.9, 0.999)). |
|
|
eps (float, optional): Term added to denominator to improve numerical stability (default: 1e-6). |
|
|
weight_decay (float, optional): Weight decay (L2 penalty) (default: 0.0). |
|
|
correct_bias (bool, optional): Whether to apply bias correction as in original Adam (default: True). |
|
|
|
|
|
Args in sub-function step: |
|
|
closure (Callable, optional): A closure that reevaluates the model and returns the loss. |
|
|
|
|
|
Reference: |
|
|
- LoRA-FA: https://huggingface.co/papers/2308.03303 |
|
|
""" |
|
|
|
|
|
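    # Note: `step` expects each param group to carry two keys beyond the AdamW
    # defaults: "names" (the parameter names, used to pair each lora_A with its
    # lora_B) and "scaling_factor" (the LoRA scaling s). `create_lorafa_optimizer`
    # below builds such groups; construct them the same way when instantiating
    # this class directly.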
    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable | None = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

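        # Parameters are visited in order; each lora_A and its matching lora_B are
        # buffered in param_list until both have been seen, then B receives the
        # projected LoRA-FA update. Every non-LoRA parameter takes the standard
        # AdamW path in the else-branch below.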
        for group in self.param_groups:
            scaling_factor = group["scaling_factor"]
            param_list = []
            name_list = []
            for p, n in zip(group["params"], group["names"]):
                # Skip non-LoRA parameters without gradients; frozen lora_A weights
                # have no gradient but are still needed for the projection.
                if "lora" not in n and p.grad is None:
                    continue
                grad = p.grad

                if "lora" in n:
                    param_list.append(p)
                    name_list.append(n)
                    if len(param_list) == 2:
                        # Both A and B collected: share one state entry keyed by the
                        # common prefix of their names.
                        name = n[: n.find("lora")] + "lora"
                    elif len(param_list) == 1:
                        # Only A seen so far; wait for its partner B.
                        continue
                else:
                    name = n

                state = self.state[name]

                # State initialization
                if len(state) == 0:
                    if len(param_list) == 2:
                        state["step"] = 0
                        # Exponential moving average of B's (projected) gradient values
                        state["exp_avg_B"] = torch.zeros_like(param_list[1])
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq_B"] = torch.zeros_like(param_list[1])
                    else:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(p)
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(p)

                # LoRA-FA update: param_list holds [A, B]; A stays frozen and only B
                # is updated, using a projected gradient.
                if len(param_list) == 2:
                    A = param_list[0]
                    B = param_list[1]
                    grad_B_orin = B.grad

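                    # Why project? With W = W0 + s * B @ A and A frozen, the update
                    # s * delta_B @ A that best matches the full-matrix gradient G in
                    # the least-squares sense is delta_B = (1/s) * G @ A.T @ (A @ A.T)^-1.
                    # Since autograd already provides grad_B_orin = s * G @ A.T, this
                    # reduces to (1 / s**2) * grad_B_orin @ (A @ A.T)^-1, computed
                    # below with a small ridge term delta for numerical stability.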
                    delta = 1e-8

                    AA_T = A @ A.T
                    AA_T_inv = torch.linalg.pinv(AA_T + delta * torch.eye(A.shape[0], device=A.device))

                    device_type = infer_device()

                    # Run the projection matmul in bfloat16 where available to save
                    # memory and time; otherwise fall back to the ambient dtype.
                    if is_bf16_available():
                        with autocast(device_type=device_type, dtype=torch.bfloat16):
                            grad_B = (1 / scaling_factor**2) * (grad_B_orin @ AA_T_inv)
                    else:
                        grad_B = (1 / scaling_factor**2) * (grad_B_orin @ AA_T_inv)

                    # autocast may have produced a different dtype; cast back before
                    # updating the moment buffers.
                    if grad_B.dtype != B.grad.dtype:
                        grad_B = grad_B.to(B.grad.dtype)

                    # Standard AdamW moment updates, applied to the projected gradient.
                    exp_avg_B, exp_avg_sq_B = state["exp_avg_B"], state["exp_avg_sq_B"]
                    beta1, beta2 = group["betas"]
                    state["step"] += 1
                    exp_avg_B.mul_(beta1).add_(grad_B, alpha=(1.0 - beta1))
                    exp_avg_sq_B.mul_(beta2).addcmul_(grad_B, grad_B, value=1.0 - beta2)

                    denom_B = exp_avg_sq_B.sqrt().add_(group["eps"])
                    step_size = group["lr"]
                    if group["correct_bias"]:
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
                    B.addcdiv_(exp_avg_B, denom_B, value=-step_size)
                    # Decoupled weight decay, applied directly to B as in AdamW.
                    if group["weight_decay"] > 0.0:
                        B.add_(B, alpha=(-group["lr"] * group["weight_decay"]))
                    # Reset the buffers for the next A/B pair.
                    param_list = []
                    name_list = []

                # Original AdamW path for all non-LoRA parameters.
                else:
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    beta1, beta2 = group["betas"]

                    state["step"] += 1

                    # Decay the first and second moment running average coefficients,
                    # updating both in place.
                    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                    step_size = group["lr"]
                    if group["correct_bias"]:
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                    p.addcdiv_(exp_avg, denom, value=-step_size)

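                    # Weight decay is decoupled from the gradient-based update: simply
                    # adding the squared weights to the loss is *not* equivalent for
                    # Adam, since it would interact with the m/v moment estimates.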
                    if group["weight_decay"] > 0.0:
                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss


def create_lorafa_optimizer(
    model: PeftModel, r: int, lora_alpha: int, lr: float, weight_decay: float = 0.0, use_rslora: bool = False
) -> Optimizer:
    """
    Helper function to instantiate a LoRA-FA optimizer configured for a given model that uses the LoRA method.

    This function will:
    - Disable gradient updates for the "lora_A" parameters (these are frozen during LoRA-FA training).
    - Compute the scaling factor from the provided `lora_alpha` and rank `r` for proper gradient projection.
    - Create and configure parameter groups for the optimizer, including the specified learning rate, weight decay,
      and additional optimizer options.

    For hyper-parameters, LoRA-FA uses the same ones as AdamW, except for the LoRA hyper-parameters (r, lora_alpha,
    use_rslora). One can keep the same hyper-parameters, such as lr and weight_decay, as one would use with AdamW in
    LoRA tuning.

    Args:
        model (PeftModel): The model containing LoRA-adapted parameters.
        r (int): Rank of the LoRA decomposition.
        lora_alpha (int): Scaling factor for the LoRA parameterization.
        lr (float): Learning rate for optimizer updates.
        weight_decay (float): Weight decay for AdamW.
        use_rslora (bool):
            Whether to use rsLoRA. With rsLoRA, the LoRA scaling factor becomes lora_alpha / math.sqrt(r) instead of
            lora_alpha / r.

    Returns:
        Optimizer: Configured LoRA-FA optimizer instance ready for training.
    """
    for name, param in model.named_parameters():
        # Freeze all lora_A matrices; LoRA-FA only trains lora_B.
        if "lora_A" in name:
            param.requires_grad_(False)
    lora_scaling = lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
    param_groups = [
        {
            "params": model.parameters(),
            "lr": lr,
            "names": [name for name, _ in model.named_parameters()],
            "scaling_factor": lora_scaling,
            "betas": (0.9, 0.999),
            "weight_decay": weight_decay,
        }
    ]
    return LoraFAOptimizer(param_groups)
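

# A minimal usage sketch. The model id and hyper-parameter values below are
# hypothetical placeholders, not prescribed by this module:
#
#     from peft import LoraConfig, get_peft_model
#     from transformers import AutoModelForCausalLM
#
#     base = AutoModelForCausalLM.from_pretrained("my-base-model")  # placeholder id
#     peft_model = get_peft_model(base, LoraConfig(r=16, lora_alpha=32))
#     optimizer = create_lorafa_optimizer(
#         model=peft_model, r=16, lora_alpha=32, lr=7e-5, weight_decay=0.0
#     )
#     # ...then the usual loop: loss.backward(); optimizer.step(); optimizer.zero_grad()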