import os

import pytest
import torch
from accelerate.utils.imports import is_bf16_available
from safetensors import safe_open
from torch import nn

from peft import PeftModel, VBLoRAConfig, get_peft_model


class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.lin1 = nn.Linear(20, 20, bias=bias)
        self.lin2 = nn.Linear(20, 20, bias=bias)
        self.lin3 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.lin0(X)
        X = self.relu(X)
        X = self.lin1(X)
        X = self.relu(X)
        X = self.lin2(X)
        X = self.relu(X)
        X = self.lin3(X)
        X = self.sm(X)
        return X


class TestVBLoRA:
    def get_mlp(self):
        model = MLP()
        return model

    def test_vblora_parameters(self):
        mlp = self.get_mlp()
        vector_length = 2
        num_vectors = 10
        config = VBLoRAConfig(
            target_modules=["lin0", "lin1", "lin3"], vector_length=vector_length, num_vectors=num_vectors
        )
        mlp_vblora = get_peft_model(mlp, config)

        vector_bank = mlp_vblora.vblora_vector_bank["default"]
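        # logits_B has shape (out_features // vector_length, r, num_vectors);
        # logits_A has shape (r, in_features // vector_length, num_vectors)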
        vblora_lin0_logits_B = mlp_vblora.lin0.vblora_logits_B["default"]
        assert vblora_lin0_logits_B.shape == (mlp.lin0.out_features // vector_length, config.r, num_vectors)

        vblora_lin1_logits_A = mlp_vblora.lin1.vblora_logits_A["default"]
        assert vblora_lin1_logits_A.shape == (config.r, mlp.lin1.in_features // vector_length, num_vectors)

        vblora_lin3_logits_A = mlp_vblora.lin3.vblora_logits_A["default"]
        assert vblora_lin3_logits_A.shape == (config.r, mlp.lin3.in_features // vector_length, num_vectors)

        assert vector_bank.shape == (num_vectors, vector_length)
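        # the vector bank should be shared (same underlying tensor) across all adapted layers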
        assert (
            mlp_vblora.lin0.vblora_vector_bank["default"].data_ptr()
            == mlp_vblora.lin3.vblora_vector_bank["default"].data_ptr()
        )
        assert mlp_vblora.lin1.vblora_vector_bank["default"].data_ptr() == vector_bank.data_ptr()
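        # sanity check: a forward pass through the adapted model should not raise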
        input = torch.randn(5, 10)
        mlp_vblora(input)

    def test_save_with_topk_weights(self, tmp_path):
        torch.manual_seed(0)
        mlp = self.get_mlp()
        vector_length = 2
        num_vectors = 10
        topk = 2
        config = VBLoRAConfig(
            target_modules=["lin0", "lin3"],
            topk=topk,
            vector_length=vector_length,
            num_vectors=num_vectors,
            save_only_topk_weights=True,
        )
        mlp_vblora = get_peft_model(mlp, config)
        save_path = tmp_path / "vblora"
        mlp_vblora.save_pretrained(save_path)
        assert os.path.exists(save_path / "adapter_model.safetensors")
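        # with save_only_topk_weights=True, only the top-k indices and weights of the logits
        # should be serialized, not the full logit tensors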
        adapter_model_dict = {}
        with safe_open(save_path / "adapter_model.safetensors", framework="pt") as f:
            for k in f.keys():
                adapter_model_dict[k] = f.get_tensor(k)
        assert "base_model.model.lin0.vblora_logits_A_topk_indices" in adapter_model_dict
        assert "base_model.model.lin0.vblora_logits_A_topk_weights" in adapter_model_dict
        assert "base_model.model.lin3.vblora_logits_B_topk_indices" in adapter_model_dict
        assert "base_model.model.lin3.vblora_logits_B_topk_weights" in adapter_model_dict
        assert "base_model.model.lin0.vblora_logits_A" not in adapter_model_dict
        assert "base_model.model.lin3.vblora_logits_B" not in adapter_model_dict
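        # only topk - 1 weights are stored per position: the top-k weights are normalized to
        # sum to 1, so the remaining weight can be reconstructed on load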
        assert adapter_model_dict["base_model.model.lin0.vblora_logits_B_topk_indices"].shape == (
            mlp.lin0.out_features // vector_length,
            config.r,
            topk,
        )
        assert adapter_model_dict["base_model.model.lin0.vblora_logits_B_topk_weights"].shape == (
            mlp.lin0.out_features // vector_length,
            config.r,
            topk - 1,
        )
        assert adapter_model_dict["base_model.model.lin3.vblora_logits_A_topk_indices"].shape == (
            config.r,
            mlp.lin3.in_features // vector_length,
            topk,
        )
        assert adapter_model_dict["base_model.model.lin3.vblora_logits_A_topk_weights"].shape == (
            config.r,
            mlp.lin3.in_features // vector_length,
            topk - 1,
        )

    @pytest.mark.parametrize("save_only_topk_weights", [True, False])
    def test_save_load(self, save_only_topk_weights, tmp_path):
        torch.manual_seed(0)
        mlp = self.get_mlp()
        config = VBLoRAConfig(
            target_modules=["lin0", "lin1", "lin3"],
            topk=2,
            vector_length=2,
            num_vectors=10,
            save_only_topk_weights=save_only_topk_weights,
        )
        mlp_vblora = get_peft_model(mlp, config)
        save_path = tmp_path / "vblora"
        mlp_vblora.save_pretrained(save_path)
        assert os.path.exists(save_path / "adapter_config.json")
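        # recreate the base model with the same seed so the loaded adapter sits on identical base weights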
        del mlp
        torch.manual_seed(0)
        mlp = self.get_mlp()
        mlp_vblora_loaded = PeftModel.from_pretrained(mlp, save_path)
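        # outputs must match exactly, even when only the top-k weights were saved,
        # because only the top-k logits contribute to the forward pass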
        input = torch.randn(5, 10)
        output = mlp_vblora(input)
        output_loaded = mlp_vblora_loaded(input)
        assert torch.allclose(output, output_loaded, atol=1e-8, rtol=1e-5)

    def test_resume_training_model_with_topk_weights(self, tmp_path):
        torch.manual_seed(1)
        mlp = self.get_mlp()
        config = VBLoRAConfig(
            target_modules=["lin0", "lin1", "lin3"],
            topk=2,
            vector_length=2,
            num_vectors=10,
            save_only_topk_weights=True,
        )
        mlp_vblora = get_peft_model(mlp, config)
        save_path = tmp_path / "vblora"
        mlp_vblora.save_pretrained(save_path)

        input = torch.randn(5, 10)
        mlp_vblora.train()
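        # a training-mode forward pass on the original model, which still holds the full logits, should not raise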
        mlp_vblora(input)
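        # a checkpoint saved with save_only_topk_weights=True discards the non-top-k logits,
        # so resuming training from it must raise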
        del mlp
        torch.manual_seed(1)
        mlp = self.get_mlp()
        mlp_vblora_loaded = PeftModel.from_pretrained(mlp, save_path)
        mlp_vblora_loaded.train()
        msg = "Found infinity values in VB-LoRA logits. Ensure training was not resumed from a `save_only_topk_weights` model."
        with pytest.raises(RuntimeError, match=msg):
            mlp_vblora_loaded(input)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_vblora_dtypes(self, dtype):
        mlp = self.get_mlp()
        if dtype == torch.bfloat16:
            if not is_bf16_available():
                pytest.skip("bfloat16 not supported on this system, skipping the test")

        config = VBLoRAConfig(
            target_modules=["lin0", "lin1", "lin3"], vector_length=2, num_vectors=10, save_only_topk_weights=False
        )
        mlp_vblora = get_peft_model(mlp.to(dtype), config)
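        # a forward pass should work and the output should come back in the requested dtype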
        inputs = torch.randn(5, 10).to(dtype)
        output = mlp_vblora(inputs)
        assert output.dtype == dtype

    def test_vblora_nb_savable_params_only_topk_weights(self):
        mlp = self.get_mlp()
        vector_length = 2
        num_vectors = 10
        topk = 2
        r = 4
        config = VBLoRAConfig(
            target_modules=["lin0", "lin1"],
            vector_length=vector_length,
            num_vectors=num_vectors,
            topk=topk,
            r=r,
            save_only_topk_weights=True,
        )
        mlp_vblora = get_peft_model(mlp, config)
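        # make a non-adapted layer trainable so that it is counted among the "other" savable parameters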
        mlp_vblora.lin3.requires_grad_(True)

        adapter_params, other_params = mlp_vblora.get_nb_savable_parameters()
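        # top-k indices are stored in a 1-byte integer dtype, so each index counts as a quarter of a float32 parameter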
        factor = 0.25
        topk_indices_parameter = int(
            (mlp.lin0.out_features + mlp.lin0.in_features + mlp.lin1.out_features + mlp.lin1.in_features)
            / vector_length
            * r
            * topk
            * factor
        )
        topk_weights_parameter = int(
            (mlp.lin0.out_features + mlp.lin0.in_features + mlp.lin1.out_features + mlp.lin1.in_features)
            / vector_length
            * r
            * (topk - 1)
        )
        vector_bank_parameter = num_vectors * vector_length
        assert adapter_params == topk_indices_parameter + topk_weights_parameter + vector_bank_parameter
        assert other_params == (mlp.lin3.in_features + 1) * mlp.lin3.out_features

    def test_vblora_nb_savable_params_all_logits(self):
        mlp = self.get_mlp()
        vector_length = 2
        num_vectors = 10
        topk = 2
        r = 4
        config = VBLoRAConfig(
            target_modules=["lin0", "lin1"],
            vector_length=vector_length,
            num_vectors=num_vectors,
            topk=topk,
            r=r,
            save_only_topk_weights=False,
        )
        mlp_vblora = get_peft_model(mlp, config)
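        # lin3 is again made trainable and should only show up in the "other" parameter count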
        mlp_vblora.lin3.requires_grad_(True)

        adapter_params, other_params = mlp_vblora.get_nb_savable_parameters()
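        # with save_only_topk_weights=False, the full logits (num_vectors entries per position) are saved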
        logits_parameter = int(
            (mlp.lin0.out_features + mlp.lin0.in_features + mlp.lin1.out_features + mlp.lin1.in_features)
            / vector_length
            * r
            * num_vectors
        )
        vector_bank_parameter = num_vectors * vector_length
        assert adapter_params == logits_parameter + vector_bank_parameter
        assert other_params == (mlp.lin3.in_features + 1) * mlp.lin3.out_features