"""
Intelligent Tokenizer v6.2.0 - Progressive Splitting Encoder
Incorporates GPT-5-suggested improvements
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math
class RoPEPositionalEncoding(nn.Module):
"""
Rotary Position Embedding (RoPE) - GPT-5 suggestion
Better for handling chunk boundaries and variable sequence lengths
"""
def __init__(self, dim: int, max_seq_len: int = 48, base: int = 10000):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
# Precompute sinusoidal frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# Precompute positional encodings
t = torch.arange(max_seq_len).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
self.register_buffer('cos_cached', freqs.cos())
self.register_buffer('sin_cached', freqs.sin())
def forward(self, x: torch.Tensor, seq_len: int = None) -> torch.Tensor:
"""
Apply RoPE to input tensor
Handles chunk boundary corrections as suggested by GPT-5
"""
if seq_len is None:
seq_len = x.shape[1]
# Get cached cos/sin values
cos = self.cos_cached[:seq_len]
sin = self.sin_cached[:seq_len]
# Apply rotary embedding
x_rot = self._apply_rotary_emb(x, cos, sin)
return x_rot
def _apply_rotary_emb(self, x, cos, sin):
"""Apply rotary embedding to input"""
x1, x2 = x[..., ::2], x[..., 1::2]
x_rot = torch.stack([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1).flatten(-2)
return x_rot
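# A minimal sketch of what _apply_rotary_emb does to one (even, odd) feature pair,
# assuming position index t and frequency inv_freq[j] (illustrative only, not executed):
#
#   theta      = t * inv_freq[j]
#   x_even_rot = x_even * cos(theta) - x_odd * sin(theta)
#   x_odd_rot  = x_even * sin(theta) + x_odd * cos(theta)
#
# Hypothetical usage:
#   rope = RoPEPositionalEncoding(dim=1280, max_seq_len=48)
#   h = torch.randn(2, 4, 1280)   # [batch, tokens, hidden]
#   h_rot = rope(h)               # same shape, rotary phases applied per feature pair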
class GatedCrossAttention(nn.Module):
"""
Gated Cross-Attention with MQA - GPT-5 suggestion
Monitor gate values for quality assessment
    16 query heads → 2 KV heads for an ~8x reduction in K/V memory
"""
def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_heads: int = 2):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.kv_heads = kv_heads # Reduced KV heads (GPT suggestion)
self.head_dim = hidden_dim // num_heads # 80
# Multi-Query Attention projections
self.q_proj = nn.Linear(hidden_dim, hidden_dim) # 16 heads
self.k_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim) # 2 heads
self.v_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim) # 2 heads
self.o_proj = nn.Linear(hidden_dim, hidden_dim)
# Gating mechanism (GPT-5 suggestion)
self.gate = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Sigmoid()
)
# Gate monitoring (for analysis)
self.register_buffer('gate_values', torch.zeros(1))
# Warmup factor (GPT suggestion)
self.register_buffer('warmup_alpha', torch.tensor(1.0))
def forward(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass with gate monitoring
Returns: (output, gate_values)
"""
batch_size, seq_len = query.shape[:2]
# Multi-head attention projections
Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim)
K = self.k_proj(key).view(batch_size, -1, self.kv_heads, self.head_dim)
V = self.v_proj(value).view(batch_size, -1, self.kv_heads, self.head_dim)
# Transpose for attention computation
Q = Q.transpose(1, 2) # [batch, heads, seq, dim]
K = K.transpose(1, 2) # [batch, kv_heads, seq, dim]
V = V.transpose(1, 2)
# Repeat KV heads to match Q heads if necessary
if self.kv_heads < self.num_heads:
repeat_factor = self.num_heads // self.kv_heads
K = K.repeat_interleave(repeat_factor, dim=1)
V = V.repeat_interleave(repeat_factor, dim=1)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# Reshape back
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
attn_output = self.o_proj(attn_output)
# Gating mechanism
gate_input = torch.cat([query, attn_output], dim=-1)
gate_values = self.gate(gate_input)
# Store gate values for monitoring (keep tensor shape consistent)
self.gate_values[0] = gate_values.mean().detach()
# Apply gate with warmup factor (GPT suggestion)
gate_values = gate_values * self.warmup_alpha
output = gate_values * attn_output + (1 - gate_values) * query
return output, gate_values
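# Rough K/V memory arithmetic for the MQA layout above (a sketch, not a benchmark):
# the 16 query heads use the full 1280-dim q_proj, while k_proj and v_proj only
# produce kv_heads * head_dim = 2 * 80 = 160 dims each, an 8x reduction in K/V
# activations before repeat_interleave expands them back for the attention matmul.
#
# Hypothetical usage (gate values can be monitored as the docstring suggests):
#   xattn = GatedCrossAttention(hidden_dim=1280, num_heads=16, kv_heads=2)
#   q = torch.randn(2, 4, 1280)
#   out, gate = xattn(q, q, q)    # out: [2, 4, 1280], gate: [2, 4, 1280] in (0, 1)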
class ProgressiveSplittingLayer(nn.Module):
"""
    Core innovation: 48 bytes → 1 token → N tokens → M tokens
"""
def __init__(self, hidden_dim: int = 1280, config: Optional[Dict] = None):
super().__init__()
self.hidden_dim = hidden_dim
self.config = config or {}
# Dynamic splitting: 1~4 tokens for efficiency
# 48 bytes / 4 tokens = 12:1 compression (still beats BPE's 4:1)
self.min_tokens = 1 # 48:1 compression
self.max_tokens = 4 # 12:1 compression (still 3x better than BPE)
        # Initial compression: 48 bytes → 1 super token
self.byte_embed = nn.Embedding(260, 64) # Small embedding
self.initial_compressor = nn.Sequential(
nn.Linear(48 * 64, 2048),
nn.LayerNorm(2048),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(2048, hidden_dim),
nn.LayerNorm(hidden_dim)
)
        # Language-aware splitting: 1 → N tokens (config-based)
self.language_splitter = nn.ModuleDict({
'analyzer': nn.Sequential(
nn.Linear(hidden_dim, 512),
nn.ReLU(),
nn.Linear(512, 256) # Language features
),
'split_predictor': nn.Linear(256, self.max_tokens), # Predict 1~4 tokens
# Single unified expander that can produce any number of tokens
'dynamic_expander': nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 2),
nn.LayerNorm(hidden_dim * 2),
nn.GELU(), # Better than ReLU for transformers
nn.Linear(hidden_dim * 2, hidden_dim * self.max_tokens) # Can produce up to 4 tokens
),
# Token-wise importance predictor
'importance_predictor': nn.Sequential(
nn.Linear(hidden_dim, 256),
nn.ReLU(),
nn.Linear(256, self.max_tokens), # Importance for each potential token
nn.Softmax(dim=-1)
)
})
        # Boundary refinement: N → M tokens with linguistic awareness
self.boundary_refiner = nn.ModuleDict({
'scorer': nn.Sequential(
nn.Linear(hidden_dim, 512),
nn.ReLU(),
nn.Linear(512, 1)
),
            'morpheme_detector': nn.Conv1d(256, 64, 3),   # morpheme-level boundaries
            'word_detector': nn.Conv1d(256, 64, 5),       # word-level boundaries
            'phrase_detector': nn.Conv1d(256, 64, 7),     # phrase-level boundaries
'adjuster': nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=16,
dim_feedforward=4 * hidden_dim,
dropout=0.1,
batch_first=True
)
})
# Initialize split_predictor bias to prefer 1 token initially
# This ensures untrained model starts with maximum compression
with torch.no_grad():
self.language_splitter['split_predictor'].bias.data = torch.tensor([2.0, -1.0, -1.0, -1.0])
# High bias for 1 token, negative for others
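        # Worked example of this initialization (a sketch): softmax([2.0, -1.0, -1.0, -1.0])
        # puts roughly 0.87 probability on the 1-token split and about 0.04 on each of the
        # other three, so an untrained model defaults to the maximum 48:1 compression.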
def forward(self, input_ids: torch.Tensor, temperature: float = 1.0) -> Dict[str, torch.Tensor]:
"""
Progressive splitting forward pass
Args:
input_ids: Input byte sequence [batch, seq_len]
temperature: Gumbel-Softmax temperature for annealing
"""
batch_size = input_ids.size(0)
        # Step 1: 48 bytes → 1 super token
byte_embeddings = self.byte_embed(input_ids) # [batch, 48, 64]
flattened = byte_embeddings.view(batch_size, -1) # [batch, 3072]
super_token = self.initial_compressor(flattened) # [batch, 1280]
super_token = super_token.unsqueeze(1) # [batch, 1, 1280]
        # Step 2: Language analysis and splitting (1 → N)
        lang_features = self.language_splitter['analyzer'](super_token)
        split_logits = self.language_splitter['split_predictor'](lang_features)
        split_weights = F.softmax(split_logits, dim=-1)  # [batch, 1, 4]; recomputed below with Gumbel-Softmax
# Direct transformation from super token to initial representation
# No hardcoded splits - let the model learn everything
lang_tokens = super_token # Start with compressed representation
# TRUE Adaptive expansion - Model learns optimal split (1~4 tokens)
# Analyze content to decide how many tokens needed
expansion_features = self.language_splitter['analyzer'](lang_tokens) # [batch, 1, 256]
# Dynamic expansion: generate up to 4 tokens from super token
expanded = self.language_splitter['dynamic_expander'](lang_tokens.squeeze(1)) # [batch, hidden_dim*4]
expanded = expanded.reshape(batch_size, self.max_tokens, self.hidden_dim) # [batch, 4, hidden_dim]
# Predict how many tokens we actually need (1~4)
split_logits = self.language_splitter['split_predictor'](expansion_features.squeeze(1)) # [batch, 4]
# Clamp logits to prevent extreme values that cause NaN
split_logits = torch.clamp(split_logits, min=-10, max=10)
# Ensure minimum temperature to prevent instability
safe_temperature = max(temperature, 0.5)
split_weights = F.gumbel_softmax(split_logits, tau=safe_temperature, hard=False, dim=-1) # [batch, 4]
# Predict importance for each potential token position
importance = self.language_splitter['importance_predictor'](lang_tokens.squeeze(1)) # [batch, 4]
# Dynamic token selection with importance-weighted allocation
# Create cumulative mask for progressive token usage
# If split_weights = [0.1, 0.2, 0.6, 0.1], we mainly use 3 tokens
# Create progressive masks for 1, 2, 3, 4 tokens
masks = []
for n in range(1, self.max_tokens + 1):
mask = torch.zeros(batch_size, self.max_tokens, 1, device=expanded.device)
mask[:, :n, :] = 1.0
masks.append(mask)
# Apply importance-weighted masking
# Important parts get more tokens, less important parts get fewer
weighted_outputs = []
for i, mask in enumerate(masks):
num_tokens = i + 1
# Weight by both split decision and importance
token_weight = split_weights[:, i:i+1].unsqueeze(-1) # [batch, 1, 1]
# Apply importance modulation for asymmetric splits
if num_tokens > 1:
# Redistribute tokens based on importance
importance_adjusted = importance[:, :num_tokens].unsqueeze(-1) # [batch, n, 1]
masked = expanded[:, :num_tokens] * importance_adjusted
else:
masked = expanded[:, :num_tokens]
# Pad to max length
if num_tokens < self.max_tokens:
padding = torch.zeros(batch_size, self.max_tokens - num_tokens, self.hidden_dim,
device=expanded.device)
masked = torch.cat([masked, padding], dim=1)
weighted_outputs.append(masked * token_weight)
# Sum all weighted possibilities (differentiable selection)
lang_tokens = sum(weighted_outputs)
# Determine effective number of tokens (for monitoring)
# Weighted average of token counts
token_counts = torch.arange(1, self.max_tokens + 1, device=split_weights.device, dtype=torch.float32)
avg_tokens = (split_weights * token_counts).sum(dim=-1).mean().item()
k = lang_tokens.size(1)
        # Step 3: Boundary refinement (N → M)
# Calculate boundary scores for each token position
boundary_scores = self.boundary_refiner['scorer'](lang_tokens) # [batch, N, 1]
# Detect linguistic boundaries (morpheme, word, phrase)
# Extract features for boundary detection
        if lang_tokens.dim() == 3:  # always true here; kept as a safety check
batch_size, num_tokens, hidden_dim = lang_tokens.shape
# For boundary detection, we need to consider the original byte sequence
# But we're working with compressed tokens here
# So we detect boundaries based on learned representations
# Apply boundary adjustment with TransformerEncoderLayer
# This learns to adjust token boundaries based on context
refined_tokens = self.boundary_refiner['adjuster'](lang_tokens)
# The adjuster should learn to:
# 1. Respect UTF-8 boundaries (learned during training)
# 2. Align with word/phrase boundaries (learned from language patterns)
# 3. Maintain semantic coherence within each token
else:
refined_tokens = lang_tokens
# Determine actual number of tokens based on highest probability
# During inference, use argmax. During training, use weighted average.
if self.training:
# During training, use weighted average for differentiability
actual_num_tokens = avg_tokens
else:
# During inference, select the split with highest probability
split_decision = torch.argmax(split_weights, dim=-1) # [batch]
actual_num_tokens = (split_decision.float().mean() + 1).item() # +1 because indices are 0-3
# Calculate compression ratio based on actual tokens used
compression_ratio = 48.0 / max(1, actual_num_tokens)
return {
'tokens': refined_tokens,
'num_tokens': actual_num_tokens,
'compression_ratio': torch.tensor(compression_ratio, device=refined_tokens.device),
'gate_values': None, # Will be filled by cross-attention
'language_features': lang_features,
'split_weights': split_weights,
            'avg_tokens': avg_tokens,
            'split_distribution': split_weights.mean(dim=0)
}
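# Compression arithmetic for ProgressiveSplittingLayer (a sketch): one 48-byte window
# maps to between 1 and max_tokens = 4 latent tokens, i.e. ratios between 48:1 and 12:1.
# With the training-time weighted average, e.g. split_weights = [0.1, 0.2, 0.6, 0.1]
# gives 0.1*1 + 0.2*2 + 0.6*3 + 0.1*4 = 2.7 effective tokens, roughly a 17.8:1 ratio.
#
# Hypothetical usage:
#   splitter = ProgressiveSplittingLayer(hidden_dim=1280)
#   bytes_in = torch.randint(0, 256, (2, 48))
#   out = splitter(bytes_in, temperature=1.0)
#   print(out['tokens'].shape, float(out['compression_ratio']))  # [2, 4, 1280], ~48/avg_tokens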
class EncoderV62(nn.Module):
"""
4-Layer Progressive Splitting Encoder with Cross-Attention
All layers: 1280 dimensions
"""
def __init__(self, config: Optional[Dict] = None):
super().__init__()
# Store config for later use
self.config = config or {}
# Configuration
self.hidden_dim = 1280
self.num_heads = 16
self.num_layers = 4
self.max_seq_len = 48
self.dropout = 0.1
# RoPE positional encoding (GPT-5 suggestion)
self.rope = RoPEPositionalEncoding(self.hidden_dim, self.max_seq_len)
# Layer 0: Progressive Splitting (48→1→N→M) - Pass config
self.progressive_splitter = ProgressiveSplittingLayer(self.hidden_dim, config)
# Layers 1-3: Transformer encoders with cross-attention
self.encoder_layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=self.hidden_dim,
nhead=self.num_heads,
dim_feedforward=4 * self.hidden_dim, # 5120
dropout=self.dropout,
batch_first=True
) for _ in range(3)
])
# Cross-attention between layers with MQA (GPT-5 suggestion)
self.cross_attentions = nn.ModuleList([
GatedCrossAttention(self.hidden_dim, self.num_heads, kv_heads=2) # 8x memory reduction
for _ in range(3)
])
# Output heads for different tasks
self.boundary_head = nn.Linear(self.hidden_dim, 4)
self.language_head = nn.Linear(self.hidden_dim, 128) # Reduced from 512 (GPT suggestion)
self.compression_head = nn.Linear(self.hidden_dim, self.hidden_dim)
# Monitoring metrics (GPT-5 suggestion)
self.register_buffer('compression_ratios', torch.zeros(1))
self.register_buffer('gate_averages', torch.zeros(3))
def forward(self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
temperature: float = 1.0) -> Dict[str, torch.Tensor]:
"""
Forward pass through the encoder
Args:
input_ids: Input byte sequence
attention_mask: Optional attention mask
temperature: Gumbel-Softmax temperature for annealing
"""
# Layer 0: Progressive splitting with temperature
split_output = self.progressive_splitter(input_ids, temperature)
x = split_output['tokens'] # [batch, M, 1280]
# Apply RoPE
x = self.rope(x, x.size(1))
# Store all hidden states for decoder
all_hidden_states = [x]
gate_values_list = []
# Layers 1-3 with cross-attention
for i, (encoder_layer, cross_attn) in enumerate(
zip(self.encoder_layers, self.cross_attentions)
):
# Self-attention through transformer layer
            # Per GPT-5 review: no attention mask is passed here, because progressive
            # splitting already changed the sequence length and removed the padding.
            x = encoder_layer(x)
            # Cross-attention with the previous layer's hidden states (skipped for the first layer)
            if i > 0:
x, gate_values = cross_attn(
query=x,
key=all_hidden_states[-1],
value=all_hidden_states[-1],
mask=None # Mask not applicable after compression
)
gate_values_list.append(gate_values)
                # Keep the buffer shape fixed; cross-attention runs for i in {1, 2},
                # so slots 0 and 1 are updated and slot 2 keeps its initial value.
                self.gate_averages[i - 1] = gate_values.mean().detach().item()
all_hidden_states.append(x)
# Output projections
boundaries = self.boundary_head(x)
language_clusters = self.language_head(x)
compressed = self.compression_head(x)
# Update monitoring metrics
# Ensure tensor is 1-dimensional for buffer assignment
compression_ratio = split_output['compression_ratio']
if compression_ratio.dim() == 0: # Scalar tensor
self.compression_ratios[0] = compression_ratio
else:
self.compression_ratios = compression_ratio
return {
'last_hidden_state': x,
'all_hidden_states': all_hidden_states,
'boundaries': boundaries,
'language_clusters': language_clusters,
'compressed': compressed,
'compression_ratio': split_output['compression_ratio'],
'num_tokens': split_output['num_tokens'],
'splitting_probs': split_output.get('split_weights', None), # Add for diagnostics
'gate_values': gate_values_list,
'gate_averages': self.gate_averages,
'split_info': {
'language_features': split_output['language_features'],
'split_weights': split_output['split_weights']
}
}
def get_monitoring_stats(self) -> Dict[str, float]:
"""
Get monitoring statistics (GPT-5 suggestion)
"""
return {
'avg_compression_ratio': self.compression_ratios.item(),
'gate_layer1': self.gate_averages[0].item(),
'gate_layer2': self.gate_averages[1].item(),
'gate_layer3': self.gate_averages[2].item(),
}
def set_warmup_step(self, step: int, total_warmup: int = 1000):
"""
Set warmup alpha for all gates (GPT suggestion)
Gradually increase gate influence from 0 to 1
"""
alpha = min(1.0, step / total_warmup)
for cross_attn in self.cross_attentions:
cross_attn.warmup_alpha = torch.tensor(alpha, device=cross_attn.warmup_alpha.device)
def adaptive_compression_control(self, reconstruction_loss: float):
"""
Adaptive compression based on reconstruction quality
No fixed phases - model learns optimal compression
"""
# If reconstruction is poor, model will learn to use more tokens
# This happens automatically through gradient descent
# No manual phase control needed
pass # Let gradients handle it
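# Sketch of how the annealing hooks above could be driven from a training loop;
# the schedule values here are illustrative assumptions, not part of this module:
#
#   encoder = EncoderV62()
#   for step, batch in enumerate(loader):
#       temperature = max(0.5, 1.0 - step / 10_000)       # Gumbel-Softmax annealing
#       encoder.set_warmup_step(step, total_warmup=1000)  # gate influence ramps 0 -> 1
#       out = encoder(batch['input_ids'], temperature=temperature)
#       stats = encoder.get_monitoring_stats()            # compression ratio + gate means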
class DualSlidingWindowEncoder(EncoderV62):
"""
Extension with dual sliding window system
Handles both chunk-level and token-level boundaries
"""
def __init__(self, config: Optional[Dict] = None):
super().__init__(config)
        # Chunk-level sliding window (defined for byte-level overlap; not used in process_long_sequence)
self.chunk_window = nn.Conv1d(
in_channels=1,
out_channels=1,
kernel_size=8, # 8-byte overlap
stride=40, # 48-8=40 stride
padding=4
)
# Token-level sliding window
self.token_window = nn.MultiheadAttention(
embed_dim=self.hidden_dim,
num_heads=self.num_heads,
batch_first=True
)
    def process_long_sequence(self, input_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Handle sequences longer than 48 bytes with sliding windows
"""
batch_size, seq_len = input_ids.shape
if seq_len <= 48:
return super().forward(input_ids)
        # Process in 48-byte windows with an 8-byte overlap (stride 40).
        # Note: any trailing bytes that do not fill a complete window are skipped.
        chunks = []
        for i in range(0, seq_len - 48 + 1, 40):
chunk = input_ids[:, i:i+48]
chunk_output = super().forward(chunk)
chunks.append(chunk_output['last_hidden_state'])
# Combine chunks with attention
combined = torch.cat(chunks, dim=1)
attended, _ = self.token_window(combined, combined, combined)
return {
'last_hidden_state': attended,
'num_chunks': len(chunks),
'total_compression': seq_len / attended.size(1)
}
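# Window arithmetic for process_long_sequence (a sketch): with a 48-byte window and a
# stride of 40, consecutive windows overlap by 8 bytes, so a sequence of L bytes yields
# floor((L - 48) / 40) + 1 full windows; e.g. L = 128 gives windows at offsets 0, 40, 80.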
if __name__ == "__main__":
# Test the encoder
encoder = EncoderV62()
# Test input
batch_size = 2
input_ids = torch.randint(0, 256, (batch_size, 48))
# Forward pass
output = encoder(input_ids)
print(f"Input shape: {input_ids.shape}")
print(f"Output tokens: {output['num_tokens']}")
print(f"Compression ratio: {output['compression_ratio']:.2f}:1")
print(f"Gate averages: {output['gate_averages']}")
print(f"Monitoring stats: {encoder.get_monitoring_stats()}")