# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
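
import pytest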
import torch
from torch import nn

from peft.utils.integrations import init_empty_weights, skip_init_on_device


class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.5)
        self.lin1 = nn.Linear(20, 2, bias=bias)


def get_mlp():
    return MLP()


class TestInitEmptyWeights:
    def test_init_empty_weights_works(self):
        # this is a very rudimentary test, as init_empty_weights is copied almost 1:1 from accelerate and is tested
        # there
        with init_empty_weights():
            mlp = get_mlp()

        expected = torch.device("meta")
        assert all(p.device == expected for p in mlp.parameters())
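
    def test_init_empty_weights_preserves_shapes(self):
        # illustrative sketch added here: meta tensors record full shape metadata
        # while allocating no storage, so a module created under init_empty_weights
        # should report the same parameter shapes as a regularly initialized MLP
        with init_empty_weights():
            mlp = get_mlp()

        assert mlp.lin0.weight.shape == (20, 10)
        assert mlp.lin1.weight.shape == (2, 20)

    def test_init_empty_weights_module_can_be_materialized(self):
        # illustrative sketch added here: a meta-device skeleton can be
        # materialized with nn.Module.to_empty, which allocates real
        # (uninitialized) storage on the target device while keeping the shapes
        with init_empty_weights():
            mlp = get_mlp()
        mlp = mlp.to_empty(device="cpu")

        expected = torch.device("cpu")
        assert all(p.device == expected for p in mlp.parameters())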

    def test_skip_init_on_device_works(self):
        # when a function is decorated with skip_init_on_device, the parameters are not moved to meta device, even when
        # inside the context
        decorated_fn = skip_init_on_device(get_mlp)
        with init_empty_weights():
            mlp = decorated_fn()

        expected = torch.device("cpu")
        assert all(p.device == expected for p in mlp.parameters())
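
    def test_skip_init_on_device_scoped_to_call(self):
        # illustrative sketch added here, assuming the decorator only suspends
        # meta-device init for the duration of the decorated call: an undecorated
        # call made afterwards inside the same context should still land on meta
        decorated_fn = skip_init_on_device(get_mlp)
        with init_empty_weights():
            mlp_cpu = decorated_fn()
            mlp_meta = get_mlp()

        assert all(p.device == torch.device("cpu") for p in mlp_cpu.parameters())
        assert all(p.device == torch.device("meta") for p in mlp_meta.parameters())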

    def test_skip_init_on_device_works_outside_context(self):
        # same as before, but ensure that skip_init_on_device does not break when no init_empty_weights context is used
        decorated_fn = skip_init_on_device(get_mlp)
        mlp = decorated_fn()

        expected = torch.device("cpu")
        assert all(p.device == expected for p in mlp.parameters())

    def test_skip_init_on_device_not_permanent(self):
        # ensure that after skip_init_on_device has been used, init_empty_weights reverts to its original functionality

        # with decorator => cpu
        decorated_fn = skip_init_on_device(get_mlp)
        with init_empty_weights():
            mlp = decorated_fn()
        expected = torch.device("cpu")
        assert all(p.device == expected for p in mlp.parameters())

        # without decorator => meta
        with init_empty_weights():
            mlp = get_mlp()
        expected = torch.device("meta")
        assert all(p.device == expected for p in mlp.parameters())

    def test_skip_init_on_device_nested(self):
        # ensure that skip_init_on_device works even if the decorated function is nested inside another decorated
        # function
        @skip_init_on_device
        def outer_fn():
            @skip_init_on_device
            def inner_fn():
                return get_mlp()

            mlp0 = inner_fn()
            mlp1 = get_mlp()
            return mlp0, mlp1

        with init_empty_weights():
            mlp0, mlp1 = outer_fn()

        expected = torch.device("cpu")
        assert all(p.device == expected for p in mlp0.parameters())
        assert all(p.device == expected for p in mlp1.parameters())
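
    def test_skip_init_on_device_restores_on_exception(self):
        # illustrative sketch added here, assuming skip_init_on_device restores
        # the meta-device patching even when the wrapped function raises: a later
        # undecorated call inside the same context should still produce
        # meta-device parameters
        @skip_init_on_device
        def failing_fn():
            raise RuntimeError("init failed")

        with init_empty_weights():
            with pytest.raises(RuntimeError, match="init failed"):
                failing_fn()
            mlp = get_mlp()

        expected = torch.device("meta")
        assert all(p.device == expected for p in mlp.parameters())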