File size: 1,630 Bytes
302920f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import torch
from peft.tuners._buffer_dict import BufferDict
class TestBufferDict:
    """Tests for constructing a ``BufferDict`` and merging entries via ``update``."""

    @staticmethod
    def _assert_holds_both(bd, default_tensor, non_default_tensor):
        """Assert ``bd`` holds exactly the two expected entries.

        Args:
            bd: The ``BufferDict`` under test, after ``update`` was called.
            default_tensor: Tensor expected under the ``"default"`` key.
            non_default_tensor: Tensor expected under the ``"non_default"`` key.
        """
        # Comparing the full key set also catches unexpected extra entries.
        assert set(bd.keys()) == {"default", "non_default"}
        assert torch.allclose(bd["default"], default_tensor)
        assert torch.allclose(bd["non_default"], non_default_tensor)

    def test_init_from_dict_works(self):
        """A ``BufferDict`` built from a plain dict exposes that dict's entry."""
        default_tensor = torch.randn(10, 2)
        bd = BufferDict(
            {
                "default": default_tensor,
            }
        )
        # Beyond "construction does not raise", verify the entry is retrievable.
        assert set(bd.keys()) == {"default"}
        assert torch.allclose(bd["default"], default_tensor)

    def test_update_from_other_bufferdict(self):
        """``update`` accepts another ``BufferDict`` as its argument."""
        default_tensor = torch.randn(10, 2)
        non_default_tensor = torch.randn(10, 2)
        bd1 = BufferDict({"default": default_tensor})
        bd2 = BufferDict({"non_default": non_default_tensor})

        bd1.update(bd2)

        self._assert_holds_both(bd1, default_tensor, non_default_tensor)

    def test_update_from_dict(self):
        """``update`` accepts a plain ``dict`` as its argument."""
        default_tensor = torch.randn(10, 2)
        non_default_tensor = torch.randn(10, 2)
        bd1 = BufferDict({"default": default_tensor})
        d1 = {"non_default": non_default_tensor}

        bd1.update(d1)

        self._assert_holds_both(bd1, default_tensor, non_default_tensor)

    def test_update_from_dict_items(self):
        """``update`` accepts an iterable of ``(key, tensor)`` pairs."""
        default_tensor = torch.randn(10, 2)
        non_default_tensor = torch.randn(10, 2)
        bd1 = BufferDict({"default": default_tensor})
        d1 = {"non_default": non_default_tensor}

        # ``dict.items()`` is a view of pairs, not a mapping.
        bd1.update(d1.items())

        self._assert_holds_both(bd1, default_tensor, non_default_tensor)
|