#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "click",
# ]
# ///
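"""Check and fix the special-token entries of a tokenizer.

The `check` command verifies that tokenizer_config.json, tokenizer.json,
vocab.json, and special_tokens_map.json in this directory agree with
DESIRED_MAPPING below; the `fix` command rewrites them in place so they do.
"""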
from dataclasses import dataclass, asdict, field
from enum import Enum
from pathlib import Path
import click
import json
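# Roles a special token can be assigned to in special_tokens_map.json and tokenizer_config.json.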
class SpecialTokensMapEnum(Enum):
BOS_TOKEN = "bos_token"
EOS_TOKEN = "eos_token"
PAD_TOKEN = "pad_token"
UNK_TOKEN = "unk_token"
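# One added token plus the flags used by Hugging Face-style tokenizer files; the
# to_*() helpers serialize it into the shape each of the four files expects.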
@dataclass(frozen=True)
class SpecialToken:
id: int
content: str
lstrip: bool = False
normalized: bool = False
rstrip: bool = False
single_word: bool = False
special: bool = False
special_token_map: list[SpecialTokensMapEnum] = field(default_factory=list)
def to_added_tokens_decoder(self):
data = asdict(self)
token_id = str(data.pop("id"))
data.pop("special_token_map")
return {token_id: data}
def to_added_tokens(self):
data = asdict(self)
data.pop("special_token_map")
return data
def to_special_tokens_map(self) -> dict[str, dict]:
special_tokens_map = {}
for special_token_map in self.special_token_map:
data = asdict(self)
data.pop("special_token_map")
data.pop("special")
data.pop("id")
special_tokens_map[special_token_map.value] = data
return special_tokens_map
MODEL_MAX_LENGTH = 65536
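# The target special-token layout: `check` validates every file against it and `fix` writes it out.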
DESIRED_MAPPING = [
SpecialToken(id=100256, content="<|extra_id_0|>"),
SpecialToken(
id=100257,
content="<|endoftext|>",
special=True,
special_token_map=[
SpecialTokensMapEnum.BOS_TOKEN,
SpecialTokensMapEnum.EOS_TOKEN,
SpecialTokensMapEnum.UNK_TOKEN,
]),
SpecialToken(id=100258, content="<|fim_prefix|>", special=True),
SpecialToken(id=100259, content="<|fim_middle|>", special=True),
    SpecialToken(id=100260, content="<|fim_suffix|>", special=True),
SpecialToken(id=100261, content="|||PHONE_NUMBER|||"),
SpecialToken(id=100262, content="|||EMAIL_ADDRESS|||"),
SpecialToken(id=100263, content="|||IP_ADDRESS|||"),
SpecialToken(id=100264, content="<|im_start|>", special=True),
SpecialToken(id=100265, content="<|im_end|>", special=True),
SpecialToken(id=100266, content="<|extra_id_1|>"),
SpecialToken(id=100267, content="<|extra_id_2|>"),
SpecialToken(id=100268, content="<|extra_id_3|>"),
SpecialToken(id=100269, content="<|extra_id_4|>"),
SpecialToken(id=100270, content="<|extra_id_5|>"),
SpecialToken(id=100271, content="<|extra_id_6|>"),
SpecialToken(id=100272, content="<|extra_id_7|>"),
SpecialToken(id=100273, content="<|extra_id_8|>"),
SpecialToken(id=100274, content="<|extra_id_9|>"),
SpecialToken(id=100275, content="<|extra_id_10|>"),
SpecialToken(id=100276, content="<|endofprompt|>", special=True),
SpecialToken(
id=100277,
content="<|pad|>",
special=True,
special_token_map=[SpecialTokensMapEnum.PAD_TOKEN],
),
]
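# All four tokenizer files are expected to sit next to this script.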
SCRIPT_DIR = Path(__file__).parent
TOKENIZER_CONFIG_FILE = SCRIPT_DIR / "tokenizer_config.json"
TOKENIZER_FILE = SCRIPT_DIR / "tokenizer.json"
VOCAB_FILE = SCRIPT_DIR / "vocab.json"
SPECIAL_TOKENS_MAP_FILE = SCRIPT_DIR / "special_tokens_map.json"
@click.group()
def cli():
    """Tokenizer special-token check and fix tools."""
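# Return the single token assigned to a given role; error if the role is unassigned
# or assigned to more than one token.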
def _get_mapped_special_token(
special_tokens: list[SpecialToken],
mapped_token: SpecialTokensMapEnum
) -> SpecialToken:
all_mapped_tokens = [token for token in special_tokens if mapped_token in token.special_token_map]
if len(all_mapped_tokens) == 0:
raise ValueError(f"Cannot find mapped token for {mapped_token}")
if len(all_mapped_tokens) > 1:
all_mapped_tokens_str = ", ".join([token.content for token in all_mapped_tokens])
raise ValueError(f"Found multiple mapped tokens for {mapped_token}: {all_mapped_tokens_str}")
return all_mapped_tokens[0]
def get_unk_token(special_tokens: list[SpecialToken]) -> SpecialToken:
return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.UNK_TOKEN)
def get_bos_token(special_tokens: list[SpecialToken]) -> SpecialToken:
return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.BOS_TOKEN)
def get_eos_token(special_tokens: list[SpecialToken]) -> SpecialToken:
return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.EOS_TOKEN)
def get_pad_token(special_tokens: list[SpecialToken]) -> SpecialToken:
return _get_mapped_special_token(special_tokens, SpecialTokensMapEnum.PAD_TOKEN)
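# `check` is read-only: it raises on the first mismatch and prints one line per passing check.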
@cli.command()
def check():
"""Check if the current config matches the desired mapping."""
# STEP 1: Check the Tokenizer Config File #
print("STEP 1: Checking tokenizer config file...")
if not TOKENIZER_CONFIG_FILE.exists():
raise FileNotFoundError(f"Tokenizer config file not found: {TOKENIZER_CONFIG_FILE}")
with open(TOKENIZER_CONFIG_FILE, "r") as f:
tokenizer_config = json.load(f)
added_tokens_decoder = tokenizer_config.get("added_tokens_decoder", {})
for token in DESIRED_MAPPING:
str_token_id = str(token.id)
if str_token_id not in added_tokens_decoder:
raise ValueError(f"Token {token.id} not found in added tokens decoder")
computed_added_tokens_decoder = token.to_added_tokens_decoder()
if computed_added_tokens_decoder[str_token_id] != added_tokens_decoder[str_token_id]:
raise ValueError(f"Token {token.id} has different content in added tokens decoder")
print(f"Token {token.id} found in added tokens decoder; content matches")
bos_token = get_bos_token(DESIRED_MAPPING)
if bos_token.content != tokenizer_config["bos_token"]:
raise ValueError(f"Bos token content mismatch: {bos_token.content} != {tokenizer_config['bos_token']}")
else:
print("Bos token content matches")
eos_token = get_eos_token(DESIRED_MAPPING)
if eos_token.content != tokenizer_config["eos_token"]:
raise ValueError(f"Eos token content mismatch: {eos_token.content} != {tokenizer_config['eos_token']}")
else:
print("Eos token content matches")
pad_token = get_pad_token(DESIRED_MAPPING)
if pad_token.content != tokenizer_config["pad_token"]:
raise ValueError(f"Pad token content mismatch: {pad_token.content} != {tokenizer_config['pad_token']}")
else:
print("Pad token content matches")
unk_token = get_unk_token(DESIRED_MAPPING)
if unk_token.content != tokenizer_config["unk_token"]:
raise ValueError(f"Unk token content mismatch: {unk_token.content} != {tokenizer_config['unk_token']}")
else:
print("Unk token content matches")
if tokenizer_config["model_max_length"] != MODEL_MAX_LENGTH:
raise ValueError(f"Model max length mismatch: {tokenizer_config['model_max_length']} != {MODEL_MAX_LENGTH}")
else:
print("Model max length matches")
# STEP 2: Check the Tokenizer File #
print("STEP 2: Checking tokenizer file...")
if not TOKENIZER_FILE.exists():
raise FileNotFoundError(f"Tokenizer file not found: {TOKENIZER_FILE}")
with open(TOKENIZER_FILE, "r") as f:
tokenizer = json.load(f)
    # every desired token must appear in tokenizer.json's added_tokens with identical fields
added_tokens_dict = {token["id"]: token for token in tokenizer.get("added_tokens", [])}
for token in DESIRED_MAPPING:
if token.id not in added_tokens_dict:
raise ValueError(f"Token {token.id} not found in added tokens")
computed_added_token = token.to_added_tokens()
if computed_added_token != added_tokens_dict[token.id]:
raise ValueError(f"Token {token.id} has different content in added tokens")
print(f"Token {token.id} found in added tokens; content matches.")
    # the model vocab must map each desired token's content to its id
vocab = tokenizer.get("model", {}).get("vocab", {})
for token in DESIRED_MAPPING:
if token.content not in vocab:
raise ValueError(f"Token `{token.content}` not found in vocab")
if token.id != vocab[token.content]:
raise ValueError(f"Token `{token.content}`: vocab=`{vocab[token.content]}` provided=`{token.id}`")
print(f"Token `{token.content}` found in vocab; id `{token.id}` matches.")
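    # Invert the vocab (id -> tokens) to flag any id that more than one token maps to.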
seen_values: dict[int, list[str]] = {}
for key, value in vocab.items():
seen_values.setdefault(value, []).append(key)
broken_vocab = False
for value, keys in seen_values.items():
if len(keys) > 1:
broken_vocab = True
print(f"Vocab value {value} is not unique; keys: {keys}")
if broken_vocab:
raise ValueError("Vocab values are not unique")
else:
print("Vocab values are unique")
# STEP 3: Check the Vocab File #
print("STEP 3: Checking vocab file...")
if not VOCAB_FILE.exists():
raise FileNotFoundError(f"Vocab file not found: {VOCAB_FILE}")
with open(VOCAB_FILE, "r") as f:
vocab = json.load(f)
for token in DESIRED_MAPPING:
if token.content not in vocab:
raise ValueError(f"Token `{token.content}` not found in vocab")
if token.id != vocab[token.content]:
raise ValueError(f"Token `{token.content}`: vocab=`{vocab[token.content]}` provided=`{token.id}`")
print(f"Token `{token.content}` found in vocab; id `{token.id}` matches.")
if len(set(vocab.values())) != len(vocab):
raise ValueError("Vocab values are not unique")
# STEP 4: Check the Special Tokens Map File #
print("STEP 4: Checking special tokens map file...")
if not SPECIAL_TOKENS_MAP_FILE.exists():
raise FileNotFoundError(f"Special tokens map file not found: {SPECIAL_TOKENS_MAP_FILE}")
with open(SPECIAL_TOKENS_MAP_FILE, "r") as f:
special_tokens_map = json.load(f)
    # Every role produced by the desired mapping must be present with identical content,
    # and the file must not contain extra entries.
seen_special_tokens = set()
for token in DESIRED_MAPPING:
for key, value in token.to_special_tokens_map().items():
if key not in special_tokens_map:
raise ValueError(f"Special token map {key} not found in special tokens map")
if value != special_tokens_map[key]:
raise ValueError(f"Special token map {key} content mismatch: {value} != {special_tokens_map[key]}")
print(f"Special token map {key} content matches")
seen_special_tokens.add(key)
if len(seen_special_tokens) != len(special_tokens_map):
        raise ValueError("Special tokens map file contains entries not covered by the desired mapping")
print("All special tokens map values match")
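# `fix` rewrites all four files in place so that a subsequent `check` passes.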
@cli.command()
def fix():
"""Fix the tokens in the tokenizer config, tokenizer file, vocab file, and special tokens map file."""
print("STEP 1: Fixing tokenizer config file...")
with open(TOKENIZER_CONFIG_FILE, "r") as f:
tokenizer_config = json.load(f)
tokenizer_config["bos_token"] = get_bos_token(DESIRED_MAPPING).content
tokenizer_config["eos_token"] = get_eos_token(DESIRED_MAPPING).content
tokenizer_config["pad_token"] = get_pad_token(DESIRED_MAPPING).content
tokenizer_config["unk_token"] = get_unk_token(DESIRED_MAPPING).content
tokenizer_config["model_max_length"] = MODEL_MAX_LENGTH
added_tokens_decoder = {}
for token in DESIRED_MAPPING:
added_tokens_decoder.update(token.to_added_tokens_decoder())
tokenizer_config["added_tokens_decoder"] = added_tokens_decoder
with open(TOKENIZER_CONFIG_FILE, "w") as f:
json.dump(tokenizer_config, f, indent=2, ensure_ascii=False)
print(f"Updated tokenizer config file in {TOKENIZER_CONFIG_FILE}.")
print("STEP 2: Fixing tokenizer file...")
with open(TOKENIZER_FILE, "r") as f:
tokenizer = json.load(f)
added_tokens = []
for token in DESIRED_MAPPING:
added_tokens.append(token.to_added_tokens())
tokenizer["added_tokens"] = added_tokens
for token in DESIRED_MAPPING:
        # drop any existing entry that currently maps to this id so ids stay unique
for key in list(tokenizer["model"]["vocab"].keys()):
if tokenizer["model"]["vocab"][key] == token.id:
tokenizer["model"]["vocab"].pop(key)
        # the id is now free; map the desired content to it
tokenizer["model"]["vocab"][token.content] = token.id
with open(TOKENIZER_FILE, "w") as f:
json.dump(tokenizer, f, indent=2, ensure_ascii=False)
print(f"Updated tokenizer file in {TOKENIZER_FILE}.")
print("STEP 3: Fixing vocab file...")
with open(VOCAB_FILE, "r") as f:
vocab = json.load(f)
for token in DESIRED_MAPPING:
        # drop any existing entry that currently maps to this id so ids stay unique
for key in list(vocab.keys()):
if vocab[key] == token.id:
vocab.pop(key)
        # the id is now free; map the desired content to it
vocab[token.content] = token.id
with open(VOCAB_FILE, "w") as f:
json.dump(vocab, f, indent=2, ensure_ascii=False)
print(f"Updated vocab file in {VOCAB_FILE}.")
print("STEP 4: Fixing special tokens map file...")
with open(SPECIAL_TOKENS_MAP_FILE, "r") as f:
special_tokens_map = json.load(f)
for token in DESIRED_MAPPING:
for key, value in token.to_special_tokens_map().items():
special_tokens_map[key] = value
print(f"Updated special token map {key} content")
with open(SPECIAL_TOKENS_MAP_FILE, "w") as f:
json.dump(special_tokens_map, f, indent=2, ensure_ascii=False)
print(f"Updated special tokens map file in {SPECIAL_TOKENS_MAP_FILE}.")
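# A quick way to sanity-check the result outside this script (assumes the
# `transformers` library is available; it is not a dependency of this script):
#
#   >>> from transformers import AutoTokenizer
#   >>> tok = AutoTokenizer.from_pretrained("<path to this directory>")
#   >>> tok.pad_token, tok.eos_token   # should be '<|pad|>' and '<|endoftext|>' after `fix`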
if __name__ == "__main__":
cli()