"""

Intelligent Tokenizer v6.2.0 - Byte Tokenizer with 46+2 Configuration

Handles chunking, sliding windows, and boundary adjustments

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple


def _trim_utf8_boundary(byte_seq: List[int], limit: int) -> int:
    """

    Trim byte sequence to valid UTF-8 boundary (GPT suggestion)

    """
    end = min(limit, len(byte_seq))
    while end > 0:
        try:
            bytes(byte_seq[:end]).decode('utf-8')
            return end
        except UnicodeDecodeError:
            end -= 1
    return limit
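
# Illustrative example (not part of the original file): "안녕" is 6 UTF-8 bytes
# (3 per character); cutting at limit=4 would split the second character, so the
# boundary backs off to 3:
#   _trim_utf8_boundary(list("안녕".encode('utf-8')), 4)   # -> 3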


class ByteTokenizerV62:
    """

    Pure byte-level tokenizer

    46 content bytes + 2 special tokens (BOS/EOS) = 48 total

    """

    def __init__(self, config: Optional[Dict] = None):
        # Configuration
        self.content_size = 46  # Actual content bytes
        self.max_seq_len = 48   # Total with BOS/EOS
        self.chunk_overlap = 8  # Overlap for sliding window

        # Special tokens
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259
        self.vocab_size = 260  # 256 bytes + 4 special

    def encode(self,
               text: str,
               add_special_tokens: bool = True,
               return_chunks: bool = False) -> Dict[str, torch.Tensor]:
        """
        Encode text to byte sequences.

        Args:
            text: Input text
            add_special_tokens: Whether to add BOS/EOS
            return_chunks: Return multiple chunks for long sequences
        """
        # Convert to UTF-8 bytes
        byte_sequence = list(text.encode('utf-8'))

        if return_chunks and len(byte_sequence) > self.content_size:
            # Handle long sequences with sliding window
            return self._encode_with_chunks(byte_sequence, add_special_tokens)

        # Single chunk processing with UTF-8 boundary (GPT suggestion)
        if len(byte_sequence) > self.content_size:
            cut_point = _trim_utf8_boundary(byte_sequence, self.content_size)
            byte_sequence = byte_sequence[:cut_point]

        # Add special tokens (GPT suggestion: cleaner padding order)
        if add_special_tokens:
            byte_sequence = [self.BOS] + byte_sequence + [self.EOS]

        # Record the unpadded length so 'length' (and the attention mask that
        # batch_encode builds from it) does not count PAD positions
        unpadded_length = len(byte_sequence)

        # Pad to max_seq_len (after special tokens for cleaner structure)
        if len(byte_sequence) < self.max_seq_len:
            padding_length = self.max_seq_len - len(byte_sequence)
            byte_sequence = byte_sequence + [self.PAD] * padding_length

        input_ids = torch.tensor(byte_sequence, dtype=torch.long)
        attention_mask = (input_ids != self.PAD)  # bool type (GPT suggestion)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'length': unpadded_length,
            'original_length': len(text.encode('utf-8'))
        }

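    # Usage sketch (illustrative): a short ASCII string becomes [BOS] + bytes +
    # [EOS], then PAD up to max_seq_len=48; the mask covers only non-PAD positions:
    #   out = ByteTokenizerV62().encode("hi")
    #   out['input_ids'][:4].tolist()        # -> [257, 104, 105, 258]
    #   int(out['attention_mask'].sum())     # -> 4
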
    def _encode_with_chunks(self,
                           byte_sequence: List[int],
                           add_special_tokens: bool) -> Dict[str, torch.Tensor]:
        """
        Encode long sequences with sliding window chunks.
        """
        chunks = []
        positions = []

        # Calculate stride (content_size - overlap)
        stride = self.content_size - self.chunk_overlap

        for i in range(0, len(byte_sequence), stride):
            # Extract chunk
            chunk = byte_sequence[i:i + self.content_size]

            # Skip if chunk is too small (last chunk)
            if len(chunk) < self.content_size // 2:
                if chunks:  # Merge with previous chunk if exists
                    last_chunk = chunks[-1]['input_ids'].tolist()
                    # Remove padding and special tokens from last chunk (GPT final check)
                    last_chunk = [b for b in last_chunk if b not in [self.PAD, self.BOS, self.EOS]]
                    # Append the small tail, then restore BOS/EOS around the merged content
                    merged = last_chunk + chunk
                    if add_special_tokens:
                        merged = [self.BOS] + merged + [self.EOS]
                    # Repad (or truncate) to max_seq_len
                    if len(merged) < self.max_seq_len:
                        merged += [self.PAD] * (self.max_seq_len - len(merged))
                    merged_ids = torch.tensor(merged[:self.max_seq_len], dtype=torch.long)
                    merged_mask = (merged_ids != self.PAD)  # Recalculate mask (GPT suggestion)
                    chunks[-1]['input_ids'] = merged_ids
                    chunks[-1]['attention_mask'] = merged_mask
                    # Extend the recorded span of the merged chunk to the end of the input
                    chunks[-1]['position'] = (chunks[-1]['position'][0], len(byte_sequence))
                    positions[-1] = chunks[-1]['position']
                break

            # Add special tokens first, then pad, matching encode()'s padding order
            if add_special_tokens:
                chunk_with_special = [self.BOS] + chunk + [self.EOS]
            else:
                chunk_with_special = list(chunk)

            # Pad to max_seq_len if necessary (only the last kept chunk can be short)
            if len(chunk_with_special) < self.max_seq_len:
                chunk_with_special += [self.PAD] * (self.max_seq_len - len(chunk_with_special))

            # Create tensors
            input_ids = torch.tensor(chunk_with_special, dtype=torch.long)
            attention_mask = (input_ids != self.PAD)  # bool type (GPT suggestion)

            chunks.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'position': (i, min(i + self.content_size, len(byte_sequence)))
            })
            positions.append((i, min(i + self.content_size, len(byte_sequence))))

        # Stack all chunks
        all_input_ids = torch.stack([c['input_ids'] for c in chunks])
        all_attention_masks = torch.stack([c['attention_mask'] for c in chunks])

        return {
            'input_ids': all_input_ids,  # [num_chunks, seq_len]
            'attention_mask': all_attention_masks,
            'num_chunks': len(chunks),
            'chunk_positions': positions,
            'original_length': len(byte_sequence)
        }
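
    # Chunking sketch (illustrative): with content_size=46 and chunk_overlap=8 the
    # stride is 38, so 100 input bytes give chunks starting at offsets 0, 38, 76:
    #   enc = ByteTokenizerV62().encode("A" * 100, return_chunks=True)
    #   enc['num_chunks']          # -> 3
    #   enc['input_ids'].shape     # -> torch.Size([3, 48])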

    def reconstruct(self,
                    input_ids: torch.Tensor,
                    positions: Optional[List[Tuple[int, int]]] = None,
                    skip_special_tokens: bool = True,
                    overlap: int = 8) -> str:
        """
        Reconstruct text from multiple chunks (GPT suggestion).

        Args:
            input_ids: [num_chunks, seq_len] for multi-chunk
            positions: List of (start, end) positions for each chunk
            skip_special_tokens: Whether to skip special tokens
            overlap: Overlap size between chunks
        """
        if input_ids.dim() == 1:
            # Single sequence, use regular decode
            return self.decode(input_ids, skip_special_tokens)

        # Multi-chunk reconstruction
        pieces = []
        for i, chunk_ids in enumerate(input_ids):
            chunk_ids = chunk_ids.cpu().tolist()

            # Remove special tokens and padding
            if skip_special_tokens:
                chunk_ids = [
                    b for b in chunk_ids
                    if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
                ]

            pieces.append(chunk_ids)

        # Merge chunks with overlap handling
        output = []
        for i, chunk in enumerate(pieces):
            if i == 0:
                output.extend(chunk)
            else:
                # Skip overlap bytes from current chunk
                output.extend(chunk[overlap:] if len(chunk) > overlap else chunk)

        # Convert to string
        try:
            text = bytes(output).decode('utf-8', errors='replace')
        except ValueError:
            # Only possible if non-byte values (>= 256) remain, e.g. when
            # skip_special_tokens=False leaves special token ids in the output
            text = ""

        return text
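
    # Round-trip sketch (illustrative): chunked output can be fed straight back in;
    # the first `overlap` bytes of every chunk after the first are dropped:
    #   tok = ByteTokenizerV62()
    #   enc = tok.encode("A" * 100, return_chunks=True)
    #   tok.reconstruct(enc['input_ids']) == "A" * 100   # -> True with default overlap=8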

    def decode(self,
               input_ids: torch.Tensor,
               skip_special_tokens: bool = True) -> str:
        """
        Decode byte sequences back to text.
        """
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().tolist()

        # Handle batch dimension (only the first sequence is decoded)
        if isinstance(input_ids[0], list):
            input_ids = input_ids[0]

        # Remove special tokens and padding
        if skip_special_tokens:
            input_ids = [
                b for b in input_ids
                if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
            ]

        # Convert bytes to string
        try:
            text = bytes(input_ids).decode('utf-8', errors='replace')
        except ValueError:
            # Only possible if non-byte values (>= 256) remain, e.g. when
            # skip_special_tokens=False leaves special token ids in the sequence
            text = ""

        return text
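
    # Decode sketch (illustrative): special tokens and padding are stripped before
    # the remaining bytes are interpreted as UTF-8:
    #   tok = ByteTokenizerV62()
    #   tok.decode(tok.encode("Hello, world!")['input_ids'])   # -> "Hello, world!"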

    def batch_encode(self,
                    texts: List[str],
                    add_special_tokens: bool = True) -> Dict[str, torch.Tensor]:
        """
        Encode multiple texts as a batch.
        """
        encoded = [self.encode(text, add_special_tokens) for text in texts]

        # Find max length
        max_len = max(e['length'] for e in encoded)
        max_len = min(max_len, self.max_seq_len)

        # Create batch tensors
        batch_size = len(texts)
        input_ids = torch.full((batch_size, max_len), self.PAD, dtype=torch.long)
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool)  # bool type (GPT suggestion)

        for i, enc in enumerate(encoded):
            seq_len = min(enc['length'], max_len)
            if enc['input_ids'].dim() == 0:  # Handle scalar
                enc['input_ids'] = enc['input_ids'].unsqueeze(0)
            input_ids[i, :seq_len] = enc['input_ids'][:seq_len]
            attention_mask[i, :seq_len] = True

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'lengths': [e['length'] for e in encoded]
        }
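
    # Batch sketch (illustrative): sequences are padded to the longest unpadded
    # length in the batch ("hi" -> 4 tokens with BOS/EOS, "hello" -> 7):
    #   batch = ByteTokenizerV62().batch_encode(["hi", "hello"])
    #   batch['input_ids'].shape               # -> torch.Size([2, 7])
    #   int(batch['attention_mask'][0].sum())  # -> 4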


class ChunkBoundaryAdjuster(nn.Module):
    """

    Neural network for adjusting chunk boundaries

    Learns optimal splitting points

    """

    def __init__(self, hidden_dim: int = 256):
        super().__init__()

        # Boundary scoring network
        self.boundary_scorer = nn.Sequential(
            nn.Linear(256, hidden_dim),  # Input: byte embeddings
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),  # Output: boundary score
            nn.Sigmoid()
        )

        # UTF-8 boundary detector
        self.utf8_detector = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=4, padding=2),  # Detect multi-byte patterns
            nn.ReLU(),
            nn.Conv1d(16, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, byte_sequence: torch.Tensor) -> torch.Tensor:
        """

        Find optimal chunk boundaries



        Args:

            byte_sequence: [batch, seq_len, embedding_dim]



        Returns:

            boundary_scores: [batch, seq_len] - probability of boundary at each position

        """
        batch_size, seq_len = byte_sequence.shape[:2]

        # Score each position as potential boundary
        boundary_scores = self.boundary_scorer(byte_sequence).squeeze(-1)

        # Detect UTF-8 boundaries (avoid splitting multi-byte characters)
        byte_values = byte_sequence[..., 0].unsqueeze(1)  # [batch, 1, seq_len]
        # kernel_size=4 with padding=2 produces seq_len + 1 outputs, so trim back to seq_len
        utf8_scores = self.utf8_detector(byte_values).squeeze(1)[:, :seq_len]  # [batch, seq_len]

        # Combine scores (prefer boundaries at valid UTF-8 positions)
        combined_scores = boundary_scores * utf8_scores

        # Apply constraints: boost scores at the expected ~46-byte boundary positions
        combined_scores[:, ::46] = combined_scores[:, ::46] * 1.5

        return combined_scores

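# Shape sketch (illustrative): 256-dim byte embeddings in, one boundary score per
# position out:
#   adjuster = ChunkBoundaryAdjuster()
#   adjuster(torch.randn(2, 48, 256)).shape   # -> torch.Size([2, 48])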

class SlidingWindowProcessor(nn.Module):
    """

    Process sequences with sliding windows at multiple scales

    """

    def __init__(self, window_sizes: List[int] = [8, 16, 32, 46]):
        super().__init__()
        self.window_sizes = window_sizes

        # Multi-scale convolutions for different window sizes
        self.convs = nn.ModuleList([
            nn.Conv1d(256, 128, kernel_size=ws, stride=ws//2, padding=ws//4)
            for ws in window_sizes
        ])

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(128 * len(window_sizes), 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256)
        )

    def forward(self, byte_embeddings: torch.Tensor) -> torch.Tensor:
        """

        Apply multi-scale sliding windows



        Args:

            byte_embeddings: [batch, seq_len, embedding_dim]



        Returns:

            processed: [batch, seq_len, embedding_dim]

        """
        # Transpose for conv1d
        x = byte_embeddings.transpose(1, 2)  # [batch, embed, seq]

        # Apply multi-scale convolutions
        multi_scale_features = []
        for conv in self.convs:
            features = conv(x)  # Different seq lengths
            # Global average pooling to fixed size
            pooled = F.adaptive_avg_pool1d(features, byte_embeddings.size(1))
            multi_scale_features.append(pooled)

        # Concatenate and transpose back
        concat = torch.cat(multi_scale_features, dim=1)  # [batch, 128*scales, seq]
        concat = concat.transpose(1, 2)  # [batch, seq, 128*scales]

        # Fuse multi-scale features
        fused = self.fusion(concat)  # [batch, seq, 256]

        # Residual connection
        output = fused + byte_embeddings

        return output
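
# Shape sketch (illustrative): multi-scale features are pooled back to the input
# length and fused, so the output keeps the input shape:
#   proc = SlidingWindowProcessor()
#   proc(torch.randn(2, 48, 256)).shape   # -> torch.Size([2, 48, 256])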


class AdaptiveChunker:
    """

    Adaptive chunking based on content complexity

    Simple heuristic-based chunker for inference

    """

    def __init__(self):
        self.min_chunk = 32
        self.max_chunk = 46
        self.target_chunk = 46

    def determine_chunk_size(self, text: str) -> int:
        """

        Determine optimal chunk size based on text characteristics

        """
        byte_seq = text.encode('utf-8')

        # Check character types
        has_arabic = any(0x0600 <= ord(c) <= 0x06FF for c in text[:100])
        has_non_ascii = any(b >= 0x80 for b in byte_seq[:100])  # CJK and other multi-byte scripts

        # Adjust chunk size based on content; Arabic is checked first because it
        # would otherwise always be caught by the broader non-ASCII test
        if has_arabic:
            # Arabic benefits from slightly smaller chunks
            return 40
        elif has_non_ascii:
            # CJK and other multi-byte characters need smaller chunks
            return self.min_chunk
        else:
            # ASCII/Latin can use larger chunks
            return self.target_chunk

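    # Heuristic sketch (illustrative): ASCII text keeps the full 46-byte chunks,
    # while multi-byte scripts fall back to smaller ones:
    #   chunker = AdaptiveChunker()
    #   chunker.determine_chunk_size("Hello, world!")   # -> 46
    #   chunker.determine_chunk_size("안녕하세요")       # -> 32
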
    def chunk_text(self, text: str) -> List[str]:
        """

        Split text into adaptive chunks

        """
        chunk_size = self.determine_chunk_size(text)
        byte_seq = text.encode('utf-8')
        chunks = []

        i = 0
        while i < len(byte_seq):
            # Find chunk boundary (don't split UTF-8 sequences)
            end = min(i + chunk_size, len(byte_seq))

            # Backtrack to a valid UTF-8 boundary if needed
            while end > i and end < len(byte_seq):
                try:
                    _ = byte_seq[i:end].decode('utf-8')
                    break
                except UnicodeDecodeError:
                    end -= 1

            chunk_bytes = byte_seq[i:end]
            chunks.append(chunk_bytes.decode('utf-8', errors='replace'))
            i = end

        return chunks
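
    # Splitting sketch (illustrative): 3-byte Hangul characters are never split
    # mid-sequence; 20 x "가" (60 bytes) at chunk_size 32 backs off to 30-byte chunks:
    #   AdaptiveChunker().chunk_text("가" * 20)   # -> ["가" * 10, "가" * 10]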


if __name__ == "__main__":
    # Test the tokenizer
    tokenizer = ByteTokenizerV62()

    # Test texts
    test_texts = [
        "Hello, world!",
        "안녕하세요, 세계!",
        "今天天气很好。",
        "مرحبا بالعالم",
        "A" * 100  # Long text
    ]

    for text in test_texts:
        print(f"\nText: {text[:50]}...")

        # Single chunk encoding
        encoded = tokenizer.encode(text)
        print(f"  Encoded shape: {encoded['input_ids'].shape}")
        print(f"  Original length: {encoded['original_length']} bytes")

        # Decode back
        decoded = tokenizer.decode(encoded['input_ids'])
        print(f"  Decoded: {decoded[:50]}...")

        # Check multi-chunk for long text
        if encoded['original_length'] > 46:
            multi_encoded = tokenizer.encode(text, return_chunks=True)
            print(f"  Chunks: {multi_encoded['num_chunks']}")

    # Test batch encoding
    batch = tokenizer.batch_encode(test_texts[:3])
    print(f"\nBatch shape: {batch['input_ids'].shape}")

    # Test adaptive chunker
    chunker = AdaptiveChunker()
    for text in test_texts[:3]:
        chunk_size = chunker.determine_chunk_size(text)
        print(f"\n{text[:30]}... → Chunk size: {chunk_size}")