Spaces:
Running
Running
Upload 4 files
Browse files- coarse_transformer.ipynb +632 -0
- fine_transformer.ipynb +183 -0
- musiclm.ipynb +304 -0
- semantic_transformer.ipynb +851 -0
coarse_transformer.ipynb
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Coarse Transformer"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "markdown",
|
| 12 |
+
"metadata": {},
|
| 13 |
+
"source": [
|
| 14 |
+
"### Libraries:"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 1,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"import torch\n",
|
| 24 |
+
"from audiolm_pytorch import HubertWithKmeans\n",
|
| 25 |
+
"from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
|
| 26 |
+
"from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
|
| 27 |
+
"from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer\n",
|
| 28 |
+
"from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream\n",
|
| 29 |
+
"import gc\n",
|
| 30 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
| 31 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 2,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
| 41 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"audio_output_dir = './audio'\n",
|
| 44 |
+
"batch_size = 1\n",
|
| 45 |
+
"data_max_length = 320 * 32\n",
|
| 46 |
+
"num_train_steps = 1000"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 3,
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [
|
| 54 |
+
{
|
| 55 |
+
"name": "stdout",
|
| 56 |
+
"output_type": "stream",
|
| 57 |
+
"text": [
|
| 58 |
+
"spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n"
|
| 59 |
+
]
|
| 60 |
+
}
|
| 61 |
+
],
|
| 62 |
+
"source": [
|
| 63 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
| 64 |
+
" dim = 512,\n",
|
| 65 |
+
" depth = 6,\n",
|
| 66 |
+
" heads = 8,\n",
|
| 67 |
+
" dim_head = 64,\n",
|
| 68 |
+
" spec_n_fft = 128,\n",
|
| 69 |
+
" spec_win_length = 24,\n",
|
| 70 |
+
" spec_aug_stretch_factor = 0.8\n",
|
| 71 |
+
")\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"text_transformer = TextTransformer(\n",
|
| 74 |
+
" dim = 512,\n",
|
| 75 |
+
" depth = 6,\n",
|
| 76 |
+
" heads = 8,\n",
|
| 77 |
+
" dim_head = 64\n",
|
| 78 |
+
")\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"mulan = MuLaN(\n",
|
| 81 |
+
" audio_transformer = audio_transformer,\n",
|
| 82 |
+
" text_transformer = text_transformer\n",
|
| 83 |
+
")\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
| 86 |
+
" mulan = mulan, \n",
|
| 87 |
+
" conditioning_dims = (1024, 1024, 1024), \n",
|
| 88 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
| 89 |
+
")\n",
|
| 90 |
+
"wavs = torch.randn(2, 1024)\n",
|
| 91 |
+
"conds = quantizer(wavs = wavs, namespace = 'semantic')"
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"cell_type": "code",
|
| 96 |
+
"execution_count": 4,
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [
|
| 99 |
+
{
|
| 100 |
+
"name": "stdout",
|
| 101 |
+
"output_type": "stream",
|
| 102 |
+
"text": [
|
| 103 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
| 104 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
| 105 |
+
"training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
|
| 106 |
+
"0: loss: 90.55248260498047\n",
|
| 107 |
+
"0: valid loss 28.765926361083984\n",
|
| 108 |
+
"0: saving model to results\n",
|
| 109 |
+
"1: loss: 39.71841812133789\n",
|
| 110 |
+
"2: loss: 89.22168731689453\n",
|
| 111 |
+
"3: loss: 64.72769927978516\n",
|
| 112 |
+
"4: loss: 46.61131286621094\n",
|
| 113 |
+
"5: loss: 71.61656951904297\n",
|
| 114 |
+
"6: loss: 51.03081130981445\n",
|
| 115 |
+
"7: loss: 41.790443420410156\n",
|
| 116 |
+
"8: loss: 53.92983627319336\n",
|
| 117 |
+
"9: loss: 34.468536376953125\n",
|
| 118 |
+
"10: loss: 33.230533599853516\n",
|
| 119 |
+
"11: loss: 39.82740020751953\n",
|
| 120 |
+
"12: loss: 25.284324645996094\n",
|
| 121 |
+
"13: loss: 28.97213363647461\n",
|
| 122 |
+
"14: loss: 30.330350875854492\n",
|
| 123 |
+
"15: loss: 29.048341751098633\n",
|
| 124 |
+
"16: loss: 22.92132568359375\n",
|
| 125 |
+
"17: loss: 19.784038543701172\n",
|
| 126 |
+
"18: loss: 24.917173385620117\n",
|
| 127 |
+
"19: loss: 21.861900329589844\n",
|
| 128 |
+
"20: loss: 21.64893913269043\n",
|
| 129 |
+
"21: loss: 19.426795959472656\n",
|
| 130 |
+
"22: loss: 16.47875213623047\n",
|
| 131 |
+
"23: loss: 14.150989532470703\n",
|
| 132 |
+
"24: loss: 16.4312686920166\n",
|
| 133 |
+
"25: loss: 10.732200622558594\n",
|
| 134 |
+
"26: loss: 9.64625358581543\n",
|
| 135 |
+
"27: loss: 13.40906047821045\n",
|
| 136 |
+
"28: loss: 8.942117691040039\n",
|
| 137 |
+
"29: loss: 14.944022178649902\n",
|
| 138 |
+
"30: loss: 17.149667739868164\n",
|
| 139 |
+
"31: loss: 8.965814590454102\n",
|
| 140 |
+
"32: loss: 10.492903709411621\n",
|
| 141 |
+
"33: loss: 11.236382484436035\n",
|
| 142 |
+
"34: loss: 10.356119155883789\n",
|
| 143 |
+
"35: loss: 9.816141128540039\n",
|
| 144 |
+
"36: loss: 11.789191246032715\n",
|
| 145 |
+
"37: loss: 10.450325012207031\n",
|
| 146 |
+
"38: loss: 18.911396026611328\n",
|
| 147 |
+
"39: loss: 8.278931617736816\n",
|
| 148 |
+
"40: loss: 10.884782791137695\n",
|
| 149 |
+
"41: loss: 8.885784149169922\n",
|
| 150 |
+
"42: loss: 9.226049423217773\n",
|
| 151 |
+
"43: loss: 10.362125396728516\n",
|
| 152 |
+
"44: loss: 4.0845770835876465\n",
|
| 153 |
+
"45: loss: 9.664544105529785\n",
|
| 154 |
+
"46: loss: 9.46312427520752\n",
|
| 155 |
+
"47: loss: 9.138323783874512\n",
|
| 156 |
+
"48: loss: 7.396448135375977\n",
|
| 157 |
+
"49: loss: 7.293612480163574\n",
|
| 158 |
+
"50: loss: 10.331693649291992\n",
|
| 159 |
+
"51: loss: 7.775559425354004\n",
|
| 160 |
+
"52: loss: 7.011277198791504\n",
|
| 161 |
+
"53: loss: 6.324047565460205\n",
|
| 162 |
+
"54: loss: 5.501199245452881\n",
|
| 163 |
+
"55: loss: 4.69442081451416\n",
|
| 164 |
+
"56: loss: 4.073971748352051\n",
|
| 165 |
+
"57: loss: 4.142904758453369\n",
|
| 166 |
+
"58: loss: 4.585968017578125\n",
|
| 167 |
+
"59: loss: 4.700481414794922\n",
|
| 168 |
+
"60: loss: 5.152374267578125\n",
|
| 169 |
+
"61: loss: 8.181085586547852\n",
|
| 170 |
+
"62: loss: 6.7371416091918945\n",
|
| 171 |
+
"63: loss: 10.67423152923584\n",
|
| 172 |
+
"64: loss: 5.926950454711914\n",
|
| 173 |
+
"65: loss: 5.470860004425049\n",
|
| 174 |
+
"66: loss: 4.630016803741455\n",
|
| 175 |
+
"67: loss: 5.366561412811279\n",
|
| 176 |
+
"68: loss: 11.271105766296387\n",
|
| 177 |
+
"69: loss: 6.516841411590576\n",
|
| 178 |
+
"70: loss: 7.9438066482543945\n",
|
| 179 |
+
"71: loss: 5.358776092529297\n",
|
| 180 |
+
"72: loss: 5.713461875915527\n",
|
| 181 |
+
"73: loss: 7.075550556182861\n",
|
| 182 |
+
"74: loss: 5.229584217071533\n",
|
| 183 |
+
"75: loss: 5.103419303894043\n",
|
| 184 |
+
"76: loss: 4.516308307647705\n",
|
| 185 |
+
"77: loss: 7.4682488441467285\n",
|
| 186 |
+
"78: loss: 7.275866508483887\n",
|
| 187 |
+
"79: loss: 5.846785545349121\n",
|
| 188 |
+
"80: loss: 5.688624382019043\n",
|
| 189 |
+
"81: loss: 5.150119781494141\n",
|
| 190 |
+
"82: loss: 4.671944618225098\n",
|
| 191 |
+
"83: loss: 8.293455123901367\n",
|
| 192 |
+
"84: loss: 7.202897071838379\n",
|
| 193 |
+
"85: loss: 4.38778018951416\n",
|
| 194 |
+
"86: loss: 4.410329818725586\n",
|
| 195 |
+
"87: loss: 4.341781139373779\n",
|
| 196 |
+
"88: loss: 4.000961780548096\n",
|
| 197 |
+
"89: loss: 4.009156703948975\n",
|
| 198 |
+
"90: loss: 3.562082052230835\n",
|
| 199 |
+
"91: loss: 3.641108989715576\n",
|
| 200 |
+
"92: loss: 5.916473388671875\n",
|
| 201 |
+
"93: loss: 4.046755790710449\n",
|
| 202 |
+
"94: loss: 6.699942111968994\n",
|
| 203 |
+
"95: loss: 6.139719009399414\n",
|
| 204 |
+
"96: loss: 10.71791934967041\n",
|
| 205 |
+
"97: loss: 4.094853401184082\n",
|
| 206 |
+
"98: loss: 6.08973503112793\n",
|
| 207 |
+
"99: loss: 9.11803150177002\n",
|
| 208 |
+
"100: loss: 8.486052513122559\n",
|
| 209 |
+
"100: valid loss 4.0021281242370605\n",
|
| 210 |
+
"101: loss: 4.0021281242370605\n",
|
| 211 |
+
"102: loss: 3.346961736679077\n",
|
| 212 |
+
"103: loss: 3.15854549407959\n",
|
| 213 |
+
"104: loss: 2.5357956886291504\n",
|
| 214 |
+
"105: loss: 5.492861270904541\n",
|
| 215 |
+
"106: loss: 2.7623958587646484\n",
|
| 216 |
+
"107: loss: 2.9482226371765137\n",
|
| 217 |
+
"108: loss: 6.3801493644714355\n",
|
| 218 |
+
"109: loss: 4.1293463706970215\n",
|
| 219 |
+
"110: loss: 3.566096067428589\n",
|
| 220 |
+
"111: loss: 3.569946527481079\n",
|
| 221 |
+
"112: loss: 3.762925624847412\n",
|
| 222 |
+
"113: loss: 6.147146701812744\n",
|
| 223 |
+
"114: loss: 5.933719635009766\n",
|
| 224 |
+
"115: loss: 6.800720691680908\n",
|
| 225 |
+
"116: loss: 2.86614990234375\n",
|
| 226 |
+
"117: loss: 3.0812878608703613\n",
|
| 227 |
+
"118: loss: 3.110222101211548\n",
|
| 228 |
+
"119: loss: 4.000320911407471\n",
|
| 229 |
+
"120: loss: 3.2422871589660645\n",
|
| 230 |
+
"121: loss: 3.7775020599365234\n",
|
| 231 |
+
"122: loss: 3.595900774002075\n",
|
| 232 |
+
"123: loss: 2.73819637298584\n",
|
| 233 |
+
"124: loss: 3.4981672763824463\n",
|
| 234 |
+
"125: loss: 5.3726325035095215\n",
|
| 235 |
+
"126: loss: 3.0014798641204834\n",
|
| 236 |
+
"127: loss: 3.5963802337646484\n",
|
| 237 |
+
"128: loss: 2.8306686878204346\n",
|
| 238 |
+
"129: loss: 2.5162878036499023\n",
|
| 239 |
+
"130: loss: 2.685560941696167\n",
|
| 240 |
+
"131: loss: 6.374442100524902\n",
|
| 241 |
+
"132: loss: 7.788975715637207\n",
|
| 242 |
+
"133: loss: 2.897576332092285\n",
|
| 243 |
+
"134: loss: 3.333127737045288\n",
|
| 244 |
+
"135: loss: 3.436774253845215\n",
|
| 245 |
+
"136: loss: 4.979071617126465\n",
|
| 246 |
+
"137: loss: 4.120012283325195\n",
|
| 247 |
+
"138: loss: 3.7855355739593506\n",
|
| 248 |
+
"139: loss: 4.324587345123291\n",
|
| 249 |
+
"140: loss: 3.4336843490600586\n",
|
| 250 |
+
"141: loss: 2.6801435947418213\n",
|
| 251 |
+
"142: loss: 3.359581470489502\n",
|
| 252 |
+
"143: loss: 5.4692182540893555\n",
|
| 253 |
+
"144: loss: 5.773078918457031\n",
|
| 254 |
+
"145: loss: 4.27987813949585\n",
|
| 255 |
+
"146: loss: 7.247451305389404\n",
|
| 256 |
+
"147: loss: 6.170166492462158\n",
|
| 257 |
+
"148: loss: 4.961609840393066\n",
|
| 258 |
+
"149: loss: 4.028770923614502\n",
|
| 259 |
+
"150: loss: 2.90120005607605\n",
|
| 260 |
+
"151: loss: 1.9893661737442017\n",
|
| 261 |
+
"152: loss: 1.652574062347412\n",
|
| 262 |
+
"153: loss: 2.374600887298584\n",
|
| 263 |
+
"154: loss: 2.1045265197753906\n",
|
| 264 |
+
"155: loss: 6.417508125305176\n",
|
| 265 |
+
"156: loss: 5.273669719696045\n",
|
| 266 |
+
"157: loss: 6.238985538482666\n",
|
| 267 |
+
"158: loss: 3.8025736808776855\n",
|
| 268 |
+
"159: loss: 6.6854705810546875\n",
|
| 269 |
+
"160: loss: 2.5476467609405518\n",
|
| 270 |
+
"161: loss: 6.810393810272217\n",
|
| 271 |
+
"162: loss: 2.2033159732818604\n",
|
| 272 |
+
"163: loss: 1.9863100051879883\n",
|
| 273 |
+
"164: loss: 4.976431369781494\n",
|
| 274 |
+
"165: loss: 3.899188756942749\n",
|
| 275 |
+
"166: loss: 4.68454647064209\n",
|
| 276 |
+
"167: loss: 2.4539690017700195\n",
|
| 277 |
+
"168: loss: 6.830282688140869\n",
|
| 278 |
+
"169: loss: 1.7942843437194824\n",
|
| 279 |
+
"170: loss: 1.242318868637085\n",
|
| 280 |
+
"171: loss: 5.012855052947998\n",
|
| 281 |
+
"172: loss: 1.6154134273529053\n",
|
| 282 |
+
"173: loss: 1.5895756483078003\n",
|
| 283 |
+
"174: loss: 5.240614891052246\n",
|
| 284 |
+
"175: loss: 1.8958660364151\n",
|
| 285 |
+
"176: loss: 2.1411402225494385\n",
|
| 286 |
+
"177: loss: 5.932228088378906\n",
|
| 287 |
+
"178: loss: 2.7539122104644775\n",
|
| 288 |
+
"179: loss: 6.218499660491943\n",
|
| 289 |
+
"180: loss: 2.991704225540161\n",
|
| 290 |
+
"181: loss: 3.378645896911621\n",
|
| 291 |
+
"182: loss: 2.719741106033325\n",
|
| 292 |
+
"183: loss: 2.5844321250915527\n",
|
| 293 |
+
"184: loss: 5.851257801055908\n",
|
| 294 |
+
"185: loss: 2.239989995956421\n",
|
| 295 |
+
"186: loss: 5.5589141845703125\n",
|
| 296 |
+
"187: loss: 3.11521053314209\n",
|
| 297 |
+
"188: loss: 2.5269265174865723\n",
|
| 298 |
+
"189: loss: 2.181260824203491\n",
|
| 299 |
+
"190: loss: 1.8941911458969116\n",
|
| 300 |
+
"191: loss: 5.106175422668457\n",
|
| 301 |
+
"192: loss: 3.5514838695526123\n",
|
| 302 |
+
"193: loss: 3.233003854751587\n",
|
| 303 |
+
"194: loss: 2.55694317817688\n",
|
| 304 |
+
"195: loss: 6.5134053230285645\n",
|
| 305 |
+
"196: loss: 6.311967372894287\n",
|
| 306 |
+
"197: loss: 2.3541362285614014\n",
|
| 307 |
+
"198: loss: 6.195401668548584\n",
|
| 308 |
+
"199: loss: 3.013007879257202\n",
|
| 309 |
+
"200: loss: 2.53104567527771\n",
|
| 310 |
+
"200: valid loss 1.895339846611023\n",
|
| 311 |
+
"201: loss: 7.572109699249268\n",
|
| 312 |
+
"202: loss: 1.946860909461975\n",
|
| 313 |
+
"203: loss: 1.6077873706817627\n",
|
| 314 |
+
"204: loss: 1.5050052404403687\n",
|
| 315 |
+
"205: loss: 1.1216596364974976\n",
|
| 316 |
+
"206: loss: 1.017206072807312\n",
|
| 317 |
+
"207: loss: 7.081823825836182\n",
|
| 318 |
+
"208: loss: 1.1608872413635254\n",
|
| 319 |
+
"209: loss: 0.728882908821106\n",
|
| 320 |
+
"210: loss: 0.514722466468811\n",
|
| 321 |
+
"211: loss: 0.6075964570045471\n",
|
| 322 |
+
"212: loss: 0.7593868970870972\n",
|
| 323 |
+
"213: loss: 0.6465023159980774\n",
|
| 324 |
+
"214: loss: 8.1160888671875\n",
|
| 325 |
+
"215: loss: 0.8256340622901917\n",
|
| 326 |
+
"216: loss: 0.5982277393341064\n",
|
| 327 |
+
"217: loss: 7.202335834503174\n",
|
| 328 |
+
"218: loss: 4.8967790603637695\n",
|
| 329 |
+
"219: loss: 2.037604331970215\n",
|
| 330 |
+
"220: loss: 1.7443571090698242\n",
|
| 331 |
+
"221: loss: 0.8838777542114258\n",
|
| 332 |
+
"222: loss: 0.7871264219284058\n",
|
| 333 |
+
"223: loss: 5.985363483428955\n",
|
| 334 |
+
"224: loss: 3.6808922290802\n",
|
| 335 |
+
"225: loss: 4.453125476837158\n",
|
| 336 |
+
"226: loss: 4.137350559234619\n",
|
| 337 |
+
"227: loss: 1.5606231689453125\n",
|
| 338 |
+
"228: loss: 5.764791488647461\n",
|
| 339 |
+
"229: loss: 1.2394036054611206\n",
|
| 340 |
+
"230: loss: 1.1438194513320923\n",
|
| 341 |
+
"231: loss: 0.5560073852539062\n",
|
| 342 |
+
"232: loss: 5.746810436248779\n",
|
| 343 |
+
"233: loss: 4.34252405166626\n",
|
| 344 |
+
"234: loss: 6.079676628112793\n",
|
| 345 |
+
"235: loss: 4.213600158691406\n",
|
| 346 |
+
"236: loss: 1.1661522388458252\n",
|
| 347 |
+
"237: loss: 7.770791053771973\n",
|
| 348 |
+
"238: loss: 3.6331183910369873\n",
|
| 349 |
+
"239: loss: 6.657710552215576\n",
|
| 350 |
+
"240: loss: 4.314018249511719\n",
|
| 351 |
+
"241: loss: 3.964081048965454\n",
|
| 352 |
+
"242: loss: 3.4643802642822266\n",
|
| 353 |
+
"243: loss: 3.2389814853668213\n",
|
| 354 |
+
"244: loss: 5.009263515472412\n",
|
| 355 |
+
"245: loss: 5.4173903465271\n",
|
| 356 |
+
"246: loss: 3.464853048324585\n",
|
| 357 |
+
"247: loss: 2.690930128097534\n",
|
| 358 |
+
"248: loss: 5.482550621032715\n",
|
| 359 |
+
"249: loss: 1.500435709953308\n",
|
| 360 |
+
"250: loss: 1.207865834236145\n",
|
| 361 |
+
"251: loss: 6.162202835083008\n",
|
| 362 |
+
"252: loss: 0.5159206986427307\n",
|
| 363 |
+
"253: loss: 0.352285772562027\n",
|
| 364 |
+
"254: loss: 0.28347644209861755\n",
|
| 365 |
+
"255: loss: 0.2998739182949066\n",
|
| 366 |
+
"256: loss: 7.412589073181152\n",
|
| 367 |
+
"257: loss: 1.0271281003952026\n",
|
| 368 |
+
"258: loss: 0.5622831583023071\n",
|
| 369 |
+
"259: loss: 6.975170135498047\n",
|
| 370 |
+
"260: loss: 0.050237879157066345\n",
|
| 371 |
+
"261: loss: 9.500787734985352\n",
|
| 372 |
+
"262: loss: 1.1100494861602783\n",
|
| 373 |
+
"263: loss: 10.5401029586792\n",
|
| 374 |
+
"264: loss: 7.637964725494385\n",
|
| 375 |
+
"265: loss: 1.5384433269500732\n",
|
| 376 |
+
"266: loss: 0.6748937368392944\n",
|
| 377 |
+
"267: loss: 0.38336750864982605\n",
|
| 378 |
+
"268: loss: 0.1832476705312729\n",
|
| 379 |
+
"269: loss: 7.080984115600586\n",
|
| 380 |
+
"270: loss: 6.806582927703857\n",
|
| 381 |
+
"271: loss: 6.216980457305908\n",
|
| 382 |
+
"272: loss: 8.122699737548828\n",
|
| 383 |
+
"273: loss: 2.344430685043335\n",
|
| 384 |
+
"274: loss: 5.185897350311279\n",
|
| 385 |
+
"275: loss: 5.136538982391357\n",
|
| 386 |
+
"276: loss: 4.847122669219971\n",
|
| 387 |
+
"277: loss: 3.447641372680664\n",
|
| 388 |
+
"278: loss: 1.9696052074432373\n",
|
| 389 |
+
"279: loss: 6.129249095916748\n",
|
| 390 |
+
"280: loss: 1.4744977951049805\n",
|
| 391 |
+
"281: loss: 4.836997032165527\n",
|
| 392 |
+
"282: loss: 4.361396789550781\n",
|
| 393 |
+
"283: loss: 4.975046157836914\n",
|
| 394 |
+
"284: loss: 5.6431074142456055\n",
|
| 395 |
+
"285: loss: 8.127538681030273\n",
|
| 396 |
+
"286: loss: 7.203218460083008\n",
|
| 397 |
+
"287: loss: 2.408040761947632\n",
|
| 398 |
+
"288: loss: 1.7607803344726562\n",
|
| 399 |
+
"289: loss: 1.1752283573150635\n",
|
| 400 |
+
"290: loss: 5.39897346496582\n",
|
| 401 |
+
"291: loss: 0.8753417134284973\n",
|
| 402 |
+
"292: loss: 6.104700088500977\n",
|
| 403 |
+
"293: loss: 0.8714774250984192\n",
|
| 404 |
+
"294: loss: 5.633414268493652\n",
|
| 405 |
+
"295: loss: 1.0734435319900513\n",
|
| 406 |
+
"296: loss: 0.5978174209594727\n",
|
| 407 |
+
"297: loss: 0.6240620613098145\n",
|
| 408 |
+
"298: loss: 0.3799970746040344\n",
|
| 409 |
+
"299: loss: 5.793654441833496\n",
|
| 410 |
+
"300: loss: 4.920631408691406\n",
|
| 411 |
+
"300: valid loss 0.5733768343925476\n",
|
| 412 |
+
"301: loss: 0.5733768343925476\n",
|
| 413 |
+
"302: loss: 0.35356906056404114\n",
|
| 414 |
+
"303: loss: 6.0288190841674805\n",
|
| 415 |
+
"304: loss: 0.17994554340839386\n",
|
| 416 |
+
"305: loss: 6.07096004486084\n",
|
| 417 |
+
"306: loss: 0.798763632774353\n",
|
| 418 |
+
"307: loss: 0.30721110105514526\n",
|
| 419 |
+
"308: loss: 0.35866862535476685\n",
|
| 420 |
+
"309: loss: 6.664376258850098\n",
|
| 421 |
+
"310: loss: 10.371112823486328\n",
|
| 422 |
+
"311: loss: 1.5442111492156982\n",
|
| 423 |
+
"312: loss: 0.5046924948692322\n",
|
| 424 |
+
"313: loss: 0.02138896845281124\n",
|
| 425 |
+
"314: loss: 11.088417053222656\n",
|
| 426 |
+
"315: loss: 0.2801823616027832\n",
|
| 427 |
+
"316: loss: 1.6325680017471313\n",
|
| 428 |
+
"317: loss: 1.042490005493164\n",
|
| 429 |
+
"318: loss: 0.19980621337890625\n",
|
| 430 |
+
"319: loss: 6.208798408508301\n",
|
| 431 |
+
"320: loss: 2.2923152446746826\n",
|
| 432 |
+
"321: loss: 1.5293265581130981\n",
|
| 433 |
+
"322: loss: 5.384918212890625\n",
|
| 434 |
+
"323: loss: 0.5806372165679932\n",
|
| 435 |
+
"324: loss: 0.11083264648914337\n",
|
| 436 |
+
"325: loss: 6.474861145019531\n",
|
| 437 |
+
"326: loss: 6.7361063957214355\n",
|
| 438 |
+
"327: loss: 6.07684850692749\n",
|
| 439 |
+
"328: loss: 0.1449495404958725\n",
|
| 440 |
+
"329: loss: 0.24492450058460236\n",
|
| 441 |
+
"330: loss: 0.0179277490824461\n",
|
| 442 |
+
"331: loss: 5.866001605987549\n",
|
| 443 |
+
"332: loss: 0.14012691378593445\n",
|
| 444 |
+
"333: loss: 0.14467062056064606\n",
|
| 445 |
+
"334: loss: 0.01395170483738184\n",
|
| 446 |
+
"335: loss: 0.04150881618261337\n",
|
| 447 |
+
"336: loss: 0.07648518681526184\n",
|
| 448 |
+
"337: loss: 9.367613792419434\n",
|
| 449 |
+
"338: loss: 8.372873306274414\n",
|
| 450 |
+
"339: loss: 0.6273093223571777\n",
|
| 451 |
+
"340: loss: 0.11360179632902145\n",
|
| 452 |
+
"341: loss: 0.02351052314043045\n",
|
| 453 |
+
"342: loss: 0.06904540210962296\n",
|
| 454 |
+
"343: loss: 0.02174321562051773\n",
|
| 455 |
+
"344: loss: 0.11702124029397964\n",
|
| 456 |
+
"345: loss: 0.061455100774765015\n",
|
| 457 |
+
"346: loss: 0.03193430230021477\n",
|
| 458 |
+
"347: loss: 0.33268794417381287\n",
|
| 459 |
+
"348: loss: 0.053275030106306076\n",
|
| 460 |
+
"349: loss: 0.009291582740843296\n",
|
| 461 |
+
"350: loss: 0.18401774764060974\n",
|
| 462 |
+
"351: loss: 0.30571281909942627\n",
|
| 463 |
+
"352: loss: 17.913070678710938\n",
|
| 464 |
+
"353: loss: 0.2126859426498413\n",
|
| 465 |
+
"354: loss: 0.6229326128959656\n",
|
| 466 |
+
"355: loss: 11.214807510375977\n",
|
| 467 |
+
"356: loss: 0.15888328850269318\n",
|
| 468 |
+
"357: loss: 0.662460446357727\n",
|
| 469 |
+
"358: loss: 7.345875263214111\n",
|
| 470 |
+
"359: loss: 7.803595066070557\n",
|
| 471 |
+
"360: loss: 1.2322083711624146\n",
|
| 472 |
+
"361: loss: 0.7014895081520081\n",
|
| 473 |
+
"362: loss: 0.10298460721969604\n",
|
| 474 |
+
"363: loss: 8.574231147766113\n",
|
| 475 |
+
"364: loss: 0.03108447603881359\n",
|
| 476 |
+
"365: loss: 0.6616091728210449\n",
|
| 477 |
+
"366: loss: 4.938299655914307\n",
|
| 478 |
+
"367: loss: 5.479018688201904\n",
|
| 479 |
+
"368: loss: 6.740688800811768\n",
|
| 480 |
+
"369: loss: 3.110865831375122\n",
|
| 481 |
+
"370: loss: 4.795236587524414\n",
|
| 482 |
+
"371: loss: 1.8502461910247803\n",
|
| 483 |
+
"372: loss: 3.737464427947998\n",
|
| 484 |
+
"373: loss: 1.9333598613739014\n",
|
| 485 |
+
"374: loss: 7.145735740661621\n",
|
| 486 |
+
"375: loss: 1.3372946977615356\n",
|
| 487 |
+
"376: loss: 5.683573246002197\n",
|
| 488 |
+
"377: loss: 1.204305648803711\n",
|
| 489 |
+
"378: loss: 0.9289284348487854\n",
|
| 490 |
+
"379: loss: 5.174688339233398\n",
|
| 491 |
+
"380: loss: 1.458616852760315\n",
|
| 492 |
+
"381: loss: 0.9457168579101562\n",
|
| 493 |
+
"382: loss: 0.4627819359302521\n",
|
| 494 |
+
"383: loss: 0.2658665180206299\n",
|
| 495 |
+
"384: loss: 4.429558753967285\n",
|
| 496 |
+
"385: loss: 1.2449607849121094\n",
|
| 497 |
+
"386: loss: 1.3288488388061523\n",
|
| 498 |
+
"387: loss: 6.628821849822998\n",
|
| 499 |
+
"388: loss: 0.4825551211833954\n",
|
| 500 |
+
"389: loss: 0.6510865688323975\n",
|
| 501 |
+
"390: loss: 0.36395493149757385\n",
|
| 502 |
+
"391: loss: 0.18036174774169922\n",
|
| 503 |
+
"392: loss: 0.3237663209438324\n",
|
| 504 |
+
"393: loss: 6.840792655944824\n",
|
| 505 |
+
"394: loss: 1.6587960720062256\n",
|
| 506 |
+
"395: loss: 7.458000659942627\n",
|
| 507 |
+
"396: loss: 0.8729283809661865\n",
|
| 508 |
+
"397: loss: 0.6731876134872437\n",
|
| 509 |
+
"398: loss: 0.1747300773859024\n",
|
| 510 |
+
"399: loss: 0.5882076621055603\n",
|
| 511 |
+
"400: loss: 0.6982569098472595\n",
|
| 512 |
+
"400: valid loss 0.4763210713863373\n",
|
| 513 |
+
"401: loss: 0.4763210713863373\n",
|
| 514 |
+
"402: loss: 0.46096739172935486\n",
|
| 515 |
+
"403: loss: 4.166454792022705\n",
|
| 516 |
+
"404: loss: 0.44991931319236755\n",
|
| 517 |
+
"405: loss: 4.830379009246826\n",
|
| 518 |
+
"406: loss: 0.5408239364624023\n",
|
| 519 |
+
"407: loss: 0.2607786953449249\n",
|
| 520 |
+
"408: loss: 0.13067474961280823\n",
|
| 521 |
+
"409: loss: 4.062631130218506\n",
|
| 522 |
+
"410: loss: 5.5028300285339355\n",
|
| 523 |
+
"411: loss: 1.2942296266555786\n",
|
| 524 |
+
"412: loss: 1.4390389919281006\n",
|
| 525 |
+
"413: loss: 5.374651908874512\n",
|
| 526 |
+
"414: loss: 1.2929461002349854\n",
|
| 527 |
+
"415: loss: 0.643798291683197\n",
|
| 528 |
+
"416: loss: 0.6353816986083984\n",
|
| 529 |
+
"417: loss: 5.8032636642456055\n",
|
| 530 |
+
"418: loss: 3.3737053871154785\n",
|
| 531 |
+
"419: loss: 1.8712362051010132\n",
|
| 532 |
+
"420: loss: 1.0622261762619019\n",
|
| 533 |
+
"421: loss: 0.8681365847587585\n",
|
| 534 |
+
"422: loss: 0.6761938333511353\n",
|
| 535 |
+
"423: loss: 4.074782371520996\n",
|
| 536 |
+
"424: loss: 0.4106965661048889\n"
|
| 537 |
+
]
|
| 538 |
+
},
|
| 539 |
+
{
|
| 540 |
+
"ename": "KeyboardInterrupt",
|
| 541 |
+
"evalue": "",
|
| 542 |
+
"output_type": "error",
|
| 543 |
+
"traceback": [
|
| 544 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
| 545 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 546 |
+
"Cell \u001b[1;32mIn[4], line 49\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m coarse_transformer, trainer, wav2vec, soundstream\n\u001b[0;32m 47\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 49\u001b[0m \u001b[43mtrain_coarse_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 547 |
+
"Cell \u001b[1;32mIn[4], line 43\u001b[0m, in \u001b[0;36mtrain_coarse_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 23\u001b[0m coarse_transformer \u001b[38;5;241m=\u001b[39m CoarseTransformer(\n\u001b[0;32m 24\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 25\u001b[0m codebook_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 29\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 30\u001b[0m )\n\u001b[0;32m 32\u001b[0m trainer \u001b[38;5;241m=\u001b[39m CoarseTransformerTrainer(\n\u001b[0;32m 33\u001b[0m transformer\u001b[38;5;241m=\u001b[39mcoarse_transformer,\n\u001b[0;32m 34\u001b[0m codec\u001b[38;5;241m=\u001b[39msoundstream,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 40\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 41\u001b[0m )\n\u001b[1;32m---> 43\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 44\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(coarse_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcoarse_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 45\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave coarse_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 548 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1302\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 1299\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 1301\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1302\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1303\u001b[0m log_fn(logs)\n\u001b[0;32m 1305\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
| 549 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1244\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1238\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[0;32m 1239\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_wrapper(\n\u001b[0;32m 1240\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_kwargs,\n\u001b[0;32m 1241\u001b[0m return_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 1242\u001b[0m )\n\u001b[1;32m-> 1244\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad_accum_every\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1246\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n\u001b[0;32m 1248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_grad_norm):\n",
|
| 550 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\accelerate\\accelerator.py:2151\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[1;34m(self, loss, **kwargs)\u001b[0m\n\u001b[0;32m 2149\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n\u001b[0;32m 2150\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 2151\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 551 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m 517\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m 518\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 523\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m 524\u001b[0m )\n\u001b[1;32m--> 525\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 552 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m 262\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m 264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[0;32m 265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m 266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 267\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 553 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[1;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[0;32m 742\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[0;32m 743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 744\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[0;32m 745\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[0;32m 746\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m 747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
|
| 554 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
| 555 |
+
]
|
| 556 |
+
}
|
| 557 |
+
],
|
| 558 |
+
"source": [
|
| 559 |
+
"def train_coarse_transformer():\n",
|
| 560 |
+
" wav2vec = HubertWithKmeans(\n",
|
| 561 |
+
" checkpoint_path=checkpoint_path,\n",
|
| 562 |
+
" kmeans_path=kmeans_path\n",
|
| 563 |
+
" )\n",
|
| 564 |
+
" soundstream = MusicLMSoundStream(\n",
|
| 565 |
+
" codebook_size=1024, # Add this line to specify the codebook size\n",
|
| 566 |
+
" strides=(3, 4, 5, 8),\n",
|
| 567 |
+
" target_sample_hz=24000,\n",
|
| 568 |
+
" rq_num_quantizers=8\n",
|
| 569 |
+
" )\n",
|
| 570 |
+
"\n",
|
| 571 |
+
" if torch.cuda.is_available():\n",
|
| 572 |
+
" coarse_transformer = CoarseTransformer(\n",
|
| 573 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 574 |
+
" codebook_size=1024,\n",
|
| 575 |
+
" num_coarse_quantizers=4,\n",
|
| 576 |
+
" dim=1024,\n",
|
| 577 |
+
" depth=6,\n",
|
| 578 |
+
" audio_text_condition=True\n",
|
| 579 |
+
" ).cuda()\n",
|
| 580 |
+
" else:\n",
|
| 581 |
+
" coarse_transformer = CoarseTransformer(\n",
|
| 582 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 583 |
+
" codebook_size=1024,\n",
|
| 584 |
+
" num_coarse_quantizers=4,\n",
|
| 585 |
+
" dim=1024,\n",
|
| 586 |
+
" depth=6,\n",
|
| 587 |
+
" audio_text_condition=True\n",
|
| 588 |
+
" )\n",
|
| 589 |
+
"\n",
|
| 590 |
+
" trainer = CoarseTransformerTrainer(\n",
|
| 591 |
+
" transformer=coarse_transformer,\n",
|
| 592 |
+
" codec=soundstream,\n",
|
| 593 |
+
" wav2vec=wav2vec,\n",
|
| 594 |
+
" audio_conditioner=quantizer,\n",
|
| 595 |
+
" folder=audio_output_dir,\n",
|
| 596 |
+
" batch_size=batch_size,\n",
|
| 597 |
+
" data_max_length=data_max_length,\n",
|
| 598 |
+
" num_train_steps=num_train_steps\n",
|
| 599 |
+
" )\n",
|
| 600 |
+
"\n",
|
| 601 |
+
" trainer.train()\n",
|
| 602 |
+
" torch.save(coarse_transformer.state_dict(), 'coarse_transformer.pth')\n",
|
| 603 |
+
" print(\"save coarse_transformer.pth\")\n",
|
| 604 |
+
" del coarse_transformer, trainer, wav2vec, soundstream\n",
|
| 605 |
+
" gc.collect()\n",
|
| 606 |
+
"\n",
|
| 607 |
+
"train_coarse_transformer()"
|
| 608 |
+
]
|
| 609 |
+
}
|
| 610 |
+
],
|
| 611 |
+
"metadata": {
|
| 612 |
+
"kernelspec": {
|
| 613 |
+
"display_name": "myenv",
|
| 614 |
+
"language": "python",
|
| 615 |
+
"name": "python3"
|
| 616 |
+
},
|
| 617 |
+
"language_info": {
|
| 618 |
+
"codemirror_mode": {
|
| 619 |
+
"name": "ipython",
|
| 620 |
+
"version": 3
|
| 621 |
+
},
|
| 622 |
+
"file_extension": ".py",
|
| 623 |
+
"mimetype": "text/x-python",
|
| 624 |
+
"name": "python",
|
| 625 |
+
"nbconvert_exporter": "python",
|
| 626 |
+
"pygments_lexer": "ipython3",
|
| 627 |
+
"version": "3.11.2"
|
| 628 |
+
}
|
| 629 |
+
},
|
| 630 |
+
"nbformat": 4,
|
| 631 |
+
"nbformat_minor": 2
|
| 632 |
+
}
|
fine_transformer.ipynb
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Fine Transformer"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "markdown",
|
| 12 |
+
"metadata": {},
|
| 13 |
+
"source": [
|
| 14 |
+
"### Libraries:"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 1,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"import torch\n",
|
| 24 |
+
"from audiolm_pytorch import HubertWithKmeans\n",
|
| 25 |
+
"from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
|
| 26 |
+
"from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
|
| 27 |
+
"from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer\n",
|
| 28 |
+
"from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream\n",
|
| 29 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
| 30 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer\n",
|
| 31 |
+
"import gc\n",
|
| 32 |
+
"from nltk.tokenize import word_tokenize\n",
|
| 33 |
+
"import nltk\n",
|
| 34 |
+
"import librosa\n",
|
| 35 |
+
"import numpy as np\n",
|
| 36 |
+
"import pickle"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "code",
|
| 41 |
+
"execution_count": 2,
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [
|
| 44 |
+
{
|
| 45 |
+
"name": "stderr",
|
| 46 |
+
"output_type": "stream",
|
| 47 |
+
"text": [
|
| 48 |
+
"[nltk_data] Downloading package punkt to\n",
|
| 49 |
+
"[nltk_data] C:\\Users\\hp\\AppData\\Roaming\\nltk_data...\n",
|
| 50 |
+
"[nltk_data] Package punkt is already up-to-date!\n"
|
| 51 |
+
]
|
| 52 |
+
}
|
| 53 |
+
],
|
| 54 |
+
"source": [
|
| 55 |
+
"nltk.download('punkt')\n",
|
| 56 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
| 57 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"audio_output_dir = './audio'\n",
|
| 60 |
+
"batch_size = 1\n",
|
| 61 |
+
"data_max_length = 320 * 32\n",
|
| 62 |
+
"num_train_steps = 1000"
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "code",
|
| 67 |
+
"execution_count": 3,
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"outputs": [
|
| 70 |
+
{
|
| 71 |
+
"name": "stdout",
|
| 72 |
+
"output_type": "stream",
|
| 73 |
+
"text": [
|
| 74 |
+
"training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
|
| 75 |
+
"spectrogram yielded shape of (65, 841), but had to be cropped to (64, 832) to be patchified for transformer\n",
|
| 76 |
+
"0: loss: 103.04938507080078\n",
|
| 77 |
+
"0: valid loss 11.681041717529297\n",
|
| 78 |
+
"0: saving model to results\n",
|
| 79 |
+
"training complete\n",
|
| 80 |
+
"save fine_transformer.pth\n"
|
| 81 |
+
]
|
| 82 |
+
}
|
| 83 |
+
],
|
| 84 |
+
"source": [
|
| 85 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
| 86 |
+
" dim = 512,\n",
|
| 87 |
+
" depth = 6,\n",
|
| 88 |
+
" heads = 8,\n",
|
| 89 |
+
" dim_head = 64,\n",
|
| 90 |
+
" spec_n_fft = 128,\n",
|
| 91 |
+
" spec_win_length = 24,\n",
|
| 92 |
+
" spec_aug_stretch_factor = 0.8\n",
|
| 93 |
+
")\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"text_transformer = TextTransformer(\n",
|
| 96 |
+
" dim = 512,\n",
|
| 97 |
+
" depth = 6,\n",
|
| 98 |
+
" heads = 8,\n",
|
| 99 |
+
" dim_head = 64\n",
|
| 100 |
+
")\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"mulan = MuLaN(\n",
|
| 103 |
+
" audio_transformer = audio_transformer,\n",
|
| 104 |
+
" text_transformer = text_transformer\n",
|
| 105 |
+
")\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
| 108 |
+
" mulan = mulan, \n",
|
| 109 |
+
" conditioning_dims = (1024, 1024, 1024), \n",
|
| 110 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
| 111 |
+
")\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"def train_fine_transformer():\n",
|
| 115 |
+
" soundstream = MusicLMSoundStream(\n",
|
| 116 |
+
" codebook_size=1024, \n",
|
| 117 |
+
" strides=(3, 4, 5, 8),\n",
|
| 118 |
+
" target_sample_hz=24000,\n",
|
| 119 |
+
" rq_num_quantizers=8\n",
|
| 120 |
+
" )\n",
|
| 121 |
+
"\n",
|
| 122 |
+
" if torch.cuda.is_available():\n",
|
| 123 |
+
" fine_transformer = FineTransformer(\n",
|
| 124 |
+
" num_coarse_quantizers = 4,\n",
|
| 125 |
+
" num_fine_quantizers = 4,\n",
|
| 126 |
+
" codebook_size = 1024,\n",
|
| 127 |
+
" dim = 1024,\n",
|
| 128 |
+
" depth = 6,\n",
|
| 129 |
+
" audio_text_condition = True\n",
|
| 130 |
+
" ).cuda()\n",
|
| 131 |
+
" else:\n",
|
| 132 |
+
" fine_transformer = FineTransformer(\n",
|
| 133 |
+
" num_coarse_quantizers = 4,\n",
|
| 134 |
+
" num_fine_quantizers = 4,\n",
|
| 135 |
+
" codebook_size = 1024,\n",
|
| 136 |
+
" dim = 1024,\n",
|
| 137 |
+
" depth = 6,\n",
|
| 138 |
+
" audio_text_condition = True\n",
|
| 139 |
+
" )\n",
|
| 140 |
+
"\n",
|
| 141 |
+
" trainer = FineTransformerTrainer(\n",
|
| 142 |
+
" transformer=fine_transformer,\n",
|
| 143 |
+
" codec=soundstream,\n",
|
| 144 |
+
" folder=audio_output_dir,\n",
|
| 145 |
+
" batch_size=batch_size,\n",
|
| 146 |
+
" data_max_length=data_max_length,\n",
|
| 147 |
+
" num_train_steps=num_train_steps,\n",
|
| 148 |
+
" audio_conditioner = quantizer\n",
|
| 149 |
+
" )\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" trainer.train()\n",
|
| 152 |
+
" torch.save(fine_transformer.state_dict(), 'fine_transformer.pth')\n",
|
| 153 |
+
" print(\"save fine_transformer.pth\")\n",
|
| 154 |
+
" del fine_transformer, trainer, soundstream\n",
|
| 155 |
+
" gc.collect()\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"train_fine_transformer()"
|
| 159 |
+
]
|
| 160 |
+
}
|
| 161 |
+
],
|
| 162 |
+
"metadata": {
|
| 163 |
+
"kernelspec": {
|
| 164 |
+
"display_name": "myenv",
|
| 165 |
+
"language": "python",
|
| 166 |
+
"name": "python3"
|
| 167 |
+
},
|
| 168 |
+
"language_info": {
|
| 169 |
+
"codemirror_mode": {
|
| 170 |
+
"name": "ipython",
|
| 171 |
+
"version": 3
|
| 172 |
+
},
|
| 173 |
+
"file_extension": ".py",
|
| 174 |
+
"mimetype": "text/x-python",
|
| 175 |
+
"name": "python",
|
| 176 |
+
"nbconvert_exporter": "python",
|
| 177 |
+
"pygments_lexer": "ipython3",
|
| 178 |
+
"version": "3.11.2"
|
| 179 |
+
}
|
| 180 |
+
},
|
| 181 |
+
"nbformat": 4,
|
| 182 |
+
"nbformat_minor": 2
|
| 183 |
+
}
|
musiclm.ipynb
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# AudioLM"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "markdown",
|
| 12 |
+
"metadata": {},
|
| 13 |
+
"source": [
|
| 14 |
+
"### Libraries:"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 2,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [
|
| 22 |
+
{
|
| 23 |
+
"name": "stderr",
|
| 24 |
+
"output_type": "stream",
|
| 25 |
+
"text": [
|
| 26 |
+
"2024-07-26 16:06:09 | WARNING | xformers | WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:\n",
|
| 27 |
+
" PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cpu)\n",
|
| 28 |
+
" Python 3.11.6 (you have 3.11.2)\n",
|
| 29 |
+
" Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)\n",
|
| 30 |
+
" Memory-efficient attention, SwiGLU, sparse and more won't be available.\n",
|
| 31 |
+
" Set XFORMERS_MORE_DETAILS=1 for more details\n",
|
| 32 |
+
"2024-07-26 16:06:09 | WARNING | xformers | Triton is not available, some optimizations will not be enabled.\n",
|
| 33 |
+
"This is just a warning: triton is not available\n"
|
| 34 |
+
]
|
| 35 |
+
}
|
| 36 |
+
],
|
| 37 |
+
"source": [
|
| 38 |
+
"import torch\n",
|
| 39 |
+
"from audiolm_pytorch import HubertWithKmeans\n",
|
| 40 |
+
"from audiolm_pytorch import SemanticTransformer\n",
|
| 41 |
+
"from audiolm_pytorch import CoarseTransformer\n",
|
| 42 |
+
"from audiolm_pytorch import FineTransformer\n",
|
| 43 |
+
"from audiolm_pytorch import AudioLMSoundStream, AudioLM\n",
|
| 44 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
| 45 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 3,
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
| 55 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": 5,
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [],
|
| 63 |
+
"source": [
|
| 64 |
+
"wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"soundstream = AudioLMSoundStream(\n",
|
| 67 |
+
" codebook_size=1024, # Add this line to specify the codebook size\n",
|
| 68 |
+
" strides=(2, 4, 5, 8),\n",
|
| 69 |
+
" target_sample_hz=16000,\n",
|
| 70 |
+
" rq_num_quantizers=8\n",
|
| 71 |
+
")\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"if torch.cuda.is_available():\n",
|
| 75 |
+
" semantic_transformer = SemanticTransformer(\n",
|
| 76 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 77 |
+
" dim=1024,\n",
|
| 78 |
+
" depth=6,\n",
|
| 79 |
+
" audio_text_condition=True\n",
|
| 80 |
+
" ).cuda()\n",
|
| 81 |
+
"\n",
|
| 82 |
+
" coarse_transformer = CoarseTransformer(\n",
|
| 83 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 84 |
+
" codebook_size=1024,\n",
|
| 85 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
| 86 |
+
" dim=1024,\n",
|
| 87 |
+
" depth=6,\n",
|
| 88 |
+
" audio_text_condition=True\n",
|
| 89 |
+
" ).cuda()\n",
|
| 90 |
+
"\n",
|
| 91 |
+
" fine_transformer = FineTransformer(\n",
|
| 92 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
| 93 |
+
" num_fine_quantizers=4,\n",
|
| 94 |
+
" codebook_size=1024,\n",
|
| 95 |
+
" dim=1024,\n",
|
| 96 |
+
" depth=6,\n",
|
| 97 |
+
" audio_text_condition=True\n",
|
| 98 |
+
" ).cuda()\n",
|
| 99 |
+
"else:\n",
|
| 100 |
+
" semantic_transformer = SemanticTransformer(\n",
|
| 101 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 102 |
+
" dim=1024,\n",
|
| 103 |
+
" depth=6,\n",
|
| 104 |
+
" audio_text_condition=True\n",
|
| 105 |
+
" )\n",
|
| 106 |
+
"\n",
|
| 107 |
+
" coarse_transformer = CoarseTransformer(\n",
|
| 108 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 109 |
+
" codebook_size=1024,\n",
|
| 110 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
| 111 |
+
" dim=1024,\n",
|
| 112 |
+
" depth=6,\n",
|
| 113 |
+
" audio_text_condition=True\n",
|
| 114 |
+
" )\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" fine_transformer = FineTransformer(\n",
|
| 117 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
| 118 |
+
" num_fine_quantizers=4,\n",
|
| 119 |
+
" codebook_size=1024,\n",
|
| 120 |
+
" dim=1024,\n",
|
| 121 |
+
" depth=6,\n",
|
| 122 |
+
" audio_text_condition=True\n",
|
| 123 |
+
" )\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"semantic_transformer.load_state_dict(torch.load('semantic_transformer.pth'))\n",
|
| 126 |
+
"coarse_transformer.load_state_dict(torch.load('coarse_transformer.pth'))\n",
|
| 127 |
+
"fine_transformer.load_state_dict(torch.load('fine_transformer.pth'))\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"audiolm = AudioLM(\n",
|
| 130 |
+
" wav2vec=wav2vec,\n",
|
| 131 |
+
" codec=soundstream,\n",
|
| 132 |
+
" semantic_transformer=semantic_transformer,\n",
|
| 133 |
+
" coarse_transformer=coarse_transformer,\n",
|
| 134 |
+
" fine_transformer=fine_transformer\n",
|
| 135 |
+
")\n"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"cell_type": "markdown",
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"source": [
|
| 142 |
+
"# MuLaN"
|
| 143 |
+
]
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"cell_type": "code",
|
| 147 |
+
"execution_count": 6,
|
| 148 |
+
"metadata": {},
|
| 149 |
+
"outputs": [],
|
| 150 |
+
"source": [
|
| 151 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
| 152 |
+
" dim = 512,\n",
|
| 153 |
+
" depth = 6,\n",
|
| 154 |
+
" heads = 8,\n",
|
| 155 |
+
" dim_head = 64,\n",
|
| 156 |
+
" spec_n_fft = 128,\n",
|
| 157 |
+
" spec_win_length = 24,\n",
|
| 158 |
+
" spec_aug_stretch_factor = 0.8\n",
|
| 159 |
+
")\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"text_transformer = TextTransformer(\n",
|
| 162 |
+
" dim = 512,\n",
|
| 163 |
+
" depth = 6,\n",
|
| 164 |
+
" heads = 8,\n",
|
| 165 |
+
" dim_head = 64\n",
|
| 166 |
+
")\n",
|
| 167 |
+
"\n",
|
| 168 |
+
"mulan = MuLaN(\n",
|
| 169 |
+
" audio_transformer = audio_transformer,\n",
|
| 170 |
+
" text_transformer = text_transformer\n",
|
| 171 |
+
")\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
| 174 |
+
" mulan = mulan, \n",
|
| 175 |
+
" conditioning_dims = (1024, 1024, 1024), \n",
|
| 176 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
| 177 |
+
")\n"
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"cell_type": "markdown",
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"source": [
|
| 184 |
+
"# MusicLM"
|
| 185 |
+
]
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"cell_type": "code",
|
| 189 |
+
"execution_count": 7,
|
| 190 |
+
"metadata": {},
|
| 191 |
+
"outputs": [],
|
| 192 |
+
"source": [
|
| 193 |
+
"from musiclm_pytorch import MusicLM\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"if torch.cuda.is_available():\n",
|
| 196 |
+
" musiclm = MusicLM(\n",
|
| 197 |
+
" audio_lm = audiolm,\n",
|
| 198 |
+
" mulan_embed_quantizer = quantizer\n",
|
| 199 |
+
" ).cuda()\n",
|
| 200 |
+
"else:\n",
|
| 201 |
+
" musiclm = MusicLM(\n",
|
| 202 |
+
" audio_lm = audiolm,\n",
|
| 203 |
+
" mulan_embed_quantizer = quantizer\n",
|
| 204 |
+
" )"
|
| 205 |
+
]
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"cell_type": "markdown",
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"source": [
|
| 211 |
+
"# Inference:"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": 10,
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"outputs": [
|
| 219 |
+
{
|
| 220 |
+
"name": "stdout",
|
| 221 |
+
"output_type": "stream",
|
| 222 |
+
"text": [
|
| 223 |
+
" 31 / 403\r"
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"ename": "KeyboardInterrupt",
|
| 228 |
+
"evalue": "",
|
| 229 |
+
"output_type": "error",
|
| 230 |
+
"traceback": [
|
| 231 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
| 232 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 233 |
+
"Cell \u001b[1;32mIn[10], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mMusiclm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\n\u001b[0;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcrazy EDM, heavy bang\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[0;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m display_audio(res, \u001b[38;5;241m32000\u001b[39m)\n",
|
| 234 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\genmodel.py:161\u001b[0m, in \u001b[0;36mBaseGenModel.generate\u001b[1;34m(self, descriptions, progress, return_tokens)\u001b[0m\n\u001b[0;32m 159\u001b[0m attributes, prompt_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_tokens_and_attributes(descriptions, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m 160\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m prompt_tokens \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m--> 161\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_generate_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattributes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 162\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_tokens:\n\u001b[0;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgenerate_audio(tokens), tokens\n",
|
| 235 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\musicgen.py:256\u001b[0m, in \u001b[0;36mMusicGen._generate_tokens\u001b[1;34m(self, attributes, prompt_tokens, progress)\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mduration \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_duration:\n\u001b[0;32m 254\u001b[0m \u001b[38;5;66;03m# generate by sampling from LM, simple case.\u001b[39;00m\n\u001b[0;32m 255\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mautocast:\n\u001b[1;32m--> 256\u001b[0m gen_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mprompt_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattributes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_gen_len\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_gen_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgeneration_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 260\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 261\u001b[0m \u001b[38;5;66;03m# now this gets a bit messier, we need to handle prompts,\u001b[39;00m\n\u001b[0;32m 262\u001b[0m \u001b[38;5;66;03m# melody conditioning etc.\u001b[39;00m\n\u001b[0;32m 263\u001b[0m ref_wavs \u001b[38;5;241m=\u001b[39m [attr\u001b[38;5;241m.\u001b[39mwav[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mself_wav\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m attr \u001b[38;5;129;01min\u001b[39;00m attributes]\n",
|
| 236 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\utils\\_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[1;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 237 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:510\u001b[0m, in \u001b[0;36mLMModel.generate\u001b[1;34m(self, prompt, conditions, num_samples, max_gen_len, use_sampling, temp, top_k, top_p, cfg_coef, two_step_cfg, remove_prompts, check, callback, **kwargs)\u001b[0m\n\u001b[0;32m 508\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (curr_sequence \u001b[38;5;241m==\u001b[39m unknown_token)\u001b[38;5;241m.\u001b[39many()\n\u001b[0;32m 509\u001b[0m \u001b[38;5;66;03m# sample next token from the model, next token shape is [B, K, 1]\u001b[39;00m\n\u001b[1;32m--> 510\u001b[0m next_token \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample_next_token\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 511\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurr_sequence\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcfg_conditions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munconditional_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_sampling\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_k\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_p\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 512\u001b[0m \u001b[43m \u001b[49m\u001b[43mcfg_coef\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg_coef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtwo_step_cfg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtwo_step_cfg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 513\u001b[0m \u001b[38;5;66;03m# ensure the tokens that should be masked are properly set to special_token_id\u001b[39;00m\n\u001b[0;32m 514\u001b[0m \u001b[38;5;66;03m# as the model never output special_token_id\u001b[39;00m\n\u001b[0;32m 515\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m mask[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, offset:offset\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mexpand(B, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
|
| 238 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:369\u001b[0m, in \u001b[0;36mLMModel._sample_next_token\u001b[1;34m(self, sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, cfg_coef, two_step_cfg)\u001b[0m\n\u001b[0;32m 366\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m condition_tensors:\n\u001b[0;32m 367\u001b[0m \u001b[38;5;66;03m# Preparing for CFG, predicting both conditional and unconditional logits.\u001b[39;00m\n\u001b[0;32m 368\u001b[0m sequence \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([sequence, sequence], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m--> 369\u001b[0m all_logits \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 370\u001b[0m \u001b[43m \u001b[49m\u001b[43msequence\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 371\u001b[0m \u001b[43m \u001b[49m\u001b[43mconditions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcondition_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcondition_tensors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 372\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m condition_tensors:\n\u001b[0;32m 373\u001b[0m cond_logits, uncond_logits \u001b[38;5;241m=\u001b[39m all_logits\u001b[38;5;241m.\u001b[39msplit(B, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;66;03m# [B, K, T, card]\u001b[39;00m\n",
|
| 239 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 240 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 241 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:257\u001b[0m, in \u001b[0;36mLMModel.forward\u001b[1;34m(self, sequence, conditions, condition_tensors, stage)\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conditions, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt pass both conditions and condition_tensors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 255\u001b[0m input_, cross_attention_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfuser(input_, condition_tensors)\n\u001b[1;32m--> 257\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attention_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattn_mask_per_stage\u001b[49m\u001b[43m[\u001b[49m\u001b[43mstage\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mstage\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 259\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mout_norm:\n\u001b[0;32m 260\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mout_norm(out)\n",
|
| 242 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 243 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 244 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:708\u001b[0m, in \u001b[0;36mStreamingTransformer.forward\u001b[1;34m(self, x, *args, **kwargs)\u001b[0m\n\u001b[0;32m 705\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpositional_scale \u001b[38;5;241m*\u001b[39m pos_emb\n\u001b[0;32m 707\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[1;32m--> 708\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply_layer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 710\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_is_streaming:\n\u001b[0;32m 711\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_streaming_state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moffsets\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m offsets \u001b[38;5;241m+\u001b[39m T\n",
|
| 245 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:665\u001b[0m, in \u001b[0;36mStreamingTransformer._apply_layer\u001b[1;34m(self, layer, *args, **kwargs)\u001b[0m\n\u001b[0;32m 663\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheckpointing\n\u001b[0;32m 664\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m--> 665\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 666\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 667\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch_checkpoint(layer, \u001b[38;5;241m*\u001b[39margs, use_reentrant\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
| 246 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 247 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 248 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:563\u001b[0m, in \u001b[0;36mStreamingTransformerLayer.forward\u001b[1;34m(self, src, src_mask, src_key_padding_mask, cross_attention_src)\u001b[0m\n\u001b[0;32m 559\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_1(\n\u001b[0;32m 560\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sa_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm1(x), src_mask, src_key_padding_mask))\n\u001b[0;32m 561\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cross_attention_src \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 562\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_cross(\n\u001b[1;32m--> 563\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cross_attention_block\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 564\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_cross\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 565\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_2(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(x)))\n\u001b[0;32m 566\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
|
| 249 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:546\u001b[0m, in \u001b[0;36mStreamingTransformerLayer._cross_attention_block\u001b[1;34m(self, src, cross_attention_src)\u001b[0m\n\u001b[0;32m 544\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcross_attention \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 545\u001b[0m \u001b[38;5;66;03m# queries are from src, keys and values from cross_attention_src.\u001b[39;00m\n\u001b[1;32m--> 546\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 547\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 548\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout_cross(x)\n",
|
| 250 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 251 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 252 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:356\u001b[0m, in \u001b[0;36mStreamingMultiheadAttention.forward\u001b[1;34m(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m 354\u001b[0m q \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mlinear(query, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_weight[:dim], bias_q)\n\u001b[0;32m 355\u001b[0m \u001b[38;5;66;03m# todo: when streaming, we could actually save k, v and check the shape actually match.\u001b[39;00m\n\u001b[1;32m--> 356\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_weight\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias_k\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 357\u001b[0m v \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mlinear(value, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_weight[\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m dim:], bias_v)\n\u001b[0;32m 358\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqk_layer_norm \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n",
|
| 253 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
| 254 |
+
]
|
| 255 |
+
}
|
| 256 |
+
],
|
| 257 |
+
"source": [
|
| 258 |
+
"music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4)"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "code",
|
| 263 |
+
"execution_count": null,
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"torch.save(music, 'generated_music.pt')"
|
| 268 |
+
]
|
| 269 |
+
},
|
| 270 |
+
{
|
| 271 |
+
"cell_type": "code",
|
| 272 |
+
"execution_count": null,
|
| 273 |
+
"metadata": {},
|
| 274 |
+
"outputs": [],
|
| 275 |
+
"source": [
|
| 276 |
+
"import torchaudio\n",
|
| 277 |
+
"output_path = \"out.wav\"\n",
|
| 278 |
+
"sample_rate = 44100\n",
|
| 279 |
+
"torchaudio.save(output_path, music.cpu() , sample_rate)"
|
| 280 |
+
]
|
| 281 |
+
}
|
| 282 |
+
],
|
| 283 |
+
"metadata": {
|
| 284 |
+
"kernelspec": {
|
| 285 |
+
"display_name": "myenv",
|
| 286 |
+
"language": "python",
|
| 287 |
+
"name": "python3"
|
| 288 |
+
},
|
| 289 |
+
"language_info": {
|
| 290 |
+
"codemirror_mode": {
|
| 291 |
+
"name": "ipython",
|
| 292 |
+
"version": 3
|
| 293 |
+
},
|
| 294 |
+
"file_extension": ".py",
|
| 295 |
+
"mimetype": "text/x-python",
|
| 296 |
+
"name": "python",
|
| 297 |
+
"nbconvert_exporter": "python",
|
| 298 |
+
"pygments_lexer": "ipython3",
|
| 299 |
+
"version": "3.11.2"
|
| 300 |
+
}
|
| 301 |
+
},
|
| 302 |
+
"nbformat": 4,
|
| 303 |
+
"nbformat_minor": 2
|
| 304 |
+
}
|
semantic_transformer.ipynb
ADDED
|
@@ -0,0 +1,851 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Semantic Transformer"
|
| 8 |
+
]
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "markdown",
|
| 12 |
+
"metadata": {},
|
| 13 |
+
"source": [
|
| 14 |
+
"### Libraries"
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"execution_count": 1,
|
| 20 |
+
"metadata": {},
|
| 21 |
+
"outputs": [],
|
| 22 |
+
"source": [
|
| 23 |
+
"import torch\n",
|
| 24 |
+
"import multiprocessing\n",
|
| 25 |
+
"from audiolm_pytorch import HubertWithKmeans, MusicLMSoundStream\n",
|
| 26 |
+
"from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
|
| 27 |
+
"from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
|
| 28 |
+
"from audiolm_pytorch import FineTransformer, FineTransformerTrainer\n",
|
| 29 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
| 30 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer\n",
|
| 31 |
+
"import gc "
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 2,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
| 41 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
|
| 42 |
+
"audio_output_dir = './audio'\n",
|
| 43 |
+
"batch_size = 1\n",
|
| 44 |
+
"data_max_length = 320 * 32\n",
|
| 45 |
+
"num_train_steps = 1000"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 3,
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [
|
| 53 |
+
{
|
| 54 |
+
"name": "stdout",
|
| 55 |
+
"output_type": "stream",
|
| 56 |
+
"text": [
|
| 57 |
+
"spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n",
|
| 58 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
| 59 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
| 60 |
+
"training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
|
| 61 |
+
"0: loss: 6.5572309494018555\n",
|
| 62 |
+
"0: valid loss 6.723005294799805\n",
|
| 63 |
+
"0: saving model to results\n",
|
| 64 |
+
"1: loss: 6.5375285148620605\n",
|
| 65 |
+
"2: loss: 5.515031337738037\n",
|
| 66 |
+
"3: loss: 0.6989991664886475\n",
|
| 67 |
+
"4: loss: 0.016623886302113533\n",
|
| 68 |
+
"5: loss: 6.3969268798828125\n",
|
| 69 |
+
"6: loss: 0.8643577098846436\n",
|
| 70 |
+
"7: loss: 0.008508207276463509\n",
|
| 71 |
+
"8: loss: 0.00020680516900029033\n",
|
| 72 |
+
"9: loss: 8.900370597839355\n",
|
| 73 |
+
"10: loss: 0.00010900969209615141\n",
|
| 74 |
+
"11: loss: 0.0001591881300555542\n",
|
| 75 |
+
"12: loss: 8.055902481079102\n",
|
| 76 |
+
"13: loss: 0.0009496303973719478\n",
|
| 77 |
+
"14: loss: 0.0027423782739788294\n",
|
| 78 |
+
"15: loss: 0.0009589337860234082\n",
|
| 79 |
+
"16: loss: 7.296541690826416\n",
|
| 80 |
+
"17: loss: 0.0005210856324993074\n",
|
| 81 |
+
"18: loss: 0.0008424322586506605\n",
|
| 82 |
+
"19: loss: 5.571179389953613\n",
|
| 83 |
+
"20: loss: 0.003094581188634038\n",
|
| 84 |
+
"21: loss: 0.0019461463671177626\n",
|
| 85 |
+
"22: loss: 5.488490104675293\n",
|
| 86 |
+
"23: loss: 4.800296783447266\n",
|
| 87 |
+
"24: loss: 4.962136268615723\n",
|
| 88 |
+
"25: loss: 5.943732738494873\n",
|
| 89 |
+
"26: loss: 0.006312617566436529\n",
|
| 90 |
+
"27: loss: 4.396454334259033\n",
|
| 91 |
+
"28: loss: 0.012498963624238968\n",
|
| 92 |
+
"29: loss: 0.0049488842487335205\n",
|
| 93 |
+
"30: loss: 0.0011625693878158927\n",
|
| 94 |
+
"31: loss: 3.445856809616089\n",
|
| 95 |
+
"32: loss: 0.000534387887455523\n",
|
| 96 |
+
"33: loss: 0.000711498549208045\n",
|
| 97 |
+
"34: loss: 0.0009514373959973454\n",
|
| 98 |
+
"35: loss: 0.001239188713952899\n",
|
| 99 |
+
"36: loss: 8.732012748718262\n",
|
| 100 |
+
"37: loss: 0.0009216524777002633\n",
|
| 101 |
+
"38: loss: 0.0006809335318394005\n",
|
| 102 |
+
"39: loss: 0.000797786982730031\n",
|
| 103 |
+
"40: loss: 4.916833400726318\n",
|
| 104 |
+
"41: loss: 0.0010107718408107758\n",
|
| 105 |
+
"42: loss: 0.0008451942121610045\n",
|
| 106 |
+
"43: loss: 3.160980701446533\n",
|
| 107 |
+
"44: loss: 0.0008387335110455751\n",
|
| 108 |
+
"45: loss: 0.0010360947344452143\n",
|
| 109 |
+
"46: loss: 0.001215349417179823\n",
|
| 110 |
+
"47: loss: 5.990973949432373\n",
|
| 111 |
+
"48: loss: 0.0017369053093716502\n",
|
| 112 |
+
"49: loss: 6.410669803619385\n",
|
| 113 |
+
"50: loss: 0.003450337564572692\n",
|
| 114 |
+
"51: loss: 0.003860922297462821\n",
|
| 115 |
+
"52: loss: 0.002359303878620267\n",
|
| 116 |
+
"53: loss: 0.001058467198163271\n",
|
| 117 |
+
"54: loss: 0.00047752217506058514\n",
|
| 118 |
+
"55: loss: 0.00025489379186183214\n",
|
| 119 |
+
"56: loss: 0.00016276698443107307\n",
|
| 120 |
+
"57: loss: 7.828070163726807\n",
|
| 121 |
+
"58: loss: 0.00011652028479147702\n",
|
| 122 |
+
"59: loss: 4.505963325500488\n",
|
| 123 |
+
"60: loss: 0.00013153781765140593\n",
|
| 124 |
+
"61: loss: 0.00015024915046524256\n",
|
| 125 |
+
"62: loss: 0.00017777853645384312\n",
|
| 126 |
+
"63: loss: 8.09732437133789\n",
|
| 127 |
+
"64: loss: 0.00041875039460137486\n",
|
| 128 |
+
"65: loss: 0.0009824583539739251\n",
|
| 129 |
+
"66: loss: 0.001990197692066431\n",
|
| 130 |
+
"67: loss: 5.392111778259277\n",
|
| 131 |
+
"68: loss: 0.0017270153621211648\n",
|
| 132 |
+
"69: loss: 0.0010434042196720839\n",
|
| 133 |
+
"70: loss: 0.0005951145431026816\n",
|
| 134 |
+
"71: loss: 0.00037293724017217755\n",
|
| 135 |
+
"72: loss: 0.00025969729176722467\n",
|
| 136 |
+
"73: loss: 7.013213157653809\n",
|
| 137 |
+
"74: loss: 3.807203531265259\n",
|
| 138 |
+
"75: loss: 0.00026780215557664633\n",
|
| 139 |
+
"76: loss: 0.00031897667213343084\n",
|
| 140 |
+
"77: loss: 0.0003657388442661613\n",
|
| 141 |
+
"78: loss: 5.076975345611572\n",
|
| 142 |
+
"79: loss: 0.001055362867191434\n",
|
| 143 |
+
"80: loss: 0.0010116726625710726\n",
|
| 144 |
+
"81: loss: 0.0017484871204942465\n",
|
| 145 |
+
"82: loss: 0.0018696936313062906\n",
|
| 146 |
+
"83: loss: 5.30266809463501\n",
|
| 147 |
+
"84: loss: 5.457505226135254\n",
|
| 148 |
+
"85: loss: 0.0012204349040985107\n",
|
| 149 |
+
"86: loss: 3.2936503887176514\n",
|
| 150 |
+
"87: loss: 0.0020471797324717045\n",
|
| 151 |
+
"88: loss: 0.0026046710554510355\n",
|
| 152 |
+
"89: loss: 0.0026721167378127575\n",
|
| 153 |
+
"90: loss: 0.0024667021352797747\n",
|
| 154 |
+
"91: loss: 5.0201215744018555\n",
|
| 155 |
+
"92: loss: 4.591504096984863\n",
|
| 156 |
+
"93: loss: 0.0025711969938129187\n",
|
| 157 |
+
"94: loss: 0.002706416416913271\n",
|
| 158 |
+
"95: loss: 0.0024713831953704357\n",
|
| 159 |
+
"96: loss: 0.002004373585805297\n",
|
| 160 |
+
"97: loss: 0.001489203074015677\n",
|
| 161 |
+
"98: loss: 0.0010426173685118556\n",
|
| 162 |
+
"99: loss: 8.796974182128906\n",
|
| 163 |
+
"100: loss: 0.0005365900578908622\n",
|
| 164 |
+
"100: valid loss 5.255128860473633\n",
|
| 165 |
+
"101: loss: 0.0004417159070726484\n",
|
| 166 |
+
"102: loss: 4.595282554626465\n",
|
| 167 |
+
"103: loss: 0.000659952696878463\n",
|
| 168 |
+
"104: loss: 0.0008260122267529368\n",
|
| 169 |
+
"105: loss: 0.0009083786280825734\n",
|
| 170 |
+
"106: loss: 4.042155742645264\n",
|
| 171 |
+
"107: loss: 4.17121696472168\n",
|
| 172 |
+
"108: loss: 0.0007671767962165177\n",
|
| 173 |
+
"109: loss: 4.022541522979736\n",
|
| 174 |
+
"110: loss: 3.5455234050750732\n",
|
| 175 |
+
"111: loss: 0.001035561435855925\n",
|
| 176 |
+
"112: loss: 0.0012967187212780118\n",
|
| 177 |
+
"113: loss: 7.237168312072754\n",
|
| 178 |
+
"114: loss: 3.522667407989502\n",
|
| 179 |
+
"115: loss: 0.004003542009741068\n",
|
| 180 |
+
"116: loss: 0.0040553268045187\n",
|
| 181 |
+
"117: loss: 0.0029700316954404116\n",
|
| 182 |
+
"118: loss: 0.0019125432008877397\n",
|
| 183 |
+
"119: loss: 3.4947195053100586\n",
|
| 184 |
+
"120: loss: 0.001095975050702691\n",
|
| 185 |
+
"121: loss: 0.0009612821158953011\n",
|
| 186 |
+
"122: loss: 0.000824352668132633\n",
|
| 187 |
+
"123: loss: 3.3077425956726074\n",
|
| 188 |
+
"124: loss: 0.0007418203167617321\n",
|
| 189 |
+
"125: loss: 0.0007488489500246942\n",
|
| 190 |
+
"126: loss: 0.0007235489320009947\n",
|
| 191 |
+
"127: loss: 3.426555633544922\n",
|
| 192 |
+
"128: loss: 0.0006980476318858564\n",
|
| 193 |
+
"129: loss: 0.0006986281368881464\n",
|
| 194 |
+
"130: loss: 0.0006706370622850955\n",
|
| 195 |
+
"131: loss: 0.0006185953388921916\n",
|
| 196 |
+
"132: loss: 4.421964645385742\n",
|
| 197 |
+
"133: loss: 0.0006264401017688215\n",
|
| 198 |
+
"134: loss: 0.0006876335828565061\n",
|
| 199 |
+
"135: loss: 0.0007215599762275815\n",
|
| 200 |
+
"136: loss: 0.0007203654968179762\n",
|
| 201 |
+
"137: loss: 0.0006922150496393442\n",
|
| 202 |
+
"138: loss: 0.0006356032681651413\n",
|
| 203 |
+
"139: loss: 3.7695367336273193\n",
|
| 204 |
+
"140: loss: 0.0006305422284640372\n",
|
| 205 |
+
"141: loss: 0.0006744156708009541\n",
|
| 206 |
+
"142: loss: 0.0006895355763845146\n",
|
| 207 |
+
"143: loss: 3.770907402038574\n",
|
| 208 |
+
"144: loss: 0.000908360059838742\n",
|
| 209 |
+
"145: loss: 0.0011299465550109744\n",
|
| 210 |
+
"146: loss: 0.0012696339981630445\n",
|
| 211 |
+
"147: loss: 0.0012722468236461282\n",
|
| 212 |
+
"148: loss: 3.8808021545410156\n",
|
| 213 |
+
"149: loss: 3.783026695251465\n",
|
| 214 |
+
"150: loss: 0.002035590121522546\n",
|
| 215 |
+
"151: loss: 0.0026034933980554342\n",
|
| 216 |
+
"152: loss: 0.0024936539120972157\n",
|
| 217 |
+
"153: loss: 0.0018582777120172977\n",
|
| 218 |
+
"154: loss: 2.8572535514831543\n",
|
| 219 |
+
"155: loss: 0.001062657218426466\n",
|
| 220 |
+
"156: loss: 0.0008821044466458261\n",
|
| 221 |
+
"157: loss: 0.0007058316841721535\n",
|
| 222 |
+
"158: loss: 0.0005539683043025434\n",
|
| 223 |
+
"159: loss: 5.476413726806641\n",
|
| 224 |
+
"160: loss: 0.00043070572428405285\n",
|
| 225 |
+
"161: loss: 0.00042034301441162825\n",
|
| 226 |
+
"162: loss: 0.0004015824815724045\n",
|
| 227 |
+
"163: loss: 0.0003759717510547489\n",
|
| 228 |
+
"164: loss: 0.00034577338374219835\n",
|
| 229 |
+
"165: loss: 3.9209775924682617\n",
|
| 230 |
+
"166: loss: 0.0003425567992962897\n",
|
| 231 |
+
"167: loss: 0.00036322552477940917\n",
|
| 232 |
+
"168: loss: 0.00037287475424818695\n",
|
| 233 |
+
"169: loss: 7.9045209884643555\n",
|
| 234 |
+
"170: loss: 0.0004473228473216295\n",
|
| 235 |
+
"171: loss: 0.0005134259699843824\n",
|
| 236 |
+
"172: loss: 2.9501657485961914\n",
|
| 237 |
+
"173: loss: 0.0008285943185910583\n",
|
| 238 |
+
"174: loss: 0.00113466486800462\n",
|
| 239 |
+
"175: loss: 0.0013167448341846466\n",
|
| 240 |
+
"176: loss: 0.0014080735854804516\n",
|
| 241 |
+
"177: loss: 4.0473408699035645\n",
|
| 242 |
+
"178: loss: 0.0016744763124734163\n",
|
| 243 |
+
"179: loss: 0.0016492144204676151\n",
|
| 244 |
+
"180: loss: 4.165207386016846\n",
|
| 245 |
+
"181: loss: 0.0017677460564300418\n",
|
| 246 |
+
"182: loss: 0.0018474040552973747\n",
|
| 247 |
+
"183: loss: 0.0017496442887932062\n",
|
| 248 |
+
"184: loss: 3.3882288932800293\n",
|
| 249 |
+
"185: loss: 0.0018872346263378859\n",
|
| 250 |
+
"186: loss: 3.0333187580108643\n",
|
| 251 |
+
"187: loss: 0.0028638774529099464\n",
|
| 252 |
+
"188: loss: 3.709534168243408\n",
|
| 253 |
+
"189: loss: 6.904417991638184\n",
|
| 254 |
+
"190: loss: 0.006619434338063002\n",
|
| 255 |
+
"191: loss: 0.00595641415566206\n",
|
| 256 |
+
"192: loss: 5.050801753997803\n",
|
| 257 |
+
"193: loss: 3.7556490898132324\n",
|
| 258 |
+
"194: loss: 0.002467694692313671\n",
|
| 259 |
+
"195: loss: 0.002025420544669032\n",
|
| 260 |
+
"196: loss: 0.001494809053838253\n",
|
| 261 |
+
"197: loss: 0.0010330628138035536\n",
|
| 262 |
+
"198: loss: 0.0006917425780557096\n",
|
| 263 |
+
"199: loss: 0.0004644835426006466\n",
|
| 264 |
+
"200: loss: 4.2029547691345215\n",
|
| 265 |
+
"200: valid loss 0.00025821171584539115\n",
|
| 266 |
+
"201: loss: 4.2771501541137695\n",
|
| 267 |
+
"202: loss: 3.7102839946746826\n",
|
| 268 |
+
"203: loss: 3.7408058643341064\n",
|
| 269 |
+
"204: loss: 0.0003981325135100633\n",
|
| 270 |
+
"205: loss: 0.0005507581518031657\n",
|
| 271 |
+
"206: loss: 4.065332889556885\n",
|
| 272 |
+
"207: loss: 0.0011804178357124329\n",
|
| 273 |
+
"208: loss: 0.0017080714460462332\n",
|
| 274 |
+
"209: loss: 0.0021062048617750406\n",
|
| 275 |
+
"210: loss: 0.0021494715474545956\n",
|
| 276 |
+
"211: loss: 6.465389251708984\n",
|
| 277 |
+
"212: loss: 0.0029505854472517967\n",
|
| 278 |
+
"213: loss: 3.367213010787964\n",
|
| 279 |
+
"214: loss: 0.27502918243408203\n",
|
| 280 |
+
"215: loss: 0.9933775663375854\n",
|
| 281 |
+
"216: loss: 0.810478925704956\n",
|
| 282 |
+
"217: loss: 0.4562891721725464\n",
|
| 283 |
+
"218: loss: 0.24387648701667786\n",
|
| 284 |
+
"219: loss: 0.11290910840034485\n",
|
| 285 |
+
"220: loss: 0.019248925149440765\n",
|
| 286 |
+
"221: loss: 0.0021138046868145466\n",
|
| 287 |
+
"222: loss: 5.169565200805664\n",
|
| 288 |
+
"223: loss: 0.0008601757581345737\n",
|
| 289 |
+
"224: loss: 3.9269232749938965\n",
|
| 290 |
+
"225: loss: 0.0007863161154091358\n",
|
| 291 |
+
"226: loss: 0.00024547570501454175\n",
|
| 292 |
+
"227: loss: 4.449281215667725\n",
|
| 293 |
+
"228: loss: 0.00019524114031810313\n",
|
| 294 |
+
"229: loss: 5.162830829620361\n",
|
| 295 |
+
"230: loss: 0.0005567128537222743\n",
|
| 296 |
+
"231: loss: 4.195521831512451\n",
|
| 297 |
+
"232: loss: 3.7389187812805176\n",
|
| 298 |
+
"233: loss: 5.919421672821045\n",
|
| 299 |
+
"234: loss: 6.7034173011779785\n",
|
| 300 |
+
"235: loss: 5.353506088256836\n",
|
| 301 |
+
"236: loss: 2.4018566608428955\n",
|
| 302 |
+
"237: loss: 3.7457311153411865\n",
|
| 303 |
+
"238: loss: 0.17652225494384766\n",
|
| 304 |
+
"239: loss: 4.564880847930908\n",
|
| 305 |
+
"240: loss: 0.027039170265197754\n",
|
| 306 |
+
"241: loss: 0.005270962603390217\n",
|
| 307 |
+
"242: loss: 0.0015485308831557631\n",
|
| 308 |
+
"243: loss: 0.0010360399028286338\n",
|
| 309 |
+
"244: loss: 0.0007773903198540211\n",
|
| 310 |
+
"245: loss: 6.206174850463867\n",
|
| 311 |
+
"246: loss: 6.409456253051758\n",
|
| 312 |
+
"247: loss: 0.04051050543785095\n",
|
| 313 |
+
"248: loss: 0.0017684113699942827\n",
|
| 314 |
+
"249: loss: 0.00044090740266256034\n",
|
| 315 |
+
"250: loss: 5.761023044586182\n",
|
| 316 |
+
"251: loss: 0.00016311556100845337\n",
|
| 317 |
+
"252: loss: 0.0001715785765554756\n",
|
| 318 |
+
"253: loss: 0.00019523760420270264\n",
|
| 319 |
+
"254: loss: 0.00023307953961193562\n",
|
| 320 |
+
"255: loss: 0.00028373271925374866\n",
|
| 321 |
+
"256: loss: 4.927147388458252\n",
|
| 322 |
+
"257: loss: 4.228280544281006\n",
|
| 323 |
+
"258: loss: 0.0011933923233300447\n",
|
| 324 |
+
"259: loss: 0.005215882323682308\n",
|
| 325 |
+
"260: loss: 0.0013388781808316708\n",
|
| 326 |
+
"261: loss: 4.206026554107666\n",
|
| 327 |
+
"262: loss: 0.0034830207005143166\n",
|
| 328 |
+
"263: loss: 4.173500061035156\n",
|
| 329 |
+
"264: loss: 0.007450783159583807\n",
|
| 330 |
+
"265: loss: 4.5892510414123535\n",
|
| 331 |
+
"266: loss: 0.006880312692373991\n",
|
| 332 |
+
"267: loss: 4.572935104370117\n",
|
| 333 |
+
"268: loss: 0.002904222346842289\n",
|
| 334 |
+
"269: loss: 3.2348222732543945\n",
|
| 335 |
+
"270: loss: 4.376621723175049\n",
|
| 336 |
+
"271: loss: 3.573988914489746\n",
|
| 337 |
+
"272: loss: 0.0010127610294148326\n",
|
| 338 |
+
"273: loss: 9.308874130249023\n",
|
| 339 |
+
"274: loss: 4.688360214233398\n",
|
| 340 |
+
"275: loss: 3.9581832885742188\n",
|
| 341 |
+
"276: loss: 0.01065391581505537\n",
|
| 342 |
+
"277: loss: 0.0067514535039663315\n",
|
| 343 |
+
"278: loss: 0.003611961379647255\n",
|
| 344 |
+
"279: loss: 0.001811509020626545\n",
|
| 345 |
+
"280: loss: 0.0009013370145112276\n",
|
| 346 |
+
"281: loss: 4.266546726226807\n",
|
| 347 |
+
"282: loss: 5.132745742797852\n",
|
| 348 |
+
"283: loss: 0.000957090116571635\n",
|
| 349 |
+
"284: loss: 0.0015025322791188955\n",
|
| 350 |
+
"285: loss: 6.258731842041016\n",
|
| 351 |
+
"286: loss: 5.029386043548584\n",
|
| 352 |
+
"287: loss: 0.007954631000757217\n",
|
| 353 |
+
"288: loss: 0.0050008054822683334\n",
|
| 354 |
+
"289: loss: 0.001655810745432973\n",
|
| 355 |
+
"290: loss: 5.501289367675781\n",
|
| 356 |
+
"291: loss: 4.655749797821045\n",
|
| 357 |
+
"292: loss: 4.383106231689453\n",
|
| 358 |
+
"293: loss: 0.000304496381431818\n",
|
| 359 |
+
"294: loss: 0.0003326725563965738\n",
|
| 360 |
+
"295: loss: 0.00035310350358486176\n",
|
| 361 |
+
"296: loss: 5.683162212371826\n",
|
| 362 |
+
"297: loss: 0.0004622728156391531\n",
|
| 363 |
+
"298: loss: 4.067113399505615\n",
|
| 364 |
+
"299: loss: 0.0008154112147167325\n",
|
| 365 |
+
"300: loss: 0.00108420941978693\n",
|
| 366 |
+
"300: valid loss 0.0013179074740037322\n",
|
| 367 |
+
"301: loss: 0.0013179074740037322\n",
|
| 368 |
+
"302: loss: 4.358561992645264\n",
|
| 369 |
+
"303: loss: 5.026749610900879\n",
|
| 370 |
+
"304: loss: 0.002862808993086219\n",
|
| 371 |
+
"305: loss: 0.003396229352802038\n",
|
| 372 |
+
"306: loss: 5.530904293060303\n",
|
| 373 |
+
"307: loss: 0.0035779180470854044\n",
|
| 374 |
+
"308: loss: 0.003205555956810713\n",
|
| 375 |
+
"309: loss: 4.112671852111816\n",
|
| 376 |
+
"310: loss: 3.6920313835144043\n",
|
| 377 |
+
"311: loss: 0.0026951604522764683\n",
|
| 378 |
+
"312: loss: 0.0026851999573409557\n",
|
| 379 |
+
"313: loss: 3.3092551231384277\n",
|
| 380 |
+
"314: loss: 0.0024079573340713978\n",
|
| 381 |
+
"315: loss: 0.0022026696242392063\n",
|
| 382 |
+
"316: loss: 0.0018284200923517346\n",
|
| 383 |
+
"317: loss: 0.0014258958399295807\n",
|
| 384 |
+
"318: loss: 0.0010761057492345572\n",
|
| 385 |
+
"319: loss: 0.0008039181702770293\n",
|
| 386 |
+
"320: loss: 0.0006038622814230621\n",
|
| 387 |
+
"321: loss: 0.00046244796249084175\n",
|
| 388 |
+
"322: loss: 5.89370059967041\n",
|
| 389 |
+
"323: loss: 0.00031747910543344915\n",
|
| 390 |
+
"324: loss: 0.00028221303364261985\n",
|
| 391 |
+
"325: loss: 0.00025451104738749564\n",
|
| 392 |
+
"326: loss: 0.00023175252135843039\n",
|
| 393 |
+
"327: loss: 0.00021364034910220653\n",
|
| 394 |
+
"328: loss: 3.906613826751709\n",
|
| 395 |
+
"329: loss: 3.844726085662842\n",
|
| 396 |
+
"330: loss: 0.00023705456987954676\n",
|
| 397 |
+
"331: loss: 0.0002663657069206238\n",
|
| 398 |
+
"332: loss: 0.0002947220054920763\n",
|
| 399 |
+
"333: loss: 6.28004264831543\n",
|
| 400 |
+
"334: loss: 0.0003821635036729276\n",
|
| 401 |
+
"335: loss: 3.633335828781128\n",
|
| 402 |
+
"336: loss: 0.0005681345355696976\n",
|
| 403 |
+
"337: loss: 6.994467735290527\n",
|
| 404 |
+
"338: loss: 7.915759086608887\n",
|
| 405 |
+
"339: loss: 0.0026061832904815674\n",
|
| 406 |
+
"340: loss: 0.0048998151905834675\n",
|
| 407 |
+
"341: loss: 0.004243680741637945\n",
|
| 408 |
+
"342: loss: 0.0025005636271089315\n",
|
| 409 |
+
"343: loss: 4.005818843841553\n",
|
| 410 |
+
"344: loss: 0.0011636920971795917\n",
|
| 411 |
+
"345: loss: 0.0009634271846152842\n",
|
| 412 |
+
"346: loss: 0.0008427661377936602\n",
|
| 413 |
+
"347: loss: 0.0007607618463225663\n",
|
| 414 |
+
"348: loss: 0.0006956492434255779\n",
|
| 415 |
+
"349: loss: 4.547393798828125\n",
|
| 416 |
+
"350: loss: 0.0006480301963165402\n",
|
| 417 |
+
"351: loss: 0.0006520788883790374\n",
|
| 418 |
+
"352: loss: 0.0006446384941227734\n",
|
| 419 |
+
"353: loss: 4.283820629119873\n",
|
| 420 |
+
"354: loss: 0.0007140468223951757\n",
|
| 421 |
+
"355: loss: 0.000788742327131331\n",
|
| 422 |
+
"356: loss: 0.0008332571596838534\n",
|
| 423 |
+
"357: loss: 0.0008390303701162338\n",
|
| 424 |
+
"358: loss: 0.000806896947324276\n",
|
| 425 |
+
"359: loss: 4.646646976470947\n",
|
| 426 |
+
"360: loss: 0.0021708165295422077\n",
|
| 427 |
+
"361: loss: 0.0009108624653890729\n",
|
| 428 |
+
"362: loss: 3.9582133293151855\n",
|
| 429 |
+
"363: loss: 3.3569955825805664\n",
|
| 430 |
+
"364: loss: 0.002499263733625412\n",
|
| 431 |
+
"365: loss: 4.646510601043701\n",
|
| 432 |
+
"366: loss: 0.0032457842025905848\n",
|
| 433 |
+
"367: loss: 0.0033331059385091066\n",
|
| 434 |
+
"368: loss: 0.00275675137527287\n",
|
| 435 |
+
"369: loss: 0.0020243506878614426\n",
|
| 436 |
+
"370: loss: 4.458893775939941\n",
|
| 437 |
+
"371: loss: 5.930361270904541\n",
|
| 438 |
+
"372: loss: 4.287806510925293\n",
|
| 439 |
+
"373: loss: 3.365216016769409\n",
|
| 440 |
+
"374: loss: 0.011499284766614437\n",
|
| 441 |
+
"375: loss: 0.0031067240051925182\n",
|
| 442 |
+
"376: loss: 0.003569819498807192\n",
|
| 443 |
+
"377: loss: 0.0032246895134449005\n",
|
| 444 |
+
"378: loss: 0.0023426800034940243\n",
|
| 445 |
+
"379: loss: 0.0016774036921560764\n",
|
| 446 |
+
"380: loss: 0.0010665183654055\n",
|
| 447 |
+
"381: loss: 0.0007539619691669941\n",
|
| 448 |
+
"382: loss: 3.873556137084961\n",
|
| 449 |
+
"383: loss: 0.08063449710607529\n",
|
| 450 |
+
"384: loss: 0.0005400768714025617\n",
|
| 451 |
+
"385: loss: 0.000518861401360482\n",
|
| 452 |
+
"386: loss: 0.00048329788842238486\n",
|
| 453 |
+
"387: loss: 4.2107648849487305\n",
|
| 454 |
+
"388: loss: 4.465734481811523\n",
|
| 455 |
+
"389: loss: 0.000529197626747191\n",
|
| 456 |
+
"390: loss: 3.872891664505005\n",
|
| 457 |
+
"391: loss: 5.214785099029541\n",
|
| 458 |
+
"392: loss: 4.345657825469971\n",
|
| 459 |
+
"393: loss: 0.0016826370265334845\n",
|
| 460 |
+
"394: loss: 0.0024580529425293207\n",
|
| 461 |
+
"395: loss: 0.002994671929627657\n",
|
| 462 |
+
"396: loss: 0.002981696743518114\n",
|
| 463 |
+
"397: loss: 0.002537172520533204\n",
|
| 464 |
+
"398: loss: 0.001975367311388254\n",
|
| 465 |
+
"399: loss: 0.0014994062948971987\n",
|
| 466 |
+
"400: loss: 0.0011500928085297346\n",
|
| 467 |
+
"400: valid loss 0.0009022268350236118\n",
|
| 468 |
+
"401: loss: 5.212808132171631\n",
|
| 469 |
+
"402: loss: 0.0008533270447514951\n",
|
| 470 |
+
"403: loss: 0.0008498210809193552\n",
|
| 471 |
+
"404: loss: 0.0008541711140424013\n",
|
| 472 |
+
"405: loss: 3.912627696990967\n",
|
| 473 |
+
"406: loss: 0.0008917151135392487\n",
|
| 474 |
+
"407: loss: 0.0009278871002607048\n",
|
| 475 |
+
"408: loss: 3.4623196125030518\n",
|
| 476 |
+
"409: loss: 0.0011483340058475733\n",
|
| 477 |
+
"410: loss: 0.0014651089441031218\n",
|
| 478 |
+
"411: loss: 3.501060962677002\n",
|
| 479 |
+
"412: loss: 4.905694484710693\n",
|
| 480 |
+
"413: loss: 0.0025538327172398567\n",
|
| 481 |
+
"414: loss: 0.0019650040194392204\n",
|
| 482 |
+
"415: loss: 0.001453581964597106\n",
|
| 483 |
+
"416: loss: 4.282127857208252\n",
|
| 484 |
+
"417: loss: 0.001117513864301145\n",
|
| 485 |
+
"418: loss: 3.2745401859283447\n",
|
| 486 |
+
"419: loss: 3.0665171146392822\n",
|
| 487 |
+
"420: loss: 0.001583368401043117\n",
|
| 488 |
+
"421: loss: 0.0018978181760758162\n",
|
| 489 |
+
"422: loss: 5.070369720458984\n",
|
| 490 |
+
"423: loss: 0.0025998111814260483\n",
|
| 491 |
+
"424: loss: 0.0028609540313482285\n",
|
| 492 |
+
"425: loss: 2.7316229343414307\n",
|
| 493 |
+
"426: loss: 0.003324385266751051\n",
|
| 494 |
+
"427: loss: 0.00243724649772048\n",
|
| 495 |
+
"428: loss: 0.0020084292627871037\n",
|
| 496 |
+
"429: loss: 0.001639676047489047\n",
|
| 497 |
+
"430: loss: 0.0012756038922816515\n",
|
| 498 |
+
"431: loss: 0.0010202551493421197\n",
|
| 499 |
+
"432: loss: 0.0008382818195968866\n",
|
| 500 |
+
"433: loss: 3.9101459980010986\n",
|
| 501 |
+
"434: loss: 3.4464950561523438\n",
|
| 502 |
+
"435: loss: 4.598957538604736\n",
|
| 503 |
+
"436: loss: 6.656869888305664\n",
|
| 504 |
+
"437: loss: 2.557544469833374\n",
|
| 505 |
+
"438: loss: 1.769715666770935\n",
|
| 506 |
+
"439: loss: 0.8786362409591675\n",
|
| 507 |
+
"440: loss: 0.09529905021190643\n",
|
| 508 |
+
"441: loss: 3.9526867866516113\n",
|
| 509 |
+
"442: loss: 3.4567954540252686\n",
|
| 510 |
+
"443: loss: 0.28547608852386475\n",
|
| 511 |
+
"444: loss: 0.1331639289855957\n",
|
| 512 |
+
"445: loss: 0.01748904585838318\n",
|
| 513 |
+
"446: loss: 3.7364015579223633\n",
|
| 514 |
+
"447: loss: 1.6454107761383057\n",
|
| 515 |
+
"448: loss: 0.007931341417133808\n",
|
| 516 |
+
"449: loss: 0.0017749288817867637\n",
|
| 517 |
+
"450: loss: 3.6518070697784424\n",
|
| 518 |
+
"451: loss: 3.056483507156372\n",
|
| 519 |
+
"452: loss: 0.0008364453678950667\n",
|
| 520 |
+
"453: loss: 0.0009152528364211321\n",
|
| 521 |
+
"454: loss: 0.0009797721868380904\n",
|
| 522 |
+
"455: loss: 4.194733142852783\n",
|
| 523 |
+
"456: loss: 0.0013897174503654242\n",
|
| 524 |
+
"457: loss: 0.0018761098617687821\n",
|
| 525 |
+
"458: loss: 0.0020015202462673187\n",
|
| 526 |
+
"459: loss: 9.263550758361816\n",
|
| 527 |
+
"460: loss: 0.0025061527267098427\n",
|
| 528 |
+
"461: loss: 0.003998400643467903\n",
|
| 529 |
+
"462: loss: 0.0031979954801499844\n",
|
| 530 |
+
"463: loss: 0.0009064731420949101\n",
|
| 531 |
+
"464: loss: 3.1668450832366943\n",
|
| 532 |
+
"465: loss: 6.006053924560547\n",
|
| 533 |
+
"466: loss: 0.0006406777538359165\n",
|
| 534 |
+
"467: loss: 0.0009267539135180414\n",
|
| 535 |
+
"468: loss: 0.0012060123262926936\n",
|
| 536 |
+
"469: loss: 0.0013315295800566673\n",
|
| 537 |
+
"470: loss: 3.5539376735687256\n",
|
| 538 |
+
"471: loss: 3.4590916633605957\n",
|
| 539 |
+
"472: loss: 0.0017678193980827928\n",
|
| 540 |
+
"473: loss: 0.00218581547960639\n",
|
| 541 |
+
"474: loss: 0.0025737383402884007\n",
|
| 542 |
+
"475: loss: 2.97592830657959\n",
|
| 543 |
+
"476: loss: 0.0032222135923802853\n",
|
| 544 |
+
"477: loss: 0.0020487091969698668\n",
|
| 545 |
+
"478: loss: 3.0420033931732178\n",
|
| 546 |
+
"479: loss: 0.001554043497890234\n",
|
| 547 |
+
"480: loss: 0.001528518507257104\n",
|
| 548 |
+
"481: loss: 0.001422215485945344\n",
|
| 549 |
+
"482: loss: 0.0012641653884202242\n",
|
| 550 |
+
"483: loss: 0.0010866222437471151\n",
|
| 551 |
+
"484: loss: 7.149199962615967\n",
|
| 552 |
+
"485: loss: 0.0010687584290280938\n",
|
| 553 |
+
"486: loss: 0.0012197017204016447\n",
|
| 554 |
+
"487: loss: 0.001343191834166646\n",
|
| 555 |
+
"488: loss: 0.0013996028574183583\n",
|
| 556 |
+
"489: loss: 0.001371717662550509\n",
|
| 557 |
+
"490: loss: 3.68569278717041\n",
|
| 558 |
+
"491: loss: 0.0014253916451707482\n",
|
| 559 |
+
"492: loss: 0.001504680491052568\n",
|
| 560 |
+
"493: loss: 0.0014929386088624597\n",
|
| 561 |
+
"494: loss: 0.0013759569264948368\n",
|
| 562 |
+
"495: loss: 3.385620355606079\n",
|
| 563 |
+
"496: loss: 0.0012212302535772324\n",
|
| 564 |
+
"497: loss: 0.0011952322674915195\n",
|
| 565 |
+
"498: loss: 3.1083197593688965\n",
|
| 566 |
+
"499: loss: 8.146794319152832\n",
|
| 567 |
+
"500: loss: 3.8151681423187256\n",
|
| 568 |
+
"500: valid loss 3.2241313457489014\n",
|
| 569 |
+
"501: loss: 0.002565972041338682\n",
|
| 570 |
+
"502: loss: 4.1275224685668945\n",
|
| 571 |
+
"503: loss: 0.004586916882544756\n",
|
| 572 |
+
"504: loss: 3.6200292110443115\n",
|
| 573 |
+
"505: loss: 0.004917770624160767\n",
|
| 574 |
+
"506: loss: 0.0035543786361813545\n",
|
| 575 |
+
"507: loss: 0.002198878675699234\n",
|
| 576 |
+
"508: loss: 3.9696688652038574\n",
|
| 577 |
+
"509: loss: 0.0012150105321779847\n",
|
| 578 |
+
"510: loss: 3.0237858295440674\n",
|
| 579 |
+
"511: loss: 0.0016711285570636392\n",
|
| 580 |
+
"512: loss: 0.0017911652103066444\n",
|
| 581 |
+
"513: loss: 0.001645330572500825\n",
|
| 582 |
+
"514: loss: 3.3689823150634766\n",
|
| 583 |
+
"515: loss: 0.0014145843451842666\n",
|
| 584 |
+
"516: loss: 0.0013438486494123936\n",
|
| 585 |
+
"517: loss: 0.0011701782932505012\n",
|
| 586 |
+
"518: loss: 0.0009688445716165006\n",
|
| 587 |
+
"519: loss: 0.0007915324531495571\n",
|
| 588 |
+
"520: loss: 4.113221645355225\n",
|
| 589 |
+
"521: loss: 0.0006360645638778806\n",
|
| 590 |
+
"522: loss: 0.0006149905384518206\n",
|
| 591 |
+
"523: loss: 8.360527038574219\n",
|
| 592 |
+
"524: loss: 0.0006234433385543525\n",
|
| 593 |
+
"525: loss: 0.0006739232921972871\n",
|
| 594 |
+
"526: loss: 0.0007281479192897677\n",
|
| 595 |
+
"527: loss: 0.000767726160120219\n",
|
| 596 |
+
"528: loss: 0.000772368221078068\n",
|
| 597 |
+
"529: loss: 0.0007228502072393894\n",
|
| 598 |
+
"530: loss: 0.0006368369213305414\n",
|
| 599 |
+
"531: loss: 3.732311725616455\n",
|
| 600 |
+
"532: loss: 5.932078838348389\n",
|
| 601 |
+
"533: loss: 3.5892159938812256\n",
|
| 602 |
+
"534: loss: 5.249965667724609\n",
|
| 603 |
+
"535: loss: 7.211183071136475\n",
|
| 604 |
+
"536: loss: 4.0714263916015625\n",
|
| 605 |
+
"537: loss: 3.1499719619750977\n",
|
| 606 |
+
"538: loss: 0.1844794750213623\n",
|
| 607 |
+
"539: loss: 3.4192230701446533\n",
|
| 608 |
+
"540: loss: 0.011980107054114342\n",
|
| 609 |
+
"541: loss: 0.010612019337713718\n",
|
| 610 |
+
"542: loss: 0.0045662750490009785\n",
|
| 611 |
+
"543: loss: 0.005457601509988308\n",
|
| 612 |
+
"544: loss: 0.015783555805683136\n",
|
| 613 |
+
"545: loss: 0.0013816619757562876\n",
|
| 614 |
+
"546: loss: 8.18481731414795\n",
|
| 615 |
+
"547: loss: 0.0006438567652367055\n",
|
| 616 |
+
"548: loss: 0.000572906865272671\n",
|
| 617 |
+
"549: loss: 10.10994815826416\n",
|
| 618 |
+
"550: loss: 0.003346000798046589\n",
|
| 619 |
+
"551: loss: 0.0006713962065987289\n",
|
| 620 |
+
"552: loss: 0.00026078836526721716\n",
|
| 621 |
+
"553: loss: 11.756505012512207\n",
|
| 622 |
+
"554: loss: 7.101832389831543\n",
|
| 623 |
+
"555: loss: 0.00021459207346197218\n",
|
| 624 |
+
"556: loss: 0.00025998923229053617\n",
|
| 625 |
+
"557: loss: 0.0003112201811745763\n",
|
| 626 |
+
"558: loss: 14.851192474365234\n",
|
| 627 |
+
"559: loss: 0.0004224810691084713\n",
|
| 628 |
+
"560: loss: 0.00047494613681919873\n",
|
| 629 |
+
"561: loss: 0.000519308028742671\n",
|
| 630 |
+
"562: loss: 0.0005509845213964581\n",
|
| 631 |
+
"563: loss: 0.0005668219528160989\n",
|
| 632 |
+
"564: loss: 14.569344520568848\n",
|
| 633 |
+
"565: loss: 6.4913740158081055\n",
|
| 634 |
+
"566: loss: 0.0008433411712758243\n",
|
| 635 |
+
"567: loss: 8.495502471923828\n",
|
| 636 |
+
"568: loss: 0.0019402098841965199\n",
|
| 637 |
+
"569: loss: 0.0035519124940037727\n",
|
| 638 |
+
"570: loss: 0.006841914728283882\n",
|
| 639 |
+
"571: loss: 4.089066982269287\n",
|
| 640 |
+
"572: loss: 5.491721153259277\n",
|
| 641 |
+
"573: loss: 3.87937331199646\n",
|
| 642 |
+
"574: loss: 0.03460773825645447\n",
|
| 643 |
+
"575: loss: 0.015647828578948975\n",
|
| 644 |
+
"576: loss: 0.002720448188483715\n",
|
| 645 |
+
"577: loss: 6.188972473144531\n",
|
| 646 |
+
"578: loss: 0.0008381525985896587\n",
|
| 647 |
+
"579: loss: 0.0008579537970945239\n",
|
| 648 |
+
"580: loss: 0.0008331844583153725\n",
|
| 649 |
+
"581: loss: 7.444668769836426\n",
|
| 650 |
+
"582: loss: 0.0013645365834236145\n",
|
| 651 |
+
"583: loss: 0.0018909723730757833\n",
|
| 652 |
+
"584: loss: 4.148159503936768\n",
|
| 653 |
+
"585: loss: 6.465692043304443\n",
|
| 654 |
+
"586: loss: 0.0040971520356833935\n",
|
| 655 |
+
"587: loss: 0.015496809035539627\n",
|
| 656 |
+
"588: loss: 0.0011185817420482635\n",
|
| 657 |
+
"589: loss: 0.00048535081441514194\n",
|
| 658 |
+
"590: loss: 0.0002821610542014241\n",
|
| 659 |
+
"591: loss: 0.00022055530280340463\n",
|
| 660 |
+
"592: loss: 0.0002070294285658747\n",
|
| 661 |
+
"593: loss: 0.00021876658138353378\n",
|
| 662 |
+
"594: loss: 0.00024527875939384103\n",
|
| 663 |
+
"595: loss: 0.00028197691426612437\n",
|
| 664 |
+
"596: loss: 0.00031235843198373914\n",
|
| 665 |
+
"597: loss: 0.00032129406463354826\n",
|
| 666 |
+
"598: loss: 0.000305092049529776\n",
|
| 667 |
+
"599: loss: 6.581624507904053\n",
|
| 668 |
+
"600: loss: 0.0004181505355518311\n",
|
| 669 |
+
"600: valid loss 0.001562803634442389\n",
|
| 670 |
+
"601: loss: 0.001562803634442389\n",
|
| 671 |
+
"602: loss: 0.0008329854463227093\n",
|
| 672 |
+
"603: loss: 8.43118953704834\n",
|
| 673 |
+
"604: loss: 0.00018880203424487263\n",
|
| 674 |
+
"605: loss: 6.225329399108887\n",
|
| 675 |
+
"606: loss: 0.0001953585451701656\n",
|
| 676 |
+
"607: loss: 0.00031005332130007446\n",
|
| 677 |
+
"608: loss: 6.243394374847412\n",
|
| 678 |
+
"609: loss: 0.002007008297368884\n",
|
| 679 |
+
"610: loss: 0.2842656672000885\n",
|
| 680 |
+
"611: loss: 0.002102950122207403\n",
|
| 681 |
+
"612: loss: 0.0013235295191407204\n",
|
| 682 |
+
"613: loss: 0.0012432391522452235\n",
|
| 683 |
+
"614: loss: 0.0011076040100306273\n",
|
| 684 |
+
"615: loss: 0.0009366637095808983\n",
|
| 685 |
+
"616: loss: 0.0007713991799391806\n",
|
| 686 |
+
"617: loss: 0.0006266268319450319\n",
|
| 687 |
+
"618: loss: 0.0005072436179034412\n",
|
| 688 |
+
"619: loss: 0.00041213506483472884\n",
|
| 689 |
+
"620: loss: 0.0003370844351593405\n",
|
| 690 |
+
"621: loss: 0.0002783465606626123\n",
|
| 691 |
+
"622: loss: 6.750359535217285\n",
|
| 692 |
+
"623: loss: 4.032569408416748\n",
|
| 693 |
+
"624: loss: 4.749107360839844\n",
|
| 694 |
+
"625: loss: 5.599199295043945\n",
|
| 695 |
+
"626: loss: 4.851316452026367\n",
|
| 696 |
+
"627: loss: 0.0012356003280729055\n",
|
| 697 |
+
"628: loss: 0.0019876735750585794\n",
|
| 698 |
+
"629: loss: 0.0022025934886187315\n",
|
| 699 |
+
"630: loss: 0.09389199316501617\n",
|
| 700 |
+
"631: loss: 0.0011942394776269794\n",
|
| 701 |
+
"632: loss: 0.0008771757711656392\n",
|
| 702 |
+
"633: loss: 0.000724500569049269\n",
|
| 703 |
+
"634: loss: 4.850365161895752\n",
|
| 704 |
+
"635: loss: 6.96458101272583\n",
|
| 705 |
+
"636: loss: 3.944305658340454\n",
|
| 706 |
+
"637: loss: 1.573992133140564\n",
|
| 707 |
+
"638: loss: 0.006376080680638552\n",
|
| 708 |
+
"639: loss: 0.004621799103915691\n",
|
| 709 |
+
"640: loss: 0.008686978369951248\n",
|
| 710 |
+
"641: loss: 0.002786734839901328\n",
|
| 711 |
+
"642: loss: 0.0012673415476456285\n",
|
| 712 |
+
"643: loss: 0.0008905518334358931\n"
|
| 713 |
+
]
|
| 714 |
+
},
|
| 715 |
+
{
|
| 716 |
+
"ename": "KeyboardInterrupt",
|
| 717 |
+
"evalue": "",
|
| 718 |
+
"output_type": "error",
|
| 719 |
+
"traceback": [
|
| 720 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
| 721 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 722 |
+
"Cell \u001b[1;32mIn[3], line 78\u001b[0m\n\u001b[0;32m 72\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m semantic_transformer, trainer, wav2vec\n\u001b[0;32m 73\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 78\u001b[0m \u001b[43mtrain_semantic_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 723 |
+
"Cell \u001b[1;32mIn[3], line 69\u001b[0m, in \u001b[0;36mtrain_semantic_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 52\u001b[0m semantic_transformer \u001b[38;5;241m=\u001b[39m SemanticTransformer(\n\u001b[0;32m 53\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 54\u001b[0m dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[0;32m 55\u001b[0m depth\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m6\u001b[39m,\n\u001b[0;32m 56\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 57\u001b[0m )\n\u001b[0;32m 59\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SemanticTransformerTrainer(\n\u001b[0;32m 60\u001b[0m transformer\u001b[38;5;241m=\u001b[39msemantic_transformer,\n\u001b[0;32m 61\u001b[0m wav2vec\u001b[38;5;241m=\u001b[39mwav2vec,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 66\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 67\u001b[0m )\n\u001b[1;32m---> 69\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 70\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(semantic_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msemantic_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave semantic_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
| 724 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1000\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 997\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 999\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1000\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1001\u001b[0m log_fn(logs)\n\u001b[0;32m 1003\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
| 725 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:944\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 941\u001b[0m data_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_tuple_to_kwargs(\u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdl_iter))\n\u001b[0;32m 943\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[1;32m--> 944\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_wrapper\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdata_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_loss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 946\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mbackward(loss \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every)\n\u001b[0;32m 948\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n",
|
| 726 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 727 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 728 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\audiolm_pytorch.py:1480\u001b[0m, in \u001b[0;36mSemanticTransformerWrapper.forward\u001b[1;34m(self, semantic_token_ids, raw_wave, text, text_embeds, return_loss, **kwargs)\u001b[0m\n\u001b[0;32m 1478\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(raw_wave)\n\u001b[0;32m 1479\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text_embeds)\n\u001b[1;32m-> 1480\u001b[0m text_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio_conditioner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mraw_wave\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnamespace\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msemantic\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(semantic_token_ids):\n\u001b[0;32m 1483\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwav2vec), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mVQWav2Vec must be be provided if given raw wave for training\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
|
| 729 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 730 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 731 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:872\u001b[0m, in \u001b[0;36mMuLaNEmbedQuantizer.forward\u001b[1;34m(self, wavs, texts, namespace)\u001b[0m\n\u001b[0;32m 869\u001b[0m \u001b[38;5;66;03m# sound and language live in joint embedding space because of contrastive learning\u001b[39;00m\n\u001b[0;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(wavs):\n\u001b[1;32m--> 872\u001b[0m latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulan\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_audio_latents\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 873\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m exists(texts):\n\u001b[0;32m 874\u001b[0m latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmulan\u001b[38;5;241m.\u001b[39mget_text_latents(texts)\n",
|
| 732 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:732\u001b[0m, in \u001b[0;36mMuLaN.get_audio_latents\u001b[1;34m(self, wavs, return_all_layers)\u001b[0m\n\u001b[0;32m 727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_audio_latents\u001b[39m(\n\u001b[0;32m 728\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 729\u001b[0m wavs,\n\u001b[0;32m 730\u001b[0m return_all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 731\u001b[0m ):\n\u001b[1;32m--> 732\u001b[0m audio_embeds, audio_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 733\u001b[0m audio_latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maudio_to_latents(audio_embeds)\n\u001b[0;32m 734\u001b[0m out \u001b[38;5;241m=\u001b[39m l2norm(audio_latents)\n",
|
| 733 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 734 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 735 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:525\u001b[0m, in \u001b[0;36mAudioSpectrogramTransformer.forward\u001b[1;34m(self, x, force_no_patch_dropout, return_all_layers)\u001b[0m\n\u001b[0;32m 521\u001b[0m rel_pos_bias \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdynamic_pos_bias_mlp(rel_dist\u001b[38;5;241m.\u001b[39mfloat())\n\u001b[0;32m 523\u001b[0m \u001b[38;5;66;03m# attention, what else\u001b[39;00m\n\u001b[1;32m--> 525\u001b[0m x, all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;66;03m# final global average and norm (most recent papers show this is superior to CLS token)\u001b[39;00m\n\u001b[0;32m 529\u001b[0m x \u001b[38;5;241m=\u001b[39m reduce(x, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb n d -> b d\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
| 736 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 737 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 738 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:247\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, x, rel_pos_bias, mask, return_all_layers)\u001b[0m\n\u001b[0;32m 245\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m attn, ff \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m 246\u001b[0m x \u001b[38;5;241m=\u001b[39m attn(x, rel_pos_bias \u001b[38;5;241m=\u001b[39m rel_pos_bias, mask \u001b[38;5;241m=\u001b[39m mask) \u001b[38;5;241m+\u001b[39m x\n\u001b[1;32m--> 247\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mff\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m x\n\u001b[0;32m 248\u001b[0m layers\u001b[38;5;241m.\u001b[39mappend(x)\n\u001b[0;32m 250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_all_layers:\n",
|
| 739 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 740 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 741 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\container.py:217\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 216\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 217\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
|
| 742 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 743 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
| 744 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
| 745 |
+
]
|
| 746 |
+
}
|
| 747 |
+
],
|
| 748 |
+
"source": [
|
| 749 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
| 750 |
+
" dim = 512,\n",
|
| 751 |
+
" depth = 6,\n",
|
| 752 |
+
" heads = 8,\n",
|
| 753 |
+
" dim_head = 64,\n",
|
| 754 |
+
" spec_n_fft = 128,\n",
|
| 755 |
+
" spec_win_length = 24,\n",
|
| 756 |
+
" spec_aug_stretch_factor = 0.8\n",
|
| 757 |
+
")\n",
|
| 758 |
+
"\n",
|
| 759 |
+
"text_transformer = TextTransformer(\n",
|
| 760 |
+
" dim = 512,\n",
|
| 761 |
+
" depth = 6,\n",
|
| 762 |
+
" heads = 8,\n",
|
| 763 |
+
" dim_head = 64\n",
|
| 764 |
+
")\n",
|
| 765 |
+
"\n",
|
| 766 |
+
"mulan = MuLaN(\n",
|
| 767 |
+
" audio_transformer = audio_transformer,\n",
|
| 768 |
+
" text_transformer = text_transformer\n",
|
| 769 |
+
")\n",
|
| 770 |
+
"\n",
|
| 771 |
+
"# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)\n",
|
| 772 |
+
"\n",
|
| 773 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
| 774 |
+
" mulan = mulan, # pass in trained mulan from above\n",
|
| 775 |
+
" conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024\n",
|
| 776 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
| 777 |
+
")\n",
|
| 778 |
+
"\n",
|
| 779 |
+
"# now say you want the conditioning embeddings for semantic transformer\n",
|
| 780 |
+
"\n",
|
| 781 |
+
"wavs = torch.randn(2, 1024)\n",
|
| 782 |
+
"conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers\n",
|
| 783 |
+
"\n",
|
| 784 |
+
"# SemanticTransformer\n",
|
| 785 |
+
"def train_semantic_transformer():\n",
|
| 786 |
+
" wav2vec = HubertWithKmeans(\n",
|
| 787 |
+
" checkpoint_path=checkpoint_path,\n",
|
| 788 |
+
" kmeans_path=kmeans_path\n",
|
| 789 |
+
" )\n",
|
| 790 |
+
"\n",
|
| 791 |
+
"\n",
|
| 792 |
+
" if torch.cuda.is_available():\n",
|
| 793 |
+
" semantic_transformer = SemanticTransformer(\n",
|
| 794 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 795 |
+
" dim=1024,\n",
|
| 796 |
+
" depth=6,\n",
|
| 797 |
+
" audio_text_condition=True\n",
|
| 798 |
+
" ).cuda()\n",
|
| 799 |
+
" else:\n",
|
| 800 |
+
" semantic_transformer = SemanticTransformer(\n",
|
| 801 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
| 802 |
+
" dim=1024,\n",
|
| 803 |
+
" depth=6,\n",
|
| 804 |
+
" audio_text_condition=True\n",
|
| 805 |
+
" )\n",
|
| 806 |
+
"\n",
|
| 807 |
+
" trainer = SemanticTransformerTrainer(\n",
|
| 808 |
+
" transformer=semantic_transformer,\n",
|
| 809 |
+
" wav2vec=wav2vec,\n",
|
| 810 |
+
" audio_conditioner=quantizer,\n",
|
| 811 |
+
" folder=audio_output_dir,\n",
|
| 812 |
+
" batch_size=batch_size,\n",
|
| 813 |
+
" data_max_length=data_max_length,\n",
|
| 814 |
+
" num_train_steps=num_train_steps\n",
|
| 815 |
+
" )\n",
|
| 816 |
+
"\n",
|
| 817 |
+
" trainer.train()\n",
|
| 818 |
+
" torch.save(semantic_transformer.state_dict(), 'semantic_transformer.pth')\n",
|
| 819 |
+
" print(\"save semantic_transformer.pth\")\n",
|
| 820 |
+
" del semantic_transformer, trainer, wav2vec\n",
|
| 821 |
+
" gc.collect()\n",
|
| 822 |
+
"\n",
|
| 823 |
+
"\n",
|
| 824 |
+
"\n",
|
| 825 |
+
"\n",
|
| 826 |
+
"train_semantic_transformer()"
|
| 827 |
+
]
|
| 828 |
+
}
|
| 829 |
+
],
|
| 830 |
+
"metadata": {
|
| 831 |
+
"kernelspec": {
|
| 832 |
+
"display_name": "myenv",
|
| 833 |
+
"language": "python",
|
| 834 |
+
"name": "python3"
|
| 835 |
+
},
|
| 836 |
+
"language_info": {
|
| 837 |
+
"codemirror_mode": {
|
| 838 |
+
"name": "ipython",
|
| 839 |
+
"version": 3
|
| 840 |
+
},
|
| 841 |
+
"file_extension": ".py",
|
| 842 |
+
"mimetype": "text/x-python",
|
| 843 |
+
"name": "python",
|
| 844 |
+
"nbconvert_exporter": "python",
|
| 845 |
+
"pygments_lexer": "ipython3",
|
| 846 |
+
"version": "3.11.2"
|
| 847 |
+
}
|
| 848 |
+
},
|
| 849 |
+
"nbformat": 4,
|
| 850 |
+
"nbformat_minor": 2
|
| 851 |
+
}
|