"""
Intelligent Tokenizer v6.2.0 - Unified Model
Integrates the encoder, decoder, and byte tokenizer, with the GPT-suggested improvements applied
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple, Union
import math
# Import our components
try:
from .encoder import EncoderV62
from .decoder import DecoderV62
from .tokenizer import ByteTokenizerV62
except ImportError:
# For standalone testing
from encoder import EncoderV62
from decoder import DecoderV62
from tokenizer import ByteTokenizerV62
class IntelligentTokenizerV62(nn.Module):
"""
Complete v6.2.0 model with progressive splitting and optimizations
Key features:
- 48-byte chunks (46+2 with BOS/EOS)
- Progressive splitting: 48→1→N→M tokens
- Multi-level cross-attention
- KV cache optimization (8x reduction)
    - All GPT-5 review suggestions integrated
"""
def __init__(self, config: Optional[Dict] = None):
super().__init__()
# Default configuration
self.config = config or {}
# Model components
self.tokenizer = ByteTokenizerV62(config)
self.encoder = EncoderV62(config)
self.decoder = DecoderV62(config)
        # Training configuration (default weights; adapted in update_training_state)
self.compression_weight = 0.1
self.reconstruction_weight = 0.1
self.boundary_weight = 0.1
# Monitoring
self.register_buffer('training_step', torch.tensor(0))
self.register_buffer('current_epoch', torch.tensor(0))
def forward(self,
input_ids: torch.Tensor = None,
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
text: str = None,
return_loss: bool = True,
temperature: float = 1.0) -> Dict[str, torch.Tensor]:
"""
Unified forward pass
Args:
input_ids: Pre-tokenized input (optional)
attention_mask: Attention mask (optional)
labels: Target labels for training (optional)
text: Raw text input (alternative to input_ids)
return_loss: Whether to compute loss
temperature: Temperature for Gumbel-Softmax in encoder
Returns:
Dictionary with model outputs
"""
# Handle text input
if text is not None:
encoded = self.tokenizer.encode(text, add_special_tokens=True)
input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
# Handle string passed as input_ids (common mistake)
if isinstance(input_ids, str):
text = input_ids
encoded = self.tokenizer.encode(text, add_special_tokens=True)
input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
# Ensure tensors are on the right device
device = next(self.parameters()).device
if input_ids is not None and torch.is_tensor(input_ids):
input_ids = input_ids.to(device)
if attention_mask is not None and torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(device)
if labels is not None and torch.is_tensor(labels):
labels = labels.to(device)
# Encoder forward pass with temperature for Gumbel annealing
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
temperature=temperature
)
# Decoder forward pass
if labels is not None:
# Training mode with teacher forcing (GPT suggestion: shift by 1)
# Input: labels[:-1], Target: labels[1:]
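            # Worked example: labels = [BOS, a, b, EOS]
            #   decoder input  labels[:, :-1] -> [BOS, a, b]
            #   loss targets   labels[:, 1:]  -> [a, b, EOS] (see compute_loss)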
decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
decoder_mask = attention_mask[:, :-1] if attention_mask is not None and attention_mask.dim() > 1 else None
decoder_outputs = self.decoder(
encoder_all_hidden=encoder_outputs['all_hidden_states'],
decoder_input_ids=decoder_input,
attention_mask=decoder_mask
)
else:
# Inference mode (without teacher forcing)
            # For now, fall back to using the input as labels for stable training
# TODO: Implement proper autoregressive generation
if return_loss and input_ids is not None:
labels = input_ids # Use input as both input and target
decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
decoder_mask = attention_mask[:, :-1] if attention_mask is not None and attention_mask.dim() > 1 else None
decoder_outputs = self.decoder(
encoder_all_hidden=encoder_outputs['all_hidden_states'],
decoder_input_ids=decoder_input,
attention_mask=decoder_mask
)
else:
decoder_outputs = self.decoder(
encoder_all_hidden=encoder_outputs['all_hidden_states'],
decoder_input_ids=None,
attention_mask=attention_mask
)
# Combine outputs with prefix to avoid key collision (GPT suggestion)
outputs = {}
for key, value in encoder_outputs.items():
outputs[f'enc_{key}'] = value
for key, value in decoder_outputs.items():
outputs[f'dec_{key}'] = value
# Compute loss if requested
if return_loss and labels is not None:
loss = self.compute_loss(outputs, labels, attention_mask)
outputs['loss'] = loss
return outputs
def compute_loss(self,
outputs: Dict[str, torch.Tensor],
labels: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute combined loss with multiple objectives
Components:
1. Reconstruction loss (cross-entropy)
2. Compression loss (encourage higher compression)
3. Boundary loss (boundary prediction accuracy)
"""
losses = {}
# 1. Reconstruction loss (GPT suggestion: use shifted targets)
if 'dec_logits' in outputs:
logits = outputs['dec_logits']
# Shift targets for next-token prediction
target_labels = labels[:, 1:] if labels.dim() > 1 else labels[1:]
target_mask = attention_mask[:, 1:] if attention_mask is not None and attention_mask.dim() > 1 else None
# Reshape for cross-entropy
batch_size, seq_len, vocab_size = logits.shape
logits_flat = logits.reshape(-1, vocab_size)
labels_flat = target_labels.reshape(-1)
# Mask out padding (GPT suggestion: use bool mask)
if target_mask is not None:
mask_flat = target_mask.reshape(-1).bool()
reconstruction_loss = F.cross_entropy(
logits_flat[mask_flat],
labels_flat[mask_flat],
ignore_index=self.tokenizer.PAD,
label_smoothing=0.1 # Added label smoothing
)
else:
reconstruction_loss = F.cross_entropy(
logits_flat,
labels_flat,
ignore_index=self.tokenizer.PAD,
label_smoothing=0.1
)
losses['reconstruction'] = reconstruction_loss * self.reconstruction_weight
# 2. Compression loss (GPT suggestion: use proper device tensor creation)
if 'enc_compression_ratio' in outputs:
# Target compression ratio (e.g., 24:1 as per config)
target_ratio = 24.0
current_ratio = outputs['enc_compression_ratio']
# Create tensors on same device (GPT suggestion)
if isinstance(current_ratio, (int, float)):
current_ratio_tensor = labels.new_tensor(current_ratio, dtype=torch.float32)
else:
current_ratio_tensor = current_ratio.float()
target_ratio_tensor = labels.new_tensor(target_ratio, dtype=torch.float32)
# Penalize deviation from target (use smooth L1 to avoid explosion)
compression_loss = F.smooth_l1_loss(
current_ratio_tensor,
target_ratio_tensor,
beta=2.0 # Transition point from L2 to L1
)
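            # For reference, smooth L1 with residual x = current - target:
            #   loss = 0.5 * x**2 / beta   if |x| < beta
            #        = |x| - 0.5 * beta    otherwise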
losses['compression'] = compression_loss * self.compression_weight
# 3. Boundary loss (GPT suggestion: more meaningful boundary learning)
if 'enc_boundaries' in outputs and outputs['enc_boundaries'] is not None:
boundary_scores = outputs['enc_boundaries']
# Boundary sparsity + smoothness (GPT suggestion)
# Encourage sparse but clear boundaries
boundary_probs = torch.sigmoid(boundary_scores)
# Sparsity loss (boundaries should be rare)
sparsity_loss = boundary_probs.mean() * 0.1
            # Smoothness loss (penalize abrupt jumps between adjacent scores)
if boundary_scores.size(1) > 1:
diff = boundary_scores[:, 1:] - boundary_scores[:, :-1]
smoothness_loss = (diff ** 2).mean() * 0.01
else:
smoothness_loss = 0.0
boundary_loss = sparsity_loss + smoothness_loss
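            # Effective coefficients on the total loss: the inner 0.1 / 0.01
            # scales combine with boundary_weight (0.1), i.e. ~0.01 for
            # sparsity and ~0.001 for smoothness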
losses['boundary'] = boundary_loss * self.boundary_weight
# Combine all losses
total_loss = sum(losses.values())
# Store individual losses for monitoring
self.last_losses = losses
return total_loss
def generate(self,
text: str = None,
input_ids: torch.Tensor = None,
max_length: int = 256,
temperature: float = 0.1,
top_k: int = 10,
top_p: float = 0.95) -> str:
"""
Generate/reconstruct text
Args:
text: Input text to encode and reconstruct
input_ids: Pre-encoded input
max_length: Maximum generation length
temperature: Sampling temperature
top_k: Top-k sampling
top_p: Top-p (nucleus) sampling
Returns:
Reconstructed/generated text
"""
# Encode input if text is provided (GPT suggestion: handle multi-chunk properly)
chunk_positions = None
if text is not None:
# Check if text needs chunking
if len(text.encode('utf-8')) > self.tokenizer.content_size:
encoded = self.tokenizer.encode(text, add_special_tokens=True, return_chunks=True)
chunk_positions = encoded.get('chunk_positions', None)
else:
encoded = self.tokenizer.encode(text, add_special_tokens=True)
input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
else:
attention_mask = (input_ids != self.tokenizer.PAD).bool() # GPT suggestion: bool mask
# Move to device
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Encode
with torch.no_grad():
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
# Prepare all hidden states for decoder
if 'all_hidden_states' in encoder_outputs:
encoder_all_hidden = encoder_outputs['all_hidden_states']
else:
compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
encoder_all_hidden = [compressed] * 4
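            # Assumption: the decoder consumes one hidden state per cross-
            # attention level, so the single compressed state is replicated
            # across all four levels as a stand-in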
# Autoregressive generation (fixed version)
batch_size = input_ids.size(0)
# Start with BOS token
generated_ids = torch.full((batch_size, 1), self.tokenizer.BOS, device=device)
for step in range(max_length - 1):
with torch.no_grad():
# Decode current sequence
decoder_outputs = self.decoder(
encoder_all_hidden=encoder_all_hidden,
decoder_input_ids=generated_ids,
attention_mask=torch.ones_like(generated_ids),
use_cache=False
)
# Get next token prediction
logits = decoder_outputs['logits'][:, -1, :] / temperature
# Top-k filtering
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float('-inf')
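                # Top-p (nucleus) filtering, standard sort/cumsum formulation,
                # so the advertised top_p parameter actually takes effect
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    # Drop tokens once cumulative probability exceeds top_p,
                    # always keeping the single most likely token
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    logits[indices_to_remove] = float('-inf')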
# Sample next token
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Append to generated sequence
generated_ids = torch.cat([generated_ids, next_token], dim=1)
                # Stop once EOS appears (with batch > 1 this waits for every sequence)
if (next_token == self.tokenizer.EOS).all():
break
        # Decode to text (GPT suggestion: proper multi-chunk reconstruction)
        # Note: the loop above always yields a 2-D (batch, seq) tensor, so the
        # multi-chunk branches apply only once chunked generation emits 3-D output
if generated_ids.dim() > 2 and chunk_positions is not None:
# Multi-chunk output with positions
text = self.tokenizer.reconstruct(
generated_ids,
positions=chunk_positions,
overlap=self.tokenizer.chunk_overlap
)
elif generated_ids.dim() > 2:
# Multi-chunk without positions (fallback)
text = self.tokenizer.reconstruct(generated_ids)
else:
# Single sequence
text = self.tokenizer.decode(generated_ids[0] if generated_ids.dim() > 1 else generated_ids)
return text
def compress(self, text: str) -> Dict[str, Union[torch.Tensor, float]]:
"""
Compress text and return compression statistics
Args:
text: Input text to compress
Returns:
Dictionary with compressed representation and statistics
"""
# Encode text
encoded = self.tokenizer.encode(text, add_special_tokens=True)
input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
# Move to device
device = next(self.parameters()).device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Get compressed representation
with torch.no_grad():
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask
)
return {
'compressed': encoder_outputs['compressed'],
'num_tokens': encoder_outputs['num_tokens'],
'compression_ratio': encoder_outputs['compression_ratio'],
'original_bytes': len(text.encode('utf-8')),
'compressed_size': encoder_outputs['num_tokens'] * 2 # Approximate bytes
}
def update_training_state(self, epoch: int, step: int = 0, reconstruction_loss: float = None):
"""
Update training state - adaptive, not phase-based
Args:
epoch: Current epoch
step: Current training step
reconstruction_loss: Current reconstruction quality
"""
        # Update the registered buffers in place to keep their device and dtype
        self.current_epoch.fill_(epoch)
        self.training_step.fill_(step)
# Update encoder warmup (gates only)
self.encoder.set_warmup_step(step)
# Adaptive weight adjustment based on performance
if reconstruction_loss is not None:
# If reconstruction is poor, increase its weight
if reconstruction_loss > 1.0:
self.reconstruction_weight = 1.0
self.compression_weight = 0.1 # Less compression focus
else:
# Good reconstruction, can focus on compression
self.reconstruction_weight = 0.5
self.compression_weight = 0.1
# Boundary weight stays moderate
self.boundary_weight = 0.1
# Let encoder know about reconstruction quality
self.encoder.adaptive_compression_control(reconstruction_loss)
else:
# Default balanced weights
self.reconstruction_weight = 0.5
self.compression_weight = 0.1
self.boundary_weight = 0.1
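        # Schedule summary: poor reconstruction (> 1.0) pushes
        # reconstruction_weight to 1.0; otherwise it relaxes to 0.5, while
        # compression_weight and boundary_weight stay at 0.1 throughout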
def get_model_stats(self) -> Dict[str, float]:
"""
Get model statistics for monitoring
Returns:
Dictionary with various model statistics
"""
stats = {}
# Encoder stats (GPT suggestion: already prefixed)
encoder_stats = self.encoder.get_monitoring_stats()
stats.update({f'encoder_{k}': v for k, v in encoder_stats.items()})
# Decoder memory stats
decoder_memory = self.decoder.get_memory_usage()
stats.update({f'decoder_{k}': v for k, v in decoder_memory.items()})
# Loss stats (if available) - check for tensor items
if hasattr(self, 'last_losses'):
for k, v in self.last_losses.items():
if isinstance(v, torch.Tensor):
stats[f'loss_{k}'] = v.item() if v.numel() == 1 else v.mean().item()
else:
stats[f'loss_{k}'] = float(v)
# Training info
stats['current_epoch'] = self.current_epoch.item()
stats['training_step'] = self.training_step.item()
return stats
def save_checkpoint(self, path: str):
"""
Save model checkpoint
Args:
path: Path to save checkpoint
"""
checkpoint = {
'model_state_dict': self.state_dict(),
'config': self.config,
'epoch': self.current_epoch.item(),
'step': self.training_step.item(),
'stats': self.get_model_stats()
}
torch.save(checkpoint, path)
print(f"Checkpoint saved to {path}")
@classmethod
def from_checkpoint(cls, path: str, device: str = 'cuda'):
"""
Load model from checkpoint
Args:
path: Path to checkpoint
device: Device to load model on
Returns:
Loaded model instance
"""
checkpoint = torch.load(path, map_location=device)
# Create model with saved config
model = cls(checkpoint.get('config', {}))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Restore training state
if 'epoch' in checkpoint:
model.current_epoch = torch.tensor(checkpoint['epoch'])
if 'step' in checkpoint:
model.training_step = torch.tensor(checkpoint['step'])
print(f"Model loaded from {path} (Epoch {checkpoint.get('epoch', 0)})")
return model
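# Checkpoint round-trip sketch ("v62.pt" is a hypothetical path; from_checkpoint
# defaults to CUDA, so pass device='cpu' on CPU-only machines):
#   model.save_checkpoint("v62.pt")
#   restored = IntelligentTokenizerV62.from_checkpoint("v62.pt", device="cpu")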
if __name__ == "__main__":
# Test unified model
print("Testing Intelligent Tokenizer v6.2.0")
# Create model
model = IntelligentTokenizerV62()
print(f"Model created with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")
# Test texts
test_texts = [
"Hello, world!",
"μ•ˆλ…•ν•˜μ„Έμš”, λ§Œλ‚˜μ„œ λ°˜κ°‘μŠ΅λ‹ˆλ‹€. 였늘 날씨가 μ’‹λ„€μš”!",
"δ»Šε€©ε€©ζ°”εΎˆε₯½γ€‚",
]
for text in test_texts:
print(f"\nInput: {text}")
# Compress
compression = model.compress(text)
print(f" Compression ratio: {compression['compression_ratio']:.1f}:1")
print(f" Tokens: {compression['num_tokens']}")
# Generate (reconstruct)
reconstructed = model.generate(text, temperature=0.1)
print(f" Reconstructed: {reconstructed}")
# Get model stats
stats = model.get_model_stats()
print(f"\nModel Statistics:")
for key, value in stats.items():
if isinstance(value, float):
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")