"""
Intelligent Tokenizer v6.2.0 - Byte Tokenizer with 46+2 Configuration

Handles chunking, sliding windows, and boundary adjustments.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
import numpy as np


def _trim_utf8_boundary(byte_seq: List[int], limit: int) -> int:
    """
    Trim a byte sequence to the nearest valid UTF-8 boundary at or below `limit` (GPT suggestion).
    """
    end = min(limit, len(byte_seq))
    while end > 0:
        try:
            bytes(byte_seq[:end]).decode('utf-8')
            return end
        except UnicodeDecodeError:
            end -= 1
    # No valid prefix found (e.g. the input is not UTF-8): fall back to a hard cut at `limit`.
    return limit
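
# Illustrative example (assumed, not executed at import time): "안녕" encodes to
# six UTF-8 bytes, so trimming at limit=4 backs off to the 3-byte character
# boundary:
#
#     _trim_utf8_boundary(list("안녕".encode("utf-8")), 4)  # -> 3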


class ByteTokenizerV62:
    """
    Pure byte-level tokenizer.

    46 content bytes + 2 special tokens (BOS/EOS) = 48 positions total.
    """

    def __init__(self, config: Optional[Dict] = None):
        # `config` is currently unused; the defaults below implement the 46+2 layout.
        self.content_size = 46    # raw bytes per chunk
        self.max_seq_len = 48     # content_size + BOS + EOS
        self.chunk_overlap = 8    # bytes shared between consecutive chunks

        # Special token ids sit above the 0-255 byte range.
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259
        self.vocab_size = 260

    def encode(self,
               text: str,
               add_special_tokens: bool = True,
               return_chunks: bool = False) -> Dict[str, torch.Tensor]:
        """
        Encode text to byte sequences.

        Args:
            text: Input text
            add_special_tokens: Whether to add BOS/EOS
            return_chunks: Return multiple chunks for long sequences
        """
        byte_sequence = list(text.encode('utf-8'))

        if return_chunks and len(byte_sequence) > self.content_size:
            return self._encode_with_chunks(byte_sequence, add_special_tokens)

        # Single-chunk path: truncate at a valid UTF-8 boundary.
        if len(byte_sequence) > self.content_size:
            cut_point = _trim_utf8_boundary(byte_sequence, self.content_size)
            byte_sequence = byte_sequence[:cut_point]

        if add_special_tokens:
            byte_sequence = [self.BOS] + byte_sequence + [self.EOS]

        # Record the unpadded length before padding up to max_seq_len.
        seq_length = len(byte_sequence)
        if len(byte_sequence) < self.max_seq_len:
            padding_length = self.max_seq_len - len(byte_sequence)
            byte_sequence = byte_sequence + [self.PAD] * padding_length

        input_ids = torch.tensor(byte_sequence, dtype=torch.long)
        attention_mask = (input_ids != self.PAD)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'length': seq_length,
            'original_length': len(text.encode('utf-8'))
        }
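
    # Usage sketch (assumed defaults): a short input is wrapped in BOS/EOS and
    # padded to 48 positions, and the attention mask covers only the real tokens.
    #
    #     tok = ByteTokenizerV62()
    #     out = tok.encode("hi")
    #     out['input_ids'].shape            # torch.Size([48])
    #     int(out['attention_mask'].sum())  # 4  (BOS + 2 bytes + EOS)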

    def _encode_with_chunks(self,
                            byte_sequence: List[int],
                            add_special_tokens: bool) -> Dict[str, torch.Tensor]:
        """
        Encode long sequences as overlapping sliding-window chunks.
        """
        chunks = []
        positions = []

        # Each chunk starts `chunk_overlap` bytes before the previous one ended.
        stride = self.content_size - self.chunk_overlap

        for i in range(0, len(byte_sequence), stride):
            chunk = byte_sequence[i:i + self.content_size]

            # A very short tail is merged into the previous chunk rather than
            # emitted as a chunk of its own.
            if len(chunk) < self.content_size // 2:
                if chunks:
                    last_chunk = chunks[-1]['input_ids'].tolist()
                    last_chunk = [b for b in last_chunk if b not in [self.PAD, self.BOS, self.EOS]]

                    # Skip the bytes already covered by the previous chunk, then
                    # re-wrap the merged content in special tokens.
                    duplicated = max(0, len(last_chunk) - stride)
                    merged = last_chunk + chunk[duplicated:]
                    if add_special_tokens:
                        merged = [self.BOS] + merged + [self.EOS]

                    if len(merged) < self.max_seq_len:
                        merged += [self.PAD] * (self.max_seq_len - len(merged))
                    # Note: anything beyond max_seq_len is dropped by this truncation.
                    merged_ids = torch.tensor(merged[:self.max_seq_len], dtype=torch.long)
                    merged_mask = (merged_ids != self.PAD)
                    chunks[-1]['input_ids'] = merged_ids
                    chunks[-1]['attention_mask'] = merged_mask
                    break

            if len(chunk) < self.content_size:
                chunk += [self.PAD] * (self.content_size - len(chunk))

            if add_special_tokens:
                chunk_with_special = [self.BOS] + chunk + [self.EOS]
            else:
                chunk_with_special = chunk

            input_ids = torch.tensor(chunk_with_special, dtype=torch.long)
            attention_mask = (input_ids != self.PAD)

            chunks.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'position': (i, min(i + self.content_size, len(byte_sequence)))
            })
            positions.append((i, min(i + self.content_size, len(byte_sequence))))

        all_input_ids = torch.stack([c['input_ids'] for c in chunks])
        all_attention_masks = torch.stack([c['attention_mask'] for c in chunks])

        return {
            'input_ids': all_input_ids,
            'attention_mask': all_attention_masks,
            'num_chunks': len(chunks),
            'chunk_positions': positions,
            'original_length': len(byte_sequence)
        }

    def reconstruct(self,
                    input_ids: torch.Tensor,
                    positions: Optional[List[Tuple[int, int]]] = None,
                    skip_special_tokens: bool = True,
                    overlap: int = 8) -> str:
        """
        Reconstruct text from multiple chunks (GPT suggestion).

        Args:
            input_ids: [num_chunks, seq_len] for multi-chunk input
            positions: List of (start, end) positions for each chunk
            skip_special_tokens: Whether to skip special tokens
            overlap: Overlap size between chunks
        """
        # A single sequence needs no stitching.
        if input_ids.dim() == 1:
            return self.decode(input_ids, skip_special_tokens)

        pieces = []
        for chunk_ids in input_ids:
            chunk_ids = chunk_ids.cpu().numpy().tolist()

            if skip_special_tokens:
                chunk_ids = [
                    b for b in chunk_ids
                    if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
                ]

            pieces.append(chunk_ids)

        # Stitch the chunks together, dropping the overlapping prefix of every
        # chunk after the first.
        output = []
        for i, chunk in enumerate(pieces):
            if i == 0:
                output.extend(chunk)
            else:
                output.extend(chunk[overlap:] if len(chunk) > overlap else chunk)

        try:
            text = bytes(output).decode('utf-8', errors='replace')
        except ValueError:
            # Values outside 0-255 can remain when skip_special_tokens=False.
            text = ""

        return text

    def decode(self,
               input_ids: torch.Tensor,
               skip_special_tokens: bool = True) -> str:
        """
        Decode byte sequences back to text.
        """
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().numpy().tolist()

        # Unwrap a batch of size one.
        if isinstance(input_ids[0], list):
            input_ids = input_ids[0]

        if skip_special_tokens:
            input_ids = [
                b for b in input_ids
                if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
            ]

        try:
            text = bytes(input_ids).decode('utf-8', errors='replace')
        except ValueError:
            # Values outside 0-255 can remain when skip_special_tokens=False.
            text = ""

        return text
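
    # Round-trip sketch (assumed defaults): for text that fits in one chunk,
    # decode(encode(text)['input_ids']) returns the original string, e.g.
    #
    #     tok.decode(tok.encode("café")['input_ids'])  # -> "café"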

    def batch_encode(self,
                     texts: List[str],
                     add_special_tokens: bool = True) -> Dict[str, torch.Tensor]:
        """
        Encode multiple texts as a batch.
        """
        encoded = [self.encode(text, add_special_tokens) for text in texts]

        max_len = max(e['length'] for e in encoded)
        max_len = min(max_len, self.max_seq_len)

        batch_size = len(texts)
        input_ids = torch.full((batch_size, max_len), self.PAD, dtype=torch.long)
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool)

        for i, enc in enumerate(encoded):
            seq_len = min(enc['length'], max_len)
            if enc['input_ids'].dim() == 0:
                enc['input_ids'] = enc['input_ids'].unsqueeze(0)
            input_ids[i, :seq_len] = enc['input_ids'][:seq_len]
            attention_mask[i, :seq_len] = True

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'lengths': [e['length'] for e in encoded]
        }


class ChunkBoundaryAdjuster(nn.Module):
    """
    Neural network for adjusting chunk boundaries.
    Learns optimal splitting points.
    """

    def __init__(self, hidden_dim: int = 256):
        super().__init__()

        # Scores each position as a potential chunk boundary.
        self.boundary_scorer = nn.Sequential(
            nn.Linear(256, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Convolutional detector for UTF-8 character boundaries.
        self.utf8_detector = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=4, padding=2),
            nn.ReLU(),
            nn.Conv1d(16, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, byte_sequence: torch.Tensor) -> torch.Tensor:
        """
        Find optimal chunk boundaries.

        Args:
            byte_sequence: [batch, seq_len, embedding_dim]

        Returns:
            boundary_scores: [batch, seq_len] - probability of a boundary at each position
        """
        batch_size, seq_len = byte_sequence.shape[:2]

        # Position-wise boundary scores: [batch, seq_len]
        boundary_scores = self.boundary_scorer(byte_sequence).squeeze(-1)

        # Run the UTF-8 detector over the first embedding channel; trim the
        # extra position produced by the even kernel with padding=2 so the
        # shapes match for the element-wise product below.
        byte_values = byte_sequence[..., 0].unsqueeze(1)
        utf8_scores = self.utf8_detector(byte_values).squeeze(1)[:, :seq_len]

        combined_scores = boundary_scores * utf8_scores

        # Boost every 46th position to favour the default chunk size.
        for i in range(0, seq_len, 46):
            combined_scores[:, i] = combined_scores[:, i] * 1.5

        return combined_scores
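
# Shape sketch (hypothetical values): with 256-dim byte embeddings, e.g.
# x = torch.randn(2, 48, 256), ChunkBoundaryAdjuster()(x) returns a [2, 48]
# tensor of boundary scores, with every 46th position boosted by 1.5.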


class SlidingWindowProcessor(nn.Module):
    """
    Process sequences with sliding windows at multiple scales.
    """

    def __init__(self, window_sizes: List[int] = [8, 16, 32, 46]):
        super().__init__()
        self.window_sizes = window_sizes

        # One strided convolution per window size.
        self.convs = nn.ModuleList([
            nn.Conv1d(256, 128, kernel_size=ws, stride=ws // 2, padding=ws // 4)
            for ws in window_sizes
        ])

        # Fuse the concatenated multi-scale features back to the embedding size.
        self.fusion = nn.Sequential(
            nn.Linear(128 * len(window_sizes), 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256)
        )

    def forward(self, byte_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Apply multi-scale sliding windows.

        Args:
            byte_embeddings: [batch, seq_len, embedding_dim]

        Returns:
            processed: [batch, seq_len, embedding_dim]
        """
        # Conv1d expects [batch, channels, seq_len].
        x = byte_embeddings.transpose(1, 2)

        multi_scale_features = []
        for conv in self.convs:
            features = conv(x)
            # Resample each scale back to the original sequence length.
            pooled = F.adaptive_avg_pool1d(features, byte_embeddings.size(1))
            multi_scale_features.append(pooled)

        concat = torch.cat(multi_scale_features, dim=1)
        concat = concat.transpose(1, 2)

        fused = self.fusion(concat)

        # Residual connection back to the input embeddings.
        output = fused + byte_embeddings

        return output


class AdaptiveChunker:
    """
    Adaptive chunking based on content complexity.
    Simple heuristic-based chunker for inference.
    """

    def __init__(self):
        self.min_chunk = 32
        self.max_chunk = 46
        self.target_chunk = 46

    def determine_chunk_size(self, text: str) -> int:
        """
        Determine the chunk size heuristically from text characteristics.
        """
        byte_seq = text.encode('utf-8')

        # Heuristics: any non-ASCII byte is treated as a multi-byte script
        # (labelled CJK here); Arabic is detected from the code-point range.
        has_cjk = any(b >= 0x80 for b in byte_seq[:100])
        has_arabic = any(0x0600 <= ord(c) <= 0x06FF for c in text[:100])

        # Check Arabic before the generic non-ASCII test, since Arabic bytes
        # would otherwise always match the CJK branch first.
        if has_arabic:
            return 40
        elif has_cjk:
            # Multi-byte scripts: smaller chunks to avoid splitting characters.
            return self.min_chunk
        else:
            return self.target_chunk
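
    # Illustrative values (assumed defaults): determine_chunk_size("Hello") -> 46,
    # while non-ASCII input such as "안녕" falls back to the 32-byte minimum.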

    def chunk_text(self, text: str) -> List[str]:
        """
        Split text into adaptive chunks.
        """
        chunk_size = self.determine_chunk_size(text)
        byte_seq = text.encode('utf-8')
        chunks = []

        i = 0
        while i < len(byte_seq):
            end = min(i + chunk_size, len(byte_seq))

            # Back off until the slice ends on a valid UTF-8 boundary.
            while end > i and end < len(byte_seq):
                try:
                    _ = byte_seq[i:end].decode('utf-8')
                    break
                except UnicodeDecodeError:
                    end -= 1

            chunk_bytes = byte_seq[i:end]
            chunks.append(chunk_bytes.decode('utf-8', errors='replace'))
            i = end

        return chunks


if __name__ == "__main__":
    # Quick smoke test of the tokenizer across a few scripts.
    tokenizer = ByteTokenizerV62()

    test_texts = [
        "Hello, world!",
        "안녕하세요, 세계!",
        "今天天气很好。",
        "مرحبا بالعالم",
        "A" * 100
    ]

    for text in test_texts:
        print(f"\nText: {text[:50]}...")

        encoded = tokenizer.encode(text)
        print(f"  Encoded shape: {encoded['input_ids'].shape}")
        print(f"  Original length: {encoded['original_length']} bytes")

        decoded = tokenizer.decode(encoded['input_ids'])
        print(f"  Decoded: {decoded[:50]}...")

        if encoded['original_length'] > 46:
            multi_encoded = tokenizer.encode(text, return_chunks=True)
            print(f"  Chunks: {multi_encoded['num_chunks']}")

    batch = tokenizer.batch_encode(test_texts[:3])
    print(f"\nBatch shape: {batch['input_ids'].shape}")

    chunker = AdaptiveChunker()
    for text in test_texts[:3]:
        chunk_size = chunker.determine_chunk_size(text)
        print(f"\n{text[:30]}... → Chunk size: {chunk_size}")