"""
Intelligent Tokenizer v6.2.0 - Unified Model

Integrates the encoder, decoder, and byte tokenizer with all GPT-5 improvements.
"""

from typing import Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    # Package-relative imports when this file is used as part of the package
    from .encoder import EncoderV62
    from .decoder import DecoderV62
    from .tokenizer import ByteTokenizerV62
except ImportError:
    # Fall back to top-level imports when run as a standalone script
    from encoder import EncoderV62
    from decoder import DecoderV62
    from tokenizer import ByteTokenizerV62


class IntelligentTokenizerV62(nn.Module):
    """
    Complete v6.2.0 model with progressive splitting and optimizations.

    Key features:
    - 48-byte chunks (46 content bytes + BOS/EOS)
    - Progressive splitting: 48 bytes -> 1 -> N -> M tokens
    - Multi-level cross-attention
    - KV-cache optimization (8x memory reduction)
    - All GPT-5 improvements integrated
    """
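
    # Illustrative round trip (checkpoint path and outputs are placeholders,
    # and assume trained weights):
    #   model = IntelligentTokenizerV62.from_checkpoint("checkpoint.pt")
    #   stats = model.compress("Hello, world!")   # latent tokens + ratio
    #   text = model.generate("Hello, world!")    # reconstruct from latents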

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        self.config = config or {}

        # Core components
        self.tokenizer = ByteTokenizerV62(self.config)
        self.encoder = EncoderV62(self.config)
        self.decoder = DecoderV62(self.config)

        # Loss weights, adapted during training (see update_training_state)
        self.compression_weight = 0.1
        self.reconstruction_weight = 0.1
        self.boundary_weight = 0.1

        # Training-state buffers (persisted in checkpoints)
        self.register_buffer('training_step', torch.tensor(0))
        self.register_buffer('current_epoch', torch.tensor(0))

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                text: Optional[str] = None,
                return_loss: bool = True,
                temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """
        Unified forward pass.

        Args:
            input_ids: Pre-tokenized input (optional)
            attention_mask: Attention mask (optional)
            labels: Target labels for training (optional)
            text: Raw text input (alternative to input_ids)
            return_loss: Whether to compute the loss
            temperature: Temperature for Gumbel-Softmax sampling in the encoder

        Returns:
            Dictionary with encoder outputs ('enc_*'), decoder outputs
            ('dec_*'), and 'loss' when it is computed
        """
        # Accept raw text via `text`, or a string passed directly as input_ids
        if isinstance(input_ids, str):
            text = input_ids
            input_ids = None
        if text is not None:
            encoded = self.tokenizer.encode(text, add_special_tokens=True)
            input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
            attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
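
        # Shape sketch (illustrative): a single 48-byte chunk yields
        # input_ids of shape (1, 48) = BOS + up to 46 content bytes + EOS.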
        # Move all tensor inputs to the model's device
        device = next(self.parameters()).device
        if torch.is_tensor(input_ids):
            input_ids = input_ids.to(device)
        if torch.is_tensor(attention_mask):
            attention_mask = attention_mask.to(device)
        if torch.is_tensor(labels):
            labels = labels.to(device)

        # Encode: compress the byte sequence into latent tokens
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            temperature=temperature
        )
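
        # encoder_outputs is expected to provide 'all_hidden_states' (one
        # tensor per decoding level) plus statistics consumed by compute_loss,
        # e.g. 'compression_ratio' and optional 'boundaries'.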

        # Without explicit labels, fall back to self-reconstruction:
        # the input bytes are their own target.
        if labels is None and return_loss and input_ids is not None:
            labels = input_ids

        if labels is not None:
            # Teacher forcing: shift right so position t predicts token t+1
            decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
            decoder_mask = attention_mask[:, :-1] if attention_mask is not None and attention_mask.dim() > 1 else None

            decoder_outputs = self.decoder(
                encoder_all_hidden=encoder_outputs['all_hidden_states'],
                decoder_input_ids=decoder_input,
                attention_mask=decoder_mask
            )
        else:
            # Inference without targets: decode from the latents alone
            decoder_outputs = self.decoder(
                encoder_all_hidden=encoder_outputs['all_hidden_states'],
                decoder_input_ids=None,
                attention_mask=attention_mask
            )
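
        # Teacher-forcing alignment on a hypothetical sequence:
        #   labels        = [BOS, a, b, EOS]
        #   decoder input = [BOS, a, b]     (labels[:, :-1])
        #   loss targets  = [a, b, EOS]     (labels[:, 1:], in compute_loss)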

        # Merge encoder and decoder outputs under prefixed keys
        outputs = {}
        for key, value in encoder_outputs.items():
            outputs[f'enc_{key}'] = value
        for key, value in decoder_outputs.items():
            outputs[f'dec_{key}'] = value

        if return_loss and labels is not None:
            outputs['loss'] = self.compute_loss(outputs, labels, attention_mask)

        return outputs

    def compute_loss(self,
                     outputs: Dict[str, torch.Tensor],
                     labels: torch.Tensor,
                     attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the combined loss over multiple objectives.

        Components:
        1. Reconstruction loss (cross-entropy over predicted bytes)
        2. Compression loss (pull the compression ratio toward its target)
        3. Boundary loss (sparsity and smoothness of boundary predictions)
        """
        losses = {}

        # 1) Reconstruction loss
        if 'dec_logits' in outputs:
            logits = outputs['dec_logits']

            # Shift left to align targets with the right-shifted decoder input
            target_labels = labels[:, 1:] if labels.dim() > 1 else labels[1:]
            target_mask = attention_mask[:, 1:] if attention_mask is not None and attention_mask.dim() > 1 else None

            vocab_size = logits.size(-1)
            logits_flat = logits.reshape(-1, vocab_size)
            labels_flat = target_labels.reshape(-1)

            if target_mask is not None:
                # Restrict the loss to non-padded positions
                mask_flat = target_mask.reshape(-1).bool()
                reconstruction_loss = F.cross_entropy(
                    logits_flat[mask_flat],
                    labels_flat[mask_flat],
                    ignore_index=self.tokenizer.PAD,
                    label_smoothing=0.1
                )
            else:
                reconstruction_loss = F.cross_entropy(
                    logits_flat,
                    labels_flat,
                    ignore_index=self.tokenizer.PAD,
                    label_smoothing=0.1
                )

            losses['reconstruction'] = reconstruction_loss * self.reconstruction_weight

        # 2) Compression loss: nudge the measured ratio toward the target
        if 'enc_compression_ratio' in outputs:
            target_ratio = 24.0  # e.g. 48 input bytes -> 2 latent tokens
            current_ratio = outputs['enc_compression_ratio']

            # The ratio may arrive as a Python scalar or as a tensor
            if isinstance(current_ratio, (int, float)):
                current_ratio_tensor = labels.new_tensor(current_ratio, dtype=torch.float32)
            else:
                current_ratio_tensor = current_ratio.float()
            target_ratio_tensor = labels.new_tensor(target_ratio, dtype=torch.float32)

            # Smooth L1 keeps gradients bounded while the ratio is far off
            compression_loss = F.smooth_l1_loss(
                current_ratio_tensor,
                target_ratio_tensor,
                beta=2.0
            )

            losses['compression'] = compression_loss * self.compression_weight

        # 3) Boundary loss: encourage sparse, smooth boundary predictions
        if 'enc_boundaries' in outputs and outputs['enc_boundaries'] is not None:
            boundary_scores = outputs['enc_boundaries']
            boundary_probs = torch.sigmoid(boundary_scores)

            # Sparsity: most positions should not be split points
            sparsity_loss = boundary_probs.mean() * 0.1

            # Smoothness: penalize abrupt jumps between adjacent scores
            if boundary_scores.size(1) > 1:
                diff = boundary_scores[:, 1:] - boundary_scores[:, :-1]
                smoothness_loss = (diff ** 2).mean() * 0.01
            else:
                smoothness_loss = 0.0

            boundary_loss = sparsity_loss + smoothness_loss
            losses['boundary'] = boundary_loss * self.boundary_weight

        total_loss = sum(losses.values())

        # Cache per-component losses for monitoring (see get_model_stats)
        self.last_losses = losses

        return total_loss

    def generate(self,
                 text: Optional[str] = None,
                 input_ids: Optional[torch.Tensor] = None,
                 max_length: int = 256,
                 temperature: float = 0.1,
                 top_k: int = 10,
                 top_p: float = 0.95) -> str:
        """
        Generate/reconstruct text.

        Args:
            text: Input text to encode and reconstruct
            input_ids: Pre-encoded input (alternative to text)
            max_length: Maximum generation length in tokens
            temperature: Sampling temperature
            top_k: Top-k filtering threshold (0 disables)
            top_p: Top-p (nucleus) filtering threshold (1.0 disables)

        Returns:
            Reconstructed/generated text
        """
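
        # Design note: the defaults (temperature=0.1, top_k=10) keep sampling
        # nearly deterministic, which suits reconstruction rather than free
        # generation; raise them for more diverse output.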
        chunk_positions = None
        if text is not None:
            # Long inputs are split into overlapping 48-byte chunks
            if len(text.encode('utf-8')) > self.tokenizer.content_size:
                encoded = self.tokenizer.encode(text, add_special_tokens=True, return_chunks=True)
                chunk_positions = encoded.get('chunk_positions', None)
            else:
                encoded = self.tokenizer.encode(text, add_special_tokens=True)

            input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
            attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
        else:
            attention_mask = (input_ids != self.tokenizer.PAD).bool()

        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Encode once; generation then attends to the cached hidden states
        with torch.no_grad():
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        if 'all_hidden_states' in encoder_outputs:
            encoder_all_hidden = encoder_outputs['all_hidden_states']
        else:
            # Fallback: replicate the compressed representation across the
            # four levels of cross-attention the decoder expects
            compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
            encoder_all_hidden = [compressed] * 4

        batch_size = input_ids.size(0)

        # Start every sequence with BOS and decode autoregressively
        generated_ids = torch.full((batch_size, 1), self.tokenizer.BOS,
                                   dtype=torch.long, device=device)

        for _ in range(max_length - 1):
            with torch.no_grad():
                decoder_outputs = self.decoder(
                    encoder_all_hidden=encoder_all_hidden,
                    decoder_input_ids=generated_ids,
                    attention_mask=torch.ones_like(generated_ids),
                    use_cache=False
                )

                # Only the last position's logits matter for the next token
                logits = decoder_outputs['logits'][:, -1, :] / temperature

                # Top-k filtering: drop everything below the k-th best logit
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float('-inf')
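
                # Top-p (nucleus) filtering: keep the smallest set of tokens
                # whose cumulative probability exceeds top_p. A standard
                # sketch of the technique; the threshold handling follows the
                # common sort-and-shift formulation.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    # Map the mask back from sorted order to vocabulary order
                    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')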

                # Sample the next token from the filtered distribution
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Stop once every sequence in the batch has emitted EOS
            if (next_token == self.tokenizer.EOS).all():
                break

        # Convert generated ids back to text
        if generated_ids.dim() > 2 and chunk_positions is not None:
            # Multi-chunk output: stitch chunks together, trimming the overlap
            text = self.tokenizer.reconstruct(
                generated_ids,
                positions=chunk_positions,
                overlap=self.tokenizer.chunk_overlap
            )
        elif generated_ids.dim() > 2:
            text = self.tokenizer.reconstruct(generated_ids)
        else:
            text = self.tokenizer.decode(generated_ids[0] if generated_ids.dim() > 1 else generated_ids)

        return text

    def compress(self, text: str) -> Dict[str, Union[torch.Tensor, float]]:
        """
        Compress text and return compression statistics.

        Args:
            text: Input text to compress

        Returns:
            Dictionary with the compressed representation and statistics
        """
        encoded = self.tokenizer.encode(text, add_special_tokens=True)
        input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
        attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']

        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        with torch.no_grad():
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
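
        # Illustrative result for a short ASCII input (numbers are made up):
        #   compress("Hello, world!") -> {'compressed': <tensor>,
        #       'num_tokens': 1, 'compression_ratio': 13.0,
        #       'original_bytes': 13, 'compressed_size': 2}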
        return {
            'compressed': encoder_outputs['compressed'],
            'num_tokens': encoder_outputs['num_tokens'],
            'compression_ratio': encoder_outputs['compression_ratio'],
            'original_bytes': len(text.encode('utf-8')),
            'compressed_size': encoder_outputs['num_tokens'] * 2  # assumes 2 bytes per latent token
        }

    def update_training_state(self, epoch: int, step: int = 0, reconstruction_loss: Optional[float] = None):
        """
        Update the training state. Loss weights adapt to reconstruction
        quality rather than following a fixed phase schedule.

        Args:
            epoch: Current epoch
            step: Current training step
            reconstruction_loss: Current reconstruction loss, if available
        """
        # Update buffers in place so they stay on the model's device
        self.current_epoch.fill_(epoch)
        self.training_step.fill_(step)

        # Let the encoder advance its warmup schedule
        self.encoder.set_warmup_step(step)

        if reconstruction_loss is not None:
            # Prioritize reconstruction until it is reliable, then relax
            # (e.g. a loss of 2.3 early in training keeps the weight at 1.0)
            if reconstruction_loss > 1.0:
                self.reconstruction_weight = 1.0
                self.compression_weight = 0.1
            else:
                self.reconstruction_weight = 0.5
                self.compression_weight = 0.1
            self.boundary_weight = 0.1

            # The encoder also adapts its compression pressure
            self.encoder.adaptive_compression_control(reconstruction_loss)
        else:
            # Defaults when no reconstruction signal is available
            self.reconstruction_weight = 0.5
            self.compression_weight = 0.1
            self.boundary_weight = 0.1

    def get_model_stats(self) -> Dict[str, float]:
        """
        Get model statistics for monitoring.

        Returns:
            Dictionary with encoder, decoder, loss, and training-state statistics
        """
        stats = {}

        # Encoder monitoring statistics
        encoder_stats = self.encoder.get_monitoring_stats()
        stats.update({f'encoder_{k}': v for k, v in encoder_stats.items()})

        # Decoder memory usage (e.g. KV-cache size)
        decoder_memory = self.decoder.get_memory_usage()
        stats.update({f'decoder_{k}': v for k, v in decoder_memory.items()})

        # Most recent per-component losses, if a loss has been computed
        if hasattr(self, 'last_losses'):
            for k, v in self.last_losses.items():
                if isinstance(v, torch.Tensor):
                    stats[f'loss_{k}'] = v.item() if v.numel() == 1 else v.mean().item()
                else:
                    stats[f'loss_{k}'] = float(v)

        stats['current_epoch'] = self.current_epoch.item()
        stats['training_step'] = self.training_step.item()

        return stats

    def save_checkpoint(self, path: str):
        """
        Save a model checkpoint.

        Args:
            path: Path to save the checkpoint to
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'epoch': self.current_epoch.item(),
            'step': self.training_step.item(),
            'stats': self.get_model_stats()
        }
        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")

    @classmethod
    def from_checkpoint(cls, path: str, device: str = 'cuda'):
        """
        Load a model from a checkpoint.

        Args:
            path: Path to the checkpoint
            device: Device to load the model on

        Returns:
            Loaded model instance
        """
        checkpoint = torch.load(path, map_location=device)

        model = cls(checkpoint.get('config', {}))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)

        # Restore training state in place so the buffers stay on `device`
        if 'epoch' in checkpoint:
            model.current_epoch.fill_(checkpoint['epoch'])
        if 'step' in checkpoint:
            model.training_step.fill_(checkpoint['step'])

        print(f"Model loaded from {path} (epoch {checkpoint.get('epoch', 0)})")
        return model


if __name__ == "__main__":
    # Smoke test: build the model and run a few multilingual round trips
    print("Testing Intelligent Tokenizer v6.2.0")

    model = IntelligentTokenizerV62()
    print(f"Model created with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")

    test_texts = [
        "Hello, world!",
        "안녕하세요, 만나서 반갑습니다. 오늘 날씨가 좋네요!",  # Korean: "Hello, nice to meet you. The weather is nice today!"
        "今天天气很好。",  # Chinese: "The weather is very nice today."
    ]

    for text in test_texts:
        print(f"\nInput: {text}")

        # Compression statistics
        compression = model.compress(text)
        print(f"  Compression ratio: {compression['compression_ratio']:.1f}:1")
        print(f"  Tokens: {compression['num_tokens']}")

        # Round-trip reconstruction
        reconstructed = model.generate(text, temperature=0.1)
        print(f"  Reconstructed: {reconstructed}")

    # Aggregate model statistics
    stats = model.get_model_stats()
    print("\nModel Statistics:")
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")