"""
Intelligent Loss Functions for v6.2.0

Multi-objective loss for the progressive splitting tokenizer, with
GPT-5-suggested improvements (dynamic weight balancing, consistency loss).
"""

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class IntelligentLoss(nn.Module):
    """
    Comprehensive loss function for the progressive splitting tokenizer.

    Combines multiple objectives (reconstruction, compression, boundary,
    language, consistency) with dynamic weighting.
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        self.config = config or {}

        # Special token ids beyond the 256 raw byte values (vocab size 260).
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259

        # Individual loss components.
        self.reconstruction_loss = ReconstructionLoss(self.PAD)
        self.compression_loss = CompressionLoss()
        self.boundary_loss = BoundaryLoss()
        self.language_loss = LanguageLoss()
        self.consistency_loss = ConsistencyLoss()

        # Dynamic weighting (magnitude balancing) and a per-component history
        # of the raw loss values seen during training.
        self.use_dynamic_weights = True
        self.weight_history = {
            'reconstruction': [],
            'compression': [],
            'boundary': [],
            'language': [],
            'consistency': []
        }

    def estimate_language_difficulty(self, targets: Dict) -> float:
        """Estimate language difficulty based on input characteristics."""
        if 'input_ids' not in targets:
            return 1.0

        input_ids = targets['input_ids']
        if input_ids.numel() == 0:
            return 1.0

        # Byte diversity as a crude difficulty proxy: the more distinct
        # values per position, the harder the sequence is assumed to be.
        unique_tokens = input_ids.unique().numel()
        total_tokens = input_ids.numel()
        diversity = min(1.0, (unique_tokens / total_tokens) * 2)

        return diversity
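
    # Illustrative numbers (not executed): a 48-byte sequence with 30 distinct
    # byte values gives 30/48 * 2 = 1.25 -> clamped to 1.0 (maximally diverse),
    # while 8 distinct values gives 8/48 * 2 ~ 0.33 (repetitive, "easier").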

    def forward(self,
                outputs: Dict[str, torch.Tensor],
                targets: Dict[str, torch.Tensor],
                weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
        """
        Compute the combined loss over all objectives.

        Args:
            outputs: Model outputs dictionary
            targets: Target values dictionary
            weights: Optional weight overrides

        Returns:
            Dictionary with the total loss and individual components
        """
        losses = {}

        # Each objective is computed only when the tensors it needs are present.
        if 'logits' in outputs and 'input_ids' in targets:
            losses['reconstruction'] = self.reconstruction_loss(
                outputs['logits'],
                targets['input_ids'],
                targets.get('attention_mask')
            )

        if 'compression_ratio' in outputs:
            losses['compression'] = self.compression_loss(
                outputs['compression_ratio'],
                outputs.get('num_tokens')
            )

        if 'boundaries' in outputs and 'boundary_targets' in targets:
            losses['boundary'] = self.boundary_loss(
                outputs['boundaries'],
                targets['boundary_targets'],
                targets.get('boundary_mask')
            )

        if 'language_clusters' in outputs and 'language_targets' in targets:
            losses['language'] = self.language_loss(
                outputs['language_clusters'],
                targets['language_targets']
            )

        if 'encoder_hidden' in outputs and 'decoder_hidden' in outputs:
            losses['consistency'] = self.consistency_loss(
                outputs['encoder_hidden'],
                outputs['decoder_hidden']
            )

        if not losses:
            raise ValueError("No recognised outputs/targets to compute a loss from")

        # Resolve weights: explicit overrides > dynamic balancing > static defaults.
        if weights is None and self.use_dynamic_weights:
            weights = self.compute_dynamic_weights(losses)
        elif weights is None:
            weights = {
                'reconstruction': 1.0,
                'compression': 1.0,
                'boundary': 1.0,
                'language': 0.5,
                'consistency': 0.5
            }

        # Iterate over a snapshot of the items, since the '*_weighted' entries
        # are added to the same dictionary while looping.
        total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device)
        for key, loss in list(losses.items()):
            weighted = weights.get(key, 1.0) * loss
            total_loss = total_loss + weighted
            losses[f'{key}_weighted'] = weighted

        losses['total'] = total_loss

        # Record the raw (unweighted) component values for later inspection.
        for key in self.weight_history:
            if key in losses:
                self.weight_history[key].append(losses[key].item())

        return losses

    def compute_dynamic_weights(self, losses: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Dynamically adjust weights based on loss magnitudes and progress.

        GPT-5 suggestion: balance loss magnitudes for stable training.
        """
        if not losses:
            return {}

        weights = {}
        eps = 1e-8

        # Current magnitude of each component; non-finite values fall back to 1.
        magnitudes = {}
        for k, v in losses.items():
            if torch.isnan(v) or torch.isinf(v):
                magnitudes[k] = 1.0
            else:
                magnitudes[k] = v.item()

        avg_magnitude = max(eps, sum(magnitudes.values()) / len(magnitudes))

        # Inverse-magnitude weighting: larger losses get smaller weights so
        # every component contributes at a comparable scale.
        for key, magnitude in magnitudes.items():
            weights[key] = avg_magnitude / max(eps, magnitude)

        # Counteract the inverse weighting for the reconstruction/compression
        # pair: keep the focus on reconstruction while it is still much larger
        # than compression, and boost compression when the opposite holds.
        if 'reconstruction' in magnitudes and 'compression' in magnitudes:
            recon_loss = magnitudes['reconstruction']
            comp_loss = magnitudes['compression']

            if recon_loss > comp_loss * 10:
                weights['compression'] *= 0.1
                weights['reconstruction'] *= 5.0
            elif recon_loss > comp_loss * 5:
                weights['compression'] *= 0.5
                weights['reconstruction'] *= 2.0
            elif recon_loss < comp_loss * 0.5:
                weights['compression'] *= 2.0
                weights['reconstruction'] *= 0.5

        # Normalise so the weights average to roughly 1, capped at 10.
        total_weight = sum(weights.values())
        if total_weight > 0:
            weights = {k: min(10.0, v / total_weight * len(weights)) for k, v in weights.items()}

        return weights
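
    # Worked example (illustrative numbers only): raw losses of
    # {reconstruction: 5.54, compression: 0.17, boundary: 0.72} average ~2.14,
    # so the inverse-magnitude weights are roughly {0.39, 12.6, 3.0}. Since
    # reconstruction > 10x compression, the rebalancing step turns these into
    # {1.93, 1.26, 3.0}, and normalising to a mean of ~1 (capped at 10)
    # yields approximately {0.94, 0.61, 1.45}.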


class ReconstructionLoss(nn.Module):
    """
    Cross-entropy loss for sequence reconstruction,
    with label smoothing and an optional focal-loss mode.
    """

    def __init__(self, pad_token: int = 256, label_smoothing: float = 0.1):
        super().__init__()
        self.pad_token = pad_token
        self.label_smoothing = label_smoothing
        self.focal_alpha = 0.25
        self.focal_gamma = 2.0
        self.use_focal = False

    def forward(self,
                logits: torch.Tensor,
                targets: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the reconstruction loss.

        Args:
            logits: [batch, seq_len, vocab_size]
            targets: [batch, seq_len]
            mask: [batch, seq_len] attention mask
        """
        vocab_size = logits.size(-1)

        # Flatten to [batch * seq_len, ...] for token-level cross-entropy.
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        if self.use_focal:
            # Focal loss: alpha * (1 - p_t)^gamma * CE, down-weighting tokens
            # the model already predicts confidently.
            ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
            pt = torch.exp(-ce_loss)
            focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss

            if mask is not None:
                mask_flat = mask.reshape(-1)
                focal_loss = focal_loss * mask_flat
                loss = focal_loss.sum() / mask_flat.sum().clamp(min=1)
            else:
                loss = focal_loss.mean()
        else:
            # Standard label-smoothed cross-entropy, ignoring PAD targets.
            if mask is not None:
                mask_flat = mask.reshape(-1).bool()
                loss = F.cross_entropy(
                    logits_flat[mask_flat],
                    targets_flat[mask_flat],
                    ignore_index=self.pad_token,
                    label_smoothing=self.label_smoothing
                )
            else:
                loss = F.cross_entropy(
                    logits_flat,
                    targets_flat,
                    ignore_index=self.pad_token,
                    label_smoothing=self.label_smoothing
                )

        return loss
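
# Illustrative focal-loss arithmetic with the defaults (alpha=0.25, gamma=2):
# an easy token with p_t = 0.9 has CE = -ln(0.9) ~ 0.105 and a focal factor of
# 0.25 * (1 - 0.9)^2 = 0.0025, giving ~0.0003; a hard token with p_t = 0.1 has
# CE ~ 2.303 and a focal factor of 0.25 * 0.81 ~ 0.20, giving ~0.47. Hard
# tokens therefore dominate the gradient far more than under plain cross-entropy.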


class CompressionLoss(nn.Module):
    """
    Aggressive compression loss: push for high compression.

    Must beat existing tokenizers (roughly 4 bytes/token, i.e. 4:1).
    """

    def __init__(self):
        super().__init__()

        # Acceptable band of compression ratios.
        self.min_ratio = 12.0
        self.target_ratio = 24.0
        self.max_ratio = 48.0

    def forward(self,
                compression_ratio: torch.Tensor,
                num_tokens: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the compression loss (GPT fix: fully vectorized).

        Args:
            compression_ratio: Current compression ratio (scalar or batch)
            num_tokens: Number of tokens used (for an additional penalty)
        """
        if not torch.is_tensor(compression_ratio):
            device = num_tokens.device if torch.is_tensor(num_tokens) else torch.device('cpu')
            compression_ratio = torch.tensor(compression_ratio, dtype=torch.float32, device=device)

        zero = torch.zeros_like(compression_ratio)

        # Penalty for falling below the minimum acceptable ratio.
        under_loss = torch.clamp(self.min_ratio - compression_ratio, min=0.0) / self.min_ratio * 0.5

        # Between min and target: push towards the target.
        # At or above the target: a small logarithmic reward (negative loss).
        target_loss = torch.where(
            compression_ratio >= self.target_ratio,
            -0.1 * torch.log(compression_ratio / self.target_ratio + 1.0),
            torch.where(
                compression_ratio >= self.min_ratio,
                (self.target_ratio - compression_ratio) / self.target_ratio * 0.5,
                zero,
            ),
        )

        # Mild penalty for over-compressing far past the maximum ratio.
        over_loss = torch.clamp(compression_ratio - self.max_ratio, min=0.0) / self.max_ratio * 0.2

        loss = under_loss + target_loss + over_loss

        # Optional quadratic penalty once more than 8 tokens are used.
        if num_tokens is not None:
            if not torch.is_tensor(num_tokens):
                num_tokens = torch.tensor(num_tokens, dtype=torch.float32, device=compression_ratio.device)
            token_penalty = 0.1 * torch.clamp(num_tokens.float() - 8, min=0.0) ** 2
            loss = loss + token_penalty

        return loss.mean() if loss.dim() > 0 else loss
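
# Illustrative values of the piecewise objective (min=12, target=24, max=48):
#   ratio  8:1 -> under-compression penalty (12 - 8) / 12 * 0.5  ~ 0.167
#   ratio 16:1 -> push-to-target term       (24 - 16) / 24 * 0.5 ~ 0.167
#   ratio 24:1 -> log reward -0.1 * ln(24/24 + 1)                ~ -0.069
#   ratio 60:1 -> log reward ~ -0.125, plus over-compression (60 - 48) / 48 * 0.2 = 0.05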


class BoundaryLoss(nn.Module):
    """
    Learn meaningful chunk boundaries.

    Combines a masked BCE term with sparsity and smoothness regularisers.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self,
                predicted: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the boundary loss.

        Args:
            predicted: [batch, seq_len, boundary_classes] predicted boundary logits
            target: [batch, seq_len, boundary_classes] target boundaries
            mask: [batch, seq_len] valid-positions mask
        """
        # Per-position binary cross-entropy on the boundary logits.
        loss = self.bce_loss(predicted, target.float())

        if mask is not None:
            # Average only over valid (unpadded) positions.
            mask_expanded = mask.unsqueeze(-1).expand_as(loss).float()
            loss = loss * mask_expanded
            loss = loss.sum() / mask_expanded.sum().clamp(min=1.0)
        else:
            loss = loss.mean()

        # Sparsity: prefer few boundaries overall.
        boundary_probs = torch.sigmoid(predicted)
        sparsity_loss = 0.01 * boundary_probs.mean()

        # Smoothness: discourage abrupt jumps between adjacent boundary logits.
        if predicted.size(1) > 1:
            diff = predicted[:, 1:] - predicted[:, :-1]
            smoothness_loss = 0.01 * (diff ** 2).mean()
        else:
            smoothness_loss = 0.0

        total_loss = loss + sparsity_loss + smoothness_loss

        return total_loss
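
# Rough scale of the auxiliary terms (illustrative): predicting boundaries
# almost everywhere (mean sigmoid ~ 1.0) adds about 0.01 of sparsity penalty,
# a sparse prediction (mean ~ 0.1) adds about 0.001, and the smoothness term
# contributes 0.01 times the mean squared jump between adjacent logits.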


class LanguageLoss(nn.Module):
    """
    Language identification/clustering loss.

    Supports both classification and clustering objectives.
    """

    def __init__(self, num_languages: int = 128, temperature: float = 0.07):
        super().__init__()
        self.num_languages = num_languages
        self.temperature = temperature

        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self,
                predicted: torch.Tensor,
                target: torch.Tensor,
                mode: str = 'classification') -> torch.Tensor:
        """
        Compute the language loss.

        Args:
            predicted: [batch, seq_len, num_languages] or [batch, num_languages]
                       logits (classification), or [batch, dim] embeddings (clustering)
            target: Language labels or cluster assignments
            mode: 'classification' or 'clustering'
        """
        if mode == 'classification':
            # Flatten per-position predictions to token-level classification.
            if predicted.dim() == 3:
                predicted = predicted.reshape(-1, self.num_languages)
                target = target.reshape(-1)

            loss = self.ce_loss(predicted, target)

        elif mode == 'clustering':
            # Contrastive (InfoNCE-style) objective over the batch: each
            # normalised embedding should be most similar to itself.
            predicted = F.normalize(predicted, dim=-1)

            sim_matrix = torch.matmul(predicted, predicted.t()) / self.temperature

            batch_size = predicted.size(0)
            labels = torch.arange(batch_size, device=predicted.device)

            loss = F.cross_entropy(sim_matrix, labels)

        else:
            raise ValueError(f"Unknown mode: {mode}")

        return loss
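
# In clustering mode the positives are the diagonal of the similarity matrix:
# with temperature 0.07, a sample's self-similarity of 1.0 becomes a logit of
# 1 / 0.07 ~ 14.3, while a neighbour at cosine similarity 0.5 contributes ~7.1,
# so the cross-entropy mostly pushes distinct samples' embeddings apart.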


class ConsistencyLoss(nn.Module):
    """
    Ensure consistency between encoder and decoder representations.

    GPT-5 suggestion: helps with training stability.
    """

    def __init__(self, margin: float = 0.5):
        super().__init__()
        self.margin = margin  # not currently used

    def forward(self,
                encoder_hidden: torch.Tensor,
                decoder_hidden: torch.Tensor) -> torch.Tensor:
        """
        Compute the consistency loss between encoder and decoder states.

        Args:
            encoder_hidden: [batch, seq_len, hidden_dim]
            decoder_hidden: [batch, seq_len, hidden_dim]
        """
        # If the sequence lengths differ, compare only the shared prefix.
        if encoder_hidden.shape != decoder_hidden.shape:
            min_len = min(encoder_hidden.size(1), decoder_hidden.size(1))
            encoder_hidden = encoder_hidden[:, :min_len]
            decoder_hidden = decoder_hidden[:, :min_len]

        # Magnitude agreement.
        l2_loss = F.mse_loss(encoder_hidden, decoder_hidden)

        # Directional agreement.
        encoder_norm = F.normalize(encoder_hidden, dim=-1)
        decoder_norm = F.normalize(decoder_hidden, dim=-1)
        cosine_sim = (encoder_norm * decoder_norm).sum(dim=-1)
        cosine_loss = 1.0 - cosine_sim.mean()

        loss = l2_loss + 0.5 * cosine_loss

        return loss
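
# Intuition (illustrative): identical encoder/decoder states give zero loss;
# a decoder state equal to the encoder state scaled by 2 keeps the cosine term
# at zero but is still penalised through the MSE term (magnitude drift), and
# fully anti-aligned states add 0.5 * (1 - (-1)) = 1.0 from the cosine term alone.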


class AdaptiveLossScheduler:
    """
    Dynamically adjust loss weights during training,
    based on training progress and performance.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.current_phase = 0
        self.phase_epochs = [30, 60, 100]

        # Per-phase base weights.
        self.phase_weights = [
            # Phase 1 (epochs <= 30): emphasise boundary learning.
            {
                'reconstruction': 2.0,
                'compression': 0.5,
                'boundary': 3.0,
                'language': 0.5,
                'consistency': 0.5
            },
            # Phase 2 (epochs <= 60): emphasise compression.
            {
                'reconstruction': 2.0,
                'compression': 3.0,
                'boundary': 1.0,
                'language': 1.0,
                'consistency': 1.0
            },
            # Phase 3 (epochs <= 100 and beyond): emphasise reconstruction and consistency.
            {
                'reconstruction': 3.0,
                'compression': 2.0,
                'boundary': 1.0,
                'language': 1.0,
                'consistency': 1.5
            }
        ]

    def get_weights(self, epoch: int, metrics: Optional[Dict] = None) -> Dict[str, float]:
        """
        Get the current loss weights for a training epoch.

        Args:
            epoch: Current training epoch
            metrics: Optional performance metrics for adaptive adjustment
        """
        # Pick the first phase whose end epoch has not been passed yet.
        for i, phase_end in enumerate(self.phase_epochs):
            if epoch <= phase_end:
                self.current_phase = i
                break
        else:
            # Past the last configured boundary: stay in the final phase.
            self.current_phase = len(self.phase_epochs) - 1

        weights = self.phase_weights[self.current_phase].copy()

        if metrics:
            # Boost reconstruction while accuracy is still poor.
            if metrics.get('reconstruction_accuracy', 1.0) < 0.9:
                weights['reconstruction'] *= 1.5

            # Boost compression when the ratio drifts out of the working band.
            compression_ratio = metrics.get('compression_ratio', 16.0)
            if compression_ratio < 8.0 or compression_ratio > 20.0:
                weights['compression'] *= 1.5

        return weights


if __name__ == "__main__":
    # Quick smoke test with random tensors.
    print("Testing Intelligent Loss Functions")

    loss_fn = IntelligentLoss()

    batch_size = 2
    seq_len = 48
    vocab_size = 260
    hidden_dim = 1280

    outputs = {
        'logits': torch.randn(batch_size, seq_len, vocab_size),
        'compression_ratio': torch.tensor(16.0),
        'num_tokens': torch.tensor(3),
        'boundaries': torch.randn(batch_size, seq_len, 4),
        'language_clusters': torch.randn(batch_size, 128),
        'encoder_hidden': torch.randn(batch_size, seq_len, hidden_dim),
        'decoder_hidden': torch.randn(batch_size, seq_len, hidden_dim)
    }

    targets = {
        'input_ids': torch.randint(0, 256, (batch_size, seq_len)),
        'attention_mask': torch.ones(batch_size, seq_len),
        'boundary_targets': torch.zeros(batch_size, seq_len, 4),
        'language_targets': torch.randint(0, 128, (batch_size,))
    }

    losses = loss_fn(outputs, targets)

    print("\nLoss components:")
    for key, value in losses.items():
        if isinstance(value, torch.Tensor):
            print(f"  {key}: {value.item():.4f}")

    scheduler = AdaptiveLossScheduler({})

    print("\nPhase weights:")
    for epoch in [10, 40, 70]:
        weights = scheduler.get_weights(epoch)
        print(f"  Epoch {epoch}: {weights}")
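
    # Additional illustrative checks (assumed usage, not part of any training
    # loop): sweep the compression objective across a few ratios and exercise
    # the scheduler's metric-driven adjustment.
    comp_loss = CompressionLoss()
    print("\nCompression loss by ratio:")
    for ratio in [4.0, 8.0, 12.0, 16.0, 24.0, 48.0, 64.0]:
        value = comp_loss(torch.tensor(ratio))
        print(f"  {ratio:5.1f}:1 -> {value.item():+.4f}")

    # Poor reconstruction accuracy boosts the reconstruction weight, and an
    # out-of-band compression ratio boosts the compression weight.
    adjusted = scheduler.get_weights(40, metrics={'reconstruction_accuracy': 0.85,
                                                  'compression_ratio': 6.0})
    print(f"\nAdjusted weights at epoch 40: {adjusted}")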