#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Intelligent Tokenizer v6.0 - Inference Module
Embedding and restoration functionality
"""

import torch
import sys
import io
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Force UTF-8 output streams so multilingual test strings print correctly
if sys.stdout.encoding != 'utf-8':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

sys.path.append(str(Path(__file__).parent))

from core.boundary_aware_model import BoundaryAwareTokenizerModel
from src.core.byte_tokenizer_v6 import ByteTokenizerV6


class IntelligentTokenizer:
    """Intelligent Tokenizer for embedding and restoration"""

    def __init__(self, checkpoint_path: str = "checkpoints/latest_checkpoint.pt", device: Optional[str] = None):
        """
        Initialize tokenizer

        Args:
            checkpoint_path: Path to model checkpoint
            device: Device to use ('cuda', 'cpu', or None for auto)
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print("Initializing Intelligent Tokenizer v6.0...")
        print(f"Device: {self.device}")

        # Load checkpoint
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Initialize model
        self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Initialize tokenizer
        self.tokenizer = ByteTokenizerV6()
        self.max_chunk_size = 250  # Safe margin below the 256-byte sequence limit

        print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
        print("Ready for inference!")

    def embed(self, text: str) -> torch.Tensor:
        """
        Convert text to embeddings

        Args:
            text: Input text

        Returns:
            Embedding tensor
        """
        # Handle long text by chunking
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            embeddings = []
            for chunk in chunks:
                emb = self._embed_single(chunk)
                embeddings.append(emb)
            # Concatenate chunk embeddings along the sequence dimension
            return torch.cat(embeddings, dim=1)
        else:
            return self._embed_single(text)

    def _embed_single(self, text: str) -> torch.Tensor:
        """Embed a single chunk"""
        # Encode text into byte IDs
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']

        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Use the encoder's hidden states as the embeddings
            encoder_outputs = self.model.encoder(input_ids, attention_mask)
            embeddings = encoder_outputs['last_hidden_state']

        return embeddings

    def restore(self, text: str) -> Tuple[str, float]:
        """
        Test restoration capability

        Args:
            text: Input text

        Returns:
            Tuple of (restored_text, accuracy)
        """
        # Handle long text by chunking
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            restored_chunks = []
            accuracies = []
            for chunk in chunks:
                restored, acc = self._restore_single(chunk)
                restored_chunks.append(restored)
                accuracies.append(acc)
            return ''.join(restored_chunks), sum(accuracies) / len(accuracies)
        else:
            return self._restore_single(text)

    def _restore_single(self, text: str) -> Tuple[str, float]:
        """Restore a single chunk"""
        # Encode text into byte IDs
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']

        if len(byte_ids) <= 1:
            return text, 1.0

        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Teacher forcing for the restoration test
            decoder_input = input_ids[:, :-1]
            labels = input_ids[:, 1:]

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input,
                labels=labels,
                use_cross_attention=True
            )

            # Get predictions and per-position accuracy against the labels
            predictions = torch.argmax(outputs['logits'], dim=-1)
            accuracy = (predictions == labels).float().mean().item()

            # Decode predictions back into text
            try:
                pred_list = predictions[0].cpu().tolist()

                # Prepend BOS so the sequence aligns with the full original input
                full_sequence = [self.tokenizer.BOS] + pred_list

                # Keep only valid byte values, dropping special token IDs
                filtered = [b for b in full_sequence if 0 <= b < 256]

                if filtered:
                    restored_bytes = bytes(filtered)
                    restored_text = restored_bytes.decode('utf-8', errors='ignore')
                else:
                    restored_text = ""
            except Exception as e:
                print(f"Restoration error: {e}")
                restored_text = ""

        return restored_text, accuracy

    def compress(self, text: str) -> Dict:
        """
        Get compression statistics

        Args:
            text: Input text

        Returns:
            Dict with compression info
        """
        text_bytes = text.encode('utf-8')
        embeddings = self.embed(text)

        original_size = len(text_bytes)
        compressed_size = embeddings.shape[1]
        compression_ratio = original_size / compressed_size if compressed_size > 0 else 0

        return {
            'original_bytes': original_size,
            'compressed_tokens': compressed_size,
            'compression_ratio': compression_ratio,
            'embedding_shape': list(embeddings.shape)
        }

    def _split_text_safely(self, text: str) -> List[str]:
        """Split text into chunks at valid UTF-8 boundaries"""
        chunks = []
        text_bytes = text.encode('utf-8')
        start = 0

        while start < len(text_bytes):
            end = min(start + self.max_chunk_size, len(text_bytes))

            # Back off until the slice decodes cleanly (no split multi-byte character)
            while end > start and end < len(text_bytes):
                try:
                    text_bytes[start:end].decode('utf-8')
                    break
                except UnicodeDecodeError:
                    end -= 1

            if end > start:
                chunk = text_bytes[start:end].decode('utf-8')
                chunks.append(chunk)
                start = end
            else:
                break

        return chunks


def test_model():
    """Test model functionality"""
    print("=" * 70)
    print("INTELLIGENT TOKENIZER v6.0 - FUNCTIONALITY TEST")
    print("=" * 70)

    # Initialize tokenizer
    tokenizer = IntelligentTokenizer()

    # Test samples
    test_samples = [
        ("English", "Hello, world!"),
        ("Korean", "안녕하세요. 반갑습니다."),
        ("Chinese", "今天天气很好"),
        ("Japanese", "こんにちは"),
        ("Arabic", "مرحبا بك"),
        ("Russian", "Привет, как дела?"),
        ("Emoji", "Hello 👋 World 🌍!"),
    ]

    print("\n" + "=" * 70)
    print("EMBEDDING & RESTORATION TESTS")
    print("=" * 70)

    total_accuracy = 0
    successful = 0

    for lang, text in test_samples:
        print(f"\n[{lang}]")
        print(f"Original: {text}")

        # Test embedding
        embeddings = tokenizer.embed(text)
        print(f"Embedding: {embeddings.shape}")

        # Test compression
        compression = tokenizer.compress(text)
        print(f"Compression: {compression['original_bytes']} bytes → {compression['compressed_tokens']} tokens")
        print(f"Ratio: {compression['compression_ratio']:.2f}x")

        # Test restoration
        restored, accuracy = tokenizer.restore(text)
        print(f"Restored: {restored}")
        print(f"Accuracy: {accuracy:.1%}")

        if accuracy > 0.7:
            successful += 1
        total_accuracy += accuracy

    # Summary
    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    print(f"Tests passed: {successful}/{len(test_samples)}")
    print(f"Average accuracy: {total_accuracy / len(test_samples):.1%}")

    if successful == len(test_samples):
        print("\n✅ ALL TESTS PASSED!")
        return True
    elif successful >= len(test_samples) * 0.7:
        print("\n⚠️ PARTIAL SUCCESS (70%+ tests passed)")
        return True
    else:
        print("\n❌ TESTS FAILED")
        return False


if __name__ == "__main__":
    success = test_model()
    sys.exit(0 if success else 1)
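

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so it never executes):
# besides running this file directly, the class can be used programmatically.
# This assumes a trained checkpoint exists at the default path; the calls
# below mirror the public API defined above.
#
#   tokenizer = IntelligentTokenizer(device="cpu")
#   embeddings = tokenizer.embed("Hello, world!")           # encoder hidden states
#   stats = tokenizer.compress("Hello, world!")             # byte/token counts and ratio
#   restored, accuracy = tokenizer.restore("Hello, world!") # round-trip check
# ---------------------------------------------------------------------------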