"""
Byte-Level Tokenizer V6.1.2 - Compression-First Learning

No vocabulary, no language rules - just bytes.
"""

from typing import List, Dict, Union, Optional

import numpy as np
import torch


class ByteTokenizerV6:
    """
    Pure byte-level tokenizer.

    - No vocabulary needed (bytes are 0-255)
    - No language-specific rules
    - Model learns all patterns from data
    """

    def __init__(self, max_seq_len: int = 64):
        """Initialize byte tokenizer."""
        self.max_seq_len = max_seq_len

        # Special tokens sit just above the 0-255 byte range
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259

        # 256 byte values + 4 special tokens
        self.vocab_size = 260

        print(f"Byte tokenizer initialized (vocab_size={self.vocab_size})")

    def encode(self, text: str, add_special_tokens: bool = True) -> Dict:
        """
        Encode text to byte IDs.

        Args:
            text: Input text
            add_special_tokens: Whether to add BOS/EOS

        Returns:
            dict with 'input_ids', 'attention_mask', 'length'
        """
        byte_sequence = list(text.encode('utf-8'))

        # Reserve room for BOS/EOS before truncating. Truncation is byte-based,
        # so it can split a multi-byte UTF-8 character; decode() drops any
        # incomplete trailing sequence.
        max_len = self.max_seq_len - 2 if add_special_tokens else self.max_seq_len
        if len(byte_sequence) > max_len:
            byte_sequence = byte_sequence[:max_len]

        if add_special_tokens:
            input_ids = [self.BOS] + byte_sequence + [self.EOS]
        else:
            input_ids = byte_sequence

        attention_mask = [1] * len(input_ids)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'length': len(input_ids)
        }
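    # Example (illustrative): encode("Hi") with the defaults above yields
    #   input_ids      -> [257, 72, 105, 258]   # BOS, 'H', 'i', EOS
    #   attention_mask -> [1, 1, 1, 1]
    #   length         -> 4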

    def encode_batch(self, texts: List[str], add_special_tokens: bool = True) -> Dict:
        """
        Encode multiple texts with padding.

        Args:
            texts: List of input texts
            add_special_tokens: Whether to add special tokens

        Returns:
            Batched tensors with padding
        """
        encoded_texts = []
        max_length = 0

        for text in texts:
            encoded = self.encode(text, add_special_tokens)
            encoded_texts.append(encoded)
            max_length = max(max_length, encoded['length'])

        max_length = min(max_length, self.max_seq_len)

        # Pre-fill with PAD, then copy each sequence in and mark its real length
        batch_size = len(texts)
        input_ids = np.full((batch_size, max_length), self.PAD, dtype=np.int64)
        attention_mask = np.zeros((batch_size, max_length), dtype=np.float32)

        for i, encoded in enumerate(encoded_texts):
            seq_len = min(encoded['length'], max_length)
            input_ids[i, :seq_len] = encoded['input_ids'][:seq_len]
            attention_mask[i, :seq_len] = 1.0

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.float32),
            'lengths': torch.tensor([e['length'] for e in encoded_texts], dtype=torch.long)
        }
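    # Example (illustrative): encode_batch(["Hi", "Hello"]) pads to the longer
    # sequence (length 7 = 5 bytes + BOS/EOS), so the "Hi" row becomes
    #   [257, 72, 105, 258, 256, 256, 256] with attention_mask [1, 1, 1, 1, 0, 0, 0].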

    def decode(self, input_ids: Union[List[int], torch.Tensor, np.ndarray],
               skip_special_tokens: bool = True) -> str:
        """
        Decode byte IDs back to text.

        Args:
            input_ids: Byte ID sequence
            skip_special_tokens: Whether to skip special tokens

        Returns:
            Decoded text string
        """
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().numpy().tolist()
        elif isinstance(input_ids, np.ndarray):
            input_ids = input_ids.tolist()

        if skip_special_tokens:
            # Keep only real byte values; drop PAD/BOS/EOS/MASK
            input_ids = [b for b in input_ids if 0 <= b <= 255]
        else:
            # Render special tokens as visible placeholder characters
            processed = []
            for b in input_ids:
                if b == self.PAD:
                    continue
                elif b == self.BOS:
                    processed.append(ord('['))
                elif b == self.EOS:
                    processed.append(ord(']'))
                elif b == self.MASK:
                    processed.append(ord('*'))
                elif 0 <= b <= 255:
                    processed.append(b)
            input_ids = processed

        if not input_ids:
            return ""

        try:
            # Keep only complete UTF-8 sequences; truncation or model output can
            # leave dangling lead/continuation bytes, which are skipped here.
            valid_bytes = []
            i = 0
            while i < len(input_ids):
                b = input_ids[i]
                if b < 128:
                    # 1-byte sequence (ASCII)
                    valid_bytes.append(b)
                    i += 1
                elif 192 <= b < 224:
                    # 2-byte lead: needs one continuation byte (0x80-0xBF)
                    if i + 1 < len(input_ids) and 128 <= input_ids[i + 1] < 192:
                        valid_bytes.extend(input_ids[i:i + 2])
                        i += 2
                    else:
                        i += 1
                elif 224 <= b < 240:
                    # 3-byte lead: needs two continuation bytes
                    if i + 2 < len(input_ids) and all(128 <= input_ids[j] < 192 for j in range(i + 1, i + 3)):
                        valid_bytes.extend(input_ids[i:i + 3])
                        i += 3
                    else:
                        i += 1
                elif 240 <= b < 248:
                    # 4-byte lead: needs three continuation bytes
                    if i + 3 < len(input_ids) and all(128 <= input_ids[j] < 192 for j in range(i + 1, i + 4)):
                        valid_bytes.extend(input_ids[i:i + 4])
                        i += 4
                    else:
                        i += 1
                else:
                    # Stray continuation byte (0x80-0xBF) or invalid lead byte
                    i += 1

            if valid_bytes:
                return bytes(valid_bytes).decode('utf-8', errors='replace')
            return ""
        except Exception:
            # Last-resort fallback: keep only ASCII bytes
            return "".join(chr(b) for b in input_ids if b < 128)
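    # Example (illustrative): decode([257, 72, 105, 258]) -> "Hi", while
    # decode([257, 72, 105, 258], skip_special_tokens=False) renders the
    # markers as placeholders -> "[Hi]".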

    def decode_batch(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
        """
        Decode a batch of byte sequences.

        Args:
            input_ids: Batch of byte IDs (batch_size, seq_len)
            skip_special_tokens: Whether to skip special tokens

        Returns:
            List of decoded texts
        """
        texts = []
        for i in range(input_ids.shape[0]):
            text = self.decode(input_ids[i], skip_special_tokens)
            texts.append(text)
        return texts

    def tokenize(self, text: str) -> List[int]:
        """
        Simple tokenization to byte IDs (no special tokens).

        Args:
            text: Input text

        Returns:
            List of byte IDs
        """
        return list(text.encode('utf-8'))

    def detokenize(self, byte_ids: List[int]) -> str:
        """
        Simple detokenization from byte IDs.

        Args:
            byte_ids: List of byte IDs

        Returns:
            Decoded text
        """
        try:
            return bytes(byte_ids).decode('utf-8', errors='replace')
        except (ValueError, TypeError):
            # bytes() rejects values outside 0-255 (e.g. special token IDs)
            return "".join(chr(b) if b < 128 else '?' for b in byte_ids)

    def get_vocab_size(self) -> int:
        """Get vocabulary size."""
        return self.vocab_size

    def get_special_tokens(self) -> Dict[str, int]:
        """Get special token IDs."""
        return {
            'pad_id': self.PAD,
            'bos_id': self.BOS,
            'eos_id': self.EOS,
            'mask_id': self.MASK
        }
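
# Illustrative sketch: one way to feed the tokenizer's 260-id vocabulary into a
# model. The embedding width (d_model=128) and the use of nn.Embedding with
# padding_idx here are assumptions for demonstration, not part of the
# tokenizer's contract.
def _embed_batch_example(texts: List[str], d_model: int = 128) -> torch.Tensor:
    """Encode texts and map the padded byte IDs to embedding vectors."""
    tok = ByteTokenizerV6(max_seq_len=64)
    batch = tok.encode_batch(texts)
    # padding_idx pins the PAD (id 256) embedding to zeros and excludes it from updates
    embedding = torch.nn.Embedding(tok.vocab_size, d_model, padding_idx=tok.PAD)
    return embedding(batch['input_ids'])  # (batch_size, seq_len, d_model)
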

if __name__ == "__main__":
    tokenizer = ByteTokenizerV6()

    test_texts = [
        "Hello World!",
        "안녕하세요",
        "你好世界",
        "こんにちは",
        "مرحبا بالعالم",
        "Здравствуй мир"
    ]

    print("=" * 50)
    print("Single Text Encoding/Decoding Test")
    print("=" * 50)

    for text in test_texts:
        print(f"\nOriginal: {text}")

        encoded = tokenizer.encode(text)
        print(f"Encoded length: {encoded['length']}")
        print(f"First 10 bytes: {encoded['input_ids'][:10]}")

        decoded = tokenizer.decode(encoded['input_ids'])
        print(f"Decoded: {decoded}")
        print(f"Match: {decoded == text}")

    print("\n" + "=" * 50)
    print("Batch Encoding/Decoding Test")
    print("=" * 50)

    batch_result = tokenizer.encode_batch(test_texts)
    print(f"Batch shape: {batch_result['input_ids'].shape}")
    print(f"Attention mask shape: {batch_result['attention_mask'].shape}")

    decoded_texts = tokenizer.decode_batch(batch_result['input_ids'])
    print("\nBatch decoding results:")
    for orig, dec in zip(test_texts, decoded_texts):
        print(f"Original: {orig}")
        print(f"Decoded: {dec}")
        print(f"Match: {orig == dec}")
        print()
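
    # Illustrative addition: round-trip the raw tokenize()/detokenize() pair
    # (no special tokens, no padding), mirroring the checks above.
    print("=" * 50)
    print("tokenize/detokenize Round-Trip Test")
    print("=" * 50)
    for text in test_texts:
        byte_ids = tokenizer.tokenize(text)
        restored = tokenizer.detokenize(byte_ids)
        print(f"{text}: {len(byte_ids)} bytes, match: {restored == text}")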