ggunio committed
Commit 905c972 · verified · 1 Parent(s): 4e3eeae

Upload inference.py with huggingface_hub
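For context, an upload commit like this one is typically produced with the huggingface_hub client. A minimal sketch, not the author's exact invocation; the repo_id below is a placeholder and the token is assumed to come from a prior `huggingface-cli login`:

    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="inference.py",   # local file to upload
        path_in_repo="inference.py",      # destination path inside the repo
        repo_id="ggunio/REPO_ID",         # placeholder; substitute the actual model repo id
        commit_message="Upload inference.py with huggingface_hub",
    )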

Files changed (1)
  1. inference.py +186 -297
inference.py CHANGED
@@ -1,297 +1,186 @@
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- Intelligent Tokenizer v6.0 - Inference Module
- Embedding and restoration functionality
- """
-
- import torch
- import sys
- import io
- from pathlib import Path
- from typing import Dict, List, Optional, Tuple
-
- # UTF-8 encoding setup
- if sys.stdout.encoding != 'utf-8':
-     sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
-     sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
-
- sys.path.append(str(Path(__file__).parent))
-
- from core.boundary_aware_model import BoundaryAwareTokenizerModel
- from src.core.byte_tokenizer_v6 import ByteTokenizerV6
-
-
- class IntelligentTokenizer:
-     """Intelligent Tokenizer for embedding and restoration"""
-
-     def __init__(self, checkpoint_path: str = "checkpoints/latest_checkpoint.pt", device: str = None):
-         """
-         Initialize tokenizer
-
-         Args:
-             checkpoint_path: Path to model checkpoint
-             device: Device to use ('cuda', 'cpu', or None for auto)
-         """
-         if device is None:
-             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-         else:
-             self.device = torch.device(device)
-
-         print(f"Initializing Intelligent Tokenizer v6.0...")
-         print(f"Device: {self.device}")
-
-         # Load checkpoint
-         checkpoint_path = Path(checkpoint_path)
-         if not checkpoint_path.exists():
-             raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
-
-         checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
-
-         # Initialize model
-         self.model = BoundaryAwareTokenizerModel(**checkpoint['model_config'])
-         self.model.load_state_dict(checkpoint['model_state_dict'])
-         self.model = self.model.to(self.device)
-         self.model.eval()
-
-         # Initialize tokenizer
-         self.tokenizer = ByteTokenizerV6()
-         self.max_chunk_size = 250  # Safe margin for 256-byte limit
-
-         print(f"Model loaded: Epoch {checkpoint['epoch']}, Loss {checkpoint['loss']:.4f}")
-         print(f"Ready for inference!")
-
-     def embed(self, text: str) -> torch.Tensor:
-         """
-         Convert text to embeddings
-
-         Args:
-             text: Input text
-
-         Returns:
-             Embedding tensor
-         """
-         # Handle long text by chunking
-         if len(text.encode('utf-8')) > self.max_chunk_size:
-             chunks = self._split_text_safely(text)
-             embeddings = []
-
-             for chunk in chunks:
-                 emb = self._embed_single(chunk)
-                 embeddings.append(emb)
-
-             # Concatenate embeddings
-             return torch.cat(embeddings, dim=1)
-         else:
-             return self._embed_single(text)
-
-     def _embed_single(self, text: str) -> torch.Tensor:
-         """Embed a single chunk"""
-         # Encode text
-         encoded = self.tokenizer.encode(text)
-         byte_ids = encoded['input_ids']
-         input_ids = torch.tensor([byte_ids], device=self.device)
-         attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)
-
-         with torch.no_grad():
-             # Get embeddings
-             encoder_outputs = self.model.encoder(input_ids, attention_mask)
-             embeddings = encoder_outputs['last_hidden_state']
-
-         return embeddings
-
-     def restore(self, text: str) -> Tuple[str, float]:
-         """
-         Test restoration capability
-
-         Args:
-             text: Input text
-
-         Returns:
-             Tuple of (restored_text, accuracy)
-         """
-         # Handle long text
-         if len(text.encode('utf-8')) > self.max_chunk_size:
-             chunks = self._split_text_safely(text)
-             restored_chunks = []
-             accuracies = []
-
-             for chunk in chunks:
-                 restored, acc = self._restore_single(chunk)
-                 restored_chunks.append(restored)
-                 accuracies.append(acc)
-
-             return ''.join(restored_chunks), sum(accuracies) / len(accuracies)
-         else:
-             return self._restore_single(text)
-
-     def _restore_single(self, text: str) -> Tuple[str, float]:
-         """Restore a single chunk"""
-         # Encode text
-         encoded = self.tokenizer.encode(text)
-         byte_ids = encoded['input_ids']
-
-         if len(byte_ids) <= 1:
-             return text, 1.0
-
-         input_ids = torch.tensor([byte_ids], device=self.device)
-         attention_mask = torch.tensor([encoded['attention_mask']], device=self.device)
-
-         with torch.no_grad():
-             # Teacher forcing for restoration test
-             decoder_input = input_ids[:, :-1]
-             labels = input_ids[:, 1:]
-
-             outputs = self.model(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 decoder_input_ids=decoder_input,
-                 labels=labels,
-                 use_cross_attention=True
-             )
-
-             # Get predictions
-             predictions = torch.argmax(outputs['logits'], dim=-1)
-             accuracy = (predictions == labels).float().mean().item()
-
-         # Decode predictions
-         try:
-             # Remove special tokens and convert to bytes
-             pred_list = predictions[0].cpu().tolist()
-             # Add BOS at the beginning for the full sequence
-             full_sequence = [self.tokenizer.BOS] + pred_list
-
-             # Filter valid bytes
-             filtered = [b for b in full_sequence if 0 <= b < 256]
-             if filtered:
-                 restored_bytes = bytes(filtered)
-                 restored_text = restored_bytes.decode('utf-8', errors='ignore')
-             else:
-                 restored_text = ""
-         except Exception as e:
-             print(f"Restoration error: {e}")
-             restored_text = ""
-
-         return restored_text, accuracy
-
-     def compress(self, text: str) -> Dict:
-         """
-         Get compression statistics
-
-         Args:
-             text: Input text
-
-         Returns:
-             Dict with compression info
-         """
-         text_bytes = text.encode('utf-8')
-         embeddings = self.embed(text)
-
-         original_size = len(text_bytes)
-         compressed_size = embeddings.shape[1]
-         compression_ratio = original_size / compressed_size if compressed_size > 0 else 0
-
-         return {
-             'original_bytes': original_size,
-             'compressed_tokens': compressed_size,
-             'compression_ratio': compression_ratio,
-             'embedding_shape': list(embeddings.shape)
-         }
-
-     def _split_text_safely(self, text: str) -> List[str]:
-         """Split text safely at UTF-8 boundaries"""
-         chunks = []
-         text_bytes = text.encode('utf-8')
-
-         start = 0
-         while start < len(text_bytes):
-             end = min(start + self.max_chunk_size, len(text_bytes))
-
-             # Find a valid UTF-8 boundary
-             while end > start and end < len(text_bytes):
-                 try:
-                     chunk = text_bytes[start:end].decode('utf-8')
-                     break
-                 except UnicodeDecodeError:
-                     end -= 1
-
-             if end > start:
-                 chunk = text_bytes[start:end].decode('utf-8')
-                 chunks.append(chunk)
-                 start = end
-             else:
-                 break
-
-         return chunks
-
-
- def test_model():
-     """Test model functionality"""
-     print("="*70)
-     print("INTELLIGENT TOKENIZER v6.0 - FUNCTIONALITY TEST")
-     print("="*70)
-
-     # Initialize tokenizer
-     tokenizer = IntelligentTokenizer()
-
-     # Test samples
-     test_samples = [
-         ("English", "Hello, world!"),
-         ("Korean", "안녕하세요. 반갑습니다."),
-         ("Chinese", "今天天气很好"),
-         ("Japanese", "こんにちは"),
-         ("Arabic", "مرحبا بك"),
-         ("Russian", "Привет, как дела?"),
-         ("Emoji", "Hello 👋 World 🌍!"),
-     ]
-
-     print("\n" + "="*70)
-     print("EMBEDDING & RESTORATION TESTS")
-     print("="*70)
-
-     total_accuracy = 0
-     successful = 0
-
-     for lang, text in test_samples:
-         print(f"\n[{lang}]")
-         print(f"Original: {text}")
-
-         # Test embedding
-         embeddings = tokenizer.embed(text)
-         print(f"Embedding: {embeddings.shape}")
-
-         # Test compression
-         compression = tokenizer.compress(text)
-         print(f"Compression: {compression['original_bytes']} bytes → {compression['compressed_tokens']} tokens")
-         print(f"Ratio: {compression['compression_ratio']:.2f}x")
-
-         # Test restoration
-         restored, accuracy = tokenizer.restore(text)
-         print(f"Restored: {restored}")
-         print(f"Accuracy: {accuracy:.1%}")
-
-         if accuracy > 0.7:
-             successful += 1
-         total_accuracy += accuracy
-
-     # Summary
-     print("\n" + "="*70)
-     print("TEST SUMMARY")
-     print("="*70)
-     print(f"Tests passed: {successful}/{len(test_samples)}")
-     print(f"Average accuracy: {total_accuracy/len(test_samples):.1%}")
-
-     if successful == len(test_samples):
-         print("\n✅ ALL TESTS PASSED!")
-         return True
-     elif successful >= len(test_samples) * 0.7:
-         print("\n⚠️ PARTIAL SUCCESS (70%+ tests passed)")
-         return True
-     else:
-         print("\n❌ TESTS FAILED")
-         return False
-
-
- if __name__ == "__main__":
-     success = test_model()
-     sys.exit(0 if success else 1)
 
+ """
+ B2NL-IntelligentTokenizer v6.2.1 - Inference code that actually works
+ This file is the main usage reference.
+ """
+
+ import torch
+ import sys
+ from pathlib import Path
+
+ # Add module paths
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent / "intelligent-tokenizer_v6.2.1"))
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent / "intelligent-tokenizer_v6.2.1/core"))
+
+ from core.unified_model import IntelligentTokenizerV62
+ from core.tokenizer import ByteTokenizerV62
+
+
+ class B2NLTokenizer:
+     """B2NL tokenizer that actually works"""
+
+     def __init__(self, checkpoint_path: str = None):
+         """
+         Args:
+             checkpoint_path: Checkpoint path (falls back to the default if omitted)
+         """
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         # Default checkpoint path
+         if checkpoint_path is None:
+             checkpoint_path = "D:/intelligent-tokenizer/intelligent-tokenizer_v6.2.1/checkpoints/v62/16.0/epoch_100.pt"
+
+         # Load the model
+         self.model = IntelligentTokenizerV62()
+         checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+         print(f"Model loaded successfully on {self.device}")
+
+     def compress(self, text: str) -> dict:
+         """Compress text"""
+         return self.model.compress(text)
+
+     def reconstruct(self, text: str, temperature: float = 0.1) -> str:
+         """
+         Compress text, then reconstruct it (the version that actually works)
+
+         Args:
+             text: Input text
+             temperature: Generation temperature (lower is more deterministic)
+
+         Returns:
+             Reconstructed text
+         """
+         # 1. Encode the text
+         tokenizer = self.model.tokenizer
+         encoded = tokenizer.encode(text)
+
+         if isinstance(encoded, dict):
+             input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
+             attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
+         else:
+             input_ids = encoded.unsqueeze(0) if encoded.dim() == 1 else encoded
+             attention_mask = torch.ones_like(input_ids)
+
+         input_ids = input_ids.to(self.device)
+         attention_mask = attention_mask.to(self.device)
+
+         # 2. Compress with the encoder
+         with torch.no_grad():
+             encoder_outputs = self.model.encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask
+             )
+
+         # Prepare all hidden states
+         if 'all_hidden_states' in encoder_outputs:
+             encoder_all_hidden = encoder_outputs['all_hidden_states']
+         else:
+             compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
+             encoder_all_hidden = [compressed] * 4
+
+         # 3. Autoregressive decoding (the approach that actually works)
+         batch_size = input_ids.size(0)
+         max_length = 48
+
+         # Start from the BOS token
+         generated = torch.full((batch_size, 1), tokenizer.BOS, device=self.device)
+
+         for step in range(max_length - 1):
+             with torch.no_grad():
+                 # Decode with the sequence generated so far
+                 decoder_outputs = self.model.decoder(
+                     encoder_all_hidden=encoder_all_hidden,
+                     decoder_input_ids=generated,
+                     attention_mask=torch.ones_like(generated),
+                     use_cache=False
+                 )
+
+                 # Predict the next token
+                 logits = decoder_outputs['logits'][:, -1, :] / temperature
+
+                 # Top-k sampling
+                 top_k = 10
+                 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                 logits[indices_to_remove] = float('-inf')
+
+                 # Compute probabilities and sample
+                 probs = torch.nn.functional.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 # Append to the generated sequence
+                 generated = torch.cat([generated, next_token], dim=1)
+
+                 # Stop at the EOS token
+                 if (next_token == tokenizer.EOS).all():
+                     break
+
+         # 4. Decode back to text
+         if generated.dim() > 1:
+             decoded = tokenizer.decode(generated[0])
+         else:
+             decoded = tokenizer.decode(generated)
+
+         return decoded
+
+
+ def test_tokenizer():
+     """Tokenizer test"""
+     print("="*60)
+     print("B2NL-IntelligentTokenizer v6.2.1 test")
+     print("="*60)
+
+     # Initialize the tokenizer
+     tokenizer = B2NLTokenizer()
+
+     # Test texts
+     test_texts = [
+         "Hello, world!",
+         "안녕하세요, 반갑습니다.",
+         "The quick brown fox jumps over the lazy dog.",
+         "人工智能技术正在改变世界。",
+     ]
+
+     for text in test_texts:
+         print(f"\nOriginal: {text}")
+
+         # Compress
+         compressed = tokenizer.compress(text)
+         print(f"Compression ratio: {compressed['compression_ratio']:.1f}:1 ({compressed['num_tokens']} tokens)")
+
+         # Reconstruct
+         reconstructed = tokenizer.reconstruct(text, temperature=0.1)
+         print(f"Reconstructed: {reconstructed}")
+
+         # Character-level accuracy
+         min_len = min(len(text), len(reconstructed))
+         accuracy = sum(1 for i in range(min_len) if text[i] == reconstructed[i]) / len(text) * 100
+         print(f"Accuracy: {accuracy:.1f}%")
+
+     print("\n" + "="*60)
+     print("Test completed!")
+     print("="*60)
+
+
+ # Usage example
+ def example_usage():
+     """Simple usage example"""
+     # 1. Initialize the tokenizer
+     tokenizer = B2NLTokenizer()
+
+     # 2. Compress text
+     text = "안녕하세요, 반갑습니다!"
+     compressed = tokenizer.compress(text)
+     print(f"Compression result: {compressed['compression_ratio']:.1f}:1")
+
+     # 3. Reconstruct text
+     reconstructed = tokenizer.reconstruct(text)
+     print(f"Reconstruction result: {reconstructed}")
+
+     return tokenizer
+
+
+ if __name__ == "__main__":
+     test_tokenizer()
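
The simplest way to exercise the uploaded file is to run it directly (`python inference.py`), which calls test_tokenizer(). To drive the class from another script instead, a minimal sketch, assuming the v6.2.1 checkpoint has been downloaded locally (the path below is illustrative; the hardcoded default is machine-specific):

    from inference import B2NLTokenizer

    # Point at wherever epoch_100.pt was saved locally (assumed path)
    tok = B2NLTokenizer(checkpoint_path="checkpoints/v62/16.0/epoch_100.pt")

    stats = tok.compress("Hello, world!")     # dict with compression_ratio, num_tokens, ...
    print(f"{stats['compression_ratio']:.1f}:1")
    print(tok.reconstruct("Hello, world!"))   # round-trip reconstruction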