#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Intelligent Tokenizer v6.0 - Inference Module
Embedding and restoration functionality
"""
import torch
import sys
import io
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# UTF-8 encoding setup for console output
if sys.stdout.encoding != 'utf-8':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

sys.path.append(str(Path(__file__).parent))

from core.boundary_aware_model import BoundaryAwareTokenizerModel
from src.core.byte_tokenizer_v6 import ByteTokenizerV6
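
# NOTE: the two imports above assume this repository's layout, with `core/`
# importable from this file's directory and the byte tokenizer under `src/core/`.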

class IntelligentTokenizer:
    """Intelligent Tokenizer for embedding and restoration"""

    def __init__(self, checkpoint_path: str = "checkpoints/latest_checkpoint.pt", device: Optional[str] = None):
        """
        Initialize the tokenizer.

        Args:
            checkpoint_path: Path to model checkpoint
            device: Device to use ('cuda', 'cpu', or None for auto)
        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print("Initializing Intelligent Tokenizer v6.0...")
        print(f"Device: {self.device}")

        # Load checkpoint
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Initialize model
        self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Initialize byte-level tokenizer
        self.tokenizer = ByteTokenizerV6()
        self.max_chunk_size = 250  # Safe margin below the 256-byte limit

        print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
        print("Ready for inference!")

    def embed(self, text: str) -> torch.Tensor:
        """
        Convert text to embeddings.

        Args:
            text: Input text

        Returns:
            Embedding tensor
        """
        # Handle long text by chunking
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            embeddings = []
            for chunk in chunks:
                emb = self._embed_single(chunk)
                embeddings.append(emb)
            # Concatenate embeddings
            return torch.cat(embeddings, dim=1)
        else:
            return self._embed_single(text)

    def _embed_single(self, text: str) -> torch.Tensor:
        """Embed a single chunk"""
        # Encode text
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']
        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Get embeddings
            encoder_outputs = self.model.encoder(input_ids, attention_mask)
            embeddings = encoder_outputs['last_hidden_state']

        return embeddings
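
    # Example (dimensions depend on the trained model config; shown illustratively):
    #   emb = IntelligentTokenizer().embed("Hello, world!")
    #   emb.shape  # -> torch.Size([1, num_tokens, hidden_dim])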

    def restore(self, text: str) -> Tuple[str, float]:
        """
        Test restoration capability.

        Args:
            text: Input text

        Returns:
            Tuple of (restored_text, accuracy)
        """
        # Handle long text
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            restored_chunks = []
            accuracies = []
            for chunk in chunks:
                restored, acc = self._restore_single(chunk)
                restored_chunks.append(restored)
                accuracies.append(acc)
            return ''.join(restored_chunks), sum(accuracies) / len(accuracies)
        else:
            return self._restore_single(text)

    def _restore_single(self, text: str) -> Tuple[str, float]:
        """Restore a single chunk"""
        # Encode text
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']
        if len(byte_ids) <= 1:
            return text, 1.0

        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Teacher forcing for restoration test
            decoder_input = input_ids[:, :-1]
            labels = input_ids[:, 1:]
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input,
                labels=labels,
                use_cross_attention=True
            )

            # Get predictions
            predictions = torch.argmax(outputs['logits'], dim=-1)
            accuracy = (predictions == labels).float().mean().item()

        # Decode predictions
        try:
            # Remove special tokens and convert to bytes
            pred_list = predictions[0].cpu().tolist()
            # Add BOS at beginning for full sequence
            full_sequence = [self.tokenizer.BOS] + pred_list
            # Filter valid bytes
            filtered = [b for b in full_sequence if 0 <= b < 256]
            if filtered:
                restored_bytes = bytes(filtered)
                restored_text = restored_bytes.decode('utf-8', errors='ignore')
            else:
                restored_text = ""
        except Exception as e:
            print(f"Restoration error: {e}")
            restored_text = ""

        return restored_text, accuracy
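
    # Note: _restore_single uses teacher forcing (the decoder always sees the gold
    # prefix), so the reported accuracy measures next-byte prediction quality rather
    # than fully autoregressive, free-running regeneration.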

    def compress(self, text: str) -> Dict:
        """
        Get compression statistics.

        Args:
            text: Input text

        Returns:
            Dict with compression info
        """
        text_bytes = text.encode('utf-8')
        embeddings = self.embed(text)

        original_size = len(text_bytes)
        compressed_size = embeddings.shape[1]
        compression_ratio = original_size / compressed_size if compressed_size > 0 else 0

        return {
            'original_bytes': original_size,
            'compressed_tokens': compressed_size,
            'compression_ratio': compression_ratio,
            'embedding_shape': list(embeddings.shape)
        }
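
    # Example return value (numbers are illustrative only):
    #   {'original_bytes': 13, 'compressed_tokens': 5, 'compression_ratio': 2.6,
    #    'embedding_shape': [1, 5, 768]}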

    def _split_text_safely(self, text: str) -> List[str]:
        """Split text safely at UTF-8 boundaries"""
        chunks = []
        text_bytes = text.encode('utf-8')
        start = 0

        while start < len(text_bytes):
            end = min(start + self.max_chunk_size, len(text_bytes))

            # Find valid UTF-8 boundary
            while end > start and end < len(text_bytes):
                try:
                    chunk = text_bytes[start:end].decode('utf-8')
                    break
                except UnicodeDecodeError:
                    end -= 1

            if end > start:
                chunk = text_bytes[start:end].decode('utf-8')
                chunks.append(chunk)
                start = end
            else:
                break

        return chunks
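
    # Example: "안녕" is 6 bytes in UTF-8 (3 bytes per Hangul syllable). A cut that
    # lands mid-character raises UnicodeDecodeError, so the inner loop above walks
    # the cut point back until the byte prefix decodes cleanly.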

def test_model():
    """Test model functionality"""
    print("="*70)
    print("INTELLIGENT TOKENIZER v6.0 - FUNCTIONALITY TEST")
    print("="*70)

    # Initialize tokenizer
    tokenizer = IntelligentTokenizer()

    # Test samples
    test_samples = [
        ("English", "Hello, world!"),
        ("Korean", "안녕하세요. 반갑습니다."),
        ("Chinese", "今天天气很好"),
        ("Japanese", "こんにちは"),
        ("Arabic", "مرحبا بك"),
        ("Russian", "Привет, как дела?"),
        ("Emoji", "Hello 👋 World 🌍!"),
    ]

    print("\n" + "="*70)
    print("EMBEDDING & RESTORATION TESTS")
    print("="*70)

    total_accuracy = 0
    successful = 0

    for lang, text in test_samples:
        print(f"\n[{lang}]")
        print(f"Original: {text}")

        # Test embedding
        embeddings = tokenizer.embed(text)
        print(f"Embedding: {embeddings.shape}")

        # Test compression
        compression = tokenizer.compress(text)
        print(f"Compression: {compression['original_bytes']} bytes → {compression['compressed_tokens']} tokens")
        print(f"Ratio: {compression['compression_ratio']:.2f}x")

        # Test restoration
        restored, accuracy = tokenizer.restore(text)
        print(f"Restored: {restored}")
        print(f"Accuracy: {accuracy:.1%}")

        if accuracy > 0.7:
            successful += 1
        total_accuracy += accuracy

    # Summary
    print("\n" + "="*70)
    print("TEST SUMMARY")
    print("="*70)
    print(f"Tests passed: {successful}/{len(test_samples)}")
    print(f"Average accuracy: {total_accuracy/len(test_samples):.1%}")

    if successful == len(test_samples):
        print("\n✅ ALL TESTS PASSED!")
        return True
    elif successful >= len(test_samples) * 0.7:
        print("\n⚠️ PARTIAL SUCCESS (70%+ tests passed)")
        return True
    else:
        print("\n❌ TESTS FAILED")
        return False

if __name__ == "__main__":
    success = test_model()
    sys.exit(0 if success else 1)
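
# Example usage as a library (illustrative; assumes a trained checkpoint exists at
# the default path and that this module is importable as `inference`):
#
#   from inference import IntelligentTokenizer
#
#   tokenizer = IntelligentTokenizer("checkpoints/latest_checkpoint.pt")
#   embeddings = tokenizer.embed("Hello, world!")            # encoder hidden states
#   restored, accuracy = tokenizer.restore("Hello, world!")  # teacher-forced check
#   stats = tokenizer.compress("Hello, world!")              # bytes-to-tokens ratio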