""" Intelligent Tokenizer v6.2.0 - Progressive Splitting Encoder With GPT-5 suggested improvements """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, List, Optional, Tuple import math class RoPEPositionalEncoding(nn.Module): """ Rotary Position Embedding (RoPE) - GPT-5 suggestion Better for handling chunk boundaries and variable sequence lengths """ def __init__(self, dim: int, max_seq_len: int = 48, base: int = 10000): super().__init__() self.dim = dim self.max_seq_len = max_seq_len self.base = base # Precompute sinusoidal frequencies inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) # Precompute positional encodings t = torch.arange(max_seq_len).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) self.register_buffer('cos_cached', freqs.cos()) self.register_buffer('sin_cached', freqs.sin()) def forward(self, x: torch.Tensor, seq_len: int = None) -> torch.Tensor: """ Apply RoPE to input tensor Handles chunk boundary corrections as suggested by GPT-5 """ if seq_len is None: seq_len = x.shape[1] # Get cached cos/sin values cos = self.cos_cached[:seq_len] sin = self.sin_cached[:seq_len] # Apply rotary embedding x_rot = self._apply_rotary_emb(x, cos, sin) return x_rot def _apply_rotary_emb(self, x, cos, sin): """Apply rotary embedding to input""" x1, x2 = x[..., ::2], x[..., 1::2] x_rot = torch.stack([ x1 * cos - x2 * sin, x1 * sin + x2 * cos ], dim=-1).flatten(-2) return x_rot class GatedCrossAttention(nn.Module): """ Gated Cross-Attention with MQA - GPT-5 suggestion Monitor gate values for quality assessment 16Q → 2K/V for 8x memory reduction """ def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_heads: int = 2): super().__init__() self.hidden_dim = hidden_dim self.num_heads = num_heads self.kv_heads = kv_heads # Reduced KV heads (GPT suggestion) self.head_dim = hidden_dim // num_heads # 80 # Multi-Query Attention projections self.q_proj = nn.Linear(hidden_dim, hidden_dim) # 16 heads self.k_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim) # 2 heads self.v_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim) # 2 heads self.o_proj = nn.Linear(hidden_dim, hidden_dim) # Gating mechanism (GPT-5 suggestion) self.gate = nn.Sequential( nn.Linear(hidden_dim * 2, hidden_dim), nn.Sigmoid() ) # Gate monitoring (for analysis) self.register_buffer('gate_values', torch.zeros(1)) # Warmup factor (GPT suggestion) self.register_buffer('warmup_alpha', torch.tensor(1.0)) def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass with gate monitoring Returns: (output, gate_values) """ batch_size, seq_len = query.shape[:2] # Multi-head attention projections Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim) K = self.k_proj(key).view(batch_size, -1, self.kv_heads, self.head_dim) V = self.v_proj(value).view(batch_size, -1, self.kv_heads, self.head_dim) # Transpose for attention computation Q = Q.transpose(1, 2) # [batch, heads, seq, dim] K = K.transpose(1, 2) # [batch, kv_heads, seq, dim] V = V.transpose(1, 2) # Repeat KV heads to match Q heads if necessary if self.kv_heads < self.num_heads: repeat_factor = self.num_heads // self.kv_heads K = K.repeat_interleave(repeat_factor, dim=1) V = V.repeat_interleave(repeat_factor, dim=1) # Scaled dot-product attention scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn_weights = F.softmax(scores, dim=-1) attn_output = torch.matmul(attn_weights, V) # Reshape back attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim) attn_output = self.o_proj(attn_output) # Gating mechanism gate_input = torch.cat([query, attn_output], dim=-1) gate_values = self.gate(gate_input) # Store gate values for monitoring (keep tensor shape consistent) self.gate_values[0] = gate_values.mean().detach() # Apply gate with warmup factor (GPT suggestion) gate_values = gate_values * self.warmup_alpha output = gate_values * attn_output + (1 - gate_values) * query return output, gate_values class ProgressiveSplittingLayer(nn.Module): """ Core innovation: 48 bytes → 1 token → N tokens → M tokens """ def __init__(self, hidden_dim: int = 1280, config: Optional[Dict] = None): super().__init__() self.hidden_dim = hidden_dim self.config = config or {} # Dynamic splitting: 1~4 tokens for efficiency # 48 bytes / 4 tokens = 12:1 compression (still beats BPE's 4:1) self.min_tokens = 1 # 48:1 compression self.max_tokens = 4 # 12:1 compression (still 3x better than BPE) # Initial compression: 48 bytes → 1 super token self.byte_embed = nn.Embedding(260, 64) # Small embedding self.initial_compressor = nn.Sequential( nn.Linear(48 * 64, 2048), nn.LayerNorm(2048), nn.ReLU(), nn.Dropout(0.1), nn.Linear(2048, hidden_dim), nn.LayerNorm(hidden_dim) ) # Language-aware splitting: 1 → N tokens (config-based) self.language_splitter = nn.ModuleDict({ 'analyzer': nn.Sequential( nn.Linear(hidden_dim, 512), nn.ReLU(), nn.Linear(512, 256) # Language features ), 'split_predictor': nn.Linear(256, self.max_tokens), # Predict 1~4 tokens # Single unified expander that can produce any number of tokens 'dynamic_expander': nn.Sequential( nn.Linear(hidden_dim, hidden_dim * 2), nn.LayerNorm(hidden_dim * 2), nn.GELU(), # Better than ReLU for transformers nn.Linear(hidden_dim * 2, hidden_dim * self.max_tokens) # Can produce up to 4 tokens ), # Token-wise importance predictor 'importance_predictor': nn.Sequential( nn.Linear(hidden_dim, 256), nn.ReLU(), nn.Linear(256, self.max_tokens), # Importance for each potential token nn.Softmax(dim=-1) ) }) # Boundary refinement: N → M tokens with linguistic awareness self.boundary_refiner = nn.ModuleDict({ 'scorer': nn.Sequential( nn.Linear(hidden_dim, 512), nn.ReLU(), nn.Linear(512, 1) ), 'morpheme_detector': nn.Conv1d(256, 64, 3), # 형태소 'word_detector': nn.Conv1d(256, 64, 5), # 단어 'phrase_detector': nn.Conv1d(256, 64, 7), # 구 'adjuster': nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=16, dim_feedforward=4 * hidden_dim, dropout=0.1, batch_first=True ) }) # Initialize split_predictor bias to prefer 1 token initially # This ensures untrained model starts with maximum compression with torch.no_grad(): self.language_splitter['split_predictor'].bias.data = torch.tensor([2.0, -1.0, -1.0, -1.0]) # High bias for 1 token, negative for others def forward(self, input_ids: torch.Tensor, temperature: float = 1.0) -> Dict[str, torch.Tensor]: """ Progressive splitting forward pass Args: input_ids: Input byte sequence [batch, seq_len] temperature: Gumbel-Softmax temperature for annealing """ batch_size = input_ids.size(0) # Step 1: 48 bytes → 1 super token byte_embeddings = self.byte_embed(input_ids) # [batch, 48, 64] flattened = byte_embeddings.view(batch_size, -1) # [batch, 3072] super_token = self.initial_compressor(flattened) # [batch, 1280] super_token = super_token.unsqueeze(1) # [batch, 1, 1280] # Step 2: Language analysis and splitting (1 → N) lang_features = self.language_splitter['analyzer'](super_token) split_logits = self.language_splitter['split_predictor'](lang_features) split_weights = F.softmax(split_logits, dim=-1) # [batch, 1, 8] # Direct transformation from super token to initial representation # No hardcoded splits - let the model learn everything lang_tokens = super_token # Start with compressed representation # TRUE Adaptive expansion - Model learns optimal split (1~4 tokens) # Analyze content to decide how many tokens needed expansion_features = self.language_splitter['analyzer'](lang_tokens) # [batch, 1, 256] # Dynamic expansion: generate up to 4 tokens from super token expanded = self.language_splitter['dynamic_expander'](lang_tokens.squeeze(1)) # [batch, hidden_dim*4] expanded = expanded.reshape(batch_size, self.max_tokens, self.hidden_dim) # [batch, 4, hidden_dim] # Predict how many tokens we actually need (1~4) split_logits = self.language_splitter['split_predictor'](expansion_features.squeeze(1)) # [batch, 4] # Clamp logits to prevent extreme values that cause NaN split_logits = torch.clamp(split_logits, min=-10, max=10) # Ensure minimum temperature to prevent instability safe_temperature = max(temperature, 0.5) split_weights = F.gumbel_softmax(split_logits, tau=safe_temperature, hard=False, dim=-1) # [batch, 4] # Predict importance for each potential token position importance = self.language_splitter['importance_predictor'](lang_tokens.squeeze(1)) # [batch, 4] # Dynamic token selection with importance-weighted allocation # Create cumulative mask for progressive token usage # If split_weights = [0.1, 0.2, 0.6, 0.1], we mainly use 3 tokens # Create progressive masks for 1, 2, 3, 4 tokens masks = [] for n in range(1, self.max_tokens + 1): mask = torch.zeros(batch_size, self.max_tokens, 1, device=expanded.device) mask[:, :n, :] = 1.0 masks.append(mask) # Apply importance-weighted masking # Important parts get more tokens, less important parts get fewer weighted_outputs = [] for i, mask in enumerate(masks): num_tokens = i + 1 # Weight by both split decision and importance token_weight = split_weights[:, i:i+1].unsqueeze(-1) # [batch, 1, 1] # Apply importance modulation for asymmetric splits if num_tokens > 1: # Redistribute tokens based on importance importance_adjusted = importance[:, :num_tokens].unsqueeze(-1) # [batch, n, 1] masked = expanded[:, :num_tokens] * importance_adjusted else: masked = expanded[:, :num_tokens] # Pad to max length if num_tokens < self.max_tokens: padding = torch.zeros(batch_size, self.max_tokens - num_tokens, self.hidden_dim, device=expanded.device) masked = torch.cat([masked, padding], dim=1) weighted_outputs.append(masked * token_weight) # Sum all weighted possibilities (differentiable selection) lang_tokens = sum(weighted_outputs) # Determine effective number of tokens (for monitoring) # Weighted average of token counts token_counts = torch.arange(1, self.max_tokens + 1, device=split_weights.device, dtype=torch.float32) avg_tokens = (split_weights * token_counts).sum(dim=-1).mean().item() k = lang_tokens.size(1) # Step 3: Boundary refinement (N → M) # Calculate boundary scores for each token position boundary_scores = self.boundary_refiner['scorer'](lang_tokens) # [batch, N, 1] # Detect linguistic boundaries (morpheme, word, phrase) # Extract features for boundary detection if hasattr(lang_tokens, 'shape') and len(lang_tokens.shape) == 3: batch_size, num_tokens, hidden_dim = lang_tokens.shape # For boundary detection, we need to consider the original byte sequence # But we're working with compressed tokens here # So we detect boundaries based on learned representations # Apply boundary adjustment with TransformerEncoderLayer # This learns to adjust token boundaries based on context refined_tokens = self.boundary_refiner['adjuster'](lang_tokens) # The adjuster should learn to: # 1. Respect UTF-8 boundaries (learned during training) # 2. Align with word/phrase boundaries (learned from language patterns) # 3. Maintain semantic coherence within each token else: refined_tokens = lang_tokens # Determine actual number of tokens based on highest probability # During inference, use argmax. During training, use weighted average. if self.training: # During training, use weighted average for differentiability actual_num_tokens = avg_tokens else: # During inference, select the split with highest probability split_decision = torch.argmax(split_weights, dim=-1) # [batch] actual_num_tokens = (split_decision.float().mean() + 1).item() # +1 because indices are 0-3 # Calculate compression ratio based on actual tokens used compression_ratio = 48.0 / max(1, actual_num_tokens) return { 'tokens': refined_tokens, 'num_tokens': actual_num_tokens, 'compression_ratio': torch.tensor(compression_ratio, device=refined_tokens.device), 'gate_values': None, # Will be filled by cross-attention 'language_features': lang_features, 'split_weights': split_weights, 'avg_tokens': avg_tokens if 'avg_tokens' in locals() else refined_tokens.size(1), 'split_distribution': split_weights.mean(dim=0) if 'split_weights' in locals() else None } class EncoderV62(nn.Module): """ 4-Layer Progressive Splitting Encoder with Cross-Attention All layers: 1280 dimensions """ def __init__(self, config: Optional[Dict] = None): super().__init__() # Store config for later use self.config = config or {} # Configuration self.hidden_dim = 1280 self.num_heads = 16 self.num_layers = 4 self.max_seq_len = 48 self.dropout = 0.1 # RoPE positional encoding (GPT-5 suggestion) self.rope = RoPEPositionalEncoding(self.hidden_dim, self.max_seq_len) # Layer 0: Progressive Splitting (48→1→N→M) - Pass config self.progressive_splitter = ProgressiveSplittingLayer(self.hidden_dim, config) # Layers 1-3: Transformer encoders with cross-attention self.encoder_layers = nn.ModuleList([ nn.TransformerEncoderLayer( d_model=self.hidden_dim, nhead=self.num_heads, dim_feedforward=4 * self.hidden_dim, # 5120 dropout=self.dropout, batch_first=True ) for _ in range(3) ]) # Cross-attention between layers with MQA (GPT-5 suggestion) self.cross_attentions = nn.ModuleList([ GatedCrossAttention(self.hidden_dim, self.num_heads, kv_heads=2) # 8x memory reduction for _ in range(3) ]) # Output heads for different tasks self.boundary_head = nn.Linear(self.hidden_dim, 4) self.language_head = nn.Linear(self.hidden_dim, 128) # Reduced from 512 (GPT suggestion) self.compression_head = nn.Linear(self.hidden_dim, self.hidden_dim) # Monitoring metrics (GPT-5 suggestion) self.register_buffer('compression_ratios', torch.zeros(1)) self.register_buffer('gate_averages', torch.zeros(3)) def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, temperature: float = 1.0) -> Dict[str, torch.Tensor]: """ Forward pass through the encoder Args: input_ids: Input byte sequence attention_mask: Optional attention mask temperature: Gumbel-Softmax temperature for annealing """ # Layer 0: Progressive splitting with temperature split_output = self.progressive_splitter(input_ids, temperature) x = split_output['tokens'] # [batch, M, 1280] # Apply RoPE x = self.rope(x, x.size(1)) # Store all hidden states for decoder all_hidden_states = [x] gate_values_list = [] # Layers 1-3 with cross-attention for i, (encoder_layer, cross_attn) in enumerate( zip(self.encoder_layers, self.cross_attentions) ): # Self-attention through transformer layer # GPT final check: Don't pass mask after progressive splitting changes sequence length x = encoder_layer(x) # No mask needed (no padding after compression) # Cross-attention with previous layer if i > 0: # Cross-attention with previous layer x, gate_values = cross_attn( query=x, key=all_hidden_states[-1], value=all_hidden_states[-1], mask=None # Mask not applicable after compression ) gate_values_list.append(gate_values) # Keep tensor shape consistent - store in existing buffer element self.gate_averages[i-1] = gate_values.mean().detach().item() # Fix indexing all_hidden_states.append(x) # Output projections boundaries = self.boundary_head(x) language_clusters = self.language_head(x) compressed = self.compression_head(x) # Update monitoring metrics # Ensure tensor is 1-dimensional for buffer assignment compression_ratio = split_output['compression_ratio'] if compression_ratio.dim() == 0: # Scalar tensor self.compression_ratios[0] = compression_ratio else: self.compression_ratios = compression_ratio return { 'last_hidden_state': x, 'all_hidden_states': all_hidden_states, 'boundaries': boundaries, 'language_clusters': language_clusters, 'compressed': compressed, 'compression_ratio': split_output['compression_ratio'], 'num_tokens': split_output['num_tokens'], 'splitting_probs': split_output.get('split_weights', None), # Add for diagnostics 'gate_values': gate_values_list, 'gate_averages': self.gate_averages, 'split_info': { 'language_features': split_output['language_features'], 'split_weights': split_output['split_weights'] } } def get_monitoring_stats(self) -> Dict[str, float]: """ Get monitoring statistics (GPT-5 suggestion) """ return { 'avg_compression_ratio': self.compression_ratios.item(), 'gate_layer1': self.gate_averages[0].item(), 'gate_layer2': self.gate_averages[1].item(), 'gate_layer3': self.gate_averages[2].item(), } def set_warmup_step(self, step: int, total_warmup: int = 1000): """ Set warmup alpha for all gates (GPT suggestion) Gradually increase gate influence from 0 to 1 """ alpha = min(1.0, step / total_warmup) for cross_attn in self.cross_attentions: cross_attn.warmup_alpha = torch.tensor(alpha, device=cross_attn.warmup_alpha.device) def adaptive_compression_control(self, reconstruction_loss: float): """ Adaptive compression based on reconstruction quality No fixed phases - model learns optimal compression """ # If reconstruction is poor, model will learn to use more tokens # This happens automatically through gradient descent # No manual phase control needed pass # Let gradients handle it class DualSlidingWindowEncoder(EncoderV62): """ Extension with dual sliding window system Handles both chunk-level and token-level boundaries """ def __init__(self, config: Optional[Dict] = None): super().__init__(config) # Chunk-level sliding window self.chunk_window = nn.Conv1d( in_channels=1, out_channels=1, kernel_size=8, # 8-byte overlap stride=40, # 48-8=40 stride padding=4 ) # Token-level sliding window self.token_window = nn.MultiheadAttention( embed_dim=self.hidden_dim, num_heads=self.num_heads, batch_first=True ) def process_long_sequence(self, input_ids: torch.Tensor) -> torch.Tensor: """ Handle sequences longer than 48 bytes with sliding windows """ batch_size, seq_len = input_ids.shape if seq_len <= 48: return super().forward(input_ids) # Process in chunks with overlap chunks = [] for i in range(0, seq_len - 48 + 1, 40): # 8-byte overlap chunk = input_ids[:, i:i+48] chunk_output = super().forward(chunk) chunks.append(chunk_output['last_hidden_state']) # Combine chunks with attention combined = torch.cat(chunks, dim=1) attended, _ = self.token_window(combined, combined, combined) return { 'last_hidden_state': attended, 'num_chunks': len(chunks), 'total_compression': seq_len / attended.size(1) } if __name__ == "__main__": # Test the encoder encoder = EncoderV62() # Test input batch_size = 2 input_ids = torch.randint(0, 256, (batch_size, 48)) # Forward pass output = encoder(input_ids) print(f"Input shape: {input_ids.shape}") print(f"Output tokens: {output['num_tokens']}") print(f"Compression ratio: {output['compression_ratio']:.2f}:1") print(f"Gate averages: {output['gate_averages']}") print(f"Monitoring stats: {encoder.get_monitoring_stats()}")