"""
Unified Intelligent Tokenizer Model v6.0
Purely learning-based: all core components integrated in a single file.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Dict, List, Optional, Tuple, Union
class PositionalEncoding(nn.Module):
"""
Sinusoidal positional encoding (the original Transformer scheme).
Uses fixed sin/cos patterns instead of learnable position embeddings.
"""
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create sinusoidal position encodings
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # Even dimensions
pe[:, 1::2] = torch.cos(position * div_term) # Odd dimensions
# Register as buffer (not trainable)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Add positional encoding to input
Args:
x: (batch_size, seq_len, d_model)
"""
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
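# Illustrative sketch (not part of the original module; shapes are arbitrary examples):
# shows how the fixed sin/cos table is added to an embedded sequence.
def _example_positional_encoding():
    pe = PositionalEncoding(d_model=384, max_len=512)
    x = torch.randn(2, 16, 384)   # (batch, seq_len, d_model)
    out = pe(x)                   # same shape, with positional offsets added
    return out.shape              # torch.Size([2, 16, 384])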
class ByteTokenizer:
"""
Pure byte-level tokenizer - no language rules
"""
def __init__(self, max_seq_len: int = 512):
self.max_seq_len = max_seq_len
self.PAD = 256
self.BOS = 257
self.EOS = 258
self.MASK = 259
def encode(self, text: str, add_special_tokens: bool = True) -> Dict[str, torch.Tensor]:
# Convert to UTF-8 bytes
byte_seq = list(text.encode('utf-8'))
# Truncate if needed
if len(byte_seq) > self.max_seq_len - 2:
byte_seq = byte_seq[:self.max_seq_len - 2]
# Add special tokens
if add_special_tokens:
byte_seq = [self.BOS] + byte_seq + [self.EOS]
input_ids = torch.tensor(byte_seq, dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'length': len(input_ids)
}
def encode_batch(self, texts: List[str]) -> Dict[str, torch.Tensor]:
encoded = [self.encode(text) for text in texts]
max_len = min(max(e['length'] for e in encoded), self.max_seq_len)
batch_size = len(texts)
input_ids = torch.full((batch_size, max_len), self.PAD, dtype=torch.long)
attention_mask = torch.zeros((batch_size, max_len), dtype=torch.float32)
for i, enc in enumerate(encoded):
seq_len = min(enc['length'], max_len)
input_ids[i, :seq_len] = enc['input_ids'][:seq_len]
attention_mask[i, :seq_len] = 1.0
return {
'input_ids': input_ids,
'attention_mask': attention_mask
}
def decode(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
if isinstance(input_ids, torch.Tensor):
input_ids = input_ids.cpu().numpy().tolist()
if skip_special_tokens:
input_ids = [b for b in input_ids if b < 256]
try:
byte_array = bytes([min(b, 255) for b in input_ids if b != self.PAD])
return byte_array.decode('utf-8', errors='replace')
except Exception:
return "".join([chr(b) if b < 128 else '?' for b in input_ids if b < 256])
class ByteEncoder(nn.Module):
"""
5-Layer Encoder with Positional Encoding
Layer dimensions: [384, 384, 512, 640, 768] (updated)
"""
def __init__(
self,
vocab_size: int = 260,
hidden_dims: List[int] = [384, 384, 512, 640, 768],  # 512-dim stage added
num_heads: int = 8,
dropout: float = 0.1,
max_seq_len: int = 512
):
super().__init__()
# Byte embedding
self.byte_embedding = nn.Embedding(vocab_size, hidden_dims[0])
# Positional encoding (Sinusoidal)
self.pos_encoding = PositionalEncoding(hidden_dims[0], max_seq_len, dropout)
# 5 Transformer layers with dimension changes
self.layers = nn.ModuleList()
for i in range(len(hidden_dims)):
input_dim = hidden_dims[i-1] if i > 0 else hidden_dims[0]
output_dim = hidden_dims[i]
# Projection layer if dimension changes
if input_dim != output_dim:
proj = nn.Linear(input_dim, output_dim)
else:
proj = None
# Transformer encoder layer
layer = nn.TransformerEncoderLayer(
d_model=output_dim,
nhead=num_heads,
dim_feedforward=output_dim * 4,
dropout=dropout,
activation='gelu',
batch_first=True,
norm_first=True
)
self.layers.append(nn.ModuleDict({
'projection': proj,
'transformer': layer,
'norm': nn.LayerNorm(output_dim)
}))
self.dropout = nn.Dropout(dropout)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
# Embed bytes
x = self.byte_embedding(input_ids)
# Add positional encoding
x = self.pos_encoding(x)
# The padding mask is converted to a per-layer key_padding_mask below,
# so no preprocessing of attention_mask is needed here.
# Process through 5 layers
all_hidden_states = []
for layer_dict in self.layers:
# Project if needed
if layer_dict['projection'] is not None:
x = layer_dict['projection'](x)
# Transformer layer - properly handle mask
if attention_mask is not None:
# TransformerEncoderLayer expects key_padding_mask (batch, seq)
# where True means "ignore this position"
key_padding_mask = (attention_mask == 0)
x = layer_dict['transformer'](x, src_key_padding_mask=key_padding_mask)
else:
x = layer_dict['transformer'](x)
x = layer_dict['norm'](x)
all_hidden_states.append(x)
# Pool for sequence representation
if attention_mask is not None:
# Masked mean pooling - attention_mask is (batch, seq)
mask = attention_mask.unsqueeze(-1) # (batch, seq, 1)
pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
else:
pooled = x.mean(dim=1)
return {
'last_hidden_state': x,
'pooled_output': pooled,
'all_hidden_states': all_hidden_states
}
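# Illustrative sketch (random inputs, untrained weights; names and shapes are examples
# only): the five stacked layers widen the representation from 384 to 768 dims, and
# the pooled output is a masked mean over the sequence.
def _example_byte_encoder():
    encoder = ByteEncoder()
    ids = torch.randint(0, 256, (2, 32))   # random byte ids
    mask = torch.ones(2, 32)
    out = encoder(ids, mask)
    # last_hidden_state: (2, 32, 768); pooled_output: (2, 768)
    return out['last_hidden_state'].shape, out['pooled_output'].shape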
class CrossAttention(nn.Module):
"""
Enhanced Cross-attention for relation learning between sequences
Strengthened relation learning for connecting to reasoning layers.
"""
def __init__(self, hidden_dim: int = 768, num_heads: int = 8, dropout: float = 0.1):
super().__init__()
self.cross_attn = nn.MultiheadAttention(
hidden_dim, num_heads, dropout, batch_first=True
)
# Enhanced relation classifier (8 types for richer relations)
# 0: identity, 1: similar, 2: different, 3: continuation
# 4: translation, 5: summary, 6: expansion, 7: contradiction
self.relation_head = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 8)
)
# Gating mechanism for adaptive fusion
self.gate = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Sigmoid()
)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
query_mask: Optional[torch.Tensor] = None,
key_mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
# Normalize inputs
query_norm = self.norm1(query)
key_norm = self.norm2(key)
# Fix key_mask dimension if needed
if key_mask is not None:
# Ensure key_mask matches key sequence length
if key_mask.dim() == 2 and key_mask.size(1) != key.size(1):
# Create new mask with correct dimensions
batch_size = key.size(0)
seq_len = key.size(1)
key_mask = torch.ones(batch_size, seq_len, dtype=key_mask.dtype, device=key_mask.device)
# Cross attention
attn_output, attn_weights = self.cross_attn(
query_norm, key_norm, key_norm,
key_padding_mask=(key_mask == 0) if key_mask is not None else None
)
# Residual connection
attn_output = attn_output + query
# Adaptive gating for fusion
gate_input = torch.cat([query.mean(dim=1), key.mean(dim=1)], dim=-1)
gate_weights = self.gate(gate_input).unsqueeze(1)
# Gated fusion: adaptively blends the cross-attention output with the original query
fused_output = gate_weights * attn_output + (1 - gate_weights) * query
# Pool for relation classification
query_pooled = query.mean(dim=1) if query_mask is None else \
(query * query_mask.unsqueeze(-1)).sum(1) / query_mask.sum(1, keepdim=True).clamp(min=1e-9)
key_pooled = key.mean(dim=1) if key_mask is None else \
(key * key_mask.unsqueeze(-1)).sum(1) / key_mask.sum(1, keepdim=True).clamp(min=1e-9)
# Classify relations with enhanced head
combined = torch.cat([query_pooled, key_pooled], dim=-1)
relation_logits = self.relation_head(combined)
return {
'cross_attention': fused_output, # Gated fusion output
'attention_weights': attn_weights,
'relation_logits': relation_logits,
'gate_weights': gate_weights.squeeze(1) # For analysis
}
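# Illustrative sketch (random tensors, untrained weights): the module returns a gated
# fusion of the query with the cross-attention output, plus 8-way relation logits.
def _example_cross_attention():
    ca = CrossAttention(hidden_dim=768, num_heads=8)
    query = torch.randn(2, 10, 768)
    key = torch.randn(2, 12, 768)
    out = ca(query, key)
    # cross_attention: (2, 10, 768); relation_logits: (2, 8); gate_weights: (2, 768)
    return out['relation_logits'].shape, out['gate_weights'].shape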
class TransformerDecoder(nn.Module):
"""
Transformer Decoder with Positional Encoding
"""
def __init__(
self,
vocab_size: int = 260,
hidden_dim: int = 768,
num_heads: int = 8,
num_layers: int = 6,
dropout: float = 0.1,
max_seq_len: int = 512
):
super().__init__()
# Token embedding
self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
# Positional encoding
self.pos_encoding = PositionalEncoding(hidden_dim, max_seq_len, dropout)
# Transformer decoder
decoder_layer = nn.TransformerDecoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
activation='gelu',
batch_first=True,
norm_first=True
)
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
# Output projection
self.output_projection = nn.Linear(hidden_dim, vocab_size)
self.hidden_dim = hidden_dim
self.vocab_size = vocab_size
def forward(
self,
encoder_hidden: torch.Tensor,
decoder_input_ids: Optional[torch.Tensor] = None,
encoder_mask: Optional[torch.Tensor] = None,
decoder_mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
batch_size = encoder_hidden.size(0)
# Start with BOS if no input
if decoder_input_ids is None:
decoder_input_ids = torch.full((batch_size, 1), 257, device=encoder_hidden.device)
# Embed and add positional encoding
dec_seq_len = decoder_input_ids.size(1)
x = self.token_embedding(decoder_input_ids)
x = self.pos_encoding(x)
# Create causal mask
causal_mask = torch.triu(
torch.ones(dec_seq_len, dec_seq_len, device=x.device) * float('-inf'),
diagonal=1
)
# Decoder forward - handle variable-length encoder outputs
# The encoder may compress the sequence, so memory (encoder_hidden) might be shorter
# than the decoder sequence. This is expected and correct behavior.
enc_seq_len = encoder_hidden.size(1)
# Adjust encoder mask if needed
if encoder_mask is not None:
if encoder_mask.size(1) != enc_seq_len:
# Encoder compressed the sequence, create new mask for compressed length
# All compressed positions are valid (not masked)
memory_key_padding_mask = torch.zeros(
encoder_hidden.size(0), enc_seq_len,
dtype=torch.bool, device=encoder_hidden.device
)
else:
memory_key_padding_mask = (encoder_mask == 0)
else:
memory_key_padding_mask = None
# Decoder attends to compressed encoder states via cross-attention
# This naturally handles different sequence lengths
decoder_output = self.transformer(
tgt=x, # Decoder sequence (original length)
memory=encoder_hidden, # Encoder sequence (possibly compressed)
tgt_mask=causal_mask,
memory_key_padding_mask=memory_key_padding_mask,
tgt_key_padding_mask=(decoder_mask == 0) if decoder_mask is not None else None
)
# Project to vocabulary
logits = self.output_projection(decoder_output)
return {
'logits': logits,
'hidden_states': decoder_output
}
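# Illustrative note (shapes are examples, not taken from the original tests): the memory
# sequence may be shorter than the decoder sequence; decoder cross-attention handles the
# mismatch, e.g. for a TransformerDecoder instance `dec`:
#   memory = torch.randn(2, 8, 768)           # compressed encoder states
#   tgt = torch.randint(0, 256, (2, 32))      # full-length target byte ids
#   dec(memory, tgt)['logits'].shape          # -> torch.Size([2, 32, 260])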
@torch.no_grad()
def generate(
self,
encoder_hidden: torch.Tensor,
encoder_mask: Optional[torch.Tensor] = None,
max_length: int = 128,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.95
) -> torch.Tensor:
batch_size = encoder_hidden.size(0)
device = encoder_hidden.device
# Start with BOS
decoder_input_ids = torch.full((batch_size, 1), 257, device=device)
for _ in range(max_length - 1):
# Forward pass
outputs = self.forward(encoder_hidden, decoder_input_ids, encoder_mask)
next_token_logits = outputs['logits'][:, -1, :] / temperature
# Top-k filtering
if top_k > 0:
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_logits[indices_to_remove] = float('-inf')
# Top-p filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
next_token_logits[indices_to_remove] = float('-inf')
# Sample
probs = F.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, 1)
decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)
# Stop at EOS
if (next_tokens == 258).all(): # EOS token
break
return decoder_input_ids
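# Illustrative sketch (untrained weights; shapes and sampling settings are arbitrary
# examples): teacher-forced scoring followed by autoregressive sampling from the same
# encoder memory.
def _example_transformer_decoder():
    dec = TransformerDecoder(vocab_size=260, hidden_dim=768, num_layers=2)
    memory = torch.randn(2, 16, 768)                 # stand-in encoder states
    tgt = torch.randint(0, 256, (2, 12))             # stand-in target byte ids
    logits = dec(memory, tgt)['logits']              # (2, 12, 260)
    sampled = dec.generate(memory, max_length=20,
                           temperature=0.8, top_k=40, top_p=0.9)
    return logits.shape, sampled.shape               # sampled: (2, <=20), BOS first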
class IntelligentTokenizerModel(nn.Module):
"""
Complete Intelligent Tokenizer Model v6.0
Unified model: encoder + decoder + cross-attention
"""
def __init__(
self,
vocab_size: int = 260,
encoder_dims: List[int] = [384, 384, 512, 640, 768],  # 512-dim stage added
decoder_hidden: int = 768,
num_heads: int = 8,
num_decoder_layers: int = 6,
dropout: float = 0.1,
max_seq_len: int = 512
):
super().__init__()
# Components
self.tokenizer = ByteTokenizer(max_seq_len)
self.encoder = ByteEncoder(vocab_size, encoder_dims, num_heads, dropout, max_seq_len)
self.decoder = TransformerDecoder(vocab_size, decoder_hidden, num_heads, num_decoder_layers, dropout, max_seq_len)
self.cross_attention = CrossAttention(encoder_dims[-1], num_heads, dropout)
def forward(
self,
input_texts: Optional[List[str]] = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cross_attention: bool = True
) -> Dict[str, torch.Tensor]:
# Tokenize if text input
if input_texts is not None:
tokenized = self.tokenizer.encode_batch(input_texts)
input_ids = tokenized['input_ids']
attention_mask = tokenized['attention_mask']
# Record batch size, sequence length, and device for use below
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Encode
encoder_outputs = self.encoder(input_ids, attention_mask)
encoder_hidden = encoder_outputs['last_hidden_state'] # [batch, seq, 768]
# Sanity-check the final hidden dimension
assert encoder_hidden.size(-1) == 768, f"Encoder dim mismatch: {encoder_hidden.size(-1)}"
# Decode
decoder_outputs = self.decoder(
encoder_hidden,
decoder_input_ids,
attention_mask
)
decoder_hidden = decoder_outputs['hidden_states'] # [batch, seq, 768]
# Cross-attention (relation learning on the final-layer representations)
cross_attn_outputs = None
relation_logits = None
if use_cross_attention and decoder_hidden is not None:
# Cross-attention between decoder outputs and encoder outputs
cross_attn_outputs = self.cross_attention(
query=decoder_hidden,    # decoder acts as query
key=encoder_hidden,      # encoder provides key/value
query_mask=None,         # decoder mask is causal and handled separately
key_mask=attention_mask
)
# Relation-learning output
relation_logits = cross_attn_outputs['relation_logits']
# Decoder representation enhanced by cross-attention
enhanced_decoder = decoder_hidden + cross_attn_outputs['cross_attention']
# Recompute the final logits after applying cross-attention
if hasattr(self.decoder, 'output_projection'):
decoder_outputs['logits'] = self.decoder.output_projection(enhanced_decoder)
# Calculate loss if labels provided
loss = None
if labels is not None:
# Reconstruction loss
loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.PAD)
recon_loss = loss_fct(
decoder_outputs['logits'].reshape(-1, decoder_outputs['logits'].size(-1)),
labels.reshape(-1)
)
# Relation loss (if cross-attention used)
relation_loss = 0
if relation_logits is not None:
# The self-relation target is identity (class 0)
batch_identity = torch.zeros(batch_size, dtype=torch.long, device=device)
relation_loss = F.cross_entropy(relation_logits, batch_identity) * 0.1
loss = recon_loss + relation_loss
return {
'loss': loss,
'logits': decoder_outputs['logits'],
'encoder_hidden_states': encoder_hidden,
'decoder_hidden_states': decoder_hidden,
'pooled_output': encoder_outputs['pooled_output'],
'cross_attention': cross_attn_outputs['cross_attention'] if cross_attn_outputs else None,
'relation_logits': relation_logits,
'all_encoder_states': encoder_outputs.get('all_hidden_states', None)
}
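# Illustrative training-step sketch (placeholder texts, untrained weights; a real setup
# would shift the decoder inputs/labels rather than reconstruct in place):
#   model = IntelligentTokenizerModel()
#   batch = model.tokenizer.encode_batch(["hello world", "bytes, not tokens"])
#   out = model(input_ids=batch['input_ids'],
#               attention_mask=batch['attention_mask'],
#               decoder_input_ids=batch['input_ids'],
#               labels=batch['input_ids'])
#   out['loss'].backward()    # reconstruction loss + 0.1 * identity-relation loss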
def encode_text(self, text: str) -> torch.Tensor:
"""Encode single text to representation"""
tokenized = self.tokenizer.encode(text)
# Move to same device as model
device = next(self.parameters()).device
input_ids = tokenized['input_ids'].unsqueeze(0).to(device)
attention_mask = tokenized['attention_mask'].unsqueeze(0).to(device)
with torch.no_grad():
outputs = self.encoder(input_ids, attention_mask)
return outputs['pooled_output'].squeeze(0)
def decode_representation(self, representation: torch.Tensor, max_length: int = 128) -> str:
"""Decode representation back to text"""
if representation.dim() == 1:
representation = representation.unsqueeze(0).unsqueeze(0)
elif representation.dim() == 2:
representation = representation.unsqueeze(1)
with torch.no_grad():
output_ids = self.decoder.generate(representation, max_length=max_length)
text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
return text
def compute_relation(self, text1: str, text2: str) -> torch.Tensor:
"""Compute relation between two texts"""
# Encode both texts
enc1 = self.encode_text(text1).unsqueeze(0).unsqueeze(0)
enc2 = self.encode_text(text2).unsqueeze(0).unsqueeze(0)
# Compute cross-attention and relations
with torch.no_grad():
outputs = self.cross_attention(enc1, enc2)
return F.softmax(outputs['relation_logits'], dim=-1)
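if __name__ == "__main__":
    # Smoke-test sketch (assumption: randomly initialized, untrained weights, so the
    # printed values are meaningless; this only checks that the pieces wire together).
    model = IntelligentTokenizerModel()
    model.eval()

    # Byte-level encoding of a short text into a pooled 768-dim representation
    rep = model.encode_text("byte-level tokenization")
    print("pooled representation:", tuple(rep.shape))             # (768,)

    # Full forward pass, reconstructing the inputs as a simple self-supervised target
    batch = model.tokenizer.encode_batch(["hello world", "hola mundo"])
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=batch['input_ids'],
        labels=batch['input_ids'],
    )
    print("reconstruction + relation loss:", float(outputs['loss']))
    print("logits:", tuple(outputs['logits'].shape))               # (2, seq_len, 260)

    # Relation probabilities between two texts (8 relation classes)
    probs = model.compute_relation("a short sentence", "a short sentence")
    print("relation probabilities:", [round(p, 3) for p in probs.squeeze().tolist()])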