{
"cells": [
{
"cell_type": "markdown",
"id": "8e8743c8",
"metadata": {},
"source": [
"# Using PEFT with custom models"
]
},
{
"cell_type": "markdown",
"id": "c42c67e1",
"metadata": {},
"source": [
"`peft` allows us to fine-tune models efficiently with LoRA. In this short notebook, we will demonstrate how to train a simple multilayer perceptron (MLP) using `peft`."
]
},
{
"cell_type": "markdown",
"id": "ce314af5",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "markdown",
"id": "b28b214d",
"metadata": {},
"source": [
"Make sure that you have the latest version of `peft` installed. To ensure that, run this in your Python environment:\n",
" \n",
" python -m pip install --upgrade peft"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4d9da3d9",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import os\n",
"\n",
"# ignore bnb warnings\n",
"os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "44075f54",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import peft\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f72acdfb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f6d9259a5d0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.manual_seed(0)"
]
},
{
"cell_type": "markdown",
"id": "2b127a78",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "markdown",
"id": "f265da76",
"metadata": {},
"source": [
"We will create a toy dataset consisting of random data for a binary classification task. A sample's label is 1 when its 20 features sum to more than 10 and 0 otherwise, so the labels are fully determined by the inputs and we should expect the loss of the model to improve during training."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b355567e",
"metadata": {},
"outputs": [],
"source": [
"X = torch.rand((1000, 20))\n",
"y = (X.sum(1) > 10).long()"
]
},
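{
"cell_type": "markdown",
"id": "a7d3e915",
"metadata": {},
"source": [
"As an optional sanity check, we can confirm that this labeling rule gives roughly balanced classes: each of the 20 features is uniform on [0, 1], so their sum has a mean of 10, right at the threshold. The sketch below draws data the same way from its own generator (`demo_generator`, a name introduced only for this check), so it neither overwrites `X` and `y` nor advances the global random state that the rest of the notebook relies on."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b81f0c2e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# a separate generator leaves the global RNG (and the results below) untouched\n",
"demo_generator = torch.Generator().manual_seed(0)\n",
"X_demo = torch.rand((1000, 20), generator=demo_generator)\n",
"y_demo = (X_demo.sum(1) > 10).long()\n",
"\n",
"# every feature has mean 0.5, so the row sums are centered on the threshold of 10\n",
"print(f\"mean row sum: {X_demo.sum(1).mean().item():.2f}\")\n",
"print(f\"fraction of positive labels: {y_demo.float().mean().item():.3f}\")"
]
},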
{
"cell_type": "code",
"execution_count": 5,
"id": "a60a869d",
"metadata": {},
"outputs": [],
"source": [
"n_train = 800\n",
"batch_size = 64"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8859572e",
"metadata": {},
"outputs": [],
"source": [
"train_dataloader = torch.utils.data.DataLoader(\n",
" torch.utils.data.TensorDataset(X[:n_train], y[:n_train]),\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
")\n",
"eval_dataloader = torch.utils.data.DataLoader(\n",
" torch.utils.data.TensorDataset(X[n_train:], y[n_train:]),\n",
" batch_size=batch_size,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "97bddd2c",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "markdown",
"id": "db694a58",
"metadata": {},
"source": [
"As a model, we use a simple multilayer perceptron (MLP). For demonstration purposes, we use a very large number of hidden units. This is totally overkill for this task but it helps to demonstrate the advantages of `peft`. In more realistic settings, models will also be quite large on average, so this is not far-fetched."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1b43cd8f",
"metadata": {},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self, num_units_hidden=2000):\n",
" super().__init__()\n",
" self.seq = nn.Sequential(\n",
" nn.Linear(20, num_units_hidden),\n",
" nn.ReLU(),\n",
" nn.Linear(num_units_hidden, num_units_hidden),\n",
" nn.ReLU(),\n",
" nn.Linear(num_units_hidden, 2),\n",
" nn.LogSoftmax(dim=-1),\n",
" )\n",
"\n",
" def forward(self, X):\n",
" return self.seq(X)"
]
},
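{
"cell_type": "markdown",
"id": "c4e6d273",
"metadata": {},
"source": [
"To put the model size into numbers: an `nn.Linear(in_features, out_features)` layer holds `in_features * out_features` weights plus `out_features` biases, while `nn.ReLU` and `nn.LogSoftmax` have no parameters. The hand count below (done without instantiating `MLP`) shows that the default model has about 4 million parameters for just 20 input features."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d02b9f48",
"metadata": {},
"outputs": [],
"source": [
"# (in_features, out_features) of the three nn.Linear layers in MLP\n",
"hidden = 2000  # default num_units_hidden\n",
"linear_shapes = [(20, hidden), (hidden, hidden), (hidden, 2)]\n",
"\n",
"# weights (in * out) plus biases (out) for each layer\n",
"n_params = sum(n_in * n_out + n_out for n_in, n_out in linear_shapes)\n",
"print(f\"{n_params:,} parameters\")  # 4,048,002 parameters"
]
},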
{
"cell_type": "markdown",
"id": "1277bf00",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"id": "02caf26a",
"metadata": {},
"source": [
"Here are just a few training hyper-parameters and a simple function that performs the training and evaluation loop."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5d14c0c4",
"metadata": {},
"outputs": [],
"source": [
"lr = 0.002\n",
"batch_size = 64\n",
"max_epochs = 30\n",
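"# use an accelerator (e.g. CUDA, MPS or XPU) if one is available, otherwise fall back to the CPU\n",
"if hasattr(torch, \"accelerator\") and torch.accelerator.current_accelerator() is not None:\n",
"    device = torch.accelerator.current_accelerator().type\n",
"else:\n",
"    device = \"cuda\" if torch.cuda.is_available() else \"cpu\""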
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "657d6b3e",
"metadata": {},
"outputs": [],
"source": [
"def train(model, optimizer, criterion, train_dataloader, eval_dataloader, epochs):\n",
" for epoch in range(epochs):\n",
" model.train()\n",
" train_loss = 0\n",
" for xb, yb in train_dataloader:\n",
" xb = xb.to(device)\n",
" yb = yb.to(device)\n",
" outputs = model(xb)\n",
" loss = criterion(outputs, yb)\n",
" train_loss += loss.detach().float()\n",
" loss.backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" model.eval()\n",
" eval_loss = 0\n",
" for xb, yb in eval_dataloader:\n",
" xb = xb.to(device)\n",
" yb = yb.to(device)\n",
" with torch.no_grad():\n",
" outputs = model(xb)\n",
" loss = criterion(outputs, yb)\n",
" eval_loss += loss.detach().float()\n",
"\n",
" eval_loss_total = (eval_loss / len(eval_dataloader)).item()\n",
" train_loss_total = (train_loss / len(train_dataloader)).item()\n",
" print(f\"{epoch=:<2} {train_loss_total=:.4f} {eval_loss_total=:.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "b382dcbe",
"metadata": {},
"source": [
"### Training without peft"
]
},
{
"cell_type": "markdown",
"id": "b40d4873",
"metadata": {},
"source": [
"Let's start without using `peft` to see what we can expect from the model training."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f059ced4",
"metadata": {},
"outputs": [],
"source": [
"module = MLP().to(device)\n",
"optimizer = torch.optim.Adam(module.parameters(), lr=lr)\n",
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "17698863",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch=0 train_loss_total=0.7970 eval_loss_total=0.6472\n",
"epoch=1 train_loss_total=0.5597 eval_loss_total=0.4898\n",
"epoch=2 train_loss_total=0.3696 eval_loss_total=0.3323\n",
"epoch=3 train_loss_total=0.2364 eval_loss_total=0.5454\n",
"epoch=4 train_loss_total=0.2428 eval_loss_total=0.2843\n",
"epoch=5 train_loss_total=0.1251 eval_loss_total=0.2514\n",
"epoch=6 train_loss_total=0.0952 eval_loss_total=0.2068\n",
"epoch=7 train_loss_total=0.0831 eval_loss_total=0.2395\n",
"epoch=8 train_loss_total=0.0655 eval_loss_total=0.2524\n",
"epoch=9 train_loss_total=0.0380 eval_loss_total=0.3650\n",
"epoch=10 train_loss_total=0.0363 eval_loss_total=0.3495\n",
"epoch=11 train_loss_total=0.0231 eval_loss_total=0.2360\n",
"epoch=12 train_loss_total=0.0162 eval_loss_total=0.2276\n",
"epoch=13 train_loss_total=0.0094 eval_loss_total=0.2716\n",
"epoch=14 train_loss_total=0.0065 eval_loss_total=0.2237\n",
"epoch=15 train_loss_total=0.0054 eval_loss_total=0.2366\n",
"epoch=16 train_loss_total=0.0035 eval_loss_total=0.2673\n",
"epoch=17 train_loss_total=0.0028 eval_loss_total=0.2630\n",
"epoch=18 train_loss_total=0.0023 eval_loss_total=0.2835\n",
"epoch=19 train_loss_total=0.0021 eval_loss_total=0.2727\n",
"epoch=20 train_loss_total=0.0018 eval_loss_total=0.2597\n",
"epoch=21 train_loss_total=0.0016 eval_loss_total=0.2553\n",
"epoch=22 train_loss_total=0.0014 eval_loss_total=0.2712\n",
"epoch=23 train_loss_total=0.0013 eval_loss_total=0.2637\n",
"epoch=24 train_loss_total=0.0012 eval_loss_total=0.2733\n",
"epoch=25 train_loss_total=0.0011 eval_loss_total=0.2738\n",
"epoch=26 train_loss_total=0.0010 eval_loss_total=0.2477\n",
"epoch=27 train_loss_total=0.0010 eval_loss_total=0.2584\n",
"epoch=28 train_loss_total=0.0009 eval_loss_total=0.2844\n",
"epoch=29 train_loss_total=0.0008 eval_loss_total=0.2633\n",
"CPU times: user 1.31 s, sys: 236 ms, total: 1.54 s\n",
"Wall time: 1.56 s\n"
]
}
],
"source": [
"%time train(module, optimizer, criterion, train_dataloader, eval_dataloader, epochs=max_epochs)"
]
},
{
"cell_type": "markdown",
"id": "4cef0029",
"metadata": {},
"source": [
"Okay, so we got an eval loss of ~0.26, which is much better than random guessing (a cross-entropy of ln 2, about 0.69, for two classes)."
]
},
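{
"cell_type": "markdown",
"id": "e5a17c3d",
"metadata": {},
"source": [
"Here, random means a classifier that assigns probability 0.5 to both classes. Its cross-entropy is ln 2, about 0.693, whatever the labels are. The following check confirms this with all-zero logits; the logits and targets are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3c8e061",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# all-zero logits correspond to a uniform prediction (p = 0.5 per class)\n",
"uniform_logits = torch.zeros(8, 2)\n",
"some_targets = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1])\n",
"baseline_loss = F.cross_entropy(uniform_logits, some_targets).item()\n",
"print(f\"{baseline_loss:.4f} vs ln 2 = {math.log(2):.4f}\")  # 0.6931 vs ln 2 = 0.6931"
]
},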
{
"cell_type": "markdown",
"id": "4f106078",
"metadata": {},
"source": [
"### Training with peft"
]
},
{
"cell_type": "markdown",
"id": "8dd47aa4",
"metadata": {},
"source": [
"Now let's train with `peft`. First we check the names of the modules, so that we can configure `peft` to fine-tune the right modules."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "922db29b",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[('', __main__.MLP),\n",
" ('seq', torch.nn.modules.container.Sequential),\n",
" ('seq.0', torch.nn.modules.linear.Linear),\n",
" ('seq.1', torch.nn.modules.activation.ReLU),\n",
" ('seq.2', torch.nn.modules.linear.Linear),\n",
" ('seq.3', torch.nn.modules.activation.ReLU),\n",
" ('seq.4', torch.nn.modules.linear.Linear),\n",
" ('seq.5', torch.nn.modules.activation.LogSoftmax)]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[(n, type(m)) for n, m in MLP().named_modules()]"
]
},
{
"cell_type": "markdown",
"id": "5efb275d",
"metadata": {},
"source": [
"Next we can define the LoRA config. There is nothing special going on here. We set the LoRA rank to 8 and select the layers `seq.0` and `seq.2` for LoRA fine-tuning. The output layer `seq.4` is listed in `modules_to_save` instead, which means it is also trained, but in full and without LoRA."
]
},
{
"cell_type": "markdown",
"id": "cf2c608d",
"metadata": {},
"source": [
"*Note: Not all layer types can be fine-tuned with LoRA. At the time of writing, linear layers, embeddings, `Conv2d` and `transformers.pytorch_utils.Conv1D` are supported; newer `peft` releases may support additional layer types.*"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b342438f",
"metadata": {},
"outputs": [],
"source": [
"config = peft.LoraConfig(\n",
" r=8,\n",
" target_modules=[\"seq.0\", \"seq.2\"],\n",
" modules_to_save=[\"seq.4\"],\n",
")"
]
},
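{
"cell_type": "markdown",
"id": "0b6d4a92",
"metadata": {},
"source": [
"To see why this configuration trains so few parameters: for each targeted `nn.Linear(in_features, out_features)`, LoRA keeps the original weight frozen and learns two low-rank factors, `lora_A` with shape `(r, in_features)` and `lora_B` with shape `(out_features, r)`, which adds `r * (in_features + out_features)` parameters. The hand count below applies this to `seq.0` and `seq.2` with `r=8`, adds the fully trained copy of `seq.4` from `modules_to_save`, and compares the total with the size of the original MLP. It is a back-of-the-envelope estimate, not PEFT's own bookkeeping."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e7f5b83",
"metadata": {},
"outputs": [],
"source": [
"lora_r = 8\n",
"# (in_features, out_features) of the layers listed in target_modules\n",
"lora_targets = {\"seq.0\": (20, 2000), \"seq.2\": (2000, 2000)}\n",
"\n",
"n_lora = 0\n",
"for layer_name, (n_in, n_out) in lora_targets.items():\n",
"    n_a = lora_r * n_in  # lora_A: (r, in_features)\n",
"    n_b = n_out * lora_r  # lora_B: (out_features, r)\n",
"    n_lora += n_a + n_b\n",
"    print(f\"{layer_name}: lora_A={n_a}, lora_B={n_b}, frozen weight={n_in * n_out}\")\n",
"\n",
"n_head = 2000 * 2 + 2  # trainable copy of seq.4 (weight + bias)\n",
"n_base = (20 * 2000 + 2000) + (2000 * 2000 + 2000) + n_head  # original MLP\n",
"n_trainable = n_lora + n_head\n",
"print(f\"trainable: {n_trainable:,} of {n_base:,} base parameters ({n_trainable / n_base:.2%})\")"
]
},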
{
"cell_type": "markdown",
"id": "829b4e2d",
"metadata": {},
"source": [
"Now let's create the `peft` model by passing our initial MLP, as well as the config we just defined, to `get_peft_model`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "602b6658",
"metadata": {},
"outputs": [],
"source": [
"module = MLP().to(device)\n",
"module_copy = copy.deepcopy(module) # we keep a copy of the original model for later\n",
"peft_model = peft.get_peft_model(module, config)\n",
"optimizer = torch.optim.Adam(peft_model.parameters(), lr=lr)\n",
"criterion = nn.CrossEntropyLoss()\n",
"peft_model.print_trainable_parameters()"
]
},
{
"cell_type": "markdown",
"id": "2103737d",
"metadata": {},
"source": [
"Checking the numbers, we see that only ~1% of parameters are actually trained, which is what we like to see.\n",
"\n",
"Now let's start the training:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9200cbc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch=0 train_loss_total=0.6695 eval_loss_total=0.6388\n",
"epoch=1 train_loss_total=0.5614 eval_loss_total=0.5456\n",
"epoch=2 train_loss_total=0.3897 eval_loss_total=0.3035\n",
"epoch=3 train_loss_total=0.2529 eval_loss_total=0.2510\n",
"epoch=4 train_loss_total=0.1914 eval_loss_total=0.2191\n",
"epoch=5 train_loss_total=0.1236 eval_loss_total=0.2586\n",
"epoch=6 train_loss_total=0.1076 eval_loss_total=0.3205\n",
"epoch=7 train_loss_total=0.1834 eval_loss_total=0.3951\n",
"epoch=8 train_loss_total=0.1037 eval_loss_total=0.1646\n",
"epoch=9 train_loss_total=0.0724 eval_loss_total=0.1409\n",
"epoch=10 train_loss_total=0.0691 eval_loss_total=0.1725\n",
"epoch=11 train_loss_total=0.0641 eval_loss_total=0.1423\n",
"epoch=12 train_loss_total=0.0382 eval_loss_total=0.1490\n",
"epoch=13 train_loss_total=0.0214 eval_loss_total=0.1517\n",
"epoch=14 train_loss_total=0.0119 eval_loss_total=0.1717\n",
"epoch=15 train_loss_total=0.0060 eval_loss_total=0.2366\n",
"epoch=16 train_loss_total=0.0029 eval_loss_total=0.2069\n",
"epoch=17 train_loss_total=0.0021 eval_loss_total=0.2082\n",
"epoch=18 train_loss_total=0.0016 eval_loss_total=0.2119\n",
"epoch=19 train_loss_total=0.0011 eval_loss_total=0.1984\n",
"epoch=20 train_loss_total=0.0010 eval_loss_total=0.1821\n",
"epoch=21 train_loss_total=0.0009 eval_loss_total=0.1892\n",
"epoch=22 train_loss_total=0.0007 eval_loss_total=0.2062\n",
"epoch=23 train_loss_total=0.0006 eval_loss_total=0.2408\n",
"epoch=24 train_loss_total=0.0006 eval_loss_total=0.2038\n",
"epoch=25 train_loss_total=0.0005 eval_loss_total=0.2374\n",
"epoch=26 train_loss_total=0.0004 eval_loss_total=0.2139\n",
"epoch=27 train_loss_total=0.0004 eval_loss_total=0.2085\n",
"epoch=28 train_loss_total=0.0004 eval_loss_total=0.2395\n",
"epoch=29 train_loss_total=0.0003 eval_loss_total=0.2100\n",
"CPU times: user 1.41 s, sys: 48.9 ms, total: 1.46 s\n",
"Wall time: 1.46 s\n"
]
}
],
"source": [
"%time train(peft_model, optimizer, criterion, train_dataloader, eval_dataloader, epochs=max_epochs)"
]
},
{
"cell_type": "markdown",
"id": "20f6f452",
"metadata": {},
"source": [
"In the end, the eval loss (~0.21) is very similar to, and in this run even slightly lower than, the ~0.26 we saw earlier when we trained without `peft`. This is quite nice to see, given that we are training a much smaller number of parameters."
]
},
{
"cell_type": "markdown",
"id": "fa55d1d4",
"metadata": {},
"source": [
"#### Check which parameters were updated"
]
},
{
"cell_type": "markdown",
"id": "a6e2146b",
"metadata": {},
"source": [
"Finally, just to check that LoRA was applied as expected, we check which of the original weights were updated and which stayed the same."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c7dcde21",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New parameter model.seq.0.lora_A.default.weight | 160 parameters | updated\n",
"New parameter model.seq.0.lora_B.default.weight | 16000 parameters | updated\n",
"New parameter model.seq.2.lora_A.default.weight | 16000 parameters | updated\n",
"New parameter model.seq.2.lora_B.default.weight | 16000 parameters | updated\n"
]
}
],
"source": [
"for name, param in peft_model.base_model.named_parameters():\n",
" if \"lora\" not in name:\n",
" continue\n",
"\n",
" print(f\"New parameter {name:<13} | {param.numel():>5} parameters | updated\")"
]
},
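{
"cell_type": "markdown",
"id": "29c0d6a4",
"metadata": {},
"source": [
"The cell above labels every LoRA parameter as updated without comparing it to a starting value. That label can be justified by how LoRA is initialized: with PEFT's default `init_lora_weights=True`, `lora_B` starts at zero, so a freshly wrapped model produces exactly the same outputs as its base model, and a nonzero `lora_B` after training must come from the optimizer. The sketch below demonstrates the initialization on a tiny, hypothetical `nn.Sequential` that is unrelated to the MLP above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a1e7fb5",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"\n",
"import peft\n",
"import torch\n",
"from torch import nn\n",
"\n",
"# tiny stand-in model; \"0\" is the module name of its first Linear layer\n",
"tiny_base = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))\n",
"tiny_reference = copy.deepcopy(tiny_base)  # untouched copy for comparison\n",
"tiny_peft = peft.get_peft_model(tiny_base, peft.LoraConfig(r=2, target_modules=[\"0\"]))\n",
"\n",
"# with the default initialization, every lora_B weight is zero right after wrapping\n",
"for name, param in tiny_peft.named_parameters():\n",
"    if \"lora_B\" in name:\n",
"        print(f\"{name}: all zeros = {bool((param == 0).all())}\")\n",
"\n",
"# hence the wrapped model reproduces the base model's outputs exactly\n",
"x_tiny = torch.rand(3, 4)\n",
"with torch.no_grad():\n",
"    same = torch.allclose(tiny_peft(x_tiny), tiny_reference(x_tiny))\n",
"print(f\"same outputs as the base model: {same}\")"
]
},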
{
"cell_type": "code",
"execution_count": 25,
"id": "022e6c41",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter seq.0.weight | 40000 parameters | not updated\n",
"Parameter seq.0.bias | 2000 parameters | not updated\n",
"Parameter seq.2.weight | 4000000 parameters | not updated\n",
"Parameter seq.2.bias | 2000 parameters | not updated\n",
"Parameter seq.4.weight | 4000 parameters | not updated\n",
"Parameter seq.4.bias | 2 parameters | not updated\n",
"Parameter seq.4.weight | 4000 parameters | updated\n",
"Parameter seq.4.bias | 2 parameters | updated\n"
]
}
],
"source": [
"params_before = dict(module_copy.named_parameters())\n",
"for name, param in peft_model.base_model.named_parameters():\n",
" if \"lora\" in name:\n",
" continue\n",
"\n",
" name_before = (\n",
" name.partition(\".\")[-1].replace(\"base_layer.\", \"\").replace(\"original_\", \"\").replace(\"module.\", \"\").replace(\"modules_to_save.default.\", \"\")\n",
" )\n",
" param_before = params_before[name_before]\n",
" if torch.allclose(param, param_before):\n",
" print(f\"Parameter {name_before:<13} | {param.numel():>7} parameters | not updated\")\n",
" else:\n",
" print(f\"Parameter {name_before:<13} | {param.numel():>7} parameters | updated\")"
]
},
{
"cell_type": "markdown",
"id": "4c09b43d",
"metadata": {},
"source": [
"So we can see that, apart from the newly added LoRA weights, only the last layer was updated. Since the LoRA weights and the last layer have comparatively few parameters, this gives us a big boost in efficiency."
]
},
{
"cell_type": "markdown",
"id": "b46c6198",
"metadata": {},
"source": [
"## Sharing the model through Hugging Face Hub"
]
},
{
"cell_type": "markdown",
"id": "6289e647",
"metadata": {},
"source": [
"### Pushing the model to HF Hub"
]
},
{
"cell_type": "markdown",
"id": "06dcdfa0",
"metadata": {},
"source": [
"With the `peft` model, it is also very easy to push a model to the Hugging Face Hub. Below, we demonstrate how it works. It is assumed that you have a valid Hugging Face account and are logged in:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "1b91a0af",
"metadata": {},
"outputs": [],
"source": [
"user = \"BenjaminB\" # put your user name here\n",
"model_name = \"peft-lora-with-custom-model\"\n",
"model_id = f\"{user}/{model_name}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1430fffd",
"metadata": {},
"outputs": [],
"source": [
"peft_model.push_to_hub(model_id);"
]
},
{
"cell_type": "markdown",
"id": "632bd799",
"metadata": {},
"source": [
"The uploaded adapter is tiny: only about 211 kB."
]
},
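{
"cell_type": "markdown",
"id": "4b2f80c6",
"metadata": {},
"source": [
"The small size follows from the parameter counts: the saved adapter contains only the LoRA factors of `seq.0` and `seq.2` plus the trained copy of `seq.4`, not the frozen base weights. A rough estimate, assuming float32 weights (4 bytes each) and ignoring file-format overhead and the separate config file, lands close to the reported size and far below the size of the full model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c3091d7",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"BYTES_PER_FLOAT32 = 4\n",
"\n",
"# tensors stored in the adapter: LoRA factors (r=8) and the trained copy of seq.4\n",
"adapter_shapes = [\n",
"    (8, 20), (2000, 8),  # seq.0 lora_A, lora_B\n",
"    (8, 2000), (2000, 8),  # seq.2 lora_A, lora_B\n",
"    (2, 2000), (2,),  # seq.4 weight, bias\n",
"]\n",
"# all weights and biases of the original MLP (seq.0, seq.2, seq.4)\n",
"full_shapes = [(2000, 20), (2000,), (2000, 2000), (2000,), (2, 2000), (2,)]\n",
"\n",
"\n",
"def n_bytes(shapes):\n",
"    return BYTES_PER_FLOAT32 * sum(math.prod(shape) for shape in shapes)\n",
"\n",
"\n",
"print(f\"adapter:    ~{n_bytes(adapter_shapes) / 1e3:.0f} kB\")  # ~209 kB\n",
"print(f\"full model: ~{n_bytes(full_shapes) / 1e6:.1f} MB\")  # ~16.2 MB"
]
},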
{
"cell_type": "markdown",
"id": "4ff78c0c",
"metadata": {},
"source": [
"### Loading the model from HF Hub"
]
},
{
"cell_type": "markdown",
"id": "e5c7e87f",
"metadata": {},
"source": [
"Now, it only takes one step to load the model from HF Hub. To do this, we can use `PeftModel.from_pretrained`, passing our base model and the model ID:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce0fcced",
"metadata": {},
"outputs": [],
"source": [
"loaded = peft.PeftModel.from_pretrained(module_copy, model_id)\n",
"type(loaded)"
]
},
{
"cell_type": "markdown",
"id": "cd4b4eac",
"metadata": {},
"source": [
"Let's check that the two models produce the same output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2cf6ac4",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"y_peft = peft_model(X.to(device))\n",
"y_loaded = loaded(X.to(device))\n",
"torch.allclose(y_peft, y_loaded)"
]
},
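{
"cell_type": "markdown",
"id": "6d41a2e8",
"metadata": {},
"source": [
"If you want to try the save/load round trip without a Hub account, `save_pretrained` and `PeftModel.from_pretrained` also accept a local directory. The following is a minimal sketch on a tiny, hypothetical model; filling `lora_B` with random values stands in for training, so that the adapter actually changes the outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e52b3f9",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import os\n",
"import tempfile\n",
"\n",
"import peft\n",
"import torch\n",
"from torch import nn\n",
"\n",
"tiny_base = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))\n",
"tiny_base_copy = copy.deepcopy(tiny_base)  # pristine base weights, like module_copy above\n",
"tiny_peft = peft.get_peft_model(tiny_base, peft.LoraConfig(r=2, target_modules=[\"0\"]))\n",
"\n",
"# stand-in for training: random lora_B values make the adapter change the outputs\n",
"with torch.no_grad():\n",
"    for name, param in tiny_peft.named_parameters():\n",
"        if \"lora_B\" in name:\n",
"            param.normal_()\n",
"\n",
"with tempfile.TemporaryDirectory() as adapter_dir:\n",
"    tiny_peft.save_pretrained(adapter_dir)  # writes the adapter config and weights\n",
"    print(sorted(os.listdir(adapter_dir)))\n",
"    tiny_loaded = peft.PeftModel.from_pretrained(tiny_base_copy, adapter_dir)\n",
"\n",
"x_tiny = torch.rand(3, 4)\n",
"with torch.no_grad():\n",
"    print(torch.allclose(tiny_peft(x_tiny), tiny_loaded(x_tiny)))"
]
},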
{
"cell_type": "markdown",
"id": "eeeb653f",
"metadata": {},
"source": [
"### Clean up"
]
},
{
"cell_type": "markdown",
"id": "61c60355",
"metadata": {},
"source": [
"Finally, as a clean-up step, you may want to delete the repo."
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "b747038f",
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import delete_repo"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7e5ab237",
"metadata": {},
"outputs": [],
"source": [
"delete_repo(model_id)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}