|
|
|
|
|
|
|
|
"""
|
|
|
POC ๋ฐ๋ชจ ์คํฌ๋ฆฝํธ - ๊ธด ํ
์คํธ ์๋ ๋ถํ ์ฒ๋ฆฌ
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import sys
|
|
|
import io
|
|
|
from pathlib import Path
|
|
|
import time
|
|
|
|
|
|
|
|
|
def ensure_utf8_stream(stream):
    """Return a text stream equivalent to *stream* that encodes as UTF-8.

    The demo prints Korean, Chinese, Arabic, ... samples, which raises
    UnicodeEncodeError on consoles whose default codec is a legacy code page.

    Args:
        stream: A text stream such as ``sys.stdout`` (may be ``None``, or a
            replacement object without ``encoding``/``buffer`` attributes).

    Returns:
        ``stream`` itself when it already uses UTF-8, when it could be
        reconfigured in place, or when it cannot be adapted at all; otherwise
        a new ``io.TextIOWrapper`` around ``stream.buffer``.
    """
    # Encoding names are case-insensitive ('UTF-8', 'utf8', 'utf_8' are all valid).
    encoding = (getattr(stream, 'encoding', None) or '').lower().replace('_', '-')
    if encoding in ('utf-8', 'utf8'):
        return stream
    # Reconfiguring in place keeps the stream's buffering settings, so stdout
    # and stderr output stay interleaved in the order they were written.
    if hasattr(stream, 'reconfigure'):
        stream.reconfigure(encoding='utf-8')
        return stream
    buffer = getattr(stream, 'buffer', None)
    if buffer is None:
        # Nothing to wrap (e.g. a custom stream object); leave it alone.
        return stream
    return io.TextIOWrapper(buffer, encoding='utf-8')


# Check each stream on its own: stderr may need fixing even when stdout does not.
sys.stdout = ensure_utf8_stream(sys.stdout)
sys.stderr = ensure_utf8_stream(sys.stderr)

# Make the packages that live next to this script importable when it is run directly.
sys.path.append(str(Path(__file__).parent))
|
|
|
|
|
|
from core.boundary_aware_model import BoundaryAwareTokenizerModel
|
|
|
from src.core.byte_tokenizer_v6 import ByteTokenizerV6
|
|
|
|
|
|
|
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
|
class IntelligentTokenizerPOC:
    """Demo harness for the POC: loads a checkpoint and reports compression stats.

    Text longer than ``max_chunk_size`` UTF-8 bytes is split into chunks on
    character boundaries and each chunk is run through the model separately.
    """

    def __init__(self, checkpoint_path="checkpoints/unified/latest_checkpoint.pt"):
        """Load the model from *checkpoint_path* and prepare the tokenizer.

        Args:
            checkpoint_path: Path to a checkpoint file containing the keys
                ``model_config``, ``model_state_dict``, ``epoch`` and ``loss``.
        """
        print("="*70)
        print("INTELLIGENT TOKENIZER v6.0 - POC Demo")
        print("="*70)
        print(f"Device: {device}")
        print("Loading checkpoint...")

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(device)
        self.model.eval()

        self.tokenizer = ByteTokenizerV6()
        # Chunk budget in UTF-8 bytes; presumably kept below the 256 limit
        # announced below to leave room for special tokens - TODO confirm.
        self.max_chunk_size = 250

        print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
        print(f"Current limitation: 256 bytes per chunk")
        print(f"(Due to POC development constraints and limited GPU resources)")
        print("="*70)
        print()

    def process_text(self, text: str, show_details=True):
        """Process *text*, splitting it into chunks automatically when needed.

        Args:
            text: Input text.
            show_details: When true, print per-chunk and overall statistics.

        Returns:
            The result dict of ``_process_single_chunk`` when the text fits in
            one chunk, otherwise a list with one such dict per chunk.
        """
        text_bytes = text.encode('utf-8')
        total_bytes = len(text_bytes)

        if show_details:
            print(f"Input text: {text[:100]}..." if len(text) > 100 else f"Input text: {text}")
            print(f"Total bytes: {total_bytes}")

        # Short input: a single pass is enough.
        if total_bytes <= self.max_chunk_size:
            return self._process_single_chunk(text, show_details)

        chunks = self._split_text_safely(text)
        if show_details:
            print(f"Auto-splitting into {len(chunks)} chunks (256 byte limit for POC)")
            print("Note: Production version will handle up to 4096+ bytes")
            print("-"*50)

        results = []
        total_compressed = 0
        for i, chunk in enumerate(chunks):
            if show_details:
                print(f"\nChunk {i+1}/{len(chunks)}:")
            result = self._process_single_chunk(chunk, show_details)
            results.append(result)
            total_compressed += result['compressed_tokens']

        if show_details:
            print("\n" + "="*50)
            print("OVERALL RESULTS:")
            print(f"Total input: {total_bytes} bytes")
            print(f"Total compressed: {total_compressed} tokens")
            print(f"Compression ratio: {total_bytes/total_compressed:.2f}x")
            print(f"Average accuracy: {sum(r['accuracy'] for r in results)/len(results):.1%}")

        return results

    def _split_text_safely(self, text: str) -> list:
        """Split *text* into chunks without cutting a UTF-8 character in half.

        Each chunk holds at most ``max_chunk_size`` encoded bytes. The only
        exception is a single character wider than the limit: it becomes a
        chunk of its own, so no input is ever dropped.

        Args:
            text: Text to split.

        Returns:
            List of non-empty strings whose concatenation equals *text*
            (empty list for empty input).
        """
        text_bytes = text.encode('utf-8')
        total = len(text_bytes)
        # A non-positive limit could never advance; treat it as 1 byte.
        limit = max(self.max_chunk_size, 1)

        chunks = []
        start = 0
        while start < total:
            end = min(start + limit, total)

            # Bytes of the form 0b10xxxxxx are continuation bytes: if ``end``
            # points at one, the cut would land inside a character, so back off.
            while start < end < total and (text_bytes[end] & 0xC0) == 0x80:
                end -= 1

            if end == start:
                # The character at ``start`` alone exceeds the limit; take the
                # whole character rather than looping forever or losing data.
                end = start + 1
                while end < total and (text_bytes[end] & 0xC0) == 0x80:
                    end += 1

            chunks.append(text_bytes[start:end].decode('utf-8'))
            start = end

        return chunks

    def _process_single_chunk(self, text: str, show_details=True):
        """Run one chunk through the model and collect its statistics.

        Args:
            text: Chunk text, passed to the tokenizer as-is.
            show_details: When true, print the statistics for this chunk.

        Returns:
            Dict with keys ``text``, ``input_bytes`` (length of the tokenizer's
            ``input_ids``), ``compressed_tokens`` (sequence length of the
            encoder's last hidden state), ``compression_ratio``, ``accuracy``
            (teacher-forced next-token accuracy) and ``time_ms`` (encoder time).
        """
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']
        input_ids = torch.tensor([byte_ids], device=device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=device)

        # CUDA kernels launch asynchronously; without synchronizing, the timer
        # would measure only the launch overhead, not the encoder itself.
        on_gpu = device.type == 'cuda'

        with torch.no_grad():
            # Compression: time the encoder pass only.
            if on_gpu:
                torch.cuda.synchronize()
            started = time.perf_counter()
            encoder_outputs = self.model.encoder(input_ids, attention_mask)
            encoder_hidden = encoder_outputs['last_hidden_state']
            if on_gpu:
                torch.cuda.synchronize()
            compression_time = time.perf_counter() - started

            compressed_tokens = encoder_hidden.shape[1]
            compression_ratio = len(byte_ids) / compressed_tokens

            # Reconstruction: teacher forcing, predicting token t+1 from tokens <= t.
            if len(byte_ids) > 1:
                decoder_input = input_ids[:, :-1]
                labels = input_ids[:, 1:]
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input,
                    labels=labels,
                    use_cross_attention=True
                )
                predictions = torch.argmax(outputs['logits'], dim=-1)
                accuracy = (predictions == labels).float().mean().item()
            else:
                # Nothing to predict for a single token.
                accuracy = 1.0

        if show_details:
            print(f" Input: {len(byte_ids)} bytes")
            print(f" Compressed: {compressed_tokens} tokens ({compression_ratio:.2f}x)")
            print(f" Accuracy: {accuracy:.1%}")
            print(f" Processing time: {compression_time*1000:.1f}ms")

        return {
            'text': text,
            'input_bytes': len(byte_ids),
            'compressed_tokens': compressed_tokens,
            'compression_ratio': compression_ratio,
            'accuracy': accuracy,
            'time_ms': compression_time * 1000
        }

    def benchmark_languages(self):
        """Run one short sample per language through the model and print its stats."""
        print("\n" + "="*70)
        print("MULTILINGUAL BENCHMARK")
        print("="*70)

        test_samples = {
            'English': "The quick brown fox jumps over the lazy dog",
            'Korean': "안녕하세요. 오늘 날씨가 정말 좋네요",
            'Chinese': "今天天气很好",
            'Japanese': "こんにちは",
            'Spanish': "Hola, ¿cómo estás?",
            'Arabic': "مرحبا بك",
            'Russian': "Привет, как дела?",
        }

        for lang, text in test_samples.items():
            print(f"\n{lang}:")
            self._process_single_chunk(text, show_details=True)

    def explain_advantages(self):
        """Print the summary of advantages, limitations and roadmap."""
        print("\n" + "="*70)
        print("KEY ADVANTAGES")
        print("="*70)
        print("""
1. PURE LEARNING-BASED
- No vocabulary files (260 fixed bytes vs 50K+ tokens)
- No language-specific rules
- Learns compression patterns from data

2. MULTILINGUAL EQUALITY
- All 204 languages treated equally
- No vocabulary bias towards English
- Better for low-resource languages

3. COMPRESSION CAPABILITY
- Current: 2-3x compression (POC stage)
- Target: 5-10x compression (with more training)
- API cost reduction: 50-80%

4. CURRENT LIMITATIONS (POC)
- 256 byte chunks (due to limited GPU resources)
- Will expand to 4096+ bytes post-POC
- Training on personal RTX 3060 (4 months development)

5. FUTURE ROADMAP
- Multimodal support (text + image + audio)
- Dynamic compression levels
- Real-time streaming mode
""")
        print("="*70)
|
|
|
|
|
|
def main():
    """Run the complete demo: short texts, a long auto-split text, the
    multilingual benchmark and the closing summary."""
    poc = IntelligentTokenizerPOC()

    # 1. Short inputs that fit in a single chunk.
    print("\n### SHORT TEXT DEMO ###")
    poc.process_text("Hello, world!")
    poc.process_text("안녕하세요. 반갑습니다.")

    # 2. A long mixed Korean/English input that must be split into chunks.
    print("\n### LONG TEXT AUTO-SPLIT DEMO ###")
    long_text = """
인공지능 기술이 빠르게 발전하고 있습니다. 특히 자연어 처리 분야에서
놀라운 성과를 보이고 있으며, 이는 우리의 일상생활에도 큰 영향을
미치고 있습니다. 앞으로 더 많은 혁신이 기대됩니다.

The development of artificial intelligence is accelerating rapidly.
Natural language processing, in particular, has shown remarkable progress,
significantly impacting our daily lives. We can expect even more innovations
in the near future.
"""
    poc.process_text(long_text)

    # 3. Per-language benchmark.
    poc.benchmark_languages()

    # 4. Advantages, limitations and roadmap.
    poc.explain_advantages()

    print("\n" + "="*70)
    print("POC DEMO COMPLETE")
    print("Developed in 4 months by a solo developer with no prior AI experience")
    print("Contact: [your contact info]")
    print("="*70)
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()