"""
B2NL-IntelligentTokenizer v6.2.1 - working inference code.
This file is the main usage reference.
"""
import torch
import sys
from pathlib import Path
# Add the local repo paths so the core package imports resolve
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "intelligent-tokenizer_v6.2.1"))
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "intelligent-tokenizer_v6.2.1/core"))
from core.unified_model import IntelligentTokenizerV62
from core.tokenizer import ByteTokenizerV62

class B2NLTokenizer:
    """Working wrapper around the B2NL tokenizer model."""

    def __init__(self, checkpoint_path: str = None):
        """
        Args:
            checkpoint_path: checkpoint path (falls back to the default below)
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Default checkpoint path (adjust to your environment)
        if checkpoint_path is None:
            checkpoint_path = "D:/intelligent-tokenizer/intelligent-tokenizer_v6.2.1/checkpoints/v62/16.0/epoch_100.pt"

        # Load the model weights
        self.model = IntelligentTokenizerV62()
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"Model loaded successfully on {self.device}")
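
    # If you do not have a local checkpoint, one option is to fetch the weights
    # from the Hugging Face Hub instead of hardcoding a path. This is a minimal
    # sketch; the repo_id and filename below are assumptions, not the confirmed
    # locations of the published weights:
    #
    #     from huggingface_hub import hf_hub_download
    #     checkpoint_path = hf_hub_download(
    #         repo_id="ggunio/B2NL-IntelligentTokenizer-v6.2.1",  # hypothetical repo id
    #         filename="epoch_100.pt",                            # hypothetical filename
    #     )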

    def compress(self, text: str) -> dict:
        """Compress text into latent tokens."""
        return self.model.compress(text)
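
    # Note: judging from test_tokenizer() below, the returned dict contains at
    # least 'compression_ratio' and 'num_tokens'; any other keys depend on the
    # model internals.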

    def reconstruct(self, text: str, temperature: float = 0.1) -> str:
        """
        Compress the text, then reconstruct it (working version).

        Args:
            text: input text
            temperature: sampling temperature (lower is more deterministic)

        Returns:
            The reconstructed text.
        """
        # 1. Encode the text to byte-level ids
        tokenizer = self.model.tokenizer
        encoded = tokenizer.encode(text)

        if isinstance(encoded, dict):
            input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
            attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
        else:
            input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded
            attention_mask = torch.ones_like(input_ids)

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        # 2. Compress with the encoder
        with torch.no_grad():
            encoder_outputs = self.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        # Prepare all hidden states for the decoder
        if 'all_hidden_states' in encoder_outputs:
            encoder_all_hidden = encoder_outputs['all_hidden_states']
        else:
            compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
            encoder_all_hidden = [compressed] * 4
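
        # NOTE (assumption): replicating the single compressed state four times
        # implies a decoder with four cross-attention inputs; when the encoder
        # exposes 'all_hidden_states', that branch is used instead.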

        # 3. Autoregressive decoding (the approach that actually works)
        batch_size = input_ids.size(0)
        max_length = 48

        # Start from the BOS token
        generated = torch.full((batch_size, 1), tokenizer.BOS, device=self.device)

        for step in range(max_length - 1):
            with torch.no_grad():
                # Decode conditioned on the sequence generated so far
                decoder_outputs = self.model.decoder(
                    encoder_all_hidden=encoder_all_hidden,
                    decoder_input_ids=generated,
                    attention_mask=torch.ones_like(generated),
                    use_cache=False
                )

            # Predict the next token
            logits = decoder_outputs['logits'][:, -1, :] / temperature

            # Top-k filtering
            top_k = 10
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float('-inf')

            # Convert to probabilities and sample
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to the generated sequence
            generated = torch.cat([generated, next_token], dim=1)

            # Stop on the EOS token
            if (next_token == tokenizer.EOS).all():
                break
        # 4. Decode the generated ids back to text
        # (a separate local name avoids shadowing the `text` argument)
        if generated.dim() > 1:
            decoded = tokenizer.decode(generated[0])
        else:
            decoded = tokenizer.decode(generated)

        return decoded
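
    # For fully deterministic reconstruction you can swap the top-k sampling in
    # reconstruct() for greedy decoding. A minimal sketch using the same
    # tensors as above:
    #
    #     next_token = decoder_outputs['logits'][:, -1, :].argmax(dim=-1, keepdim=True)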

def test_tokenizer():
    """Tokenizer smoke test."""
    print("="*60)
    print("B2NL-IntelligentTokenizer v6.2.1 test")
    print("="*60)

    # Initialize the tokenizer
    tokenizer = B2NLTokenizer()

    # Test texts (deliberately multilingual)
    test_texts = [
        "Hello, world!",
        "์•ˆ๋…•ํ•˜์„ธ์š”, ๋ฐ˜๊ฐ‘์Šต๋‹ˆ๋‹ค.",
        "The quick brown fox jumps over the lazy dog.",
        "ไบบๅทฅๆ™บ่ƒฝๆŠ€ๆœฏๆญฃๅœจๆ”นๅ˜ไธ–็•Œใ€‚",
    ]

    for text in test_texts:
        print(f"\nOriginal: {text}")

        # Compression
        compressed = tokenizer.compress(text)
        print(f"Compression: {compressed['compression_ratio']:.1f}:1 ({compressed['num_tokens']} tokens)")

        # Reconstruction
        reconstructed = tokenizer.reconstruct(text, temperature=0.1)
        print(f"Reconstructed: {reconstructed}")

        # Character-level accuracy
        min_len = min(len(text), len(reconstructed))
        accuracy = sum(1 for i in range(min_len) if text[i] == reconstructed[i]) / len(text) * 100
        print(f"Accuracy: {accuracy:.1f}%")

    print("\n" + "="*60)
    print("Test completed!")
    print("="*60)
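
# The position-wise character accuracy used in test_tokenizer(), factored into
# a reusable helper (a small convenience added here, not part of the original API):
def char_accuracy(original: str, reconstructed: str) -> float:
    """Percentage of positions where the reconstruction matches the original."""
    if not original:
        return 0.0
    min_len = min(len(original), len(reconstructed))
    matches = sum(1 for i in range(min_len) if original[i] == reconstructed[i])
    return matches / len(original) * 100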

# Usage example
def example_usage():
    """Minimal usage example."""
    # 1. Initialize the tokenizer
    tokenizer = B2NLTokenizer()

    # 2. Compress text
    text = "์•ˆ๋…•ํ•˜์„ธ์š”, ๋ฐ˜๊ฐ‘์Šต๋‹ˆ๋‹ค!"
    compressed = tokenizer.compress(text)
    print(f"Compression: {compressed['compression_ratio']:.1f}:1")

    # 3. Reconstruct the text
    reconstructed = tokenizer.reconstruct(text)
    print(f"Reconstructed: {reconstructed}")

    return tokenizer

if __name__ == "__main__":
    test_tokenizer()