"""Check and fix the special-token definitions in the tokenizer files that sit next to
this script (tokenizer_config.json, tokenizer.json, vocab.json, special_tokens_map.json)
so that they match DESIRED_MAPPING."""

import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path

import click


class SpecialTokensMapEnum(Enum):
    BOS_TOKEN = "bos_token"
    EOS_TOKEN = "eos_token"
    PAD_TOKEN = "pad_token"
    UNK_TOKEN = "unk_token"


@dataclass(frozen=True)
class SpecialToken:
    id: int
    content: str
    lstrip: bool = False
    normalized: bool = False
    rstrip: bool = False
    single_word: bool = False
    special: bool = False
    special_token_map: list[SpecialTokensMapEnum] = field(default_factory=list)

    def to_added_tokens_decoder(self):
        data = asdict(self)
        token_id = str(data.pop("id"))
        data.pop("special_token_map")
        return {token_id: data}

    def to_added_tokens(self):
        data = asdict(self)
        data.pop("special_token_map")
        return data

    def to_special_tokens_map(self) -> dict[str, dict]:
        special_tokens_map = {}
        for mapped_token in self.special_token_map:
            data = asdict(self)
            data.pop("special_token_map")
            data.pop("special")
            data.pop("id")
            special_tokens_map[mapped_token.value] = data
        return special_tokens_map
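
# Illustrative shapes produced by the helpers above (a sketch derived from the field list,
# not output captured from a run):
#   SpecialToken(id=100277, content="<|pad|>", special=True).to_added_tokens_decoder()
#     -> {"100277": {"content": "<|pad|>", "lstrip": False, "normalized": False,
#                    "rstrip": False, "single_word": False, "special": True}}
#   For a token whose special_token_map includes SpecialTokensMapEnum.PAD_TOKEN,
#   .to_special_tokens_map() keys the same payload, minus "id" and "special", under
#   "pad_token".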

MODEL_MAX_LENGTH = 65536

DESIRED_MAPPING = [
    SpecialToken(id=100256, content="<|extra_id_0|>"),
    SpecialToken(
        id=100257,
        content="<|endoftext|>",
        special=True,
        special_token_map=[
            SpecialTokensMapEnum.BOS_TOKEN,
            SpecialTokensMapEnum.EOS_TOKEN,
            SpecialTokensMapEnum.UNK_TOKEN,
        ],
    ),
    SpecialToken(id=100258, content="<|fim_prefix|>", special=True),
    SpecialToken(id=100259, content="<|fim_middle|>", special=True),
    SpecialToken(id=100260, content="<|fim_suffix|>", special=True),
    SpecialToken(id=100261, content="|||PHONE_NUMBER|||"),
    SpecialToken(id=100262, content="|||EMAIL_ADDRESS|||"),
    SpecialToken(id=100263, content="|||IP_ADDRESS|||"),
    SpecialToken(id=100264, content="<|im_start|>", special=True),
    SpecialToken(id=100265, content="<|im_end|>", special=True),
    SpecialToken(id=100266, content="<|extra_id_1|>"),
    SpecialToken(id=100267, content="<|extra_id_2|>"),
    SpecialToken(id=100268, content="<|extra_id_3|>"),
    SpecialToken(id=100269, content="<|extra_id_4|>"),
    SpecialToken(id=100270, content="<|extra_id_5|>"),
    SpecialToken(id=100271, content="<|extra_id_6|>"),
    SpecialToken(id=100272, content="<|extra_id_7|>"),
    SpecialToken(id=100273, content="<|extra_id_8|>"),
    SpecialToken(id=100274, content="<|extra_id_9|>"),
    SpecialToken(id=100275, content="<|extra_id_10|>"),
    SpecialToken(id=100276, content="<|endofprompt|>", special=True),
    SpecialToken(
        id=100277,
        content="<|pad|>",
        special=True,
        special_token_map=[SpecialTokensMapEnum.PAD_TOKEN],
    ),
]
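
# Note: DESIRED_MAPPING above pins ids 100256-100277 as added tokens; "<|endoftext|>"
# doubles as the bos/eos/unk token and "<|pad|>" is the pad token.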

SCRIPT_DIR = Path(__file__).parent
TOKENIZER_CONFIG_FILE = SCRIPT_DIR / "tokenizer_config.json"
TOKENIZER_FILE = SCRIPT_DIR / "tokenizer.json"
VOCAB_FILE = SCRIPT_DIR / "vocab.json"
SPECIAL_TOKENS_MAP_FILE = SCRIPT_DIR / "special_tokens_map.json"


@click.group()
def cli():
    """Check and fix the special tokens in the tokenizer files."""
    pass
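
# Typical invocation (the filename "fix_special_tokens.py" is illustrative; use whatever
# this script is actually saved as):
#   python fix_special_tokens.py check
#   python fix_special_tokens.py fix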


def _get_mapped_special_token(
    special_tokens: list[SpecialToken],
    mapped_token: SpecialTokensMapEnum,
) -> SpecialToken:
    all_mapped_tokens = [token for token in special_tokens if mapped_token in token.special_token_map]
    if len(all_mapped_tokens) == 0:
        raise ValueError(f"Cannot find mapped token for {mapped_token}")
    if len(all_mapped_tokens) > 1:
        all_mapped_tokens_str = ", ".join([token.content for token in all_mapped_tokens])
        raise ValueError(f"Found multiple mapped tokens for {mapped_token}: {all_mapped_tokens_str}")
    return all_mapped_tokens[0]


def get_unk_token(special_tokens: list[SpecialToken]) -> SpecialToken:
    return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.UNK_TOKEN)


def get_bos_token(special_tokens: list[SpecialToken]) -> SpecialToken:
    return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.BOS_TOKEN)


def get_eos_token(special_tokens: list[SpecialToken]) -> SpecialToken:
    return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.EOS_TOKEN)


def get_pad_token(special_tokens: list[SpecialToken]) -> SpecialToken:
    return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.PAD_TOKEN)


@cli.command()
def check():
    """Check if the current config matches the desired mapping."""
    print("STEP 1: Checking tokenizer config file...")

    if not TOKENIZER_CONFIG_FILE.exists():
        raise FileNotFoundError(f"Tokenizer config file not found: {TOKENIZER_CONFIG_FILE}")

    with open(TOKENIZER_CONFIG_FILE, "r") as f:
        tokenizer_config = json.load(f)

    # Every desired token must appear in added_tokens_decoder with identical attributes.
    added_tokens_decoder = tokenizer_config.get("added_tokens_decoder", {})
    for token in DESIRED_MAPPING:
        str_token_id = str(token.id)
        if str_token_id not in added_tokens_decoder:
            raise ValueError(f"Token {token.id} not found in added tokens decoder")

        computed_added_tokens_decoder = token.to_added_tokens_decoder()
        if computed_added_tokens_decoder[str_token_id] != added_tokens_decoder[str_token_id]:
            raise ValueError(f"Token {token.id} has different content in added tokens decoder")

        print(f"Token {token.id} found in added tokens decoder; content matches")

    bos_token = get_bos_token(DESIRED_MAPPING)
    if bos_token.content != tokenizer_config["bos_token"]:
        raise ValueError(f"Bos token content mismatch: {bos_token.content} != {tokenizer_config['bos_token']}")
    else:
        print("Bos token content matches")

    eos_token = get_eos_token(DESIRED_MAPPING)
    if eos_token.content != tokenizer_config["eos_token"]:
        raise ValueError(f"Eos token content mismatch: {eos_token.content} != {tokenizer_config['eos_token']}")
    else:
        print("Eos token content matches")

    pad_token = get_pad_token(DESIRED_MAPPING)
    if pad_token.content != tokenizer_config["pad_token"]:
        raise ValueError(f"Pad token content mismatch: {pad_token.content} != {tokenizer_config['pad_token']}")
    else:
        print("Pad token content matches")

    unk_token = get_unk_token(DESIRED_MAPPING)
    if unk_token.content != tokenizer_config["unk_token"]:
        raise ValueError(f"Unk token content mismatch: {unk_token.content} != {tokenizer_config['unk_token']}")
    else:
        print("Unk token content matches")

    if tokenizer_config["model_max_length"] != MODEL_MAX_LENGTH:
        raise ValueError(f"Model max length mismatch: {tokenizer_config['model_max_length']} != {MODEL_MAX_LENGTH}")
    else:
        print("Model max length matches")

    print("STEP 2: Checking tokenizer file...")

    if not TOKENIZER_FILE.exists():
        raise FileNotFoundError(f"Tokenizer file not found: {TOKENIZER_FILE}")

    with open(TOKENIZER_FILE, "r") as f:
        tokenizer = json.load(f)

    # Each desired token must appear verbatim in the top-level added_tokens list.
    added_tokens_dict = {token["id"]: token for token in tokenizer.get("added_tokens", [])}
    for token in DESIRED_MAPPING:
        if token.id not in added_tokens_dict:
            raise ValueError(f"Token {token.id} not found in added tokens")

        computed_added_token = token.to_added_tokens()
        if computed_added_token != added_tokens_dict[token.id]:
            raise ValueError(f"Token {token.id} has different content in added tokens")
        print(f"Token {token.id} found in added tokens; content matches.")

    # The model vocab must map each desired content string to the desired id.
    vocab = tokenizer.get("model", {}).get("vocab", {})
    for token in DESIRED_MAPPING:
        if token.content not in vocab:
            raise ValueError(f"Token `{token.content}` not found in vocab")
        if token.id != vocab[token.content]:
            raise ValueError(f"Token `{token.content}`: vocab=`{vocab[token.content]}` provided=`{token.id}`")
        print(f"Token `{token.content}` found in vocab; id `{token.id}` matches.")

    # No two vocab entries may share an id.
    seen_values: dict[int, list[str]] = {}
    for key, value in vocab.items():
        seen_values.setdefault(value, []).append(key)

    broken_vocab = False
    for value, keys in seen_values.items():
        if len(keys) > 1:
            broken_vocab = True
            print(f"Vocab value {value} is not unique; keys: {keys}")

    if broken_vocab:
        raise ValueError("Vocab values are not unique")
    else:
        print("Vocab values are unique")

    print("STEP 3: Checking vocab file...")

    if not VOCAB_FILE.exists():
        raise FileNotFoundError(f"Vocab file not found: {VOCAB_FILE}")

    with open(VOCAB_FILE, "r") as f:
        vocab = json.load(f)

    for token in DESIRED_MAPPING:
        if token.content not in vocab:
            raise ValueError(f"Token `{token.content}` not found in vocab")
        if token.id != vocab[token.content]:
            raise ValueError(f"Token `{token.content}`: vocab=`{vocab[token.content]}` provided=`{token.id}`")
        print(f"Token `{token.content}` found in vocab; id `{token.id}` matches.")

    if len(set(vocab.values())) != len(vocab):
        raise ValueError("Vocab values are not unique")

    print("STEP 4: Checking special tokens map file...")

    if not SPECIAL_TOKENS_MAP_FILE.exists():
        raise FileNotFoundError(f"Special tokens map file not found: {SPECIAL_TOKENS_MAP_FILE}")

    with open(SPECIAL_TOKENS_MAP_FILE, "r") as f:
        special_tokens_map = json.load(f)

    seen_special_tokens = set()
    for token in DESIRED_MAPPING:
        for key, value in token.to_special_tokens_map().items():
            if key not in special_tokens_map:
                raise ValueError(f"Special token map {key} not found in special tokens map")
            if value != special_tokens_map[key]:
                raise ValueError(f"Special token map {key} content mismatch: {value} != {special_tokens_map[key]}")

            print(f"Special token map {key} content matches")
            seen_special_tokens.add(key)

    # The file must not contain entries beyond those derived from DESIRED_MAPPING.
    if len(seen_special_tokens) != len(special_tokens_map):
        raise ValueError("Special tokens map contains entries not covered by the desired mapping")
    print("All special tokens map values match")


@cli.command()
def fix():
    """Fix the tokens in the tokenizer config, tokenizer file, vocab file, and special tokens map file."""
    print("STEP 1: Fixing tokenizer config file...")
    with open(TOKENIZER_CONFIG_FILE, "r") as f:
        tokenizer_config = json.load(f)

    tokenizer_config["bos_token"] = get_bos_token(DESIRED_MAPPING).content
    tokenizer_config["eos_token"] = get_eos_token(DESIRED_MAPPING).content
    tokenizer_config["pad_token"] = get_pad_token(DESIRED_MAPPING).content
    tokenizer_config["unk_token"] = get_unk_token(DESIRED_MAPPING).content
    tokenizer_config["model_max_length"] = MODEL_MAX_LENGTH

    added_tokens_decoder = {}
    for token in DESIRED_MAPPING:
        added_tokens_decoder.update(token.to_added_tokens_decoder())
    tokenizer_config["added_tokens_decoder"] = added_tokens_decoder

    with open(TOKENIZER_CONFIG_FILE, "w") as f:
        json.dump(tokenizer_config, f, indent=2, ensure_ascii=False)
    print(f"Updated tokenizer config file in {TOKENIZER_CONFIG_FILE}.")

    print("STEP 2: Fixing tokenizer file...")
    with open(TOKENIZER_FILE, "r") as f:
        tokenizer = json.load(f)

    added_tokens = []
    for token in DESIRED_MAPPING:
        added_tokens.append(token.to_added_tokens())
    tokenizer["added_tokens"] = added_tokens

    for token in DESIRED_MAPPING:
        # Evict whatever string currently occupies this id in the model vocab, then
        # (re)assign the desired content to that id.
        for key in list(tokenizer["model"]["vocab"].keys()):
            if tokenizer["model"]["vocab"][key] == token.id:
                tokenizer["model"]["vocab"].pop(key)
        tokenizer["model"]["vocab"][token.content] = token.id

    with open(TOKENIZER_FILE, "w") as f:
        json.dump(tokenizer, f, indent=2, ensure_ascii=False)

    print(f"Updated tokenizer file in {TOKENIZER_FILE}.")

    print("STEP 3: Fixing vocab file...")
    with open(VOCAB_FILE, "r") as f:
        vocab = json.load(f)
    for token in DESIRED_MAPPING:
        # Same eviction as above: drop any key already mapped to this id before assigning.
        for key in list(vocab.keys()):
            if vocab[key] == token.id:
                vocab.pop(key)
        vocab[token.content] = token.id
    with open(VOCAB_FILE, "w") as f:
        json.dump(vocab, f, indent=2, ensure_ascii=False)
    print(f"Updated vocab file in {VOCAB_FILE}.")

    print("STEP 4: Fixing special tokens map file...")
    with open(SPECIAL_TOKENS_MAP_FILE, "r") as f:
        special_tokens_map = json.load(f)

    for token in DESIRED_MAPPING:
        for key, value in token.to_special_tokens_map().items():
            special_tokens_map[key] = value
            print(f"Updated special token map {key} content")

    with open(SPECIAL_TOKENS_MAP_FILE, "w") as f:
        json.dump(special_tokens_map, f, indent=2, ensure_ascii=False)

    print(f"Updated special tokens map file in {SPECIAL_TOKENS_MAP_FILE}.")


if __name__ == "__main__":
    cli()