# intelligent-tokenizer-v6-demo/core/intelligent_loss.py
"""
Intelligent Loss Functions for v6.2.0
Multi-objective loss with GPT-5 suggested improvements
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
import math
class IntelligentLoss(nn.Module):
"""
Comprehensive loss function for progressive splitting tokenizer
Combines multiple objectives with dynamic weighting
"""
def __init__(self, config: Optional[Dict] = None):
super().__init__()
# Default configuration
self.config = config or {}
# Special tokens (must match tokenizer)
self.PAD = 256
self.BOS = 257
self.EOS = 258
self.MASK = 259
# Loss components
self.reconstruction_loss = ReconstructionLoss(self.PAD)
self.compression_loss = CompressionLoss()
self.boundary_loss = BoundaryLoss()
self.language_loss = LanguageLoss()
self.consistency_loss = ConsistencyLoss()
# Dynamic weight adjustment
self.use_dynamic_weights = True
self.weight_history = {
'reconstruction': [],
'compression': [],
'boundary': [],
'language': [],
'consistency': []
}
def estimate_language_difficulty(self, targets: Dict) -> float:
"""Estimate language difficulty based on input characteristics"""
if 'input_ids' not in targets:
return 1.0
input_ids = targets['input_ids']
if input_ids.numel() == 0:
return 1.0
        # Token diversity (unique / total) is used as a rough proxy for language complexity
        unique_tokens = input_ids.unique().numel()
        total_tokens = input_ids.numel()
        diversity = min(1.0, (unique_tokens / total_tokens) * 2)
return diversity
def forward(self,
outputs: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor],
weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
"""
Compute combined loss with all objectives
Args:
outputs: Model outputs dictionary
targets: Target values dictionary
weights: Optional weight overrides
Returns:
Dictionary with total loss and individual components
"""
losses = {}
# 1. Reconstruction loss (primary objective)
if 'logits' in outputs and 'input_ids' in targets:
losses['reconstruction'] = self.reconstruction_loss(
outputs['logits'],
targets['input_ids'],
targets.get('attention_mask')
)
# 2. Compression loss (encourage optimal compression)
if 'compression_ratio' in outputs:
losses['compression'] = self.compression_loss(
outputs['compression_ratio'],
outputs.get('num_tokens')
)
# 3. Boundary loss (learn meaningful boundaries)
if 'boundaries' in outputs and 'boundary_targets' in targets:
losses['boundary'] = self.boundary_loss(
outputs['boundaries'],
targets['boundary_targets'],
targets.get('boundary_mask')
)
# 4. Language loss (language identification/clustering)
if 'language_clusters' in outputs and 'language_targets' in targets:
losses['language'] = self.language_loss(
outputs['language_clusters'],
targets['language_targets']
)
# 5. Consistency loss (encoder-decoder consistency)
if 'encoder_hidden' in outputs and 'decoder_hidden' in outputs:
losses['consistency'] = self.consistency_loss(
outputs['encoder_hidden'],
outputs['decoder_hidden']
)
# Apply weights (either provided or dynamic)
if weights is None and self.use_dynamic_weights:
weights = self.compute_dynamic_weights(losses)
elif weights is None:
weights = {
'reconstruction': 1.0,
'compression': 1.0,
'boundary': 1.0,
'language': 0.5,
'consistency': 0.5
}
        # Weighted sum (guard against the case where no loss component was computed)
        if not losses:
            return {'total': torch.tensor(0.0)}
        total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device)
for key, loss in losses.items():
weight = weights.get(key, 1.0)
total_loss = total_loss + weight * loss
losses[f'{key}_weighted'] = weight * loss
losses['total'] = total_loss
# Update weight history
for key in self.weight_history:
if key in losses:
self.weight_history[key].append(losses[key].item())
return losses
def compute_dynamic_weights(self, losses: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""
Dynamically adjust weights based on loss magnitudes and progress
GPT-5 suggestion: balance loss magnitudes for stable training
"""
weights = {}
eps = 1e-8 # GPT fix: prevent division by zero
# Get loss magnitudes with NaN protection
magnitudes = {}
for k, v in losses.items():
if torch.isnan(v) or torch.isinf(v):
magnitudes[k] = 1.0 # Default safe value
else:
magnitudes[k] = v.item()
# Compute relative scales (GPT fix: add epsilon)
avg_magnitude = max(eps, sum(magnitudes.values()) / len(magnitudes))
for key, magnitude in magnitudes.items():
# Inverse scaling to balance magnitudes (GPT fix: add epsilon)
weights[key] = avg_magnitude / max(eps, magnitude)
# Dynamic adjustment based on loss ratios
if 'reconstruction' in magnitudes and 'compression' in magnitudes:
recon_loss = magnitudes['reconstruction']
comp_loss = magnitudes['compression']
# If reconstruction loss is too high relative to compression
if recon_loss > comp_loss * 10:
# Drastically reduce compression pressure
weights['compression'] *= 0.1
weights['reconstruction'] *= 5.0
elif recon_loss > comp_loss * 5:
# Moderate adjustment
weights['compression'] *= 0.5
weights['reconstruction'] *= 2.0
elif recon_loss < comp_loss * 0.5:
# Good reconstruction, can push compression
weights['compression'] *= 2.0
weights['reconstruction'] *= 0.5
# Normalize weights to prevent explosion
total_weight = sum(weights.values())
if total_weight > 0:
weights = {k: min(10.0, v / total_weight * len(weights)) for k, v in weights.items()}
return weights
class ReconstructionLoss(nn.Module):
"""
Cross-entropy loss for sequence reconstruction
With label smoothing and focal loss options
"""
def __init__(self, pad_token: int = 256, label_smoothing: float = 0.1):
super().__init__()
self.pad_token = pad_token
self.label_smoothing = label_smoothing
self.focal_alpha = 0.25
self.focal_gamma = 2.0
self.use_focal = False
def forward(self,
logits: torch.Tensor,
targets: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute reconstruction loss
Args:
logits: [batch, seq_len, vocab_size]
targets: [batch, seq_len]
mask: [batch, seq_len] attention mask
"""
batch_size, seq_len, vocab_size = logits.shape
# Reshape for loss computation
logits_flat = logits.reshape(-1, vocab_size)
targets_flat = targets.reshape(-1)
if self.use_focal:
# Focal loss for hard examples
ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss
if mask is not None:
mask_flat = mask.reshape(-1)
focal_loss = focal_loss * mask_flat
loss = focal_loss.sum() / mask_flat.sum()
else:
loss = focal_loss.mean()
else:
# Standard cross-entropy with label smoothing
if mask is not None:
mask_flat = mask.reshape(-1).bool() # GPT fix: ensure bool dtype
loss = F.cross_entropy(
logits_flat[mask_flat],
targets_flat[mask_flat],
ignore_index=self.pad_token,
label_smoothing=self.label_smoothing
)
else:
loss = F.cross_entropy(
logits_flat,
targets_flat,
ignore_index=self.pad_token,
label_smoothing=self.label_smoothing
)
return loss
class CompressionLoss(nn.Module):
"""
Aggressive compression loss - push for high compression
Must beat existing tokenizers (4 bytes/token = 4:1)
"""
def __init__(self):
super().__init__()
# Dynamic compression based on token count
# 1 token = 48:1, 2 = 24:1, 3 = 16:1, 4 = 12:1
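        # e.g. a 48-byte input window reduced to 2 tokens gives 48 / 2 = 24:1 (the target below)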
self.min_ratio = 12.0 # 4 tokens (worst case, still 3x better than BPE)
self.target_ratio = 24.0 # 2 tokens (optimal balance)
self.max_ratio = 48.0 # 1 token (best compression)
def forward(self,
compression_ratio: torch.Tensor,
num_tokens: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute compression loss (GPT fix: fully vectorized)
Args:
compression_ratio: Current compression ratio (scalar or batch)
num_tokens: Number of tokens used (for additional penalty)
"""
# Ensure tensor (GPT fix: handle device properly)
if not torch.is_tensor(compression_ratio):
device = num_tokens.device if torch.is_tensor(num_tokens) else torch.device('cpu')
compression_ratio = torch.tensor(compression_ratio, dtype=torch.float32, device=device)
        # Aggressive compression enforcement, written with clamp/where so that
        # batched ratios work as well as scalars (matching the vectorization intent above)
        zero = torch.zeros_like(compression_ratio)

        # Moderate penalty for falling below the minimum viable ratio (reduced for stability)
        under_loss = torch.clamp(self.min_ratio - compression_ratio, min=0.0) / self.min_ratio * 0.5

        # Encourage reaching the target (24:1); small reward for exceeding it
        target_loss = torch.where(
            compression_ratio >= self.target_ratio,
            -0.1 * torch.log(compression_ratio / self.target_ratio + 1.0),
            torch.where(
                compression_ratio >= self.min_ratio,
                (self.target_ratio - compression_ratio) / self.target_ratio * 0.5,
                zero
            )
        )

        # Only a mild penalty for extreme compression (>48:1)
        over_loss = torch.clamp(compression_ratio - self.max_ratio, min=0.0) / self.max_ratio * 0.2

        loss = under_loss + target_loss + over_loss

        # Additional penalty based on token count (discourage using more than 8 tokens)
        if num_tokens is not None:
            if not torch.is_tensor(num_tokens):
                num_tokens = torch.tensor(num_tokens, dtype=torch.float32, device=compression_ratio.device)
            token_penalty = 0.1 * torch.clamp(num_tokens.float() - 8, min=0.0) ** 2
            loss = loss + token_penalty

        return loss.mean() if loss.dim() > 0 else loss
class BoundaryLoss(nn.Module):
"""
Learn meaningful chunk boundaries
Combines multiple boundary objectives
"""
def __init__(self):
super().__init__()
self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
def forward(self,
predicted: torch.Tensor,
target: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Compute boundary loss
Args:
predicted: [batch, seq_len, boundary_classes] predicted boundaries
target: [batch, seq_len, boundary_classes] target boundaries
mask: [batch, seq_len] valid positions mask
"""
# Binary cross-entropy for boundary prediction
loss = self.bce_loss(predicted, target.float())
if mask is not None:
# Apply mask
mask_expanded = mask.unsqueeze(-1).expand_as(loss)
loss = loss * mask_expanded
loss = loss.sum() / mask_expanded.sum()
else:
loss = loss.mean()
# Add regularization for boundary sparsity
# (boundaries should be relatively rare)
boundary_probs = torch.sigmoid(predicted)
sparsity_loss = 0.01 * boundary_probs.mean()
# Add smoothness regularization
# (boundaries should be somewhat smooth/continuous)
if predicted.size(1) > 1:
diff = predicted[:, 1:] - predicted[:, :-1]
smoothness_loss = 0.01 * (diff ** 2).mean()
else:
smoothness_loss = 0.0
total_loss = loss + sparsity_loss + smoothness_loss
return total_loss
class LanguageLoss(nn.Module):
"""
Language identification/clustering loss
Supports both classification and clustering objectives
"""
def __init__(self, num_languages: int = 128, temperature: float = 0.07):
super().__init__()
self.num_languages = num_languages
self.temperature = temperature
# For supervised language classification
self.ce_loss = nn.CrossEntropyLoss()
def forward(self,
predicted: torch.Tensor,
target: torch.Tensor,
mode: str = 'classification') -> torch.Tensor:
"""
Compute language loss
Args:
predicted: [batch, seq_len, num_languages] or [batch, num_languages]
target: Language labels or cluster assignments
mode: 'classification' or 'clustering'
"""
if mode == 'classification':
# Standard classification loss
if predicted.dim() == 3:
# Sequence-level predictions
batch_size, seq_len, _ = predicted.shape
predicted = predicted.reshape(-1, self.num_languages)
target = target.reshape(-1)
loss = self.ce_loss(predicted, target)
elif mode == 'clustering':
# Contrastive clustering loss (similar to SimCLR)
# Normalize embeddings
predicted = F.normalize(predicted, dim=-1)
# Compute similarity matrix
sim_matrix = torch.matmul(predicted, predicted.t()) / self.temperature
# Create labels (assuming batch contains similar samples)
batch_size = predicted.size(0)
labels = torch.arange(batch_size, device=predicted.device)
# Contrastive loss
loss = F.cross_entropy(sim_matrix, labels)
else:
raise ValueError(f"Unknown mode: {mode}")
return loss
class ConsistencyLoss(nn.Module):
"""
Ensure consistency between encoder and decoder representations
GPT-5 suggestion: helps with training stability
"""
def __init__(self, margin: float = 0.5):
super().__init__()
self.margin = margin
def forward(self,
encoder_hidden: torch.Tensor,
decoder_hidden: torch.Tensor) -> torch.Tensor:
"""
Compute consistency loss between encoder and decoder
Args:
encoder_hidden: [batch, seq_len, hidden_dim]
decoder_hidden: [batch, seq_len, hidden_dim]
"""
# Ensure same shape
if encoder_hidden.shape != decoder_hidden.shape:
# Align sequence lengths if different
min_len = min(encoder_hidden.size(1), decoder_hidden.size(1))
encoder_hidden = encoder_hidden[:, :min_len]
decoder_hidden = decoder_hidden[:, :min_len]
# L2 distance
l2_loss = F.mse_loss(encoder_hidden, decoder_hidden)
# Cosine similarity loss
encoder_norm = F.normalize(encoder_hidden, dim=-1)
decoder_norm = F.normalize(decoder_hidden, dim=-1)
cosine_sim = (encoder_norm * decoder_norm).sum(dim=-1)
cosine_loss = 1.0 - cosine_sim.mean()
# Combined loss
loss = l2_loss + 0.5 * cosine_loss
return loss
class AdaptiveLossScheduler:
"""
Dynamically adjust loss weights during training
Based on training progress and performance
"""
def __init__(self, config: Dict):
self.config = config
self.current_phase = 0
self.phase_epochs = [30, 60, 100] # Phase transition points
# Define phase-specific weights
self.phase_weights = [
# Phase 1: Boundary mastery
{
'reconstruction': 2.0,
'compression': 0.5,
'boundary': 3.0,
'language': 0.5,
'consistency': 0.5
},
# Phase 2: Compression focus
{
'reconstruction': 2.0,
'compression': 3.0,
'boundary': 1.0,
'language': 1.0,
'consistency': 1.0
},
# Phase 3: Balanced optimization
{
'reconstruction': 3.0,
'compression': 2.0,
'boundary': 1.0,
'language': 1.0,
'consistency': 1.5
}
]
def get_weights(self, epoch: int, metrics: Optional[Dict] = None) -> Dict[str, float]:
"""
Get current loss weights based on training phase
Args:
epoch: Current training epoch
metrics: Optional performance metrics for adaptive adjustment
"""
        # Determine current phase; default to the final phase once past all transition points
        self.current_phase = len(self.phase_epochs) - 1
        for i, phase_end in enumerate(self.phase_epochs):
            if epoch <= phase_end:
                self.current_phase = i
                break
weights = self.phase_weights[self.current_phase].copy()
# Adaptive adjustments based on metrics
if metrics:
# If reconstruction is poor, increase its weight
if metrics.get('reconstruction_accuracy', 1.0) < 0.9:
weights['reconstruction'] *= 1.5
# If compression is off target, adjust weight
compression_ratio = metrics.get('compression_ratio', 16.0)
if compression_ratio < 8.0 or compression_ratio > 20.0:
weights['compression'] *= 1.5
return weights
if __name__ == "__main__":
# Test losses
print("Testing Intelligent Loss Functions")
# Create loss module
loss_fn = IntelligentLoss()
# Create dummy data
batch_size = 2
seq_len = 48
vocab_size = 260
hidden_dim = 1280
outputs = {
'logits': torch.randn(batch_size, seq_len, vocab_size),
'compression_ratio': torch.tensor(16.0),
'num_tokens': torch.tensor(3),
'boundaries': torch.randn(batch_size, seq_len, 4),
'language_clusters': torch.randn(batch_size, 128),
'encoder_hidden': torch.randn(batch_size, seq_len, hidden_dim),
'decoder_hidden': torch.randn(batch_size, seq_len, hidden_dim)
}
targets = {
'input_ids': torch.randint(0, 256, (batch_size, seq_len)),
'attention_mask': torch.ones(batch_size, seq_len),
'boundary_targets': torch.zeros(batch_size, seq_len, 4),
'language_targets': torch.randint(0, 128, (batch_size,))
}
# Compute losses
losses = loss_fn(outputs, targets)
print("\nLoss components:")
for key, value in losses.items():
if isinstance(value, torch.Tensor):
print(f" {key}: {value.item():.4f}")
# Test adaptive scheduler
scheduler = AdaptiveLossScheduler({})
print("\nPhase weights:")
for epoch in [10, 40, 70]:
weights = scheduler.get_weights(epoch)
print(f" Epoch {epoch}: {weights}")