# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
from accelerate.utils.imports import is_bf16_available
from safetensors import safe_open
from torch import nn

from peft import PeftModel, VBLoRAConfig, get_peft_model
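

# Minimal MLP used as the base model for all VB-LoRA tests in this file.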
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.relu = nn.ReLU()
self.lin0 = nn.Linear(10, 20, bias=bias)
        self.lin1 = nn.Linear(20, 20, bias=bias)  # lin1 and lin2 have the same shape
self.lin2 = nn.Linear(20, 20, bias=bias)
self.lin3 = nn.Linear(20, 2, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.lin0(X)
X = self.relu(X)
X = self.lin1(X)
X = self.relu(X)
X = self.lin2(X)
X = self.relu(X)
X = self.lin3(X)
X = self.sm(X)
        return X


class TestVBLoRA:
def get_mlp(self):
model = MLP()
return model
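
    # Check the shapes of the VB-LoRA logits and that the vector bank is shared
    # across all adapted layers.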
def test_vblora_parameters(self):
mlp = self.get_mlp()
vector_length = 2
num_vectors = 10
config = VBLoRAConfig(
target_modules=["lin0", "lin1", "lin3"], vector_length=vector_length, num_vectors=num_vectors
)
mlp_vblora = get_peft_model(mlp, config)
vector_bank = mlp_vblora.vblora_vector_bank["default"]
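
        # logits_B selects bank vectors for the output side and has shape
        # (out_features // vector_length, r, num_vectors); logits_A selects bank
        # vectors for the input side with shape (r, in_features // vector_length, num_vectors).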
vblora_lin0_logits_B = mlp_vblora.lin0.vblora_logits_B["default"]
assert vblora_lin0_logits_B.shape == (mlp.lin0.out_features // vector_length, config.r, num_vectors)
vblora_lin1_logits_A = mlp_vblora.lin1.vblora_logits_A["default"]
assert vblora_lin1_logits_A.shape == (config.r, mlp.lin1.in_features // vector_length, num_vectors)
vblora_lin3_logits_A = mlp_vblora.lin3.vblora_logits_A["default"]
assert vblora_lin3_logits_A.shape == (config.r, mlp.lin3.in_features // vector_length, num_vectors)
assert vector_bank.shape == (num_vectors, vector_length)
# test if the vector bank is shared across the layers
assert (
mlp_vblora.lin0.vblora_vector_bank["default"].data_ptr()
== mlp_vblora.lin3.vblora_vector_bank["default"].data_ptr()
)
assert mlp_vblora.lin1.vblora_vector_bank["default"].data_ptr() == vector_bank.data_ptr()
# should not raise
input = torch.randn(5, 10)
mlp_vblora(input)
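
    # With save_only_topk_weights=True, only the top-k indices and topk - 1 weights
    # per position are saved instead of the full logits tensors.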
def test_save_with_topk_weights(self, tmp_path):
torch.manual_seed(0)
mlp = self.get_mlp()
vector_length = 2
num_vectors = 10
topk = 2
config = VBLoRAConfig(
target_modules=["lin0", "lin3"],
topk=topk,
vector_length=vector_length,
num_vectors=num_vectors,
save_only_topk_weights=True,
)
mlp_vblora = get_peft_model(mlp, config)
save_path = tmp_path / "vblora"
mlp_vblora.save_pretrained(save_path)
assert os.path.exists(save_path / "adapter_model.safetensors")
adapter_model_dict = {}
with safe_open(save_path / "adapter_model.safetensors", framework="pt") as f:
for k in f.keys():
adapter_model_dict[k] = f.get_tensor(k)
assert "base_model.model.lin0.vblora_logits_A_topk_indices" in adapter_model_dict
assert "base_model.model.lin0.vblora_logits_A_topk_weights" in adapter_model_dict
assert "base_model.model.lin3.vblora_logits_B_topk_indices" in adapter_model_dict
assert "base_model.model.lin3.vblora_logits_B_topk_weights" in adapter_model_dict
assert "base_model.model.lin0.vblora_logits_A" not in adapter_model_dict
assert "base_model.model.lin3.vblora_logits_B" not in adapter_model_dict
assert adapter_model_dict["base_model.model.lin0.vblora_logits_B_topk_indices"].shape == (
mlp.lin0.out_features // vector_length,
config.r,
topk,
)
assert adapter_model_dict["base_model.model.lin0.vblora_logits_B_topk_weights"].shape == (
mlp.lin0.out_features // vector_length,
config.r,
topk - 1,
)
assert adapter_model_dict["base_model.model.lin3.vblora_logits_A_topk_indices"].shape == (
config.r,
mlp.lin3.in_features // vector_length,
topk,
)
assert adapter_model_dict["base_model.model.lin3.vblora_logits_A_topk_weights"].shape == (
config.r,
mlp.lin3.in_features // vector_length,
topk - 1,
)
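
    # Saving and reloading the adapter must reproduce the forward pass exactly,
    # with or without save_only_topk_weights.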
@pytest.mark.parametrize("save_only_topk_weights", [True, False])
def test_save_load(self, save_only_topk_weights, tmp_path):
torch.manual_seed(0)
mlp = self.get_mlp()
config = VBLoRAConfig(
target_modules=["lin0", "lin1", "lin3"],
topk=2,
vector_length=2,
num_vectors=10,
save_only_topk_weights=save_only_topk_weights,
)
mlp_vblora = get_peft_model(mlp, config)
save_path = tmp_path / "vblora"
mlp_vblora.save_pretrained(save_path)
assert os.path.exists(save_path / "adapter_config.json")
del mlp
torch.manual_seed(0) # make sure the base model has the same weights
mlp = self.get_mlp()
mlp_vblora_loaded = PeftModel.from_pretrained(mlp, save_path)
input = torch.randn(5, 10)
output = mlp_vblora(input)
output_loaded = mlp_vblora_loaded(input)
assert torch.allclose(output, output_loaded, atol=1e-8, rtol=1e-5)
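
    # A checkpoint saved with save_only_topk_weights=True only supports inference:
    # the pruned logits are restored with infinite values, so calling the loaded
    # model in training mode must raise.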
def test_resume_training_model_with_topk_weights(self, tmp_path):
torch.manual_seed(1)
mlp = self.get_mlp()
config = VBLoRAConfig(
target_modules=["lin0", "lin1", "lin3"],
topk=2,
vector_length=2,
num_vectors=10,
save_only_topk_weights=True,
)
mlp_vblora = get_peft_model(mlp, config)
save_path = tmp_path / "vblora"
mlp_vblora.save_pretrained(save_path)
input = torch.randn(5, 10)
mlp_vblora.train()
# should not raise
mlp_vblora(input)
del mlp
torch.manual_seed(1)
mlp = self.get_mlp()
mlp_vblora_loaded = PeftModel.from_pretrained(mlp, save_path)
mlp_vblora_loaded.train()
msg = "Found infinity values in VB-LoRA logits. Ensure training was not resumed from a `save_only_topk_weights` model."
with pytest.raises(RuntimeError, match=msg):
mlp_vblora_loaded(input)
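
    # The forward pass should work and preserve the dtype for float32, float16 and bfloat16.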
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_vblora_dtypes(self, dtype):
mlp = self.get_mlp()
if dtype == torch.bfloat16:
if not is_bf16_available():
pytest.skip("bfloat16 not supported on this system, skipping the test")
config = VBLoRAConfig(
target_modules=["lin0", "lin1", "lin3"], vector_length=2, num_vectors=10, save_only_topk_weights=False
)
mlp_vblora = get_peft_model(mlp.to(dtype), config)
inputs = torch.randn(5, 10).to(dtype)
output = mlp_vblora(inputs) # should not raise
assert output.dtype == dtype
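
    # Verify the parameter count reported by get_nb_savable_parameters when only the
    # top-k weights are stored: uint8 indices count as a quarter of a float32
    # parameter, plus topk - 1 weights per position and the shared vector bank.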
def test_vblora_nb_savable_params_only_topk_weights(self):
mlp = self.get_mlp()
vector_length = 2
num_vectors = 10
topk = 2
r = 4
config = VBLoRAConfig(
target_modules=["lin0", "lin1"],
vector_length=vector_length,
num_vectors=num_vectors,
topk=topk,
r=r,
save_only_topk_weights=True,
)
mlp_vblora = get_peft_model(mlp, config)
mlp_vblora.lin3.requires_grad_(True) # set lin3 to trainable
adapter_params, other_params = mlp_vblora.get_nb_savable_parameters()
        factor = 0.25  # indices are stored as uint8, i.e. a quarter of the size of a float32 parameter
topk_indices_parameter = int(
(mlp.lin0.out_features + mlp.lin0.in_features + mlp.lin1.out_features + mlp.lin1.in_features)
/ vector_length
* r
* topk
* factor
)
topk_weights_parameter = int(
(mlp.lin0.out_features + mlp.lin0.in_features + mlp.lin1.out_features + mlp.lin1.in_features)
/ vector_length
* r
* (topk - 1)
)
vector_bank_parameter = num_vectors * vector_length
assert adapter_params == topk_indices_parameter + topk_weights_parameter + vector_bank_parameter
assert other_params == (mlp.lin3.in_features + 1) * mlp.lin3.out_features
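
    # Same check without save_only_topk_weights: the full logits (num_vectors entries
    # per position) plus the vector bank are counted.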
def test_vblora_nb_savable_params_all_logits(self):
mlp = self.get_mlp()
vector_length = 2
num_vectors = 10
topk = 2
r = 4
config = VBLoRAConfig(
target_modules=["lin0", "lin1"],
vector_length=vector_length,
num_vectors=num_vectors,
topk=topk,
r=r,
save_only_topk_weights=False,
)
mlp_vblora = get_peft_model(mlp, config)
mlp_vblora.lin3.requires_grad_(True) # set lin3 to trainable
adapter_params, other_params = mlp_vblora.get_nb_savable_parameters()
logits_parameter = int(
(mlp.lin0.out_features + mlp.lin0.in_features + mlp.lin1.out_features + mlp.lin1.in_features)
/ vector_length
* r
* num_vectors
)
vector_bank_parameter = num_vectors * vector_length
assert adapter_params == logits_parameter + vector_bank_parameter
assert other_params == (mlp.lin3.in_features + 1) * mlp.lin3.out_features