""" Intelligent Tokenizer v6.2.0 - 6-Layer Decoder with Multi-Level Cross-Attention Incorporates GPT-5 suggestions for KV cache optimization """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, List, Optional, Tuple import math class KVCacheOptimizedAttention(nn.Module): """ KV Cache Optimized Attention - GPT-5 suggestion 16Q → 2K/V for 8x memory reduction """ def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_compression: int = 8): super().__init__() self.hidden_dim = hidden_dim self.num_heads = num_heads self.kv_heads = max(2, num_heads // kv_compression) # 16/8 = 2 KV heads self.head_dim = hidden_dim // num_heads # 80 # Query uses all heads self.q_proj = nn.Linear(hidden_dim, hidden_dim) # 16 heads # Key/Value use fewer heads (GPT-5 suggestion) self.k_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim) # 2 heads self.v_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim) # 2 heads # Output projection self.o_proj = nn.Linear(hidden_dim, hidden_dim) # KV cache for inference self.register_buffer('cached_keys', None) self.register_buffer('cached_values', None) def forward(self, hidden_states: torch.Tensor, encoder_hidden: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]: """ Forward pass with KV cache optimization """ batch_size, seq_len = hidden_states.shape[:2] # Query projection (all heads) Q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim) Q = Q.transpose(1, 2) # [batch, heads, seq, dim] # Key/Value source (self or cross) kv_source = encoder_hidden if encoder_hidden is not None else hidden_states # Key/Value projection (fewer heads) K = self.k_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim) V = self.v_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim) K = K.transpose(1, 2) # [batch, kv_heads, seq, dim] V = V.transpose(1, 2) # Repeat KV heads to match Q heads (broadcast) K = K.repeat_interleave(self.num_heads // self.kv_heads, dim=1) V = V.repeat_interleave(self.num_heads // self.kv_heads, dim=1) # Cache management for incremental generation (GPT suggestion) if use_cache: # For incremental generation, only process new token if self.cached_keys is not None and hidden_states.size(1) == 1: # Append new K/V to cache K = torch.cat([self.cached_keys, K], dim=2) V = torch.cat([self.cached_values, V], dim=2) # Update cache self.cached_keys = K self.cached_values = V # Scaled dot-product attention scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # Use additive mask (GPT suggestion) if attention_mask is not None: scores = scores + attention_mask # additive mask: -inf where masked, 0 elsewhere attn_weights = F.softmax(scores, dim=-1) attn_output = torch.matmul(attn_weights, V) # Reshape and project attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim) output = self.o_proj(attn_output) return output, (K, V) if use_cache else None class SelectiveCrossAttention(nn.Module): """ Selective cross-attention - only attend to relevant encoder layers Reduces 24 → 8 cross-attentions for efficiency """ def __init__(self, hidden_dim: int = 1280, layer_id: int = 0): super().__init__() self.hidden_dim = hidden_dim self.layer_id = layer_id # Define which encoder layers this decoder layer should attend to self.encoder_connections = { 0: [0], # Decoder L0 → Encoder L0 (byte info) 1: [0], # 
            2: [1, 2],  # Decoder L2 → Encoder L1,2 (language info)
            3: [1, 2],  # Decoder L3 → Encoder L1,2 (language info)
            4: [3],     # Decoder L4 → Encoder L3 (semantic info)
            5: [3],     # Decoder L5 → Encoder L3 (semantic info)
        }

        # Get connections for this layer
        self.connected_layers = self.encoder_connections.get(layer_id, [0])

        # Create attention modules only for connected layers
        self.cross_attentions = nn.ModuleList([
            KVCacheOptimizedAttention(hidden_dim, num_heads=16, kv_compression=8)
            for _ in self.connected_layers
        ])

        # Lightweight fusion with weighted sum (GPT suggestion)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(0.1)
        )

        # Learnable weights for the connected layers only
        self.layer_weights = nn.Parameter(
            torch.ones(len(self.connected_layers)) / len(self.connected_layers)
        )

    def forward(self,
                decoder_hidden: torch.Tensor,
                encoder_all_hidden: List[torch.Tensor],
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Selectively attend to the connected encoder layers only.
        """
        cross_outputs = []
        for i, layer_idx in enumerate(self.connected_layers):
            if layer_idx < len(encoder_all_hidden):
                encoder_hidden = encoder_all_hidden[layer_idx]
                cross_out, _ = self.cross_attentions[i](
                    hidden_states=decoder_hidden,
                    encoder_hidden=encoder_hidden,
                    attention_mask=attention_mask
                )
                cross_outputs.append(cross_out)

        # Weighted-sum fusion over the connected layers
        if len(cross_outputs) > 1:
            weighted_outputs = torch.stack(cross_outputs, dim=0)  # [N, batch, seq, hidden]
            weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
            fused = (weighted_outputs * weights).sum(dim=0)  # [batch, seq, hidden]
        else:
            # Single connection - no fusion needed
            fused = cross_outputs[0] if cross_outputs else decoder_hidden

        # Apply the lightweight fusion layer
        fused = self.fusion(fused)

        return fused


class SwiGLU(nn.Module):
    """SwiGLU activation for better convergence (GPT suggestion)."""

    def __init__(self, dim: int, mult: float = 2.66):
        super().__init__()
        inner = int(round(dim * mult / 2)) * 2  # even alignment
        self.w1 = nn.Linear(dim, inner // 2)
        self.w2 = nn.Linear(dim, inner // 2)
        self.w3 = nn.Linear(inner // 2, dim)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class DecoderLayer(nn.Module):
    """
    Single decoder layer with self-attention and selective cross-attention.
    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, layer_id: int = 0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_id = layer_id

        # Self-attention (with KV cache optimization)
        self.self_attn = KVCacheOptimizedAttention(hidden_dim, num_heads, kv_compression=8)
        self.self_attn_norm = nn.LayerNorm(hidden_dim)

        # Selective cross-attention to specific encoder layers
        self.cross_attn = SelectiveCrossAttention(hidden_dim, layer_id=layer_id)
        self.cross_attn_norm = nn.LayerNorm(hidden_dim)

        # Feed-forward network with SwiGLU (GPT suggestion)
        self.ffn = SwiGLU(hidden_dim, mult=2.66)
        self.ffn_norm = nn.LayerNorm(hidden_dim)

        # Dropout for residual connections
        self.dropout = nn.Dropout(0.1)

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_all_hidden: List[torch.Tensor],
                self_attention_mask: Optional[torch.Tensor] = None,
                cross_attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
        """
        Forward pass through one decoder layer.
        """
        # Self-attention with residual
        residual = hidden_states
        hidden_states = self.self_attn_norm(hidden_states)
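        # Note (added): only this causal self-attention is cached during
        # incremental decoding; the cross-attention branch below re-projects
        # the encoder states on every step (see SelectiveCrossAttention).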
        self_attn_out, cache = self.self_attn(
            hidden_states,
            attention_mask=self_attention_mask,
            use_cache=use_cache
        )
        hidden_states = residual + self.dropout(self_attn_out)

        # Cross-attention with residual
        residual = hidden_states
        hidden_states = self.cross_attn_norm(hidden_states)
        cross_attn_out = self.cross_attn(
            hidden_states,
            encoder_all_hidden,
            attention_mask=cross_attention_mask
        )
        hidden_states = residual + self.dropout(cross_attn_out)

        # FFN with residual
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        ffn_out = self.ffn(hidden_states)
        hidden_states = residual + self.dropout(ffn_out)

        return hidden_states, cache


class DecoderV62(nn.Module):
    """
    6-Layer Decoder with Multi-Level Cross-Attention.
    Reduced from 8 layers but compensated with better cross-attention.
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        # Configuration
        self.hidden_dim = 1280
        self.num_heads = 16
        self.num_layers = 6  # reduced from 8
        self.vocab_size = 260  # 256 bytes + special tokens
        self.max_seq_len = 48

        # Token constants (GPT suggestion - explicit constants)
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259

        # Token embedding and position encoding
        self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_dim)
        self.position_embedding = nn.Embedding(self.max_seq_len, self.hidden_dim)

        # 6 decoder layers with layer-specific cross-attention
        self.layers = nn.ModuleList([
            DecoderLayer(self.hidden_dim, self.num_heads, layer_id=i)
            for i in range(self.num_layers)
        ])

        # Output projection
        self.output_norm = nn.LayerNorm(self.hidden_dim)
        self.output_projection = nn.Linear(self.hidden_dim, self.vocab_size)

        # Monitoring (GPT-5 suggestion): track the importance of the 4 encoder
        # layers used by the decoder
        self.register_buffer('layer_importance', torch.zeros(4))

    def forward(self,
                encoder_all_hidden: List[torch.Tensor],
                decoder_input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                past_key_values: Optional[List] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the decoder.

        Args:
            encoder_all_hidden: All encoder layer outputs (4 layers)
            decoder_input_ids: Input token IDs for teacher forcing
            attention_mask: Attention mask
            use_cache: Whether to cache K/V for inference
            past_key_values: Cached K/V from previous steps; used here to offset
                the position embeddings during incremental decoding (the K/V
                reuse itself happens inside KVCacheOptimizedAttention)
        """
        batch_size = encoder_all_hidden[0].size(0)
        device = encoder_all_hidden[0].device

        # Length of the already-decoded prefix, read from the cached keys
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].size(2)

        # If no decoder input, start from the compressed representation
        if decoder_input_ids is None:
            # Use the encoder's final compressed output as the starting point
            hidden_states = encoder_all_hidden[-1]  # [batch, M tokens, 1280]
            seq_len = hidden_states.size(1)
        else:
            # Teacher-forcing mode: use the provided tokens
            seq_len = decoder_input_ids.size(1)

            # Embeddings (positions are offset by the cached prefix length)
            token_embeds = self.token_embedding(decoder_input_ids)
            position_ids = torch.arange(
                past_length, past_length + seq_len, device=device
            ).expand(batch_size, -1)
            position_embeds = self.position_embedding(position_ids)
            hidden_states = token_embeds + position_embeds

        # Causal mask for self-attention (additive mask - GPT suggestion)
        causal_mask = torch.full((1, 1, seq_len, seq_len), float('-inf'), device=device)
        causal_mask = torch.triu(causal_mask, diagonal=1)  # [1, 1, seq, seq]
        # Cross-attention mask over the encoder states (GPT final check).
        # Additive convention: 0 = attend, -inf = mask. All encoder positions
        # are valid here, so the mask is all zeros; it is built once, outside
        # the layer loop, since it does not depend on the layer.
        if encoder_all_hidden is not None and len(encoder_all_hidden) > 0:
            S_enc = encoder_all_hidden[0].size(1)  # encoder sequence length
            cross_mask = torch.zeros((batch_size, 1, 1, S_enc), device=device)
        else:
            cross_mask = None

        # Pass through the decoder layers
        all_hidden_states = []
        all_caches = [] if use_cache else None

        for layer in self.layers:
            hidden_states, cache = layer(
                hidden_states,
                encoder_all_hidden,
                self_attention_mask=causal_mask,
                cross_attention_mask=cross_mask,
                use_cache=use_cache
            )
            all_hidden_states.append(hidden_states)
            if use_cache:
                all_caches.append(cache)

        # Final output projection
        hidden_states = self.output_norm(hidden_states)
        logits = self.output_projection(hidden_states)

        # Update monitoring: track encoder layer importance
        # (in practice this would be computed from the cross-attention weights)
        with torch.no_grad():
            # Simplified: assume equal importance for now
            self.layer_importance.fill_(0.25)

        outputs = {
            'logits': logits,
            'last_hidden_state': hidden_states,
            'all_hidden_states': all_hidden_states,
            'encoder_layer_importance': self.layer_importance
        }

        if use_cache:
            outputs['past_key_values'] = all_caches

        return outputs

    def generate(self,
                 encoder_all_hidden: List[torch.Tensor],
                 max_length: int = 48,
                 temperature: float = 1.0,
                 top_k: int = 50,
                 top_p: float = 0.95) -> torch.Tensor:
        """
        Autoregressive generation.
        """
        batch_size = encoder_all_hidden[0].size(0)
        device = encoder_all_hidden[0].device

        # Clear any stale KV cache left over from a previous generation run
        for layer in self.layers:
            layer.self_attn.reset_cache()

        # Start with the BOS token
        generated = torch.full((batch_size, 1), self.BOS, dtype=torch.long, device=device)

        # Generate tokens one by one
        past_key_values = None
        for _ in range(max_length - 1):
            # GPT optimization: only pass the last token for O(T) complexity
            if past_key_values is not None:
                decoder_input = generated[:, -1:]  # last token only
            else:
                decoder_input = generated  # full sequence for the first step

            outputs = self.forward(
                encoder_all_hidden,
                decoder_input_ids=decoder_input,
                use_cache=True,
                past_key_values=past_key_values
            )

            logits = outputs['logits'][:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to the generated sequence
            generated = torch.cat([generated, next_token], dim=1)

            # Stop once every sequence has produced EOS
            if (next_token == self.EOS).all():
                break

            past_key_values = outputs.get('past_key_values')

        return generated

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Calculate KV-cache memory usage with and without the optimization (GPT-5 metric).
        """
        # Standard attention: 16 heads for K and V (fp32, head_dim 80)
        standard_kv_memory = 2 * 16 * self.max_seq_len * 80 * 4  # bytes

        # Optimized: 2 heads for K and V
        optimized_kv_memory = 2 * 2 * self.max_seq_len * 80 * 4  # bytes

        return {
            'standard_kv_mb': standard_kv_memory / (1024 * 1024),
            'optimized_kv_mb': optimized_kv_memory / (1024 * 1024),
            'reduction_ratio': standard_kv_memory / optimized_kv_memory,
            'total_params_m': sum(p.numel() for p in self.parameters()) / 1e6
        }


if __name__ == "__main__":
    # Test the decoder
    decoder = DecoderV62()
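    # Added sanity check (illustrative): confirm that each attention module
    # uses the grouped-query layout described in the docstrings (16 query heads
    # sharing 2 K/V heads) and show the decoder → encoder connection pattern.
    first_attn = decoder.layers[0].self_attn
    print(f"Q heads: {first_attn.num_heads}, KV heads: {first_attn.kv_heads}")
    for i, layer in enumerate(decoder.layers):
        print(f"Decoder L{i} attends to encoder layers {layer.cross_attn.connected_layers}")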
    # Simulate encoder outputs (4 layers, 6 tokens each)
    batch_size = 2
    num_tokens = 6  # after progressive splitting
    hidden_dim = 1280

    encoder_outputs = [
        torch.randn(batch_size, num_tokens, hidden_dim)
        for _ in range(4)
    ]

    # Test with teacher forcing
    decoder_input = torch.randint(0, 256, (batch_size, 48))
    output = decoder(encoder_outputs, decoder_input_ids=decoder_input)
    print(f"Decoder output shape: {output['logits'].shape}")
    print(f"Encoder layer importance: {output['encoder_layer_importance']}")

    # Test generation
    generated = decoder.generate(encoder_outputs, max_length=48)
    print(f"Generated shape: {generated.shape}")

    # Memory usage
    memory_stats = decoder.get_memory_usage()
    print(f"Memory optimization: {memory_stats['reduction_ratio']:.1f}x reduction")
    print(f"Total parameters: {memory_stats['total_params_m']:.1f}M")
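    # Added smoke test (illustrative, random inputs): exercise the GQA attention
    # module on its own and check that the compressed KV cache grows by exactly
    # one position per incremental decoding step.
    attn = KVCacheOptimizedAttention(hidden_dim=hidden_dim, num_heads=16, kv_compression=8)
    attn.reset_cache()
    for step in range(3):
        token_state = torch.randn(batch_size, 1, hidden_dim)
        _, (k_cache, v_cache) = attn(token_state, use_cache=True)
        assert k_cache.size(2) == step + 1, "cache should grow by one position per step"
    print(f"KV cache after 3 steps: {tuple(attn.cached_keys.shape)} (kv_heads={attn.kv_heads})")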