|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from peft.utils.integrations import init_empty_weights, skip_init_on_device |
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Small two-layer perceptron used as a parameter-owning test fixture.

    Submodules are registered in the order ``lin0`` -> ``relu`` -> ``drop`` -> ``lin1``.
    No ``forward`` is defined; the module only serves as a container of parameters.

    Args:
        bias: Whether both ``nn.Linear`` layers are created with a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        hidden = 20  # width shared by lin0's output and lin1's input
        self.lin0 = nn.Linear(in_features=10, out_features=hidden, bias=bias)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(p=0.5)
        self.lin1 = nn.Linear(in_features=hidden, out_features=2, bias=bias)
|
|
|
|
|
|
|
|
def get_mlp():
    """Build and return a fresh ``MLP`` with default settings (biases enabled)."""
    model = MLP(bias=True)
    return model
|
|
|
|
|
|
|
|
class TestInitEmptyWeights:
    """Tests for ``init_empty_weights`` and the ``skip_init_on_device`` decorator."""

    @staticmethod
    def _assert_params_on_device(module, expected):
        """Assert that ``module`` has parameters and that every one of them lives on ``expected``.

        A bare ``all(p.device == expected for p in module.parameters())`` passes
        vacuously for a parameter-less module. On failure it also does not say
        which parameter was misplaced. This helper guards against both problems.

        Args:
            module: The ``nn.Module`` whose parameters are checked.
            expected: The ``torch.device`` every parameter must be on.
        """
        params = list(module.named_parameters())
        assert params, "module has no parameters; the device check would be vacuous"
        misplaced = {name: p.device for name, p in params if p.device != expected}
        assert not misplaced, f"expected all parameters on {expected}, but found: {misplaced}"

    def test_init_empty_weights_works(self):
        # Modules built inside init_empty_weights() should have all parameters on the meta device.
        with init_empty_weights():
            mlp = get_mlp()

        self._assert_params_on_device(mlp, torch.device("meta"))

    def test_skip_init_on_device_works(self):
        # A function decorated with skip_init_on_device should allocate real (CPU) parameters
        # even when it is called inside an init_empty_weights() context.
        decorated_fn = skip_init_on_device(get_mlp)
        with init_empty_weights():
            mlp = decorated_fn()

        self._assert_params_on_device(mlp, torch.device("cpu"))

    def test_skip_init_on_device_works_outside_context(self):
        # The decorated function must also work when no init_empty_weights() context is active.
        decorated_fn = skip_init_on_device(get_mlp)
        mlp = decorated_fn()

        self._assert_params_on_device(mlp, torch.device("cpu"))

    def test_skip_init_on_device_not_permanent(self):
        # Calling a skip_init_on_device-decorated function must not permanently disable
        # init_empty_weights(): a later undecorated call should produce meta parameters again.
        decorated_fn = skip_init_on_device(get_mlp)
        with init_empty_weights():
            mlp = decorated_fn()

        self._assert_params_on_device(mlp, torch.device("cpu"))

        # Now, calling get_mlp directly under init_empty_weights() should use the meta device again.
        with init_empty_weights():
            mlp = get_mlp()

        self._assert_params_on_device(mlp, torch.device("meta"))

    def test_skip_init_on_device_nested(self):
        # Nested decorated functions: after the inner decorated function returns, the outer
        # decorated function must still be skipping meta init, so both modules end up on CPU.
        @skip_init_on_device
        def outer_fn():
            @skip_init_on_device
            def inner_fn():
                return get_mlp()

            mlp0 = inner_fn()
            mlp1 = get_mlp()  # created after inner_fn returned, still inside outer_fn
            return mlp0, mlp1

        with init_empty_weights():
            mlp0, mlp1 = outer_fn()

        expected = torch.device("cpu")
        self._assert_params_on_device(mlp0, expected)
        self._assert_params_on_device(mlp1, expected)
|
|
|