"""
Intelligent Tokenizer v6.2.0 - Progressive Splitting Encoder

Incorporates improvements suggested by GPT-5.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math
class RoPEPositionalEncoding(nn.Module):
    """
    Rotary Position Embedding (RoPE) - GPT-5 suggestion.
    Better suited to chunk boundaries and variable sequence lengths
    than learned absolute positions.
    """

    def __init__(self, dim: int, max_seq_len: int = 48, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies, one per pair of feature dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Pre-compute cos/sin tables for all positions up to max_seq_len.
        t = torch.arange(max_seq_len).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        self.register_buffer('cos_cached', freqs.cos())
        self.register_buffer('sin_cached', freqs.sin())

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None) -> torch.Tensor:
        """
        Apply RoPE to the input tensor.
        Handles chunk-boundary corrections as suggested by GPT-5.
        """
        if seq_len is None:
            seq_len = x.shape[1]

        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        return self._apply_rotary_emb(x, cos, sin)

    def _apply_rotary_emb(self, x, cos, sin):
        """Rotate each even/odd pair of feature dimensions by its position angle."""
        x1, x2 = x[..., ::2], x[..., 1::2]
        x_rot = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1).flatten(-2)
        return x_rot
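
# Minimal usage sketch for RoPEPositionalEncoding. The shapes below are
# illustrative assumptions (batch of 2, 48 positions, 1280 dims, matching the
# defaults used elsewhere in this file); `_demo_rope` is not part of the model.
def _demo_rope():
    rope = RoPEPositionalEncoding(dim=1280, max_seq_len=48)
    x = torch.randn(2, 48, 1280)   # [batch, seq_len, hidden_dim]
    x_rot = rope(x)                # same shape, positions encoded via rotation
    assert x_rot.shape == x.shape
    # Rotation preserves the norm of each (even, odd) feature pair.
    return x_rot
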
class GatedCrossAttention(nn.Module):
    """
    Gated cross-attention with multi-query attention (MQA) - GPT-5 suggestion.
    Gate values are exposed for quality monitoring.
    16 query heads -> 2 key/value heads for an 8x K/V memory reduction.
    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_heads: int = 2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.kv_heads = kv_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

        # Sigmoid gate over [query; attention output] decides how much of the
        # attended signal to mix back into the residual stream.
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # Mean gate activation from the last forward pass, kept for monitoring.
        self.register_buffer('gate_values', torch.zeros(1))

        # Warmup scaling: gate influence is ramped from 0 to 1 during training.
        self.register_buffer('warmup_alpha', torch.tensor(1.0))

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with gate monitoring.
        Returns: (output, gate_values)
        """
        batch_size, seq_len = query.shape[:2]

        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.k_proj(key).view(batch_size, -1, self.kv_heads, self.head_dim)
        V = self.v_proj(value).view(batch_size, -1, self.kv_heads, self.head_dim)

        # [batch, heads, seq, head_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Expand the shared K/V heads to match the number of query heads.
        if self.kv_heads < self.num_heads:
            repeat_factor = self.num_heads // self.kv_heads
            K = K.repeat_interleave(repeat_factor, dim=1)
            V = V.repeat_interleave(repeat_factor, dim=1)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
        attn_output = self.o_proj(attn_output)

        gate_input = torch.cat([query, attn_output], dim=-1)
        gate_values = self.gate(gate_input)

        # Record the mean gate activation for monitoring.
        self.gate_values[0] = gate_values.mean().detach()

        # Blend attended output with the residual query, scaled by warmup.
        gate_values = gate_values * self.warmup_alpha
        output = gate_values * attn_output + (1 - gate_values) * query

        return output, gate_values
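
# Minimal usage sketch for GatedCrossAttention. Shapes are illustrative
# assumptions; `_demo_gated_cross_attention` is not part of the model. It
# shows the MQA shape contract (16 query heads sharing 2 K/V heads) and how
# the returned gate values can be monitored.
def _demo_gated_cross_attention():
    attn = GatedCrossAttention(hidden_dim=1280, num_heads=16, kv_heads=2)
    x_cur = torch.randn(2, 4, 1280)    # current layer states [batch, seq, dim]
    x_prev = torch.randn(2, 4, 1280)   # previous layer states used as K/V
    out, gates = attn(query=x_cur, key=x_prev, value=x_prev)
    assert out.shape == x_cur.shape
    # A mean gate near 0 means the layer mostly passes the residual through;
    # near 1 means it relies heavily on the cross-attended signal.
    return gates.mean().item()
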
class ProgressiveSplittingLayer(nn.Module):
    """
    Core innovation: 48 bytes -> 1 token -> N tokens -> M tokens
    """

    def __init__(self, hidden_dim: int = 1280, config: Optional[Dict] = None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.config = config or {}

        # Each 48-byte chunk is compressed into between 1 and 4 tokens.
        self.min_tokens = 1
        self.max_tokens = 4

        # 260 = 256 byte values plus 4 special ids.
        self.byte_embed = nn.Embedding(260, 64)
        self.initial_compressor = nn.Sequential(
            nn.Linear(48 * 64, 2048),
            nn.LayerNorm(2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )

        self.language_splitter = nn.ModuleDict({
            # Compact language features used to predict the split count.
            'analyzer': nn.Sequential(
                nn.Linear(hidden_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256)
            ),
            # Logits over "split into 1..max_tokens tokens".
            'split_predictor': nn.Linear(256, self.max_tokens),
            # Expands the single super-token into max_tokens candidate tokens.
            'dynamic_expander': nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.LayerNorm(hidden_dim * 2),
                nn.GELU(),
                nn.Linear(hidden_dim * 2, hidden_dim * self.max_tokens)
            ),
            # Relative importance of each candidate token.
            'importance_predictor': nn.Sequential(
                nn.Linear(hidden_dim, 256),
                nn.ReLU(),
                nn.Linear(256, self.max_tokens),
                nn.Softmax(dim=-1)
            )
        })

        self.boundary_refiner = nn.ModuleDict({
            'scorer': nn.Sequential(
                nn.Linear(hidden_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 1)
            ),
            'morpheme_detector': nn.Conv1d(256, 64, 3),
            'word_detector': nn.Conv1d(256, 64, 5),
            'phrase_detector': nn.Conv1d(256, 64, 7),
            'adjuster': nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=16,
                dim_feedforward=4 * hidden_dim,
                dropout=0.1,
                batch_first=True
            )
        })

        # Bias the split predictor toward a single token at initialization,
        # so training starts from maximum compression.
        with torch.no_grad():
            self.language_splitter['split_predictor'].bias.data = torch.tensor([2.0, -1.0, -1.0, -1.0])

    def forward(self, input_ids: torch.Tensor, temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """
        Progressive splitting forward pass.

        Args:
            input_ids: Input byte sequence [batch, seq_len]
            temperature: Gumbel-Softmax temperature for annealing
        """
        batch_size = input_ids.size(0)

        # Stage 1: compress the 48-byte chunk into a single super-token.
        byte_embeddings = self.byte_embed(input_ids)
        flattened = byte_embeddings.view(batch_size, -1)
        super_token = self.initial_compressor(flattened)
        super_token = super_token.unsqueeze(1)  # [batch, 1, hidden_dim]

        # Stage 2: analyze the super-token and expand it into candidates.
        lang_features = self.language_splitter['analyzer'](super_token)

        expanded = self.language_splitter['dynamic_expander'](super_token.squeeze(1))
        expanded = expanded.reshape(batch_size, self.max_tokens, self.hidden_dim)

        split_logits = self.language_splitter['split_predictor'](lang_features.squeeze(1))
        # Clamp logits and floor the temperature to keep Gumbel-Softmax stable.
        split_logits = torch.clamp(split_logits, min=-10, max=10)
        safe_temperature = max(temperature, 0.5)
        split_weights = F.gumbel_softmax(split_logits, tau=safe_temperature, hard=False, dim=-1)

        importance = self.language_splitter['importance_predictor'](super_token.squeeze(1))

        # Soft mixture over the 1..max_tokens split options, so the choice of
        # token count stays differentiable.
        weighted_outputs = []
        for i in range(self.max_tokens):
            num_tokens = i + 1
            token_weight = split_weights[:, i:i+1].unsqueeze(-1)  # [batch, 1, 1]

            if num_tokens > 1:
                # Weight the first num_tokens candidates by their importance.
                importance_adjusted = importance[:, :num_tokens].unsqueeze(-1)
                masked = expanded[:, :num_tokens] * importance_adjusted
            else:
                masked = expanded[:, :num_tokens]

            # Zero-pad to max_tokens so all options share one shape.
            if num_tokens < self.max_tokens:
                padding = torch.zeros(batch_size, self.max_tokens - num_tokens, self.hidden_dim,
                                      device=expanded.device)
                masked = torch.cat([masked, padding], dim=1)

            weighted_outputs.append(masked * token_weight)

        lang_tokens = sum(weighted_outputs)

        # Expected token count under the current split distribution.
        token_counts = torch.arange(1, self.max_tokens + 1, device=split_weights.device, dtype=torch.float32)
        avg_tokens = (split_weights * token_counts).sum(dim=-1).mean().item()

        # Stage 3: refine token boundaries with a small transformer layer.
        if lang_tokens.dim() == 3:
            refined_tokens = self.boundary_refiner['adjuster'](lang_tokens)
        else:
            refined_tokens = lang_tokens

        if self.training:
            # During training, report the differentiable expectation.
            actual_num_tokens = avg_tokens
        else:
            # At inference, commit to the argmax split decision.
            split_decision = torch.argmax(split_weights, dim=-1)
            actual_num_tokens = (split_decision.float().mean() + 1).item()

        compression_ratio = 48.0 / max(1, actual_num_tokens)

        return {
            'tokens': refined_tokens,
            'num_tokens': actual_num_tokens,
            'compression_ratio': torch.tensor(compression_ratio, device=refined_tokens.device),
            'gate_values': None,
            'language_features': lang_features,
            'split_weights': split_weights,
            'avg_tokens': avg_tokens,
            'split_distribution': split_weights.mean(dim=0)
        }
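
# Minimal sketch of how the Gumbel-Softmax temperature interacts with the
# splitter. The annealing schedule below (1.0 -> 0.5 over 1000 steps) is an
# illustrative assumption, not the project's training recipe; note that the
# layer itself floors the temperature at 0.5.
def _demo_progressive_splitting():
    splitter = ProgressiveSplittingLayer(hidden_dim=1280)
    input_ids = torch.randint(0, 256, (2, 48))  # one 48-byte chunk per sample

    for step in (0, 500, 1000):
        tau = max(0.5, 1.0 - 0.5 * step / 1000)  # hypothetical linear decay
        out = splitter(input_ids, temperature=tau)
        # tokens: [batch, max_tokens, hidden]; compression = 48 / num_tokens
        print(step, out['avg_tokens'], out['compression_ratio'].item())
    return out
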
class EncoderV62(nn.Module):
    """
    4-layer progressive splitting encoder with cross-attention.
    All layers: 1280 dimensions.
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        self.config = config or {}

        self.hidden_dim = 1280
        self.num_heads = 16
        self.num_layers = 4
        self.max_seq_len = 48
        self.dropout = 0.1

        self.rope = RoPEPositionalEncoding(self.hidden_dim, self.max_seq_len)

        # Layer 0: byte chunk -> progressive token split.
        self.progressive_splitter = ProgressiveSplittingLayer(self.hidden_dim, config)

        # Layers 1-3: standard transformer encoder layers.
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.hidden_dim,
                nhead=self.num_heads,
                dim_feedforward=4 * self.hidden_dim,
                dropout=self.dropout,
                batch_first=True
            ) for _ in range(3)
        ])

        # Gated cross-attention back to the previous layer's states.
        self.cross_attentions = nn.ModuleList([
            GatedCrossAttention(self.hidden_dim, self.num_heads, kv_heads=2)
            for _ in range(3)
        ])

        # Auxiliary prediction heads.
        self.boundary_head = nn.Linear(self.hidden_dim, 4)
        self.language_head = nn.Linear(self.hidden_dim, 128)
        self.compression_head = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Monitoring buffers (GPT-5 suggestion).
        self.register_buffer('compression_ratios', torch.zeros(1))
        self.register_buffer('gate_averages', torch.zeros(3))

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the encoder.

        Args:
            input_ids: Input byte sequence
            attention_mask: Optional attention mask
            temperature: Gumbel-Softmax temperature for annealing
        """
        # Progressive splitting: bytes -> 1..4 tokens.
        split_output = self.progressive_splitter(input_ids, temperature)
        x = split_output['tokens']

        # Positional encoding on the split tokens.
        x = self.rope(x, x.size(1))

        all_hidden_states = [x]
        gate_values_list = []

        for i, (encoder_layer, cross_attn) in enumerate(
            zip(self.encoder_layers, self.cross_attentions)
        ):
            x = encoder_layer(x)

            # From the second layer on, cross-attend to the previous layer's
            # states. (cross_attentions[0] is effectively unused.)
            if i > 0:
                x, gate_values = cross_attn(
                    query=x,
                    key=all_hidden_states[-1],
                    value=all_hidden_states[-1],
                    mask=None
                )
                gate_values_list.append(gate_values)

                self.gate_averages[i-1] = gate_values.mean().detach().item()

            all_hidden_states.append(x)

        boundaries = self.boundary_head(x)
        language_clusters = self.language_head(x)
        compressed = self.compression_head(x)

        # Track the compression ratio for monitoring.
        compression_ratio = split_output['compression_ratio']
        if compression_ratio.dim() == 0:
            self.compression_ratios[0] = compression_ratio
        else:
            self.compression_ratios = compression_ratio

        return {
            'last_hidden_state': x,
            'all_hidden_states': all_hidden_states,
            'boundaries': boundaries,
            'language_clusters': language_clusters,
            'compressed': compressed,
            'compression_ratio': split_output['compression_ratio'],
            'num_tokens': split_output['num_tokens'],
            'splitting_probs': split_output.get('split_weights', None),
            'gate_values': gate_values_list,
            'gate_averages': self.gate_averages,
            'split_info': {
                'language_features': split_output['language_features'],
                'split_weights': split_output['split_weights']
            }
        }

    def get_monitoring_stats(self) -> Dict[str, float]:
        """
        Get monitoring statistics (GPT-5 suggestion).
        """
        # Note: forward() only updates gate_averages[0] and [1]; gate_layer3
        # retains its initial value.
        return {
            'avg_compression_ratio': self.compression_ratios.item(),
            'gate_layer1': self.gate_averages[0].item(),
            'gate_layer2': self.gate_averages[1].item(),
            'gate_layer3': self.gate_averages[2].item(),
        }

    def set_warmup_step(self, step: int, total_warmup: int = 1000):
        """
        Set the warmup alpha for all gates (GPT-5 suggestion).
        Gradually increases gate influence from 0 to 1.
        """
        alpha = min(1.0, step / total_warmup)
        for cross_attn in self.cross_attentions:
            cross_attn.warmup_alpha.fill_(alpha)

    def adaptive_compression_control(self, reconstruction_loss: float):
        """
        Adaptive compression based on reconstruction quality.
        No fixed phases - the model learns the optimal compression.
        Placeholder: not yet implemented.
        """
        pass
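
# Minimal training-loop sketch showing how the gate warmup and Gumbel
# temperature annealing might be driven together. The optimizer, loss, and
# schedule constants are illustrative assumptions, not the project's recipe.
def _demo_warmup_schedule():
    encoder = EncoderV62()
    optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-4)

    for step in range(3):  # a real run would use many more steps
        input_ids = torch.randint(0, 256, (2, 48))

        # Ramp gate influence over the first 1000 steps; anneal the split
        # temperature from 1.0 toward the 0.5 floor over the same window.
        encoder.set_warmup_step(step, total_warmup=1000)
        tau = max(0.5, 1.0 - 0.5 * step / 1000)

        output = encoder(input_ids, temperature=tau)
        # Hypothetical objective: just exercises the graph in this sketch.
        loss = output['last_hidden_state'].pow(2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return encoder.get_monitoring_stats()
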
class DualSlidingWindowEncoder(EncoderV62):
    """
    Extension with a dual sliding-window system.
    Handles both chunk-level and token-level boundaries.
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__(config)

        # Chunk-level window: 48-byte chunks with an 8-byte overlap (stride 40).
        self.chunk_window = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=8,
            stride=40,
            padding=4
        )

        # Token-level window: attention across chunk outputs.
        self.token_window = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=self.num_heads,
            batch_first=True
        )

    def process_long_sequence(self, input_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Handle sequences longer than 48 bytes with sliding windows.
        """
        batch_size, seq_len = input_ids.shape

        if seq_len <= 48:
            return super().forward(input_ids)

        # Slide a 48-byte window with stride 40 (8 bytes of overlap).
        # Note: any tail shorter than 48 bytes after the last full window is
        # currently dropped.
        chunks = []
        for i in range(0, seq_len - 48 + 1, 40):
            chunk = input_ids[:, i:i+48]
            chunk_output = super().forward(chunk)
            chunks.append(chunk_output['last_hidden_state'])

        # Let chunk outputs attend to each other across boundaries.
        combined = torch.cat(chunks, dim=1)
        attended, _ = self.token_window(combined, combined, combined)

        return {
            'last_hidden_state': attended,
            'num_chunks': len(chunks),
            'total_compression': seq_len / attended.size(1)
        }
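
# Minimal sketch of the sliding-window arithmetic: with window 48 and stride
# 40, a window starts every 40 bytes and consecutive windows overlap by 8
# bytes. The 128-byte input length below is an illustrative assumption.
def _demo_sliding_windows():
    encoder = DualSlidingWindowEncoder()
    input_ids = torch.randint(0, 256, (1, 128))

    # Window starts: 0, 40, 80 -> 3 chunks; the last window spans bytes
    # 80..127, exactly 48 bytes, so nothing is dropped at this length.
    output = encoder.process_long_sequence(input_ids)
    print(output['num_chunks'], output['total_compression'])
    return output
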
if __name__ == "__main__":
    # Quick smoke test.
    encoder = EncoderV62()

    batch_size = 2
    input_ids = torch.randint(0, 256, (batch_size, 48))

    output = encoder(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output tokens: {output['num_tokens']}")
    print(f"Compression ratio: {output['compression_ratio'].item():.2f}:1")
    print(f"Gate averages: {output['gate_averages']}")
    print(f"Monitoring stats: {encoder.get_monitoring_stats()}")