A couple of explore scripts for adding two special tokens (<idiom>, </idiom>)
Browse files
explore/explore_bart_for_conditional_generation.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
"""Placeholder explore script for BartForConditionalGeneration.

This skeleton was removed in favor of the tokenizer-focused explore
scripts; it imported the model classes but never exercised them.
"""
from transformers import BartTokenizer, BartForConditionalGeneration


def main():
    # Intentionally a no-op stub — exploration happens in the sibling scripts.
    pass


if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explore/explore_bart_tokenizer_add_special_tokens.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Explore registering <idiom> / </idiom> as additional special tokens on BART.

Adds the two idiom-boundary markers to the tokenizer, then resizes the
model's shared embedding table so the new token ids have vectors.
"""
from transformers import BartTokenizer, BartForConditionalGeneration


def main():
    # Load the pretrained tokenizer and model we experiment on.
    tok = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    # Register the idiom boundary markers as additional special tokens.
    added = tok.add_special_tokens({
        "additional_special_tokens": ["<idiom>", "</idiom>"],  # beginning and end of an idiom
    })
    print(added)
    print(tok.additional_special_tokens)  # more special tokens are added here

    # and then you should resize the embedding table of your model
    print(model.model.shared.weight.shape)  # before
    model.resize_token_embeddings(len(tok))
    print(model.model.shared.weight.shape)  # after


if __name__ == '__main__':
    main()

"""
2
['<idiom>', '</idiom>']
torch.Size([50265, 768])
torch.Size([50267, 768]) # you can see that 2 more embedding vectors have been added here.
later, you may want to save the tokenizer after you add the idiom special tokens.
"""
|
explore/explore_bart_tokenizer_special_tokens.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Explore which special tokens BART's pretrained tokenizer defines.

Prints each named special token of ``facebook/bart-base`` so we can see
which symbols are already taken before inventing idiom markers.
"""
from transformers import BartTokenizer


def main():
    tok = BartTokenizer.from_pretrained("facebook/bart-base")
    # Print the seven conventional special-token attributes in a fixed order.
    for attr in (
        "bos_token",
        "cls_token",
        "eos_token",
        "sep_token",
        "mask_token",
        "pad_token",
        "unk_token",
    ):
        print(getattr(tok, attr))


"""
<s>
<s>
</s>
</s>
<mask>
<pad>
<unk>

right, so this is just like the symbols for BERT but in lowercase.
bos = cls
sep = eos
would it be okay to use <idiom> = <sep>?
no, sep implies that a sentence somehow ends.
"""


if __name__ == '__main__':
    main()
|