""" Intelligent Loss Functions for v6.2.0 Multi-objective loss with GPT-5 suggested improvements """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Optional, Tuple import math class IntelligentLoss(nn.Module): """ Comprehensive loss function for progressive splitting tokenizer Combines multiple objectives with dynamic weighting """ def __init__(self, config: Optional[Dict] = None): super().__init__() # Default configuration self.config = config or {} # Special tokens (must match tokenizer) self.PAD = 256 self.BOS = 257 self.EOS = 258 self.MASK = 259 # Loss components self.reconstruction_loss = ReconstructionLoss(self.PAD) self.compression_loss = CompressionLoss() self.boundary_loss = BoundaryLoss() self.language_loss = LanguageLoss() self.consistency_loss = ConsistencyLoss() # Dynamic weight adjustment self.use_dynamic_weights = True self.weight_history = { 'reconstruction': [], 'compression': [], 'boundary': [], 'language': [], 'consistency': [] } def estimate_language_difficulty(self, targets: Dict) -> float: """Estimate language difficulty based on input characteristics""" if 'input_ids' not in targets: return 1.0 input_ids = targets['input_ids'] if input_ids.numel() == 0: return 1.0 # Higher entropy = more complex language unique_tokens = input_ids.unique().numel() total_tokens = input_ids.numel() diversity = min(1.0, (unique_tokens / total_tokens) * 2) return diversity def forward(self, outputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]: """ Compute combined loss with all objectives Args: outputs: Model outputs dictionary targets: Target values dictionary weights: Optional weight overrides Returns: Dictionary with total loss and individual components """ losses = {} # 1. Reconstruction loss (primary objective) if 'logits' in outputs and 'input_ids' in targets: losses['reconstruction'] = self.reconstruction_loss( outputs['logits'], targets['input_ids'], targets.get('attention_mask') ) # 2. Compression loss (encourage optimal compression) if 'compression_ratio' in outputs: losses['compression'] = self.compression_loss( outputs['compression_ratio'], outputs.get('num_tokens') ) # 3. Boundary loss (learn meaningful boundaries) if 'boundaries' in outputs and 'boundary_targets' in targets: losses['boundary'] = self.boundary_loss( outputs['boundaries'], targets['boundary_targets'], targets.get('boundary_mask') ) # 4. Language loss (language identification/clustering) if 'language_clusters' in outputs and 'language_targets' in targets: losses['language'] = self.language_loss( outputs['language_clusters'], targets['language_targets'] ) # 5. 
        # 5. Consistency loss (encoder-decoder consistency)
        if 'encoder_hidden' in outputs and 'decoder_hidden' in outputs:
            losses['consistency'] = self.consistency_loss(
                outputs['encoder_hidden'],
                outputs['decoder_hidden']
            )

        # Apply weights (either provided or dynamic)
        if weights is None and self.use_dynamic_weights:
            weights = self.compute_dynamic_weights(losses)
        elif weights is None:
            weights = {
                'reconstruction': 1.0,
                'compression': 1.0,
                'boundary': 1.0,
                'language': 0.5,
                'consistency': 0.5
            }

        # Weighted sum (iterate over a snapshot: the dict gains '*_weighted' entries inside the loop)
        total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device)
        for key, loss in list(losses.items()):
            weight = weights.get(key, 1.0)
            total_loss = total_loss + weight * loss
            losses[f'{key}_weighted'] = weight * loss

        losses['total'] = total_loss

        # Update weight history
        for key in self.weight_history:
            if key in losses:
                self.weight_history[key].append(losses[key].item())

        return losses

    def compute_dynamic_weights(self, losses: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Dynamically adjust weights based on loss magnitudes and progress
        GPT-5 suggestion: balance loss magnitudes for stable training
        """
        weights = {}
        eps = 1e-8  # GPT fix: prevent division by zero

        # Get loss magnitudes with NaN protection
        magnitudes = {}
        for k, v in losses.items():
            if torch.isnan(v) or torch.isinf(v):
                magnitudes[k] = 1.0  # Default safe value
            else:
                magnitudes[k] = v.item()

        # Compute relative scales (GPT fix: add epsilon)
        avg_magnitude = max(eps, sum(magnitudes.values()) / len(magnitudes))

        for key, magnitude in magnitudes.items():
            # Inverse scaling to balance magnitudes (GPT fix: add epsilon)
            weights[key] = avg_magnitude / max(eps, magnitude)

        # Dynamic adjustment based on loss ratios
        if 'reconstruction' in magnitudes and 'compression' in magnitudes:
            recon_loss = magnitudes['reconstruction']
            comp_loss = magnitudes['compression']

            # If reconstruction loss is too high relative to compression
            if recon_loss > comp_loss * 10:
                # Drastically reduce compression pressure
                weights['compression'] *= 0.1
                weights['reconstruction'] *= 5.0
            elif recon_loss > comp_loss * 5:
                # Moderate adjustment
                weights['compression'] *= 0.5
                weights['reconstruction'] *= 2.0
            elif recon_loss < comp_loss * 0.5:
                # Good reconstruction, can push compression
                weights['compression'] *= 2.0
                weights['reconstruction'] *= 0.5

        # Normalize weights to prevent explosion
        total_weight = sum(weights.values())
        if total_weight > 0:
            weights = {k: min(10.0, v / total_weight * len(weights)) for k, v in weights.items()}

        return weights


class ReconstructionLoss(nn.Module):
    """
    Cross-entropy loss for sequence reconstruction
    With label smoothing and focal loss options
    """

    def __init__(self, pad_token: int = 256, label_smoothing: float = 0.1):
        super().__init__()
        self.pad_token = pad_token
        self.label_smoothing = label_smoothing
        self.focal_alpha = 0.25
        self.focal_gamma = 2.0
        self.use_focal = False

    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute reconstruction loss

        Args:
            logits: [batch, seq_len, vocab_size]
            targets: [batch, seq_len]
            mask: [batch, seq_len] attention mask
        """
        batch_size, seq_len, vocab_size = logits.shape

        # Reshape for loss computation
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        if self.use_focal:
            # Focal loss for hard examples
            ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
            pt = torch.exp(-ce_loss)
            focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss

            if mask is not None:
                mask_flat = mask.reshape(-1)
                focal_loss = focal_loss * mask_flat
                loss = focal_loss.sum() / mask_flat.sum()
            else:
                loss = focal_loss.mean()
        else:
            # Standard cross-entropy with label smoothing
            if mask is not None:
                mask_flat = mask.reshape(-1).bool()  # GPT fix: ensure bool dtype
                loss = F.cross_entropy(
                    logits_flat[mask_flat],
                    targets_flat[mask_flat],
                    ignore_index=self.pad_token,
                    label_smoothing=self.label_smoothing
                )
            else:
                loss = F.cross_entropy(
                    logits_flat,
                    targets_flat,
                    ignore_index=self.pad_token,
                    label_smoothing=self.label_smoothing
                )

        return loss


class CompressionLoss(nn.Module):
    """
    Aggressive compression loss - push for high compression
    Must beat existing tokenizers (4 bytes/token = 4:1)
    """

    def __init__(self):
        super().__init__()
        # Dynamic compression based on token count:
        # 1 token = 48:1, 2 = 24:1, 3 = 16:1, 4 = 12:1
        self.min_ratio = 12.0     # 4 tokens (worst case, still 3x better than BPE)
        self.target_ratio = 24.0  # 2 tokens (optimal balance)
        self.max_ratio = 48.0     # 1 token (best compression)

    def forward(self,
                compression_ratio: torch.Tensor,
                num_tokens: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute compression loss

        Args:
            compression_ratio: Current compression ratio (scalar tensor or float)
            num_tokens: Number of tokens used (for additional penalty)
        """
        # Ensure tensor (GPT fix: handle device properly)
        if not torch.is_tensor(compression_ratio):
            device = num_tokens.device if torch.is_tensor(num_tokens) else torch.device('cpu')
            compression_ratio = torch.tensor(compression_ratio, dtype=torch.float32, device=device)

        # Aggressive compression enforcement
        # MUST achieve at least 12:1 to be viable
        if compression_ratio < self.min_ratio:
            # Moderate penalty for falling below minimum (reduced for stability)
            under_loss = ((self.min_ratio - compression_ratio) / self.min_ratio) * 0.5
        else:
            under_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)

        # Reward getting close to target (24:1)
        if self.min_ratio <= compression_ratio < self.target_ratio:
            # Encourage reaching target
            target_loss = ((self.target_ratio - compression_ratio) / self.target_ratio) * 0.5
        elif compression_ratio >= self.target_ratio:
            # Excellent compression - small reward for going higher
            target_loss = -0.1 * torch.log(compression_ratio / self.target_ratio + 1.0)
        else:
            target_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)

        # Only mild penalty for extreme compression (>48:1)
        if compression_ratio > self.max_ratio:
            over_loss = ((compression_ratio - self.max_ratio) / self.max_ratio) * 0.2
        else:
            over_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)

        loss = under_loss + target_loss + over_loss

        # Additional penalty based on token count (GPT fix: vectorized)
        if num_tokens is not None:
            if not torch.is_tensor(num_tokens):
                num_tokens = torch.tensor(num_tokens, dtype=torch.float32, device=compression_ratio.device)
            # Cast to float so the clamp/penalty arithmetic stays in floating point
            token_penalty = 0.1 * torch.clamp(num_tokens.float() - 8.0, min=0.0) ** 2
            loss = loss + token_penalty

        return loss.mean() if loss.dim() > 0 else loss


class BoundaryLoss(nn.Module):
    """
    Learn meaningful chunk boundaries
    Combines multiple boundary objectives
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self,
                predicted: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute boundary loss

        Args:
            predicted: [batch, seq_len, boundary_classes] predicted boundaries
            target: [batch, seq_len, boundary_classes] target boundaries
            mask: [batch, seq_len] valid positions mask
        """
        # Binary cross-entropy for boundary prediction
        loss = self.bce_loss(predicted, target.float())

        if mask is not None:
            # Apply mask
            mask_expanded = mask.unsqueeze(-1).expand_as(loss)
            loss = loss * mask_expanded
            loss = loss.sum() / mask_expanded.sum()
        else:
            loss = loss.mean()

        # Add regularization for boundary sparsity
        # (boundaries should be relatively rare)
        boundary_probs = torch.sigmoid(predicted)
        sparsity_loss = 0.01 * boundary_probs.mean()

        # Add smoothness regularization
        # (boundaries should be somewhat smooth/continuous)
        if predicted.size(1) > 1:
            diff = predicted[:, 1:] - predicted[:, :-1]
            smoothness_loss = 0.01 * (diff ** 2).mean()
        else:
            smoothness_loss = 0.0

        total_loss = loss + sparsity_loss + smoothness_loss
        return total_loss


class LanguageLoss(nn.Module):
    """
    Language identification/clustering loss
    Supports both classification and clustering objectives
    """

    def __init__(self, num_languages: int = 128, temperature: float = 0.07):
        super().__init__()
        self.num_languages = num_languages
        self.temperature = temperature

        # For supervised language classification
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self,
                predicted: torch.Tensor,
                target: torch.Tensor,
                mode: str = 'classification') -> torch.Tensor:
        """
        Compute language loss

        Args:
            predicted: [batch, seq_len, num_languages] or [batch, num_languages]
            target: Language labels or cluster assignments
            mode: 'classification' or 'clustering'
        """
        if mode == 'classification':
            # Standard classification loss
            if predicted.dim() == 3:
                # Sequence-level predictions
                batch_size, seq_len, _ = predicted.shape
                predicted = predicted.reshape(-1, self.num_languages)
                target = target.reshape(-1)

            loss = self.ce_loss(predicted, target)

        elif mode == 'clustering':
            # Contrastive clustering loss (similar to SimCLR)
            # Normalize embeddings
            predicted = F.normalize(predicted, dim=-1)

            # Compute similarity matrix
            sim_matrix = torch.matmul(predicted, predicted.t()) / self.temperature

            # Create labels (assuming batch contains similar samples)
            batch_size = predicted.size(0)
            labels = torch.arange(batch_size, device=predicted.device)

            # Contrastive loss
            loss = F.cross_entropy(sim_matrix, labels)

        else:
            raise ValueError(f"Unknown mode: {mode}")

        return loss


class ConsistencyLoss(nn.Module):
    """
    Ensure consistency between encoder and decoder representations
    GPT-5 suggestion: helps with training stability
    """

    def __init__(self, margin: float = 0.5):
        super().__init__()
        self.margin = margin

    def forward(self,
                encoder_hidden: torch.Tensor,
                decoder_hidden: torch.Tensor) -> torch.Tensor:
        """
        Compute consistency loss between encoder and decoder

        Args:
            encoder_hidden: [batch, seq_len, hidden_dim]
            decoder_hidden: [batch, seq_len, hidden_dim]
        """
        # Ensure same shape
        if encoder_hidden.shape != decoder_hidden.shape:
            # Align sequence lengths if different
            min_len = min(encoder_hidden.size(1), decoder_hidden.size(1))
            encoder_hidden = encoder_hidden[:, :min_len]
            decoder_hidden = decoder_hidden[:, :min_len]

        # L2 distance
        l2_loss = F.mse_loss(encoder_hidden, decoder_hidden)

        # Cosine similarity loss
        encoder_norm = F.normalize(encoder_hidden, dim=-1)
        decoder_norm = F.normalize(decoder_hidden, dim=-1)
        cosine_sim = (encoder_norm * decoder_norm).sum(dim=-1)
        cosine_loss = 1.0 - cosine_sim.mean()

        # Combined loss
        loss = l2_loss + 0.5 * cosine_loss
        return loss


class AdaptiveLossScheduler:
    """
    Dynamically adjust loss weights during training
    Based on training progress and performance
    """

    def __init__(self, config: Dict):
        self.config = config
        self.current_phase = 0
        self.phase_epochs = [30, 60, 100]  # Phase transition points

        # Define phase-specific weights
        self.phase_weights = [
            # Phase 1: Boundary mastery
            {
                'reconstruction': 2.0,
                'compression': 0.5,
                'boundary': 3.0,
                'language': 0.5,
                'consistency': 0.5
            },
            # Phase 2: Compression focus
            {
                'reconstruction': 2.0,
                'compression': 3.0,
                'boundary': 1.0,
                'language': 1.0,
                'consistency': 1.0
            },
            # Phase 3: Balanced optimization
            {
                'reconstruction': 3.0,
                'compression': 2.0,
                'boundary': 1.0,
                'language': 1.0,
                'consistency': 1.5
            }
        ]

    def get_weights(self, epoch: int, metrics: Optional[Dict] = None) -> Dict[str, float]:
        """
        Get current loss weights based on training phase

        Args:
            epoch: Current training epoch
            metrics: Optional performance metrics for adaptive adjustment
        """
        # Determine current phase
        for i, phase_end in enumerate(self.phase_epochs):
            if epoch <= phase_end:
                self.current_phase = i
                break
        else:
            # Past the final transition point: stay in the last phase
            self.current_phase = len(self.phase_weights) - 1

        weights = self.phase_weights[self.current_phase].copy()

        # Adaptive adjustments based on metrics
        if metrics:
            # If reconstruction is poor, increase its weight
            if metrics.get('reconstruction_accuracy', 1.0) < 0.9:
                weights['reconstruction'] *= 1.5

            # If compression is off target, adjust weight
            compression_ratio = metrics.get('compression_ratio', 16.0)
            if compression_ratio < 8.0 or compression_ratio > 20.0:
                weights['compression'] *= 1.5

        return weights


if __name__ == "__main__":
    # Test losses
    print("Testing Intelligent Loss Functions")

    # Create loss module
    loss_fn = IntelligentLoss()

    # Create dummy data
    batch_size = 2
    seq_len = 48
    vocab_size = 260
    hidden_dim = 1280

    outputs = {
        'logits': torch.randn(batch_size, seq_len, vocab_size),
        'compression_ratio': torch.tensor(16.0),
        'num_tokens': torch.tensor(3),
        'boundaries': torch.randn(batch_size, seq_len, 4),
        'language_clusters': torch.randn(batch_size, 128),
        'encoder_hidden': torch.randn(batch_size, seq_len, hidden_dim),
        'decoder_hidden': torch.randn(batch_size, seq_len, hidden_dim)
    }

    targets = {
        'input_ids': torch.randint(0, 256, (batch_size, seq_len)),
        'attention_mask': torch.ones(batch_size, seq_len),
        'boundary_targets': torch.zeros(batch_size, seq_len, 4),
        'language_targets': torch.randint(0, 128, (batch_size,))
    }

    # Compute losses
    losses = loss_fn(outputs, targets)

    print("\nLoss components:")
    for key, value in losses.items():
        if isinstance(value, torch.Tensor):
            print(f"  {key}: {value.item():.4f}")

    # Test adaptive scheduler
    scheduler = AdaptiveLossScheduler({})

    print("\nPhase weights:")
    for epoch in [10, 40, 70]:
        weights = scheduler.get_weights(epoch)
        print(f"  Epoch {epoch}: {weights}")
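
    # Extra sanity checks: exercise the metrics-driven adjustment path in
    # AdaptiveLossScheduler.get_weights and the magnitude balancing in
    # IntelligentLoss.compute_dynamic_weights. The metric values below are
    # arbitrary illustrative numbers chosen only to trigger the adjustment branches.
    example_metrics = {'reconstruction_accuracy': 0.85, 'compression_ratio': 6.0}
    adjusted = scheduler.get_weights(40, metrics=example_metrics)
    print(f"\nAdjusted weights with example metrics: {adjusted}")

    # compute_dynamic_weights expects the raw per-component losses, so filter out
    # the derived '*_weighted' and 'total' entries before passing them in.
    raw_components = {k: v for k, v in losses.items() if k in loss_fn.weight_history}
    dynamic = loss_fn.compute_dynamic_weights(raw_components)
    print(f"Dynamic weights from current loss magnitudes: {dynamic}")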