#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Intelligent Tokenizer v6.0 - Inference Module

Embedding and restoration functionality.
"""

import torch
import sys
import io
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Ensure stdout/stderr use UTF-8 (Windows consoles may default to a legacy code page)
if sys.stdout.encoding is None or sys.stdout.encoding.lower() != 'utf-8':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

sys.path.append(str(Path(__file__).parent))

from core.boundary_aware_model import BoundaryAwareTokenizerModel
from src.core.byte_tokenizer_v6 import ByteTokenizerV6


class IntelligentTokenizer:
    """Intelligent Tokenizer for embedding and restoration"""

    def __init__(self, checkpoint_path: str = "checkpoints/latest_checkpoint.pt", device: str = None):
        """

        Initialize tokenizer



        Args:

            checkpoint_path: Path to model checkpoint

            device: Device to use ('cuda', 'cpu', or None for auto)

        """
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"Initializing Intelligent Tokenizer v6.0...")
        print(f"Device: {self.device}")

        # Load checkpoint
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # Initialize model
        self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Initialize tokenizer
        self.tokenizer = ByteTokenizerV6()
        self.max_chunk_size = 250  # safety margin under the model's 256-byte sequence limit

        print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
        print(f"Ready for inference!")

    def embed(self, text: str) -> torch.Tensor:
        """

        Convert text to embeddings



        Args:

            text: Input text



        Returns:

            Embedding tensor

        """
        # Handle long text by chunking
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            embeddings = []

            for chunk in chunks:
                emb = self._embed_single(chunk)
                embeddings.append(emb)

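            # Note: chunks are encoded independently, so tokens in one chunk
            # cannot attend to tokens in another chunk.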
            # Concatenate chunk embeddings along the sequence dimension
            return torch.cat(embeddings, dim=1)
        else:
            return self._embed_single(text)

    def _embed_single(self, text: str) -> torch.Tensor:
        """Embed single chunk"""
        # Encode text
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']
        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Get embeddings
            encoder_outputs = self.model.encoder(input_ids, attention_mask)
            embeddings = encoder_outputs['last_hidden_state']

        return embeddings

    def restore(self, text: str) -> Tuple[str, float]:
        """

        Test restoration capability



        Args:

            text: Input text



        Returns:

            Tuple of (restored_text, accuracy)

        """
        # Handle long text
        if len(text.encode('utf-8')) > self.max_chunk_size:
            chunks = self._split_text_safely(text)
            restored_chunks = []
            accuracies = []

            for chunk in chunks:
                restored, acc = self._restore_single(chunk)
                restored_chunks.append(restored)
                accuracies.append(acc)

            return ''.join(restored_chunks), sum(accuracies) / len(accuracies)
        else:
            return self._restore_single(text)

    def _restore_single(self, text: str) -> Tuple[str, float]:
        """Restore single chunk"""
        # Encode text
        encoded = self.tokenizer.encode(text)
        byte_ids = encoded['input_ids']

        if len(byte_ids) <= 1:
            return text, 1.0

        input_ids = torch.tensor([byte_ids], device=self.device)
        attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)

        with torch.no_grad():
            # Teacher forcing for restoration test
            decoder_input = input_ids[:, :-1]
            labels = input_ids[:, 1:]
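            # With this shift the decoder always sees the gold prefix, so the
            # accuracy computed below measures next-byte prediction under
            # teacher forcing; it is typically optimistic compared to
            # free-running autoregressive restoration.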

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input,
                labels=labels,
                use_cross_attention=True
            )

            # Get predictions
            predictions = torch.argmax(outputs['logits'], dim=-1)
            accuracy = (predictions == labels).float().mean().item()

            # Decode predictions
            try:
                # Remove special tokens and convert to bytes
                pred_list = predictions[0].cpu().tolist()
                # Add BOS at beginning for full sequence
                full_sequence = [self.tokenizer.BOS] + pred_list

                # Keep only raw byte values (0-255); anything else (e.g. special tokens) is dropped
                filtered = [b for b in full_sequence if 0 <= b < 256]
                if filtered:
                    restored_bytes = bytes(filtered)
                    restored_text = restored_bytes.decode('utf-8', errors='ignore')
                else:
                    restored_text = ""
            except Exception as e:
                print(f"Restoration error: {e}")
                restored_text = ""

        return restored_text, accuracy

    def compress(self, text: str) -> Dict:
        """

        Get compression statistics



        Args:

            text: Input text



        Returns:

            Dict with compression info

        """
        text_bytes = text.encode('utf-8')
        embeddings = self.embed(text)

        original_size = len(text_bytes)
        compressed_size = embeddings.shape[1]
        compression_ratio = original_size / compressed_size if compressed_size > 0 else 0
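        # Example: 26 UTF-8 bytes represented by 10 token positions gives a
        # ratio of 2.6x (illustrative numbers only).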

        return {
            'original_bytes': original_size,
            'compressed_tokens': compressed_size,
            'compression_ratio': compression_ratio,
            'embedding_shape': list(embeddings.shape)
        }

    def _split_text_safely(self, text: str) -> List[str]:
        """Split text into chunks that end on valid UTF-8 boundaries."""
        chunks = []
        text_bytes = text.encode('utf-8')

        start = 0
        while start < len(text_bytes):
            end = min(start + self.max_chunk_size, len(text_bytes))

            # Back off until the slice decodes cleanly, so a multi-byte
            # character is never split across chunks.
            while end > start:
                try:
                    chunk = text_bytes[start:end].decode('utf-8')
                    break
                except UnicodeDecodeError:
                    end -= 1
            else:
                break  # no progress possible; bail out to avoid looping forever

            chunks.append(chunk)
            start = end

        return chunks
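
    # Worked example of the boundary backoff above (with a 2-byte budget for
    # illustration instead of self.max_chunk_size):
    #   "hé".encode("utf-8") == b"h\xc3\xa9"       # 'é' occupies the two bytes C3 A9
    #   b"h\xc3\xa9"[:2].decode("utf-8")           # UnicodeDecodeError: splits 'é'
    #   b"h\xc3\xa9"[:1].decode("utf-8") == "h"    # back off one byte -> valid chunk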


def test_model():
    """Test model functionality"""
    print("="*70)
    print("INTELLIGENT TOKENIZER v6.0 - FUNCTIONALITY TEST")
    print("="*70)

    # Initialize tokenizer
    tokenizer = IntelligentTokenizer()

    # Test samples
    test_samples = [
        ("English", "Hello, world!"),
        ("Korean", "안녕하세요. 반갑습니다."),
        ("Chinese", "今天天气很好"),
        ("Japanese", "こんにちは"),
        ("Arabic", "مرحبا بك"),
        ("Russian", "Привет, как дела?"),
        ("Emoji", "Hello 👋 World 🌍!"),
    ]

    print("\n" + "="*70)
    print("EMBEDDING & RESTORATION TESTS")
    print("="*70)

    total_accuracy = 0
    successful = 0

    for lang, text in test_samples:
        print(f"\n[{lang}]")
        print(f"Original: {text}")

        # Test embedding
        embeddings = tokenizer.embed(text)
        print(f"Embedding: {embeddings.shape}")

        # Test compression
        compression = tokenizer.compress(text)
        print(f"Compression: {compression['original_bytes']} bytes → {compression['compressed_tokens']} tokens")
        print(f"Ratio: {compression['compression_ratio']:.2f}x")

        # Test restoration
        restored, accuracy = tokenizer.restore(text)
        print(f"Restored: {restored}")
        print(f"Accuracy: {accuracy:.1%}")

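        # A sample counts as passed when restoration accuracy exceeds 70%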
        if accuracy > 0.7:
            successful += 1
        total_accuracy += accuracy

    # Summary
    print("\n" + "="*70)
    print("TEST SUMMARY")
    print("="*70)
    print(f"Tests passed: {successful}/{len(test_samples)}")
    print(f"Average accuracy: {total_accuracy/len(test_samples):.1%}")

    if successful == len(test_samples):
        print("\n✅ ALL TESTS PASSED!")
        return True
    elif successful >= len(test_samples) * 0.7:
        print("\n⚠️ PARTIAL SUCCESS (70%+ tests passed)")
        return True
    else:
        print("\n❌ TESTS FAILED")
        return False


if __name__ == "__main__":
    success = test_model()
    sys.exit(0 if success else 1)