from __future__ import annotations |

import torch
from torch import nn

from peft.import_utils import is_bnb_available
from peft.optimizers import create_loraplus_optimizer

from .testing_utils import require_bitsandbytes, torch_device


if is_bnb_available():
    import bitsandbytes as bnb


# Minimal model with an embedding layer, so that the dedicated
# loraplus_lr_embedding parameter group is exercised.
class SimpleNet(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.embedding = nn.Embedding(100, 20)
        self.layer_norm = nn.LayerNorm(20)
        self.lin0 = nn.Linear(20, 20, bias=bias)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16, bias=bias)

    def forward(self, X):
        X = self.lin0(self.layer_norm(self.embedding(X)))
        X = self.relu(X)
        X = self.lin1(X)
        return X


@require_bitsandbytes
def test_lora_plus_helper_success():
    model = SimpleNet()
    optimizer_cls = bnb.optim.Adam8bit
    lr = 5e-5
    optim_config = {
        "eps": 1e-6,
        "betas": (0.9, 0.999),
        "loraplus_weight_decay": 0.0,
    }
    loraplus_lr_ratio = 1.2
    loraplus_lr_embedding = 1e-6
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
        lr=lr,
        loraplus_lr_ratio=loraplus_lr_ratio,
        loraplus_lr_embedding=loraplus_lr_embedding,
        **optim_config,
    )
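    # create_loraplus_optimizer is expected to return four parameter groups: the
    # default group at the base lr, the embedding group at loraplus_lr_embedding,
    # and two LoRA-B groups (with and without weight decay) at lr * loraplus_lr_ratio.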
    assert optim is not None
    assert len(optim.param_groups) == 4
    assert optim.param_groups[0]["lr"] == lr
    assert optim.param_groups[1]["lr"] == loraplus_lr_embedding
    assert optim.param_groups[2]["lr"] == optim.param_groups[3]["lr"] == (lr * loraplus_lr_ratio)


@require_bitsandbytes
def test_lora_plus_optimizer_success():
    """
    Test that the optimizer is created correctly and that an optimizer step runs without raising an exception.
    """
    optimizer_cls = bnb.optim.Adam8bit
    optim_config = {
        "eps": 1e-6,
        "betas": (0.9, 0.999),
        "loraplus_weight_decay": 0.0,
    }
    model: SimpleNet = SimpleNet().to(torch_device)
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
        lr=5e-5,
        loraplus_lr_ratio=1.2,
        loraplus_lr_embedding=1e-6,
        **optim_config,
    )
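    # Smoke test: run one forward/backward pass and a single optimizer step.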
    loss = torch.nn.CrossEntropyLoss()
    # Register the parameters with bitsandbytes' global optimizer manager.
    bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
    x = torch.randint(100, (2, 4, 10)).to(torch_device)
    # Move the class dimension to position 1, as CrossEntropyLoss expects:
    # (2, 4, 10, 16) -> (2, 16, 4, 10).
    output = model(x).permute(0, 3, 1, 2)
    label = torch.randint(16, (2, 4, 10)).to(torch_device)
    loss_value = loss(output, label)
    loss_value.backward()
    optim.step()