ggunio committed
Commit 318d977 · verified · 1 Parent(s): 51bc354

Upload inference.py with huggingface_hub
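
For reference, an upload like this is typically produced with the hub client's `upload_file` helper. A minimal sketch (the repo id below is illustrative, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="inference.py",
    path_in_repo="inference.py",
    repo_id="ggunio/intelligent-tokenizer-v6",  # illustrative repo id
    commit_message="Upload inference.py with huggingface_hub",
)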

Files changed (1)
  1. inference.py +297 -0
inference.py ADDED
@@ -0,0 +1,297 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ """
+ Intelligent Tokenizer v6.0 - Inference Module
+ Embedding and restoration functionality
+ """
+
+ import torch
+ import sys
+ import io
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ # UTF-8 encoding setup
+ if sys.stdout.encoding != 'utf-8':
+     sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+     sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
+
+ sys.path.append(str(Path(__file__).parent))
+
+ from core.boundary_aware_model import BoundaryAwareTokenizerModel
+ from src.core.byte_tokenizer_v6 import ByteTokenizerV6
+
+
+ class IntelligentTokenizer:
+     """Intelligent Tokenizer for embedding and restoration"""
+
+     def __init__(self, checkpoint_path: str = "checkpoints/latest_checkpoint.pt", device: Optional[str] = None):
+         """
+         Initialize tokenizer
+
+         Args:
+             checkpoint_path: Path to model checkpoint
+             device: Device to use ('cuda', 'cpu', or None for auto)
+         """
+         if device is None:
+             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         else:
+             self.device = torch.device(device)
+
+         print("Initializing Intelligent Tokenizer v6.0...")
+         print(f"Device: {self.device}")
+
+         # Load checkpoint
+         checkpoint_path = Path(checkpoint_path)
+         if not checkpoint_path.exists():
+             raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+         checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
+
+         # Initialize model
+         self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+         # Initialize tokenizer
+         self.tokenizer = ByteTokenizerV6()
+         self.max_chunk_size = 250  # Safe margin under the 256-byte limit
+
+         print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
+         print("Ready for inference!")
+
+     def embed(self, text: str) -> torch.Tensor:
+         """
+         Convert text to embeddings
+
+         Args:
+             text: Input text
+
+         Returns:
+             Embedding tensor
+         """
+         # Handle long text by chunking
+         if len(text.encode('utf-8')) > self.max_chunk_size:
+             chunks = self._split_text_safely(text)
+             embeddings = []
+
+             for chunk in chunks:
+                 emb = self._embed_single(chunk)
+                 embeddings.append(emb)
+
+             # Concatenate chunk embeddings along the sequence dimension
+             return torch.cat(embeddings, dim=1)
+         else:
+             return self._embed_single(text)
+
+     def _embed_single(self, text: str) -> torch.Tensor:
+         """Embed a single chunk"""
+         # Encode text to byte IDs
+         encoded = self.tokenizer.encode(text)
+         byte_ids = encoded['input_ids']
+         input_ids = torch.tensor([byte_ids], device=self.device)
+         attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)
+
+         with torch.no_grad():
+             # Get embeddings from the encoder
+             encoder_outputs = self.model.encoder(input_ids, attention_mask)
+             embeddings = encoder_outputs['last_hidden_state']
+
+         return embeddings
+
+     def restore(self, text: str) -> Tuple[str, float]:
+         """
+         Test restoration capability
+
+         Args:
+             text: Input text
+
+         Returns:
+             Tuple of (restored_text, accuracy)
+         """
+         # Handle long text by chunking
+         if len(text.encode('utf-8')) > self.max_chunk_size:
+             chunks = self._split_text_safely(text)
+             restored_chunks = []
+             accuracies = []
+
+             for chunk in chunks:
+                 restored, acc = self._restore_single(chunk)
+                 restored_chunks.append(restored)
+                 accuracies.append(acc)
+
+             return ''.join(restored_chunks), sum(accuracies) / len(accuracies)
+         else:
+             return self._restore_single(text)
+
+     def _restore_single(self, text: str) -> Tuple[str, float]:
+         """Restore a single chunk"""
+         # Encode text
+         encoded = self.tokenizer.encode(text)
+         byte_ids = encoded['input_ids']
+
+         if len(byte_ids) <= 1:
+             return text, 1.0
+
+         input_ids = torch.tensor([byte_ids], device=self.device)
+         attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)
+
+         with torch.no_grad():
+             # Teacher forcing for the restoration test: the decoder sees the
+             # ground-truth prefix and predicts the next byte at each position
+             decoder_input = input_ids[:, :-1]
+             labels = input_ids[:, 1:]
+
+             outputs = self.model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 decoder_input_ids=decoder_input,
+                 labels=labels,
+                 use_cross_attention=True
+             )
+
+             # Per-position next-byte prediction accuracy
+             predictions = torch.argmax(outputs['logits'], dim=-1)
+             accuracy = (predictions == labels).float().mean().item()
+
+         # Decode predictions
+         try:
+             pred_list = predictions[0].cpu().tolist()
+             # Prepend BOS so the sequence aligns with the full input
+             full_sequence = [self.tokenizer.BOS] + pred_list
+
+             # Keep only raw byte values; special tokens fall outside 0-255
+             filtered = [b for b in full_sequence if 0 <= b < 256]
+             if filtered:
+                 restored_bytes = bytes(filtered)
+                 restored_text = restored_bytes.decode('utf-8', errors='ignore')
+             else:
+                 restored_text = ""
+         except Exception as e:
+             print(f"Restoration error: {e}")
+             restored_text = ""
+
+         return restored_text, accuracy
+
+     def compress(self, text: str) -> Dict:
+         """
+         Get compression statistics
+
+         Args:
+             text: Input text
+
+         Returns:
+             Dict with compression info
+         """
+         text_bytes = text.encode('utf-8')
+         embeddings = self.embed(text)
+
+         original_size = len(text_bytes)
+         compressed_size = embeddings.shape[1]
+         compression_ratio = original_size / compressed_size if compressed_size > 0 else 0
+
+         return {
+             'original_bytes': original_size,
+             'compressed_tokens': compressed_size,
+             'compression_ratio': compression_ratio,
+             'embedding_shape': list(embeddings.shape)
+         }
+
+     def _split_text_safely(self, text: str) -> List[str]:
+         """Split text into chunks at valid UTF-8 boundaries"""
+         chunks = []
+         text_bytes = text.encode('utf-8')
+
+         start = 0
+         while start < len(text_bytes):
+             end = min(start + self.max_chunk_size, len(text_bytes))
+
+             # Back up until the slice ends on a valid UTF-8 boundary
+             while end > start and end < len(text_bytes):
+                 try:
+                     text_bytes[start:end].decode('utf-8')
+                     break
+                 except UnicodeDecodeError:
+                     end -= 1
+
+             if end > start:
+                 chunk = text_bytes[start:end].decode('utf-8')
+                 chunks.append(chunk)
+                 start = end
+             else:
+                 break
+
+         return chunks
+
+
+ def test_model():
+     """Test model functionality"""
+     print("="*70)
+     print("INTELLIGENT TOKENIZER v6.0 - FUNCTIONALITY TEST")
+     print("="*70)
+
+     # Initialize tokenizer
+     tokenizer = IntelligentTokenizer()
+
+     # Test samples
+     test_samples = [
+         ("English", "Hello, world!"),
+         ("Korean", "안녕하세요. 반갑습니다."),
+         ("Chinese", "今天天气很好"),
+         ("Japanese", "こんにちは"),
+         ("Arabic", "مرحبا بك"),
+         ("Russian", "Привет, как дела?"),
+         ("Emoji", "Hello 👋 World 🌍!"),
+     ]
+
+     print("\n" + "="*70)
+     print("EMBEDDING & RESTORATION TESTS")
+     print("="*70)
+
+     total_accuracy = 0
+     successful = 0
+
+     for lang, text in test_samples:
+         print(f"\n[{lang}]")
+         print(f"Original: {text}")
+
+         # Test embedding
+         embeddings = tokenizer.embed(text)
+         print(f"Embedding: {embeddings.shape}")
+
+         # Test compression
+         compression = tokenizer.compress(text)
+         print(f"Compression: {compression['original_bytes']} bytes → {compression['compressed_tokens']} tokens")
+         print(f"Ratio: {compression['compression_ratio']:.2f}x")
+
+         # Test restoration
+         restored, accuracy = tokenizer.restore(text)
+         print(f"Restored: {restored}")
+         print(f"Accuracy: {accuracy:.1%}")
+
+         if accuracy > 0.7:
+             successful += 1
+         total_accuracy += accuracy  # accumulate over all samples for the average below
+
+     # Summary
+     print("\n" + "="*70)
+     print("TEST SUMMARY")
+     print("="*70)
+     print(f"Tests passed: {successful}/{len(test_samples)}")
+     print(f"Average accuracy: {total_accuracy/len(test_samples):.1%}")
+
+     if successful == len(test_samples):
+         print("\n✅ ALL TESTS PASSED!")
+         return True
+     elif successful >= len(test_samples) * 0.7:
+         print("\n⚠️ PARTIAL SUCCESS (70%+ tests passed)")
+         return True
+     else:
+         print("\n❌ TESTS FAILED")
+         return False
+
+
+ if __name__ == "__main__":
+     success = test_model()
+     sys.exit(0 if success else 1)
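
For quick local testing, a minimal usage sketch along these lines should work, assuming the checkpoint file and the core/src.core modules are laid out as the imports above expect:

from inference import IntelligentTokenizer

tokenizer = IntelligentTokenizer()                  # loads checkpoints/latest_checkpoint.pt
emb = tokenizer.embed("Hello, world!")              # encoder hidden states, shape (1, seq_len, hidden)
stats = tokenizer.compress("Hello, world!")         # byte count vs. compressed token count
restored, acc = tokenizer.restore("Hello, world!")  # round-trip restoration check
print(stats['compression_ratio'], restored, acc)

Note that restore() scores teacher-forced next-byte prediction (the decoder is fed ground-truth inputs), so the reported accuracy does not measure free-running autoregressive restoration.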