|
|
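"""
Byte-Level Tokenizer V6 - Pure Learning Based

No learned vocabulary and no language-specific rules - just bytes.
Text is represented by its raw UTF-8 byte values (IDs 0-255) plus four
special tokens (PAD, BOS, EOS, MASK); the model learns all structure from data.
"""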
|
|
|
|
|
|
import torch
|
|
|
from typing import List, Dict, Union, Optional
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
class ByteTokenizerV6:
    """
    Pure byte-level tokenizer

    - No vocabulary needed (bytes are 0-255)
    - No language-specific rules
    - Model learns all patterns from data

    Token IDs 0-255 are raw UTF-8 byte values. IDs 256-259 are the special
    tokens PAD, BOS, EOS and MASK, giving a vocabulary size of 260.
    """

    def __init__(self, max_seq_len: int = 512):
        """Initialize byte tokenizer.

        Args:
            max_seq_len: Maximum sequence length used to truncate in
                ``encode`` and to cap the padded width in ``encode_batch``.
        """
        self.max_seq_len = max_seq_len

        # Special tokens sit directly above the 256 possible byte values.
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259

        self.vocab_size = 260

        print(f"Byte tokenizer initialized (vocab_size={self.vocab_size})")

    def encode(self, text: str, add_special_tokens: bool = True) -> Dict:
        """
        Encode text to byte IDs

        The UTF-8 bytes of ``text`` are truncated so that, together with the
        optional BOS/EOS pair, the sequence fits in ``max_seq_len``. If
        ``max_seq_len`` is smaller than 2, no content bytes are kept, but
        BOS/EOS are still emitted when requested. Truncation is byte-wise, so
        it may split a multi-byte UTF-8 character.

        Args:
            text: Input text
            add_special_tokens: Whether to add BOS/EOS

        Returns:
            dict with 'input_ids' (list of int), 'attention_mask' (list of 1s)
            and 'length' (number of IDs)
        """
        n_special = 2 if add_special_tokens else 0
        # Clamp at 0: a negative budget would make the slice below mean
        # "drop the last k bytes" rather than "keep no bytes".
        budget = max(self.max_seq_len - n_special, 0)
        byte_sequence = list(text.encode('utf-8')[:budget])

        if add_special_tokens:
            input_ids = [self.BOS] + byte_sequence + [self.EOS]
        else:
            input_ids = byte_sequence

        return {
            'input_ids': input_ids,
            'attention_mask': [1] * len(input_ids),
            'length': len(input_ids),
        }

    def encode_batch(self, texts: List[str], add_special_tokens: bool = True) -> Dict:
        """
        Encode multiple texts with padding

        Args:
            texts: Input texts (any iterable of str)
            add_special_tokens: Whether to add special tokens

        Returns:
            dict of tensors:
              'input_ids'      long    (batch, width), right-padded with PAD
              'attention_mask' float32 (batch, width), 1.0 on real tokens
              'lengths'        long    (batch,), number of real tokens per row
            where width is the longest encoded text, capped at max_seq_len.
        """
        encoded_texts = [self.encode(text, add_special_tokens) for text in texts]

        max_length = max((e['length'] for e in encoded_texts), default=0)
        max_length = min(max_length, self.max_seq_len)

        batch_size = len(encoded_texts)
        input_ids = np.full((batch_size, max_length), self.PAD, dtype=np.int64)
        attention_mask = np.zeros((batch_size, max_length), dtype=np.float32)
        lengths = []

        for i, encoded in enumerate(encoded_texts):
            seq_len = min(encoded['length'], max_length)
            input_ids[i, :seq_len] = encoded['input_ids'][:seq_len]
            attention_mask[i, :seq_len] = 1.0
            # Report what actually landed in the row so 'lengths' always
            # agrees with 'attention_mask', even if the row was cut.
            lengths.append(seq_len)

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.float32),
            'lengths': torch.tensor(lengths, dtype=torch.long),
        }

    def decode(self, input_ids: Union[List[int], torch.Tensor, np.ndarray],
               skip_special_tokens: bool = True) -> str:
        """
        Decode byte IDs back to text

        Args:
            input_ids: 1-D byte ID sequence
            skip_special_tokens: If True, keep only IDs 0-255. If False,
                render BOS as '[', EOS as ']', MASK as '*', and drop PAD.
                In both modes, IDs outside 0-259 are dropped.

        Returns:
            Decoded text string. Invalid UTF-8 becomes U+FFFD.
        """
        if isinstance(input_ids, torch.Tensor):
            # detach(): tensors that require grad cannot go through .numpy().
            input_ids = input_ids.detach().cpu().tolist()
        elif isinstance(input_ids, np.ndarray):
            input_ids = input_ids.tolist()

        if skip_special_tokens:
            byte_values = [b for b in input_ids if 0 <= b <= 255]
        else:
            visible = {self.BOS: ord('['), self.EOS: ord(']'), self.MASK: ord('*')}
            byte_values = []
            for b in input_ids:
                if 0 <= b <= 255:
                    byte_values.append(b)
                elif b in visible:
                    byte_values.append(visible[b])
                # PAD and any other out-of-range ID are dropped.

        return self._bytes_to_text(byte_values)

    def decode_batch(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
        """
        Decode a batch of byte sequences

        Args:
            input_ids: Batch of byte IDs (batch_size, seq_len). This can be a
                tensor, an ndarray, or a list of lists.
            skip_special_tokens: Whether to skip special tokens

        Returns:
            List of decoded texts
        """
        return [self.decode(row, skip_special_tokens) for row in input_ids]

    def tokenize(self, text: str) -> List[int]:
        """
        Simple tokenization to byte IDs (no special tokens, no truncation)

        Args:
            text: Input text

        Returns:
            List of byte IDs
        """
        return list(text.encode('utf-8'))

    def detokenize(self, byte_ids: List[int]) -> str:
        """
        Simple detokenization from byte IDs

        Args:
            byte_ids: List of byte IDs

        Returns:
            Decoded text. If any ID is not a valid byte, ASCII IDs are kept
            and every other ID becomes '?'.
        """
        return self._bytes_to_text(byte_ids)

    @staticmethod
    def _bytes_to_text(byte_values: List[int]) -> str:
        """Turn byte values into text, with an ASCII/'?' fallback for invalid bytes."""
        try:
            return bytes(byte_values).decode('utf-8', errors='replace')
        except (ValueError, TypeError):
            # bytes() rejects values outside 0-255 (ValueError) or non-integers
            # (TypeError). Fall back to a lossy ASCII rendering instead.
            return "".join([chr(b) if b < 128 else '?' for b in byte_values])

    def get_vocab_size(self) -> int:
        """Get vocabulary size (256 byte values + 4 special tokens)"""
        return self.vocab_size

    def get_special_tokens(self) -> Dict[str, int]:
        """Get special token IDs"""
        return {
            'pad_id': self.PAD,
            'bos_id': self.BOS,
            'eos_id': self.EOS,
            'mask_id': self.MASK,
        }
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    demo = ByteTokenizerV6()

    # One sample per writing system, to show the tokenizer is script-agnostic.
    samples = [
        "Hello World!",
        "안녕하세요",
        "你好世界",
        "こんにちは",
        "مرحبا بالعالم",
        "Здравствуй мир",
    ]

    separator = "=" * 50

    def print_section(title, prefix=""):
        """Print a title framed by separator lines (prefix goes before the first rule)."""
        print(prefix + separator)
        print(title)
        print(separator)

    # --- single-text round trip -------------------------------------------
    print_section("Single Text Encoding/Decoding Test")
    for sample in samples:
        print(f"\nOriginal: {sample}")

        single = demo.encode(sample)
        print(f"Encoded length: {single['length']}")
        print(f"First 10 bytes: {single['input_ids'][:10]}")

        round_trip = demo.decode(single['input_ids'])
        print(f"Decoded: {round_trip}")
        print(f"Match: {round_trip == sample}")

    # --- batched round trip -----------------------------------------------
    print_section("Batch Encoding/Decoding Test", prefix="\n")

    batch = demo.encode_batch(samples)
    print(f"Batch shape: {batch['input_ids'].shape}")
    print(f"Attention mask shape: {batch['attention_mask'].shape}")

    restored_texts = demo.decode_batch(batch['input_ids'])
    print("\nBatch decoding results:")
    for original, restored in zip(samples, restored_texts):
        print(f"Original: {original}")
        print(f"Decoded: {restored}")
        print(f"Match: {original == restored}")
        print()