"""
Intelligent Tokenizer v6.0 - Inference Module

Embedding and restoration functionality
"""

import io
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch

# Force UTF-8 console output so multilingual test strings print correctly
if (sys.stdout.encoding or '').lower() != 'utf-8':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

sys.path.append(str(Path(__file__).parent))

from core.boundary_aware_model import BoundaryAwareTokenizerModel
from src.core.byte_tokenizer_v6 import ByteTokenizerV6


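# Minimal usage sketch (assumes a trained checkpoint is available at the default
# "checkpoints/latest_checkpoint.pt" path; exact shapes and scores depend on the
# loaded model):
#
#     tokenizer = IntelligentTokenizer()
#     emb = tokenizer.embed("Hello, world!")        # encoder hidden states
#     restored, acc = tokenizer.restore("Hello!")   # (restored text, byte accuracy)
#     stats = tokenizer.compress("Hello, world!")   # dict: bytes, tokens, ratio, shape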
class IntelligentTokenizer:
    """Intelligent Tokenizer for embedding and restoration."""

    def __init__(self, checkpoint_path: str = "checkpoints/latest_checkpoint.pt", device: str = None):
        """
        Initialize the tokenizer.

        Args:
            checkpoint_path: Path to model checkpoint
            device: Device to use ('cuda', 'cpu', or None for auto)
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print("Initializing Intelligent Tokenizer v6.0...")
        print(f"Device: {self.device}")

        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        self.tokenizer = ByteTokenizerV6()
        self.max_chunk_size = 250  # maximum bytes per chunk before splitting

        print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
        print("Ready for inference!")

    def embed(self, text: str) -> torch.Tensor:
        """
        Convert text to embeddings.

        Args:
            text: Input text

        Returns:
            Embedding tensor
        """
        if len(text.encode('utf-8')) > self.max_chunk_size:
            # Long inputs are split at UTF-8 boundaries and embedded chunk by chunk
            chunks = self._split_text_safely(text)
            embeddings = []

            for chunk in chunks:
                emb = self._embed_single(chunk)
                embeddings.append(emb)

            return torch.cat(embeddings, dim=1)
        else:
            return self._embed_single(text)

    def _embed_single(self, text: str) -> torch.Tensor:
        """Embed a single chunk."""
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']
        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            encoder_outputs = self.model.encoder(input_ids, attention_mask)
            embeddings = encoder_outputs['last_hidden_state']

        return embeddings

    def restore(self, text: str) -> Tuple[str, float]:
        """
        Test restoration capability.

        Args:
            text: Input text

        Returns:
            Tuple of (restored_text, accuracy)
        """
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            restored_chunks = []
            accuracies = []

            for chunk in chunks:
                restored, acc = self._restore_single(chunk)
                restored_chunks.append(restored)
                accuracies.append(acc)

            return ''.join(restored_chunks), sum(accuracies) / len(accuracies)
        else:
            return self._restore_single(text)

    def _restore_single(self, text: str) -> Tuple[str, float]:
        """Restore a single chunk."""
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']

        if len(byte_ids) <= 1:
            return text, 1.0

        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Teacher-forced decoding: predict each byte from the preceding ones
            decoder_input = input_ids[:, :-1]
            labels = input_ids[:, 1:]

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input,
                labels=labels,
                use_cross_attention=True
            )

        predictions = torch.argmax(outputs['logits'], dim=-1)
        accuracy = (predictions == labels).float().mean().item()

        try:
            pred_list = predictions[0].cpu().tolist()
            full_sequence = [self.tokenizer.BOS] + pred_list

            # Keep only raw byte values (0-255); special token ids are dropped
            filtered = [b for b in full_sequence if 0 <= b < 256]
            if filtered:
                restored_bytes = bytes(filtered)
                restored_text = restored_bytes.decode('utf-8', errors='ignore')
            else:
                restored_text = ""
        except Exception as e:
            print(f"Restoration error: {e}")
            restored_text = ""

        return restored_text, accuracy

    def compress(self, text: str) -> Dict:
        """
        Get compression statistics.

        Args:
            text: Input text

        Returns:
            Dict with compression info
        """
        text_bytes = text.encode('utf-8')
        embeddings = self.embed(text)

        original_size = len(text_bytes)
        compressed_size = embeddings.shape[1]
        compression_ratio = original_size / compressed_size if compressed_size > 0 else 0

        return {
            'original_bytes': original_size,
            'compressed_tokens': compressed_size,
            'compression_ratio': compression_ratio,
            'embedding_shape': list(embeddings.shape)
        }

    def _split_text_safely(self, text: str) -> List[str]:
        """Split text into chunks at valid UTF-8 boundaries."""
        chunks = []
        text_bytes = text.encode('utf-8')

        start = 0
        while start < len(text_bytes):
            end = min(start + self.max_chunk_size, len(text_bytes))

            # Back off until the slice ends on a complete UTF-8 character
            while end > start and end < len(text_bytes):
                try:
                    text_bytes[start:end].decode('utf-8')
                    break
                except UnicodeDecodeError:
                    end -= 1

            if end > start:
                chunks.append(text_bytes[start:end].decode('utf-8'))
                start = end
            else:
                break

        return chunks


def test_model():
    """Test model functionality."""
    print("=" * 70)
    print("INTELLIGENT TOKENIZER v6.0 - FUNCTIONALITY TEST")
    print("=" * 70)

    tokenizer = IntelligentTokenizer()

    test_samples = [
        ("English", "Hello, world!"),
        ("Korean", "안녕하세요. 반갑습니다."),
        ("Chinese", "今天天气很好"),
        ("Japanese", "こんにちは"),
        ("Arabic", "مرحبا بك"),
        ("Russian", "Привет, как дела?"),
        ("Emoji", "Hello 👋 World 🌍!"),
    ]

    print("\n" + "=" * 70)
    print("EMBEDDING & RESTORATION TESTS")
    print("=" * 70)

    total_accuracy = 0
    successful = 0

    for lang, text in test_samples:
        print(f"\n[{lang}]")
        print(f"Original: {text}")

        embeddings = tokenizer.embed(text)
        print(f"Embedding: {embeddings.shape}")

        compression = tokenizer.compress(text)
        print(f"Compression: {compression['original_bytes']} bytes → {compression['compressed_tokens']} tokens")
        print(f"Ratio: {compression['compression_ratio']:.2f}x")

        restored, accuracy = tokenizer.restore(text)
        print(f"Restored: {restored}")
        print(f"Accuracy: {accuracy:.1%}")

        if accuracy > 0.7:
            successful += 1
        total_accuracy += accuracy

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    print(f"Tests passed: {successful}/{len(test_samples)}")
    print(f"Average accuracy: {total_accuracy / len(test_samples):.1%}")

    if successful == len(test_samples):
        print("\n✅ ALL TESTS PASSED!")
        return True
    elif successful >= len(test_samples) * 0.7:
        print("\n⚠️ PARTIAL SUCCESS (70%+ tests passed)")
        return True
    else:
        print("\n❌ TESTS FAILED")
        return False


if __name__ == "__main__":
    success = test_model()
    sys.exit(0 if success else 1)