[#7] training & fetching m-1-3 is ready
- config.yaml +6 -5
- explore/explore_fetch_tokenizer.py +4 -0
- idiomify/fetchers.py +1 -0
- main_train.py +7 -5
config.yaml CHANGED

@@ -1,12 +1,13 @@
 # for training an idiomifier
 idiomifier:
-  ver: m-1-
-  desc:
+  ver: m-1-3
+  desc: Just overfitting on PIE dataset, but now with <idiom> & </idiom> special tokens.
   bart: facebook/bart-base
   lr: 0.0001
-  literal2idiomatic_ver: d-1-
-  idioms_ver: d-1-
-
+  literal2idiomatic_ver: d-1-3
+  idioms_ver: d-1-3
+  tokenizer_ver: t-1-1
+  max_epochs: 3
   batch_size: 40
   shuffle: true
   seed: 104
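The config now pins every artifact the trainer needs: the tokenizer (t-1-1), both dataset versions (d-1-3), and the number of epochs. Below is a minimal sketch of how these nested keys could be read back, assuming PyYAML and a config.yaml at the project root; the helper name load_idiomifier_config is hypothetical and only stands in for the repository's fetch_config.

# hypothetical loader for the YAML above; the real fetch_config in idiomify/fetchers.py may differ
from pathlib import Path
import yaml

def load_idiomifier_config(root: Path = Path(".")) -> dict:
    with open(root / "config.yaml") as fh:
        return yaml.safe_load(fh)["idiomifier"]

config = load_idiomifier_config()
assert config["ver"] == "m-1-3"
assert config["tokenizer_ver"] == "t-1-1"
assert config["max_epochs"] == 3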
explore/explore_fetch_tokenizer.py CHANGED

@@ -12,6 +12,9 @@ def main():
     print(tokenizer.unk_token)
     print(tokenizer.additional_special_tokens)  # this should have been added
 
+    # the size of the vocab
+    print(len(tokenizer))
+
 
     """
     <s>
@@ -22,6 +25,7 @@ def main():
     <pad>
     <unk>
     ['<idiom>', '</idiom>']
+    50267
     """
 
 if __name__ == '__main__':
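The 50267 printed above is the facebook/bart-base vocabulary (50265 tokens) plus the two idiom markers. A minimal sketch of how such a tokenizer could be produced, assuming the markers are registered with add_special_tokens; the actual script that builds and uploads t-1-1 is not part of this commit.

# sketch: extending the stock BART tokenizer with the <idiom> / </idiom> markers
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
tokenizer.add_special_tokens({"additional_special_tokens": ["<idiom>", "</idiom>"]})

print(tokenizer.additional_special_tokens)  # ['<idiom>', '</idiom>']
print(len(tokenizer))                       # 50267 = 50265 + 2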
idiomify/fetchers.py CHANGED

@@ -60,6 +60,7 @@ def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
     artifact_dir = artifact.download(root=idiomifier_dir(ver))
     ckpt_path = path.join(artifact_dir, "model.ckpt")
     bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained(config['bart']))
+    bart.resize_token_embeddings(config['vocab_size'])
     model = Idiomifier.load_from_checkpoint(ckpt_path, bart=bart)
     return model
 
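The resize has to happen before load_from_checkpoint: the checkpoint was trained with the enlarged vocabulary, so its embedding matrices have config['vocab_size'] rows, while a freshly configured facebook/bart-base only has 50265. A rough illustration of the mismatch, with the concrete size taken from the tokenizer exploration above (treat 50267 as an example value; the authoritative number is whatever is stored in the artifact's metadata):

# why the resize precedes loading the checkpoint; sizes are illustrative
from transformers import AutoConfig, AutoModelForSeq2SeqLM

bart = AutoModelForSeq2SeqLM.from_config(AutoConfig.from_pretrained("facebook/bart-base"))
print(bart.get_input_embeddings().weight.shape[0])  # 50265, too small for the checkpoint weights

bart.resize_token_embeddings(50267)                 # match config['vocab_size'] first
print(bart.get_input_embeddings().weight.shape[0])  # 50267, now load_from_checkpoint can succeed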
main_train.py CHANGED

@@ -5,9 +5,9 @@ import argparse
 import pytorch_lightning as pl
 from termcolor import colored
 from pytorch_lightning.loggers import WandbLogger
-from transformers import
+from transformers import BartForConditionalGeneration
 from idiomify.datamodules import IdiomifyDataModule
-from idiomify.fetchers import fetch_config
+from idiomify.fetchers import fetch_config, fetch_tokenizer
 from idiomify.models import Idiomifier
 from idiomify.paths import ROOT_DIR
 
@@ -23,12 +23,13 @@ def main():
     config.update(vars(args))
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
-    # prepare
+    # prepare a pre-trained BART
     bart = BartForConditionalGeneration.from_pretrained(config['bart'])
-    tokenizer = BartTokenizer.from_pretrained(config['bart'])
-    model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
     # prepare the datamodule
     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        tokenizer = fetch_tokenizer(config['tokenizer_ver'], run)
+        bart.resize_token_embeddings(len(tokenizer))  # because new tokens are added, this process is necessary
+        model = Idiomifier(bart, config['lr'], tokenizer.bos_token_id, tokenizer.pad_token_id)
         datamodule = IdiomifyDataModule(config, tokenizer, run)
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(max_epochs=config['max_epochs'],
@@ -44,6 +45,7 @@
         if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
             ckpt_path = ROOT_DIR / "model.ckpt"
             trainer.save_checkpoint(str(ckpt_path))
+            config['vocab_size'] = len(tokenizer)  # this will be needed to fetch a pretrained idiomifier later
             artifact = wandb.Artifact(name="idiomifier", type="model", metadata=config)
             artifact.add_file(str(ckpt_path))
             run.log_artifact(artifact, aliases=["latest", config['ver']])
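main_train.py now fetches the tokenizer from W&B instead of rebuilding a plain BartTokenizer from the BART checkpoint, and it records len(tokenizer) in the model artifact's metadata so that fetch_idiomifier can resize the embeddings before loading the weights. For reference, here is a hypothetical sketch of what fetch_tokenizer could look like, assuming the t-1-1 artifact holds the files written by tokenizer.save_pretrained; the artifact name and directory handling are assumptions, since the real implementation lives in idiomify/fetchers.py and is not shown in this diff.

# hypothetical sketch of fetch_tokenizer; the real function may differ
from wandb.sdk.wandb_run import Run
from transformers import BartTokenizer

def fetch_tokenizer_sketch(ver: str, run: Run) -> BartTokenizer:
    # assumes the tokenizer was uploaded as an artifact named "tokenizer",
    # aliased with its version (e.g. "t-1-1"), via tokenizer.save_pretrained(...)
    artifact = run.use_artifact(f"tokenizer:{ver}")
    artifact_dir = artifact.download()
    return BartTokenizer.from_pretrained(artifact_dir)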