""" Train a tokenizer using the HuggingFace Tokenizers library. In the style of GPT-4 tokenizer. """ import os import time import argparse import torch from nanochat.tokenizer import RustBPETokenizer from nanochat.common import get_base_dir from nanochat.dataset import parquets_iter_batched # ----------------------------------------------------------------------------- # Parse command line arguments parser = argparse.ArgumentParser(description='Train a BPE tokenizer') parser.add_argument('--max_chars', type=int, default=10_000_000_000, help='Maximum characters to train on (default: 10B)') parser.add_argument('--doc_cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') parser.add_argument('--vocab_size', type=int, default=65536, help='Vocabulary size (default: 65536 = 2^16)') args = parser.parse_args() print(f"max_chars: {args.max_chars:,}") print(f"doc_cap: {args.doc_cap:,}") print(f"vocab_size: {args.vocab_size:,}") # ----------------------------------------------------------------------------- # Text iterator def text_iterator(): """ 1) Flatten the batches into a single iterator 2) Crop every document to args.doc_cap characters 3) Break when we've seen args.max_chars characters """ nchars = 0 for batch in parquets_iter_batched(split="train"): for doc in batch: doc_text = doc if len(doc_text) > args.doc_cap: doc_text = doc_text[:args.doc_cap] nchars += len(doc_text) yield doc_text if nchars > args.max_chars: return text_iter = text_iterator() # ----------------------------------------------------------------------------- # Train the tokenizer t0 = time.time() tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size) t1 = time.time() train_time = t1 - t0 print(f"Training time: {train_time:.2f}s") # ----------------------------------------------------------------------------- # Save the tokenizer to disk base_dir = get_base_dir() tokenizer_dir = os.path.join(base_dir, "tokenizer") tokenizer.save(tokenizer_dir) # ----------------------------------------------------------------------------- # Quick inline sanity check test_text = """Hello world! This is a test. Numbers: 123, 4567, 89 Contractions: I'm, you're, it's Special chars: @#$%^&*() Unicode: 你好世界 🌍""" encoded = tokenizer.encode(test_text) decoded = tokenizer.decode(encoded) assert decoded == test_text # ----------------------------------------------------------------------------- # One more thing: we wish to cache a mapping from token id to number of bytes of that token # for efficient evaluation of bits per byte. Unlike the typical mean loss, this # allows us to report a loss that is invariant to the vocab size of the tokenizer. # The bits per byte on the validation set is then one of the primary metrics we care about. 
vocab_size = tokenizer.get_vocab_size()
special_set = set(tokenizer.get_special_tokens())
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
token_bytes = []
for token_id in range(vocab_size):
    token_str = token_strings[token_id] # the Python string representation of this token
    if token_str in special_set:
        token_bytes.append(0) # special characters are not counted
    else:
        id_bytes = len(token_str.encode("utf-8")) # number of bytes that make up this token
        token_bytes.append(id_bytes)
token_bytes = torch.tensor(token_bytes, dtype=torch.int32, device='cpu')
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
with open(token_bytes_path, "wb") as f:
    torch.save(token_bytes, f)
print(f"Saved token_bytes to {token_bytes_path}")

# Log to report
from nanochat.report import get_report
token_bytes_nonzero = (token_bytes[token_bytes > 0]).to(dtype=torch.float32)
get_report().log(section="Tokenizer training", data=[
    vars(args), # argparse command line arguments
    {"train_time": train_time},
    {"num_special_tokens": len(special_set)},
    {
        "token_bytes_min": int(token_bytes_nonzero.min().item()),
        "token_bytes_max": int(token_bytes_nonzero.max().item()),
        "token_bytes_mean": token_bytes_nonzero.mean().item(),
        "token_bytes_std": token_bytes_nonzero.std().item(),
    }
])
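
# -----------------------------------------------------------------------------
# For reference, a minimal sketch of how the cached token_bytes tensor could be
# used at evaluation time to turn per-token cross-entropy loss into bits per byte.
# This is illustrative only and not part of this training script: the function
# name `compute_bpb` and its arguments are hypothetical, and the actual evaluation
# code lives elsewhere in the codebase. The idea: cross-entropy loss is measured
# in nats per token, so we convert to bits (divide by ln 2) and normalize by the
# number of UTF-8 bytes the target tokens represent (special tokens contribute
# 0 bytes), giving a metric that does not depend on the tokenizer's vocab size.
import math

def compute_bpb(losses, target_ids, token_bytes):
    """
    losses:      float tensor of per-token cross-entropy losses, in nats
    target_ids:  int tensor of the corresponding target token ids
    token_bytes: the int32 tensor saved above, mapping token id -> UTF-8 byte count
    """
    num_bytes = token_bytes[target_ids].sum()  # total bytes represented by the target tokens
    total_nats = losses.sum()                  # total loss in nats over those tokens
    total_bits = total_nats / math.log(2)      # convert nats -> bits
    return (total_bits / num_bytes).item()     # bits per byte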