"""

Intelligent Tokenizer v6.2.0 - Progressive Splitting Encoder

With GPT-5 suggested improvements

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math


class RoPEPositionalEncoding(nn.Module):
    """

    Rotary Position Embedding (RoPE) - GPT-5 suggestion

    Better for handling chunk boundaries and variable sequence lengths

    """

    def __init__(self, dim: int, max_seq_len: int = 48, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute sinusoidal frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute positional encodings
        t = torch.arange(max_seq_len).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        self.register_buffer('cos_cached', freqs.cos())
        self.register_buffer('sin_cached', freqs.sin())

    def forward(self, x: torch.Tensor, seq_len: int = None) -> torch.Tensor:
        """

        Apply RoPE to input tensor

        Handles chunk boundary corrections as suggested by GPT-5

        """
        if seq_len is None:
            seq_len = x.shape[1]

        # Get cached cos/sin values
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        # Apply rotary embedding
        x_rot = self._apply_rotary_emb(x, cos, sin)

        return x_rot

    def _apply_rotary_emb(self, x, cos, sin):
        """Apply rotary embedding to input"""
        x1, x2 = x[..., ::2], x[..., 1::2]
        x_rot = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1).flatten(-2)
        return x_rot
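
# Rotation convention used by RoPEPositionalEncoding._apply_rotary_emb above:
# the input is split into even/odd channel pairs (x1 = x[..., ::2], x2 = x[..., 1::2]),
# and pair i at position t is rotated by the angle t * theta_i, with
# theta_i = base^(-2i/dim):
#   x1' = x1 * cos(t * theta_i) - x2 * sin(t * theta_i)
#   x2' = x1 * sin(t * theta_i) + x2 * cos(t * theta_i)
# The stack(...).flatten(-2) call re-interleaves the rotated pairs back into the
# original channel order.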


class GatedCrossAttention(nn.Module):
    """

    Gated Cross-Attention with MQA - GPT-5 suggestion

    Monitor gate values for quality assessment

    16Q β†’ 2K/V for 8x memory reduction

    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_heads: int = 2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.kv_heads = kv_heads  # Reduced KV heads (GPT suggestion)
        self.head_dim = hidden_dim // num_heads  # 80

        # Multi-Query Attention projections
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)  # 16 heads
        self.k_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim)  # 2 heads
        self.v_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim)  # 2 heads
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

        # Gating mechanism (GPT-5 suggestion)
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # Gate monitoring (for analysis)
        self.register_buffer('gate_values', torch.zeros(1))

        # Warmup factor (GPT suggestion)
        self.register_buffer('warmup_alpha', torch.tensor(1.0))

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with gate monitoring.

        Returns: (output, gate_values)
        """
        batch_size, seq_len = query.shape[:2]

        # Multi-head attention projections
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.k_proj(key).view(batch_size, -1, self.kv_heads, self.head_dim)
        V = self.v_proj(value).view(batch_size, -1, self.kv_heads, self.head_dim)

        # Transpose for attention computation
        Q = Q.transpose(1, 2)  # [batch, heads, seq, dim]
        K = K.transpose(1, 2)  # [batch, kv_heads, seq, dim]
        V = V.transpose(1, 2)

        # Repeat KV heads to match Q heads if necessary
        if self.kv_heads < self.num_heads:
            repeat_factor = self.num_heads // self.kv_heads
            K = K.repeat_interleave(repeat_factor, dim=1)
            V = V.repeat_interleave(repeat_factor, dim=1)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
        attn_output = self.o_proj(attn_output)

        # Gating mechanism
        gate_input = torch.cat([query, attn_output], dim=-1)
        gate_values = self.gate(gate_input)

        # Store gate values for monitoring (keep tensor shape consistent)
        self.gate_values[0] = gate_values.mean().detach()

        # Apply gate with warmup factor (GPT suggestion)
        gate_values = gate_values * self.warmup_alpha
        output = gate_values * attn_output + (1 - gate_values) * query

        return output, gate_values
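
# Multi-query sizing note: with hidden_dim=1280 and head_dim=80, the Q projection
# maps 1280 -> 1280 (16 heads) while K and V each map 1280 -> 160 (2 heads), so the
# K/V projections are 8x smaller than in full multi-head attention; K and V are
# repeat_interleave'd back to 16 heads only at attention time.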



class ProgressiveSplittingLayer(nn.Module):
    """

    Core innovation: 48 bytes β†’ 1 token β†’ N tokens β†’ M tokens

    """

    def __init__(self, hidden_dim: int = 1280, config: Optional[Dict] = None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.config = config or {}

        # Dynamic splitting: 1~4 tokens for efficiency
        # 48 bytes / 4 tokens = 12:1 compression (still beats BPE's 4:1)
        self.min_tokens = 1  # 48:1 compression
        self.max_tokens = 4  # 12:1 compression (still 3x better than BPE)

        # Initial compression: 48 bytes → 1 super token
        self.byte_embed = nn.Embedding(260, 64)  # Small embedding
        self.initial_compressor = nn.Sequential(
            nn.Linear(48 * 64, 2048),
            nn.LayerNorm(2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )

        # Language-aware splitting: 1 → N tokens (config-based)
        self.language_splitter = nn.ModuleDict({
            'analyzer': nn.Sequential(
                nn.Linear(hidden_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256)  # Language features
            ),
            'split_predictor': nn.Linear(256, self.max_tokens),  # Predict 1~4 tokens
            # Single unified expander that can produce any number of tokens
            'dynamic_expander': nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.LayerNorm(hidden_dim * 2),
                nn.GELU(),  # Better than ReLU for transformers
                nn.Linear(hidden_dim * 2, hidden_dim * self.max_tokens)  # Can produce up to 4 tokens
            ),
            # Token-wise importance predictor
            'importance_predictor': nn.Sequential(
                nn.Linear(hidden_dim, 256),
                nn.ReLU(),
                nn.Linear(256, self.max_tokens),  # Importance for each potential token
                nn.Softmax(dim=-1)
            )
        })

        # Boundary refinement: N → M tokens with linguistic awareness
        self.boundary_refiner = nn.ModuleDict({
            'scorer': nn.Sequential(
                nn.Linear(hidden_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 1)
            ),
            # Linguistic boundary detectors (currently defined but not applied in forward)
            'morpheme_detector': nn.Conv1d(256, 64, 3),  # morpheme-level cues
            'word_detector': nn.Conv1d(256, 64, 5),       # word-level cues
            'phrase_detector': nn.Conv1d(256, 64, 7),     # phrase-level cues
            'adjuster': nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=16,
                dim_feedforward=4 * hidden_dim,
                dropout=0.1,
                batch_first=True
            )
        })

        # Initialize split_predictor bias to prefer 1 token initially
        # This ensures untrained model starts with maximum compression
        with torch.no_grad():
            self.language_splitter['split_predictor'].bias.data = torch.tensor([2.0, -1.0, -1.0, -1.0])
            # High bias for 1 token, negative for others
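            # With this bias (and ignoring the small random weight contribution),
            # softmax([2.0, -1.0, -1.0, -1.0]) is approximately [0.87, 0.04, 0.04, 0.04],
            # so roughly 87% of the initial probability mass goes to the 1-token
            # (48:1 compression) option.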

    def forward(self, input_ids: torch.Tensor, temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """

        Progressive splitting forward pass



        Args:

            input_ids: Input byte sequence [batch, seq_len]

            temperature: Gumbel-Softmax temperature for annealing

        """
        batch_size = input_ids.size(0)

        # Step 1: 48 bytes → 1 super token
        byte_embeddings = self.byte_embed(input_ids)  # [batch, 48, 64]
        flattened = byte_embeddings.view(batch_size, -1)  # [batch, 3072]
        super_token = self.initial_compressor(flattened)  # [batch, 1280]
        super_token = super_token.unsqueeze(1)  # [batch, 1, 1280]

        # Step 2: Language analysis and splitting (1 → N)
        lang_features = self.language_splitter['analyzer'](super_token)
        split_logits = self.language_splitter['split_predictor'](lang_features)
        split_weights = F.softmax(split_logits, dim=-1)  # [batch, 1, 4]
        # Note: split_logits / split_weights are recomputed below with Gumbel-Softmax;
        # lang_features is kept and returned for diagnostics.

        # Direct transformation from super token to initial representation
        # No hardcoded splits - let the model learn everything
        lang_tokens = super_token  # Start with compressed representation

        # TRUE Adaptive expansion - Model learns optimal split (1~4 tokens)
        # Analyze content to decide how many tokens needed
        expansion_features = self.language_splitter['analyzer'](lang_tokens)  # [batch, 1, 256]

        # Dynamic expansion: generate up to 4 tokens from super token
        expanded = self.language_splitter['dynamic_expander'](lang_tokens.squeeze(1))  # [batch, hidden_dim*4]
        expanded = expanded.reshape(batch_size, self.max_tokens, self.hidden_dim)  # [batch, 4, hidden_dim]

        # Predict how many tokens we actually need (1~4)
        split_logits = self.language_splitter['split_predictor'](expansion_features.squeeze(1))  # [batch, 4]
        # Clamp logits to prevent extreme values that cause NaN
        split_logits = torch.clamp(split_logits, min=-10, max=10)
        # Ensure minimum temperature to prevent instability
        safe_temperature = max(temperature, 0.5)
        split_weights = F.gumbel_softmax(split_logits, tau=safe_temperature, hard=False, dim=-1)  # [batch, 4]

        # Predict importance for each potential token position
        importance = self.language_splitter['importance_predictor'](lang_tokens.squeeze(1))  # [batch, 4]

        # Dynamic token selection with importance-weighted allocation
        # Create cumulative mask for progressive token usage
        # If split_weights = [0.1, 0.2, 0.6, 0.1], we mainly use 3 tokens

        # Create progressive masks for 1, 2, 3, 4 tokens
        masks = []
        for n in range(1, self.max_tokens + 1):
            mask = torch.zeros(batch_size, self.max_tokens, 1, device=expanded.device)
            mask[:, :n, :] = 1.0
            masks.append(mask)

        # Apply importance-weighted masking
        # Important parts get more tokens, less important parts get fewer
        weighted_outputs = []
        for i, mask in enumerate(masks):
            num_tokens = i + 1
            # Weight by both split decision and importance
            token_weight = split_weights[:, i:i+1].unsqueeze(-1)  # [batch, 1, 1]

            # Apply importance modulation for asymmetric splits
            if num_tokens > 1:
                # Redistribute tokens based on importance
                importance_adjusted = importance[:, :num_tokens].unsqueeze(-1)  # [batch, n, 1]
                masked = expanded[:, :num_tokens] * importance_adjusted
            else:
                masked = expanded[:, :num_tokens]

            # Pad to max length
            if num_tokens < self.max_tokens:
                padding = torch.zeros(batch_size, self.max_tokens - num_tokens, self.hidden_dim,
                                     device=expanded.device)
                masked = torch.cat([masked, padding], dim=1)

            weighted_outputs.append(masked * token_weight)

        # Sum all weighted possibilities (differentiable selection)
        lang_tokens = sum(weighted_outputs)

        # Determine effective number of tokens (for monitoring)
        # Weighted average of token counts
        token_counts = torch.arange(1, self.max_tokens + 1, device=split_weights.device, dtype=torch.float32)
        avg_tokens = (split_weights * token_counts).sum(dim=-1).mean().item()

        k = lang_tokens.size(1)

        # Step 3: Boundary refinement (N → M)
        # Calculate boundary scores for each token position
        boundary_scores = self.boundary_refiner['scorer'](lang_tokens)  # [batch, N, 1]

        # Detect linguistic boundaries (morpheme, word, phrase)
        # Extract features for boundary detection
        if hasattr(lang_tokens, 'shape') and len(lang_tokens.shape) == 3:
            batch_size, num_tokens, hidden_dim = lang_tokens.shape

            # For boundary detection, we need to consider the original byte sequence
            # But we're working with compressed tokens here
            # So we detect boundaries based on learned representations

            # Apply boundary adjustment with TransformerEncoderLayer
            # This learns to adjust token boundaries based on context
            refined_tokens = self.boundary_refiner['adjuster'](lang_tokens)

            # The adjuster should learn to:
            # 1. Respect UTF-8 boundaries (learned during training)
            # 2. Align with word/phrase boundaries (learned from language patterns)
            # 3. Maintain semantic coherence within each token
        else:
            refined_tokens = lang_tokens

        # Determine actual number of tokens based on highest probability
        # During inference, use argmax. During training, use weighted average.
        if self.training:
            # During training, use weighted average for differentiability
            actual_num_tokens = avg_tokens
        else:
            # During inference, select the split with highest probability
            split_decision = torch.argmax(split_weights, dim=-1)  # [batch]
            actual_num_tokens = (split_decision.float().mean() + 1).item()  # +1 because indices are 0-3

        # Calculate compression ratio based on actual tokens used
        compression_ratio = 48.0 / max(1, actual_num_tokens)

        return {
            'tokens': refined_tokens,
            'num_tokens': actual_num_tokens,
            'compression_ratio': torch.tensor(compression_ratio, device=refined_tokens.device),
            'gate_values': None,  # Will be filled by cross-attention
            'language_features': lang_features,
            'split_weights': split_weights,
            'avg_tokens': avg_tokens,
            'split_distribution': split_weights.mean(dim=0)
        }


class EncoderV62(nn.Module):
    """

    4-Layer Progressive Splitting Encoder with Cross-Attention

    All layers: 1280 dimensions

    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        # Store config for later use
        self.config = config or {}

        # Configuration
        self.hidden_dim = 1280
        self.num_heads = 16
        self.num_layers = 4
        self.max_seq_len = 48
        self.dropout = 0.1

        # RoPE positional encoding (GPT-5 suggestion)
        self.rope = RoPEPositionalEncoding(self.hidden_dim, self.max_seq_len)

        # Layer 0: Progressive Splitting (48→1→N→M) - Pass config
        self.progressive_splitter = ProgressiveSplittingLayer(self.hidden_dim, config)

        # Layers 1-3: Transformer encoders with cross-attention
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.hidden_dim,
                nhead=self.num_heads,
                dim_feedforward=4 * self.hidden_dim,  # 5120
                dropout=self.dropout,
                batch_first=True
            ) for _ in range(3)
        ])

        # Cross-attention between layers with MQA (GPT-5 suggestion)
        self.cross_attentions = nn.ModuleList([
            GatedCrossAttention(self.hidden_dim, self.num_heads, kv_heads=2)  # 8x memory reduction
            for _ in range(3)
        ])

        # Output heads for different tasks
        self.boundary_head = nn.Linear(self.hidden_dim, 4)
        self.language_head = nn.Linear(self.hidden_dim, 128)  # Reduced from 512 (GPT suggestion)
        self.compression_head = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Monitoring metrics (GPT-5 suggestion)
        self.register_buffer('compression_ratios', torch.zeros(1))
        self.register_buffer('gate_averages', torch.zeros(3))

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the encoder.

        Args:
            input_ids: Input byte sequence
            attention_mask: Optional attention mask
            temperature: Gumbel-Softmax temperature for annealing
        """
        # Layer 0: Progressive splitting with temperature
        split_output = self.progressive_splitter(input_ids, temperature)
        x = split_output['tokens']  # [batch, M, 1280]

        # Apply RoPE
        x = self.rope(x, x.size(1))

        # Store all hidden states for decoder
        all_hidden_states = [x]
        gate_values_list = []

        # Layers 1-3 with cross-attention
        for i, (encoder_layer, cross_attn) in enumerate(
            zip(self.encoder_layers, self.cross_attentions)
        ):
            # Self-attention through transformer layer
            # GPT final check: Don't pass mask after progressive splitting changes sequence length
            x = encoder_layer(x)  # No mask needed (no padding after compression)

            # Cross-attention with the previous layer's hidden states
            if i > 0:
                x, gate_values = cross_attn(
                    query=x,
                    key=all_hidden_states[-1],
                    value=all_hidden_states[-1],
                    mask=None  # Mask not applicable after compression
                )
                gate_values_list.append(gate_values)
                # Store running gate average in the preallocated buffer.
                # Note: cross-attention only runs for i > 0, so only gate_averages[0]
                # and gate_averages[1] are updated; gate_averages[2] keeps its initial 0.
                self.gate_averages[i - 1] = gate_values.mean().detach().item()

            all_hidden_states.append(x)

        # Output projections
        boundaries = self.boundary_head(x)
        language_clusters = self.language_head(x)
        compressed = self.compression_head(x)

        # Update monitoring metrics
        # Ensure tensor is 1-dimensional for buffer assignment
        compression_ratio = split_output['compression_ratio']
        if compression_ratio.dim() == 0:  # Scalar tensor
            self.compression_ratios[0] = compression_ratio
        else:
            self.compression_ratios = compression_ratio

        return {
            'last_hidden_state': x,
            'all_hidden_states': all_hidden_states,
            'boundaries': boundaries,
            'language_clusters': language_clusters,
            'compressed': compressed,
            'compression_ratio': split_output['compression_ratio'],
            'num_tokens': split_output['num_tokens'],
            'splitting_probs': split_output.get('split_weights', None),  # Add for diagnostics
            'gate_values': gate_values_list,
            'gate_averages': self.gate_averages,
            'split_info': {
                'language_features': split_output['language_features'],
                'split_weights': split_output['split_weights']
            }
        }

    def get_monitoring_stats(self) -> Dict[str, float]:
        """

        Get monitoring statistics (GPT-5 suggestion)

        """
        return {
            'avg_compression_ratio': self.compression_ratios.item(),
            'gate_layer1': self.gate_averages[0].item(),
            'gate_layer2': self.gate_averages[1].item(),
            'gate_layer3': self.gate_averages[2].item(),
        }

    def set_warmup_step(self, step: int, total_warmup: int = 1000):
        """

        Set warmup alpha for all gates (GPT suggestion)

        Gradually increase gate influence from 0 to 1

        """
        alpha = min(1.0, step / total_warmup)
        for cross_attn in self.cross_attentions:
            cross_attn.warmup_alpha = torch.tensor(alpha, device=cross_attn.warmup_alpha.device)

    def adaptive_compression_control(self, reconstruction_loss: float):
        """

        Adaptive compression based on reconstruction quality

        No fixed phases - model learns optimal compression

        """
        # If reconstruction is poor, model will learn to use more tokens
        # This happens automatically through gradient descent
        # No manual phase control needed
        pass  # Let gradients handle it


class DualSlidingWindowEncoder(EncoderV62):
    """

    Extension with dual sliding window system

    Handles both chunk-level and token-level boundaries

    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__(config)

        # Chunk-level sliding window
        self.chunk_window = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=8,  # 8-byte overlap
            stride=40,      # 48-8=40 stride
            padding=4
        )

        # Token-level sliding window
        self.token_window = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=self.num_heads,
            batch_first=True
        )

    def process_long_sequence(self, input_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Handle sequences longer than 48 bytes with sliding windows.
        """
        batch_size, seq_len = input_ids.shape

        if seq_len <= 48:
            return super().forward(input_ids)

        # Process in chunks with overlap
        chunks = []
        for i in range(0, seq_len - 48 + 1, 40):  # 8-byte overlap
            chunk = input_ids[:, i:i+48]
            chunk_output = super().forward(chunk)
            chunks.append(chunk_output['last_hidden_state'])
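        # Chunking arithmetic (illustrative): for seq_len = 128 the chunk starts are
        # 0, 40, 80, giving three 48-byte chunks with an 8-byte overlap; a tail shorter
        # than 48 bytes (e.g. seq_len = 130) is not covered by the range(...) bound above.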

        # Combine chunks with attention
        combined = torch.cat(chunks, dim=1)
        attended, _ = self.token_window(combined, combined, combined)

        return {
            'last_hidden_state': attended,
            'num_chunks': len(chunks),
            'total_compression': seq_len / attended.size(1)
        }


if __name__ == "__main__":
    # Test the encoder
    encoder = EncoderV62()

    # Test input
    batch_size = 2
    input_ids = torch.randint(0, 256, (batch_size, 48))

    # Forward pass
    output = encoder(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output tokens: {output['num_tokens']}")
    print(f"Compression ratio: {output['compression_ratio']:.2f}:1")
    print(f"Gate averages: {output['gate_averages']}")
    print(f"Monitoring stats: {encoder.get_monitoring_stats()}")