# intelligent-tokenizer-v6-demo/core/byte_tokenizer_v6.py
"""
Byte-Level Tokenizer V6.1.2 - Compression-First Learning
No vocabulary, no language rules - just bytes
"""
import torch
from typing import List, Dict, Union, Optional
import numpy as np
class ByteTokenizerV6:
"""
Pure byte-level tokenizer
- No vocabulary needed (bytes are 0-255)
- No language-specific rules
- Model learns all patterns from data
"""
def __init__(self, max_seq_len: int = 64):
"""Initialize byte tokenizer"""
self.max_seq_len = max_seq_len
# Special tokens (beyond byte range 0-255)
self.PAD = 256
self.BOS = 257
self.EOS = 258
self.MASK = 259
# Total vocabulary size = 256 bytes + 4 special tokens
self.vocab_size = 260
print(f"Byte tokenizer initialized (vocab_size={self.vocab_size})")
def encode(self, text: str, add_special_tokens: bool = True) -> Dict:
"""
Encode text to byte IDs
Args:
text: Input text
add_special_tokens: Whether to add BOS/EOS
Returns:
dict with 'input_ids', 'attention_mask', 'length'
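        Example (illustrative; "Hi" is UTF-8 bytes [72, 105], BOS=257, EOS=258):
            encode("Hi") -> {'input_ids': [257, 72, 105, 258],
                             'attention_mask': [1, 1, 1, 1],
                             'length': 4}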
"""
# Convert text to UTF-8 bytes (pure bytes, no rules)
byte_sequence = list(text.encode('utf-8'))
# Truncate if necessary
max_len = self.max_seq_len - 2 if add_special_tokens else self.max_seq_len
if len(byte_sequence) > max_len:
byte_sequence = byte_sequence[:max_len]
# Add special tokens
if add_special_tokens:
input_ids = [self.BOS] + byte_sequence + [self.EOS]
else:
input_ids = byte_sequence
# Create attention mask (1 for real tokens, 0 for padding)
attention_mask = [1] * len(input_ids)
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'length': len(input_ids)
}
def encode_batch(self, texts: List[str], add_special_tokens: bool = True) -> Dict:
"""
Encode multiple texts with padding
Args:
texts: List of input texts
add_special_tokens: Whether to add special tokens
Returns:
Batched tensors with padding
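        Example (illustrative; default max_seq_len=64):
            encode_batch(["Hi", "Hello"]) returns
                input_ids of shape (2, 7)       # longest item: BOS + 5 bytes + EOS
                attention_mask of shape (2, 7)  # 1.0 for real tokens, 0.0 for PAD
                lengths = tensor([4, 7])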
"""
encoded_texts = []
max_length = 0
# Encode each text
for text in texts:
encoded = self.encode(text, add_special_tokens)
encoded_texts.append(encoded)
max_length = max(max_length, encoded['length'])
# Limit to max sequence length
max_length = min(max_length, self.max_seq_len)
# Initialize batch tensors
batch_size = len(texts)
input_ids = np.full((batch_size, max_length), self.PAD, dtype=np.int64)
attention_mask = np.zeros((batch_size, max_length), dtype=np.float32)
# Fill batch tensors
for i, encoded in enumerate(encoded_texts):
seq_len = min(encoded['length'], max_length)
input_ids[i, :seq_len] = encoded['input_ids'][:seq_len]
attention_mask[i, :seq_len] = 1.0
return {
'input_ids': torch.tensor(input_ids, dtype=torch.long),
'attention_mask': torch.tensor(attention_mask, dtype=torch.float32),
'lengths': torch.tensor([e['length'] for e in encoded_texts], dtype=torch.long)
}
def decode(self, input_ids: Union[List[int], torch.Tensor, np.ndarray],
skip_special_tokens: bool = True) -> str:
"""
Decode byte IDs back to text
Args:
input_ids: Byte ID sequence
skip_special_tokens: Whether to skip special tokens
Returns:
Decoded text string
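        Example (illustrative; 72 and 105 are the UTF-8 bytes of "Hi"):
            decode([257, 72, 105, 258])                            -> "Hi"
            decode([257, 72, 105, 258], skip_special_tokens=False) -> "[Hi]"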
"""
# Convert to list if needed
if isinstance(input_ids, torch.Tensor):
input_ids = input_ids.cpu().numpy().tolist()
elif isinstance(input_ids, np.ndarray):
input_ids = input_ids.tolist()
# Filter special tokens if requested
if skip_special_tokens:
# Only keep actual bytes (0-255)
input_ids = [b for b in input_ids if 0 <= b <= 255]
else:
# Replace special tokens with readable markers
processed = []
for b in input_ids:
if b == self.PAD:
continue # Skip padding
elif b == self.BOS:
processed.append(ord('[')) # Use [ for BOS
elif b == self.EOS:
processed.append(ord(']')) # Use ] for EOS
elif b == self.MASK:
processed.append(ord('*')) # Use * for MASK
elif 0 <= b <= 255:
processed.append(b)
input_ids = processed
# Convert bytes to text
if not input_ids:
return ""
try:
            # Extract only valid UTF-8 byte sequences
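            # e.g. '안' arrives as [236, 149, 136]: lead byte 236 lies in
            # [224, 240), so a 3-byte sequence is expected, followed by two
            # continuation bytes in [128, 192)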
valid_bytes = []
i = 0
while i < len(input_ids):
b = input_ids[i]
if b < 128: # ASCII
valid_bytes.append(b)
i += 1
elif 192 <= b < 224: # 2-byte UTF-8
if i + 1 < len(input_ids) and 128 <= input_ids[i+1] < 192:
valid_bytes.extend(input_ids[i:i+2])
i += 2
else:
i += 1 # Skip invalid
elif 224 <= b < 240: # 3-byte UTF-8
if i + 2 < len(input_ids) and all(128 <= input_ids[j] < 192 for j in range(i+1, min(i+3, len(input_ids)))):
valid_bytes.extend(input_ids[i:i+3])
i += 3
else:
i += 1 # Skip invalid
elif 240 <= b < 248: # 4-byte UTF-8
if i + 3 < len(input_ids) and all(128 <= input_ids[j] < 192 for j in range(i+1, min(i+4, len(input_ids)))):
valid_bytes.extend(input_ids[i:i+4])
i += 4
else:
i += 1 # Skip invalid
else:
i += 1 # Skip invalid byte
# Decode valid bytes
if valid_bytes:
byte_array = bytes(valid_bytes)
                text = byte_array.decode('utf-8', errors='replace')  # errors='replace' substitutes any invalid sequences
return text
else:
return ""
        except Exception:
# Fallback: convert ASCII only
return "".join([chr(b) if b < 128 else '' for b in input_ids])
def decode_batch(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> List[str]:
"""
Decode a batch of byte sequences
Args:
input_ids: Batch of byte IDs (batch_size, seq_len)
skip_special_tokens: Whether to skip special tokens
Returns:
List of decoded texts
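        Example (illustrative):
            decode_batch(encode_batch(["Hi", "안녕"])['input_ids']) -> ["Hi", "안녕"]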
"""
texts = []
for i in range(input_ids.shape[0]):
text = self.decode(input_ids[i], skip_special_tokens)
texts.append(text)
return texts
def tokenize(self, text: str) -> List[int]:
"""
Simple tokenization to byte IDs (no special tokens)
Args:
text: Input text
Returns:
List of byte IDs
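        Example (illustrative):
            tokenize("Hi") -> [72, 105]
            tokenize("안") -> [236, 149, 136]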
"""
return list(text.encode('utf-8'))
def detokenize(self, byte_ids: List[int]) -> str:
"""
Simple detokenization from byte IDs
Args:
byte_ids: List of byte IDs
Returns:
Decoded text
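        Example (illustrative; inverse of tokenize for valid UTF-8):
            detokenize([72, 105])       -> "Hi"
            detokenize([236, 149, 136]) -> "안"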
"""
try:
return bytes(byte_ids).decode('utf-8', errors='replace')
        except (ValueError, TypeError):
return "".join([chr(b) if b < 128 else '?' for b in byte_ids])
def get_vocab_size(self) -> int:
"""Get vocabulary size"""
return self.vocab_size
def get_special_tokens(self) -> Dict[str, int]:
"""Get special token IDs"""
return {
'pad_id': self.PAD,
'bos_id': self.BOS,
'eos_id': self.EOS,
'mask_id': self.MASK
}
# Test code
if __name__ == "__main__":
# Initialize tokenizer
tokenizer = ByteTokenizerV6()
# Test texts in multiple languages
test_texts = [
"Hello World!",
"안녕하세요",
"你好世界",
"こんにちは",
"مرحبا بالعالم",
"Здравствуй мир"
]
print("=" * 50)
print("Single Text Encoding/Decoding Test")
print("=" * 50)
for text in test_texts:
print(f"\nOriginal: {text}")
# Encode
encoded = tokenizer.encode(text)
print(f"Encoded length: {encoded['length']}")
print(f"First 10 bytes: {encoded['input_ids'][:10]}")
# Decode
decoded = tokenizer.decode(encoded['input_ids'])
print(f"Decoded: {decoded}")
print(f"Match: {decoded == text}")
print("\n" + "=" * 50)
print("Batch Encoding/Decoding Test")
print("=" * 50)
# Batch test
batch_result = tokenizer.encode_batch(test_texts)
print(f"Batch shape: {batch_result['input_ids'].shape}")
print(f"Attention mask shape: {batch_result['attention_mask'].shape}")
# Decode batch
decoded_texts = tokenizer.decode_batch(batch_result['input_ids'])
print("\nBatch decoding results:")
for orig, dec in zip(test_texts, decoded_texts):
print(f"Original: {orig}")
print(f"Decoded: {dec}")
print(f"Match: {orig == dec}")
print()
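    print("\n" + "=" * 50)
    print("Raw tokenize/detokenize Round-Trip Test")
    print("=" * 50)
    # Illustrative check of the low-level byte path (no special tokens, no
    # padding): tokenize followed by detokenize should reproduce the original
    # text for any valid UTF-8 input.
    for text in test_texts:
        byte_ids = tokenizer.tokenize(text)
        restored = tokenizer.detokenize(byte_ids)
        print(f"{text}: {len(byte_ids)} bytes, round-trip match: {restored == text}")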