# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This test file is for tests specific to SHiRA.

import os

import pytest
import torch
from accelerate.utils.imports import is_bf16_available
from torch import nn

from peft import PeftModel, ShiraConfig, get_peft_model


def custom_random_mask_function_with_custom_kwargs(custom_arg):
def mask_fn(base_layer, r):
"""
This mask function is similar to the random_mask provided in src/peft/tuners/shira/mask_functions.py except the
seed is derived from custom_kwargs. Please use this as an example to create your own custom sparse masks that
may use custom_kwargs. Remember, for a pretrained weight with shape m, n, mask_fn must return only one mask
(shape: m, n) which must be binary 0 or 1 with num_shira_parameters = r(m+n) for linear layers. Device and
dtype of mask must be same as base layer's weight's device and dtype.
"""
new_seed = custom_arg
shape = base_layer.weight.shape
num_shira_weights = r * (shape[0] + shape[1])
random_generator = torch.Generator()
random_generator.manual_seed(new_seed)
idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
base_layer.weight.device
)
val = torch.ones_like(idx.type(base_layer.weight.dtype))
mask = torch.zeros_like(base_layer.weight.view(1, -1))
mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)
return mask
return mask_fn
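

# A minimal usage sketch for the custom mask function above, kept as a comment so it does not run at
# import time. It illustrates the contract from the docstring: the returned mask has the same shape,
# dtype, and device as the base layer's weight and contains exactly r * (m + n) ones. The
# `mask_type="custom"` / `config.mask_fn` wiring mirrors what the tests below do; the seed value 42 is
# an arbitrary illustration.
#
#   layer = nn.Linear(10, 20)  # weight shape (20, 10), i.e. m=20, n=10
#   mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg=42)
#   mask = mask_fn(layer, r=2)
#   assert mask.shape == layer.weight.shape
#   assert int(mask.sum().item()) == 2 * (20 + 10)  # r * (m + n) ones, rest zeros
#
#   config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"])
#   config.mask_fn = mask_fn

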
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.relu = nn.ReLU()
self.lin0 = nn.Linear(10, 20, bias=bias)
        self.lin1 = nn.Linear(20, 40, bias=bias)
self.lin2 = nn.Linear(40, 30, bias=bias)
self.lin3 = nn.Linear(30, 10, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
X = self.lin0(X)
X = self.relu(X)
X = self.lin1(X)
X = self.relu(X)
X = self.lin2(X)
X = self.relu(X)
X = self.lin3(X)
X = self.sm(X)
return X


class TestShira:
@pytest.fixture
def mlp(self):
torch.manual_seed(0)
model = MLP()
return model

    def test_mlp_single_adapter_shapes(self, mlp):
# torch.manual_seed(0)
r = 2
config = ShiraConfig(r=r, target_modules=["lin1", "lin2"])
# creates a default SHiRA adapter
peft_model = get_peft_model(mlp, config)
shira_weight1_size = peft_model.base_model.model.lin1.shira_weight["default"].shape[0]
shira_weight2_size = peft_model.base_model.model.lin2.shira_weight["default"].shape[0]
shira_indices1_size = peft_model.base_model.model.lin1.shira_indices["default"].shape[1]
shira_indices2_size = peft_model.base_model.model.lin2.shira_indices["default"].shape[1]
base_weight1_size = peft_model.base_model.model.lin1.base_layer.weight.shape
base_weight2_size = peft_model.base_model.model.lin2.base_layer.weight.shape
delta_weight1_shape = peft_model.base_model.model.lin1.get_delta_weight("default").shape
delta_weight2_shape = peft_model.base_model.model.lin2.get_delta_weight("default").shape
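        # For concreteness with r=2: lin1 = nn.Linear(20, 40) has a weight of shape (40, 20), so SHiRA
        # stores 2 * (40 + 20) = 120 trainable values for it; lin2 = nn.Linear(40, 30) has a weight of
        # shape (30, 40), giving 2 * (30 + 40) = 140 values. The delta weights reconstructed from these
        # sparse values must match the base weight shapes.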
assert shira_weight1_size == r * (base_weight1_size[0] + base_weight1_size[1])
assert shira_weight2_size == r * (base_weight2_size[0] + base_weight2_size[1])
assert shira_weight1_size == shira_indices1_size
assert shira_weight2_size == shira_indices2_size
assert delta_weight1_shape == base_weight1_size
assert delta_weight2_shape == base_weight2_size
return peft_model

    def test_multiple_adapters_save_load(self, mlp, tmp_path):
        # check that saving and loading works with multiple adapters
        # note: the random seeds in the two configs below are not the default values, so different
        # random sparse masks will be generated between saving and loading.
        # our goal is to make sure that the loaded indices are exactly the same as the saved indices,
        # regardless of what initial random mask gets generated. we also make sure that the parameters
        # are saved and loaded correctly and that the output remains the same.
config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=56)
# creates a default SHiRA adapter
peft_model = get_peft_model(mlp, config, adapter_name="first")
config2 = ShiraConfig(r=3, target_modules=["lin1", "lin2", "lin3"], random_seed=67)
peft_model.add_adapter("second", config2)
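        # freshly added SHiRA adapters start with all-zero weights by default; verify this before
        # assigning distinct values to each adapter below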
assert torch.all(peft_model.base_model.model.lin1.shira_weight["first"] == 0)
assert torch.all(peft_model.base_model.model.lin2.shira_weight["first"] == 0)
assert torch.all(peft_model.base_model.model.lin1.shira_weight["second"] == 0)
assert torch.all(peft_model.base_model.model.lin2.shira_weight["second"] == 0)
assert torch.all(peft_model.base_model.model.lin3.shira_weight["second"] == 0)
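        # assign random values to each adapter's SHiRA weights (the generated indices are kept as-is)
        # so the two adapters produce different outputs and the save/load round trip below has
        # non-trivial values to compare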
shira_assign_val1_f = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["first"])
peft_model.base_model.model.lin1.shira_weight["first"] = shira_assign_val1_f
shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
shira_assign_val2_f = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["first"])
peft_model.base_model.model.lin2.shira_weight["first"] = shira_assign_val2_f
shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]
shira_assign_val1_s = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["second"])
peft_model.base_model.model.lin1.shira_weight["second"] = shira_assign_val1_s
shira_indices1_s = peft_model.base_model.model.lin1.shira_indices["second"]
shira_assign_val2_s = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["second"])
peft_model.base_model.model.lin2.shira_weight["second"] = shira_assign_val2_s
shira_indices2_s = peft_model.base_model.model.lin2.shira_indices["second"]
shira_assign_val3_s = torch.randn_like(peft_model.base_model.model.lin3.shira_weight["second"])
peft_model.base_model.model.lin3.shira_weight["second"] = shira_assign_val3_s
shira_indices3_s = peft_model.base_model.model.lin3.shira_indices["second"]
input = torch.randn(5, 10)
peft_model.set_adapter("first")
output_first = peft_model(input)
peft_model.set_adapter("second")
output_second = peft_model(input)
# sanity check
assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3)
save_path = os.path.join(tmp_path, "shira")
peft_model.save_pretrained(save_path)
assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
assert os.path.exists(os.path.join(save_path, "second", "adapter_config.json"))
del peft_model
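        # recreate the base model with the same seed so its weights match the model the adapters were
        # created on, then reload both adapters from disk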
torch.manual_seed(0)
mlp = MLP()
peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
peft_model.load_adapter(os.path.join(save_path, "second"), "second")
peft_model.set_adapter("first")
output_first_loaded = peft_model(input)
peft_model.set_adapter("second")
output_second_loaded = peft_model(input)
assert torch.allclose(output_first, output_first_loaded)
assert torch.allclose(output_second, output_second_loaded)
assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
assert torch.all(shira_assign_val1_s == peft_model.base_model.model.lin1.shira_weight["second"])
assert torch.all(shira_assign_val2_s == peft_model.base_model.model.lin2.shira_weight["second"])
assert torch.all(shira_assign_val3_s == peft_model.base_model.model.lin3.shira_weight["second"])
assert torch.all(shira_indices1_s == peft_model.base_model.model.lin1.shira_indices["second"])
assert torch.all(shira_indices2_s == peft_model.base_model.model.lin2.shira_indices["second"])
assert torch.all(shira_indices3_s == peft_model.base_model.model.lin3.shira_indices["second"])
return peft_model

    def test_save_load_custom_mask_function(self, mlp, tmp_path):
# we want to see if saving and loading works when a custom mask is involved
config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"], init_weights=False)
custom_arg = 120
custom_mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)
config.mask_fn = custom_mask_fn
# create a custom mask SHiRA adapter
peft_model = get_peft_model(mlp, config, adapter_name="first")
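        # init_weights=False is used so the adapter does not start from the default all-zero
        # initialization, giving non-trivial values and outputs to compare after reloading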
shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]
input = torch.randn(5, 10)
peft_model.set_adapter("first")
output_first = peft_model(input)
save_path = os.path.join(tmp_path, "shira")
peft_model.save_pretrained(save_path)
assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
del peft_model
torch.manual_seed(0)
mlp = MLP()
peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
peft_model.set_adapter("first")
output_first_loaded = peft_model(input)
assert torch.allclose(output_first, output_first_loaded)
assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
return peft_model

    def test_save_load_default_random_mask_with_seed_function(self, mlp, tmp_path):
# we want to see if saving and loading works when a random mask is involved but the random seed is fixed.
config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=567, init_weights=False)
        # create a SHiRA adapter with the default random mask and a fixed random seed
peft_model = get_peft_model(mlp, config, adapter_name="first")
shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]
input = torch.randn(5, 10)
peft_model.set_adapter("first")
output_first = peft_model(input)
save_path = os.path.join(tmp_path, "shira")
peft_model.save_pretrained(save_path)
assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
del peft_model
torch.manual_seed(0)
mlp = MLP()
peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
peft_model.set_adapter("first")
output_first_loaded = peft_model(input)
assert torch.allclose(output_first, output_first_loaded)
assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
return peft_model

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_shira_dtypes(self, dtype):
if dtype == torch.bfloat16:
# skip if bf16 is not supported on hardware, see #1872
if not is_bf16_available():
pytest.skip("bfloat16 not supported on this system, skipping the test")
model = MLP().to(dtype)
config = ShiraConfig(r=2, target_modules=["lin1", "lin2"])
peft_model = get_peft_model(model, config)
inputs = torch.randn(5, 10).to(dtype)
output = peft_model(inputs) # should not raise
assert output.dtype == dtype