Spaces:
Build error
Build error
Create token_weighter.py
Browse files
rhyme_with_ai/token_weighter.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TokenWeighter:
    """Compute a per-token 0/1 weight mask over a tokenizer's vocabulary.

    Tokens that are a single character, or that contain "#" (WordPiece
    partial/subword pieces such as "##ing"), get weight 0; every other
    token gets weight 1.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer and precompute the token weight mask.

        Args:
            tokenizer: object exposing a ``vocab`` mapping of
                token string -> integer token id (e.g. a BERT tokenizer).
        """
        self.tokenizer_ = tokenizer
        self.proba = self.get_token_proba()

    def get_token_proba(self):
        """Return the weight array (1.0 = usable token, 0.0 = filtered)."""
        # No useful intermediate here — delegate directly to the filter.
        return self._filter_short_partial(self.tokenizer_.vocab)

    def _filter_short_partial(self, vocab):
        """Build a 0/1 mask indexed by token id.

        A token is valid (weight 1) when it is longer than one character
        and contains no "#" (i.e. it is not a subword continuation piece).

        Args:
            vocab: mapping of token string -> token id.
                NOTE(review): assumes ids densely cover 0..len(vocab)-1,
                or the fancy-index assignment below goes out of bounds —
                TODO confirm against the tokenizer used.

        Returns:
            np.ndarray of shape (len(vocab),) with 1.0 at valid token ids.
        """
        valid_token_ids = [
            token_id
            for token, token_id in vocab.items()
            if len(token) > 1 and "#" not in token
        ]
        # len(vocab) — not len(vocab.keys()) — is the idiomatic size.
        is_valid = np.zeros(len(vocab))
        is_valid[valid_token_ids] = 1
        return is_valid
|