{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8e8743c8",
   "metadata": {},
   "source": [
    "# Using PEFT with custom models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c42c67e1",
   "metadata": {},
   "source": [
    "`peft` allows us to fine-tune models efficiently with LoRA. In this short notebook, we will demonstrate how to train a simple multilayer perceptron (MLP) using `peft`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce314af5",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b28b214d",
   "metadata": {},
   "source": [
    "Make sure that you have the latest version of `peft` installed. To ensure that, run this in your Python environment:\n",
    "    \n",
    "    python -m pip install --upgrade peft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4d9da3d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "\n",
    "# ignore bnb warnings\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "44075f54",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.11/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import peft\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f72acdfb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f6d9259a5d0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b127a78",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f265da76",
   "metadata": {},
   "source": [
    "We will create a toy dataset consisting of random data for a classification task. There is a little bit of signal in the data, so we should expect that the loss of the model can improve during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b355567e",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.rand((1000, 20))\n",
    "y = (X.sum(1) > 10).long()"
   ]
  },
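  {
   "cell_type": "markdown",
   "id": "a1d3f0b2",
   "metadata": {},
   "source": [
    "Optionally, we can check how balanced the labels are. Since each feature is drawn uniformly from [0, 1] and the threshold of 10 equals the expected sum, the two classes should be roughly balanced (the exact fraction depends on the random draw):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e9b7d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fraction of positive labels; expected to be close to 0.5 for this synthetic dataset\n",
    "y.float().mean()"
   ]
  },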
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a60a869d",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_train = 800\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8859572e",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(X[:n_train], y[:n_train]),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "eval_dataloader = torch.utils.data.DataLoader(\n",
    "    torch.utils.data.TensorDataset(X[n_train:], y[n_train:]),\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97bddd2c",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db694a58",
   "metadata": {},
   "source": [
    "As a model, we use a simple multilayer perceptron (MLP). For demonstration purposes, we use a very large number of hidden units. This is totally overkill for this task but it helps to demonstrate the advantages of `peft`. In more realistic settings, models will also be quite large on average, so this is not far-fetched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1b43cd8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, num_units_hidden=2000):\n",
    "        super().__init__()\n",
    "        self.seq = nn.Sequential(\n",
    "            nn.Linear(20, num_units_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_units_hidden, num_units_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_units_hidden, 2),\n",
    "            nn.LogSoftmax(dim=-1),\n",
    "        )\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.seq(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1277bf00",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02caf26a",
   "metadata": {},
   "source": [
    "Here are just a few training hyper-parameters and a simple function that performs the training and evaluation loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5d14c0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.002\n",
    "batch_size = 64\n",
    "max_epochs = 30\n",
    "device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "657d6b3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, criterion, train_dataloader, eval_dataloader, epochs):\n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        train_loss = 0\n",
    "        for xb, yb in train_dataloader:\n",
    "            xb = xb.to(device)\n",
    "            yb = yb.to(device)\n",
    "            outputs = model(xb)\n",
    "            loss = criterion(outputs, yb)\n",
    "            train_loss += loss.detach().float()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        model.eval()\n",
    "        eval_loss = 0\n",
    "        for xb, yb in eval_dataloader:\n",
    "            xb = xb.to(device)\n",
    "            yb = yb.to(device)\n",
    "            with torch.no_grad():\n",
    "                outputs = model(xb)\n",
    "            loss = criterion(outputs, yb)\n",
    "            eval_loss += loss.detach().float()\n",
    "\n",
    "        eval_loss_total = (eval_loss / len(eval_dataloader)).item()\n",
    "        train_loss_total = (train_loss / len(train_dataloader)).item()\n",
    "        print(f\"{epoch=:<2}  {train_loss_total=:.4f}  {eval_loss_total=:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b382dcbe",
   "metadata": {},
   "source": [
    "### Training without peft"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b40d4873",
   "metadata": {},
   "source": [
    "Let's start without using `peft` to see what we can expect from the model training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f059ced4",
   "metadata": {},
   "outputs": [],
   "source": [
    "module = MLP().to(device)\n",
    "optimizer = torch.optim.Adam(module.parameters(), lr=lr)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "17698863",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0   train_loss_total=0.7970  eval_loss_total=0.6472\n",
      "epoch=1   train_loss_total=0.5597  eval_loss_total=0.4898\n",
      "epoch=2   train_loss_total=0.3696  eval_loss_total=0.3323\n",
      "epoch=3   train_loss_total=0.2364  eval_loss_total=0.5454\n",
      "epoch=4   train_loss_total=0.2428  eval_loss_total=0.2843\n",
      "epoch=5   train_loss_total=0.1251  eval_loss_total=0.2514\n",
      "epoch=6   train_loss_total=0.0952  eval_loss_total=0.2068\n",
      "epoch=7   train_loss_total=0.0831  eval_loss_total=0.2395\n",
      "epoch=8   train_loss_total=0.0655  eval_loss_total=0.2524\n",
      "epoch=9   train_loss_total=0.0380  eval_loss_total=0.3650\n",
      "epoch=10  train_loss_total=0.0363  eval_loss_total=0.3495\n",
      "epoch=11  train_loss_total=0.0231  eval_loss_total=0.2360\n",
      "epoch=12  train_loss_total=0.0162  eval_loss_total=0.2276\n",
      "epoch=13  train_loss_total=0.0094  eval_loss_total=0.2716\n",
      "epoch=14  train_loss_total=0.0065  eval_loss_total=0.2237\n",
      "epoch=15  train_loss_total=0.0054  eval_loss_total=0.2366\n",
      "epoch=16  train_loss_total=0.0035  eval_loss_total=0.2673\n",
      "epoch=17  train_loss_total=0.0028  eval_loss_total=0.2630\n",
      "epoch=18  train_loss_total=0.0023  eval_loss_total=0.2835\n",
      "epoch=19  train_loss_total=0.0021  eval_loss_total=0.2727\n",
      "epoch=20  train_loss_total=0.0018  eval_loss_total=0.2597\n",
      "epoch=21  train_loss_total=0.0016  eval_loss_total=0.2553\n",
      "epoch=22  train_loss_total=0.0014  eval_loss_total=0.2712\n",
      "epoch=23  train_loss_total=0.0013  eval_loss_total=0.2637\n",
      "epoch=24  train_loss_total=0.0012  eval_loss_total=0.2733\n",
      "epoch=25  train_loss_total=0.0011  eval_loss_total=0.2738\n",
      "epoch=26  train_loss_total=0.0010  eval_loss_total=0.2477\n",
      "epoch=27  train_loss_total=0.0010  eval_loss_total=0.2584\n",
      "epoch=28  train_loss_total=0.0009  eval_loss_total=0.2844\n",
      "epoch=29  train_loss_total=0.0008  eval_loss_total=0.2633\n",
      "CPU times: user 1.31 s, sys: 236 ms, total: 1.54 s\n",
      "Wall time: 1.56 s\n"
     ]
    }
   ],
   "source": [
    "%time train(module, optimizer, criterion, train_dataloader, eval_dataloader, epochs=max_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cef0029",
   "metadata": {},
   "source": [
    "Okay, so we got an eval loss of ~0.26, which is much better than random."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f106078",
   "metadata": {},
   "source": [
    "### Training with peft"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dd47aa4",
   "metadata": {},
   "source": [
    "Now let's train with `peft`. First we check the names of the modules, so that we can configure `peft` to fine-tune the right modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "922db29b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('', __main__.MLP),\n",
       " ('seq', torch.nn.modules.container.Sequential),\n",
       " ('seq.0', torch.nn.modules.linear.Linear),\n",
       " ('seq.1', torch.nn.modules.activation.ReLU),\n",
       " ('seq.2', torch.nn.modules.linear.Linear),\n",
       " ('seq.3', torch.nn.modules.activation.ReLU),\n",
       " ('seq.4', torch.nn.modules.linear.Linear),\n",
       " ('seq.5', torch.nn.modules.activation.LogSoftmax)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[(n, type(m)) for n, m in MLP().named_modules()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5efb275d",
   "metadata": {},
   "source": [
    "Next we can define the LoRA config. There is nothing special going on here. We set the LoRA rank to 8 and select the layers `seq.0` and `seq.2` to be used for LoRA fine-tuning. As for `seq.4`, which is the output layer, we set it as `module_to_save`, which means it is also trained but no LoRA is applied."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf2c608d",
   "metadata": {},
   "source": [
    "*Note: Not all layers types can be fine-tuned with LoRA. At the moment, linear layers, embeddings, `Conv2D` and `transformers.pytorch_utils.Conv1D` are supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b342438f",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = peft.LoraConfig(\n",
    "    r=8,\n",
    "    target_modules=[\"seq.0\", \"seq.2\"],\n",
    "    modules_to_save=[\"seq.4\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "829b4e2d",
   "metadata": {},
   "source": [
    "Now let's create the `peft` model by passing our initial MLP, as well as the config we just defined, to `get_peft_model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "602b6658",
   "metadata": {},
   "outputs": [],
   "source": [
    "module = MLP().to(device)\n",
    "module_copy = copy.deepcopy(module)  # we keep a copy of the original model for later\n",
    "peft_model = peft.get_peft_model(module, config)\n",
    "optimizer = torch.optim.Adam(peft_model.parameters(), lr=lr)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "peft_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2103737d",
   "metadata": {},
   "source": [
    "Checking the numbers, we see that only ~1% of parameters are actually trained, which is what we like to see.\n",
    "\n",
    "Now let's start the training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9200cbc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0   train_loss_total=0.6695  eval_loss_total=0.6388\n",
      "epoch=1   train_loss_total=0.5614  eval_loss_total=0.5456\n",
      "epoch=2   train_loss_total=0.3897  eval_loss_total=0.3035\n",
      "epoch=3   train_loss_total=0.2529  eval_loss_total=0.2510\n",
      "epoch=4   train_loss_total=0.1914  eval_loss_total=0.2191\n",
      "epoch=5   train_loss_total=0.1236  eval_loss_total=0.2586\n",
      "epoch=6   train_loss_total=0.1076  eval_loss_total=0.3205\n",
      "epoch=7   train_loss_total=0.1834  eval_loss_total=0.3951\n",
      "epoch=8   train_loss_total=0.1037  eval_loss_total=0.1646\n",
      "epoch=9   train_loss_total=0.0724  eval_loss_total=0.1409\n",
      "epoch=10  train_loss_total=0.0691  eval_loss_total=0.1725\n",
      "epoch=11  train_loss_total=0.0641  eval_loss_total=0.1423\n",
      "epoch=12  train_loss_total=0.0382  eval_loss_total=0.1490\n",
      "epoch=13  train_loss_total=0.0214  eval_loss_total=0.1517\n",
      "epoch=14  train_loss_total=0.0119  eval_loss_total=0.1717\n",
      "epoch=15  train_loss_total=0.0060  eval_loss_total=0.2366\n",
      "epoch=16  train_loss_total=0.0029  eval_loss_total=0.2069\n",
      "epoch=17  train_loss_total=0.0021  eval_loss_total=0.2082\n",
      "epoch=18  train_loss_total=0.0016  eval_loss_total=0.2119\n",
      "epoch=19  train_loss_total=0.0011  eval_loss_total=0.1984\n",
      "epoch=20  train_loss_total=0.0010  eval_loss_total=0.1821\n",
      "epoch=21  train_loss_total=0.0009  eval_loss_total=0.1892\n",
      "epoch=22  train_loss_total=0.0007  eval_loss_total=0.2062\n",
      "epoch=23  train_loss_total=0.0006  eval_loss_total=0.2408\n",
      "epoch=24  train_loss_total=0.0006  eval_loss_total=0.2038\n",
      "epoch=25  train_loss_total=0.0005  eval_loss_total=0.2374\n",
      "epoch=26  train_loss_total=0.0004  eval_loss_total=0.2139\n",
      "epoch=27  train_loss_total=0.0004  eval_loss_total=0.2085\n",
      "epoch=28  train_loss_total=0.0004  eval_loss_total=0.2395\n",
      "epoch=29  train_loss_total=0.0003  eval_loss_total=0.2100\n",
      "CPU times: user 1.41 s, sys: 48.9 ms, total: 1.46 s\n",
      "Wall time: 1.46 s\n"
     ]
    }
   ],
   "source": [
    "%time train(peft_model, optimizer, criterion, train_dataloader, eval_dataloader, epochs=max_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20f6f452",
   "metadata": {},
   "source": [
    "In the end, we see that the eval loss is very similar to the one we saw earlier when we trained without `peft`. This is quite nice to see, given that we are training a much smaller number of parameters."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa55d1d4",
   "metadata": {},
   "source": [
    "#### Check which parameters were updated"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6e2146b",
   "metadata": {},
   "source": [
    "Finally, just to check that LoRA was applied as expected, we check what original weights were updated what weights stayed the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c7dcde21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New parameter model.seq.0.lora_A.default.weight |   160 parameters | updated\n",
      "New parameter model.seq.0.lora_B.default.weight | 16000 parameters | updated\n",
      "New parameter model.seq.2.lora_A.default.weight | 16000 parameters | updated\n",
      "New parameter model.seq.2.lora_B.default.weight | 16000 parameters | updated\n"
     ]
    }
   ],
   "source": [
    "for name, param in peft_model.base_model.named_parameters():\n",
    "    if \"lora\" not in name:\n",
    "        continue\n",
    "\n",
    "    print(f\"New parameter {name:<13} | {param.numel():>5} parameters | updated\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "022e6c41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter seq.0.weight  |   40000 parameters | not updated\n",
      "Parameter seq.0.bias    |    2000 parameters | not updated\n",
      "Parameter seq.2.weight  | 4000000 parameters | not updated\n",
      "Parameter seq.2.bias    |    2000 parameters | not updated\n",
      "Parameter seq.4.weight  |    4000 parameters | not updated\n",
      "Parameter seq.4.bias    |       2 parameters | not updated\n",
      "Parameter seq.4.weight  |    4000 parameters | updated\n",
      "Parameter seq.4.bias    |       2 parameters | updated\n"
     ]
    }
   ],
   "source": [
    "params_before = dict(module_copy.named_parameters())\n",
    "for name, param in peft_model.base_model.named_parameters():\n",
    "    if \"lora\" in name:\n",
    "        continue\n",
    "\n",
    "    name_before = (\n",
    "        name.partition(\".\")[-1].replace(\"base_layer.\", \"\").replace(\"original_\", \"\").replace(\"module.\", \"\").replace(\"modules_to_save.default.\", \"\")\n",
    "    )\n",
    "    param_before = params_before[name_before]\n",
    "    if torch.allclose(param, param_before):\n",
    "        print(f\"Parameter {name_before:<13} | {param.numel():>7} parameters | not updated\")\n",
    "    else:\n",
    "        print(f\"Parameter {name_before:<13} | {param.numel():>7} parameters | updated\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c09b43d",
   "metadata": {},
   "source": [
    "So we can see that apart from the new LoRA weights that were added, only the last layer was updated. Since the LoRA weights and the last layer have comparitively few parameters, this gives us a big boost in efficiency."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b46c6198",
   "metadata": {},
   "source": [
    "## Sharing the model through Hugging Face Hub"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6289e647",
   "metadata": {},
   "source": [
    "### Pushing the model to HF Hub"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06dcdfa0",
   "metadata": {},
   "source": [
    "With the `peft` model, it is also very easy to push a model the Hugging Face Hub. Below, we demonstrate how it works. It is assumed that you have a valid Hugging Face account and are logged in:"
   ]
  },
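  {
   "cell_type": "markdown",
   "id": "f3a8c2e7",
   "metadata": {},
   "source": [
    "If you are not logged in yet, one way to authenticate from inside a notebook is with `notebook_login` from `huggingface_hub` (you will be prompted for an access token from your Hugging Face account settings). This is only a sketch; skip this cell if you are already logged in, e.g. via `huggingface-cli login`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b5e9a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: authenticate with the Hugging Face Hub from inside the notebook\n",
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },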
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "1b91a0af",
   "metadata": {},
   "outputs": [],
   "source": [
    "user = \"BenjaminB\"  # put your user name here\n",
    "model_name = \"peft-lora-with-custom-model\"\n",
    "model_id = f\"{user}/{model_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1430fffd",
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_model.push_to_hub(model_id);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "632bd799",
   "metadata": {},
   "source": [
    "As we can see, the adapter size is only 211 kB."
   ]
  },
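  {
   "cell_type": "markdown",
   "id": "b9c1d4e6",
   "metadata": {},
   "source": [
    "We can also verify the small footprint locally. The sketch below saves only the adapter with `save_pretrained` into an arbitrary directory (the name is just an example) and prints the size of each saved file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7f2a3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# save only the adapter weights (LoRA matrices + modules_to_save), not the full base model\n",
    "adapter_dir = \"peft-lora-custom-model-adapter\"  # example path, adjust as needed\n",
    "peft_model.save_pretrained(adapter_dir)\n",
    "\n",
    "for fname in sorted(os.listdir(adapter_dir)):\n",
    "    size_kb = os.path.getsize(os.path.join(adapter_dir, fname)) / 1024\n",
    "    print(f\"{fname}: {size_kb:.1f} kB\")"
   ]
  },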
  {
   "cell_type": "markdown",
   "id": "4ff78c0c",
   "metadata": {},
   "source": [
    "### Loading the model from HF Hub"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5c7e87f",
   "metadata": {},
   "source": [
    "Now, it only takes one step to load the model from HF Hub. To do this, we can use `PeftModel.from_pretrained`, passing our base model and the model ID:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce0fcced",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded = peft.PeftModel.from_pretrained(module_copy, model_id)\n",
    "type(loaded)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd4b4eac",
   "metadata": {},
   "source": [
    "Let's check that the two models produce the same output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2cf6ac4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "y_peft = peft_model(X.to(device))\n",
    "y_loaded = loaded(X.to(device))\n",
    "torch.allclose(y_peft, y_loaded)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eeeb653f",
   "metadata": {},
   "source": [
    "### Clean up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c60355",
   "metadata": {},
   "source": [
    "Finally, as a clean up step, you may want to delete the repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b747038f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import delete_repo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "7e5ab237",
   "metadata": {},
   "outputs": [],
   "source": [
    "delete_repo(model_id)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}