import os

import pytest
import torch
from accelerate.utils.imports import is_bf16_available
from torch import nn

from peft import PeftModel, ShiraConfig, get_peft_model


def custom_random_mask_function_with_custom_kwargs(custom_arg):
    def mask_fn(base_layer, r):
        """
        This mask function is similar to the random_mask provided in src/peft/tuners/shira/mask_functions.py except
        that the seed is derived from custom_kwargs. Please use this as an example to create your own custom sparse
        masks that may use custom_kwargs. Remember, for a pretrained weight with shape (m, n), mask_fn must return
        exactly one mask (shape: (m, n)) which must be binary (0 or 1) with num_shira_parameters = r * (m + n) for
        linear layers. The device and dtype of the mask must match the base layer's weight's device and dtype.
        """
        new_seed = custom_arg
        shape = base_layer.weight.shape
        num_shira_weights = r * (shape[0] + shape[1])
        random_generator = torch.Generator()
        random_generator.manual_seed(new_seed)

        idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
            base_layer.weight.device
        )
        val = torch.ones_like(idx.type(base_layer.weight.dtype))
        mask = torch.zeros_like(base_layer.weight.view(1, -1))
        mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)

        return mask

    return mask_fn
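

# A minimal usage sketch (mirroring test_save_load_custom_mask_function below): a custom mask function is
# attached by setting mask_type="custom" on the config and assigning the callable to config.mask_fn before
# calling get_peft_model, e.g.
#
#     config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"])
#     config.mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg=120)
#     peft_model = get_peft_model(MLP(), config)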


class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.lin1 = nn.Linear(20, 40, bias=bias)
        self.lin2 = nn.Linear(40, 30, bias=bias)
        self.lin3 = nn.Linear(30, 10, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.lin0(X)
        X = self.relu(X)
        X = self.lin1(X)
        X = self.relu(X)
        X = self.lin2(X)
        X = self.relu(X)
        X = self.lin3(X)
        X = self.sm(X)
        return X


class TestShira:
    @pytest.fixture
    def mlp(self):
        torch.manual_seed(0)
        model = MLP()
        return model

    def test_mlp_single_adapter_shapes(self, mlp):
        r = 2
        config = ShiraConfig(r=r, target_modules=["lin1", "lin2"])

        peft_model = get_peft_model(mlp, config)

        shira_weight1_size = peft_model.base_model.model.lin1.shira_weight["default"].shape[0]
        shira_weight2_size = peft_model.base_model.model.lin2.shira_weight["default"].shape[0]
        shira_indices1_size = peft_model.base_model.model.lin1.shira_indices["default"].shape[1]
        shira_indices2_size = peft_model.base_model.model.lin2.shira_indices["default"].shape[1]

        base_weight1_size = peft_model.base_model.model.lin1.base_layer.weight.shape
        base_weight2_size = peft_model.base_model.model.lin2.base_layer.weight.shape

        delta_weight1_shape = peft_model.base_model.model.lin1.get_delta_weight("default").shape
        delta_weight2_shape = peft_model.base_model.model.lin2.get_delta_weight("default").shape
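
        # SHiRA adapts an (m, n) linear weight with r * (m + n) trainable values placed at the masked positions,
        # keeping one index per value. For lin1 (nn.Linear(20, 40), weight shape (40, 20)) with r = 2 that is
        # 2 * (40 + 20) = 120 entries, which the assertions below check against the stored weights and indices.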
        assert shira_weight1_size == r * (base_weight1_size[0] + base_weight1_size[1])
        assert shira_weight2_size == r * (base_weight2_size[0] + base_weight2_size[1])

        assert shira_weight1_size == shira_indices1_size
        assert shira_weight2_size == shira_indices2_size

        assert delta_weight1_shape == base_weight1_size
        assert delta_weight2_shape == base_weight2_size

        return peft_model

    def test_multiple_adapters_save_load(self, mlp, tmp_path):
        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=56)

        peft_model = get_peft_model(mlp, config, adapter_name="first")
        config2 = ShiraConfig(r=3, target_modules=["lin1", "lin2", "lin3"], random_seed=67)
        peft_model.add_adapter("second", config2)
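
        # Freshly added SHiRA adapters start with all-zero shira_weight (checked below), so at this point both
        # adapters leave the base model output unchanged; the test then fills them with random values so the two
        # adapters produce distinguishable outputs.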
        assert torch.all(peft_model.base_model.model.lin1.shira_weight["first"] == 0)
        assert torch.all(peft_model.base_model.model.lin2.shira_weight["first"] == 0)
        assert torch.all(peft_model.base_model.model.lin1.shira_weight["second"] == 0)
        assert torch.all(peft_model.base_model.model.lin2.shira_weight["second"] == 0)
        assert torch.all(peft_model.base_model.model.lin3.shira_weight["second"] == 0)

        shira_assign_val1_f = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["first"])
        peft_model.base_model.model.lin1.shira_weight["first"] = shira_assign_val1_f
        shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
        shira_assign_val2_f = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["first"])
        peft_model.base_model.model.lin2.shira_weight["first"] = shira_assign_val2_f
        shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]

        shira_assign_val1_s = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["second"])
        peft_model.base_model.model.lin1.shira_weight["second"] = shira_assign_val1_s
        shira_indices1_s = peft_model.base_model.model.lin1.shira_indices["second"]
        shira_assign_val2_s = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["second"])
        peft_model.base_model.model.lin2.shira_weight["second"] = shira_assign_val2_s
        shira_indices2_s = peft_model.base_model.model.lin2.shira_indices["second"]
        shira_assign_val3_s = torch.randn_like(peft_model.base_model.model.lin3.shira_weight["second"])
        peft_model.base_model.model.lin3.shira_weight["second"] = shira_assign_val3_s
        shira_indices3_s = peft_model.base_model.model.lin3.shira_indices["second"]

        input = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(input)
        peft_model.set_adapter("second")
        output_second = peft_model(input)

        assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        assert os.path.exists(os.path.join(save_path, "second", "adapter_config.json"))
        del peft_model

        torch.manual_seed(0)
        mlp = MLP()
        peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
        peft_model.load_adapter(os.path.join(save_path, "second"), "second")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(input)
        peft_model.set_adapter("second")
        output_second_loaded = peft_model(input)

        assert torch.allclose(output_first, output_first_loaded)
        assert torch.allclose(output_second, output_second_loaded)

        assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
        assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
        assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
        assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
        assert torch.all(shira_assign_val1_s == peft_model.base_model.model.lin1.shira_weight["second"])
        assert torch.all(shira_assign_val2_s == peft_model.base_model.model.lin2.shira_weight["second"])
        assert torch.all(shira_assign_val3_s == peft_model.base_model.model.lin3.shira_weight["second"])
        assert torch.all(shira_indices1_s == peft_model.base_model.model.lin1.shira_indices["second"])
        assert torch.all(shira_indices2_s == peft_model.base_model.model.lin2.shira_indices["second"])
        assert torch.all(shira_indices3_s == peft_model.base_model.model.lin3.shira_indices["second"])

        return peft_model

    def test_save_load_custom_mask_function(self, mlp, tmp_path):
        config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"], init_weights=False)
        custom_arg = 120
        custom_mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)
        config.mask_fn = custom_mask_fn

        peft_model = get_peft_model(mlp, config, adapter_name="first")

        shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
        shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
        shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
        shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]

        input = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(input)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        del peft_model
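
        # The adapter is reloaded below without re-attaching config.mask_fn; the equality checks on shira_indices
        # after from_pretrained indicate that the chosen sparse locations are persisted with the adapter, so the
        # custom mask function is only needed when the adapter is first created.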
        torch.manual_seed(0)
        mlp = MLP()
        peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(input)

        assert torch.allclose(output_first, output_first_loaded)

        assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
        assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
        assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
        assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])

        return peft_model

    def test_save_load_default_random_mask_with_seed_function(self, mlp, tmp_path):
        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=567, init_weights=False)

        peft_model = get_peft_model(mlp, config, adapter_name="first")

        shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
        shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
        shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
        shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]

        input = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(input)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        del peft_model

        torch.manual_seed(0)
        mlp = MLP()
        peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(input)

        assert torch.allclose(output_first, output_first_loaded)

        assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
        assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
        assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
        assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])

        return peft_model

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_shira_dtypes(self, dtype):
        if dtype == torch.bfloat16:
            if not is_bf16_available():
                pytest.skip("bfloat16 not supported on this system, skipping the test")

        model = MLP().to(dtype)
        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"])
        peft_model = get_peft_model(model, config)
        inputs = torch.randn(5, 10).to(dtype)
        output = peft_model(inputs)
        assert output.dtype == dtype