ggunio committed
Commit a6c6452 · verified · 1 Parent(s): c65503b

Upload core/unified_model.py with huggingface_hub
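For reference, a single file is typically pushed to a Hub repo this way with huggingface_hub; a minimal sketch, where the repo id is a placeholder and not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="core/unified_model.py",   # local file to upload
    path_in_repo="core/unified_model.py",      # destination path inside the repo
    repo_id="<namespace>/<repo-name>",         # placeholder; the actual repo id is not shown here
    commit_message="Upload core/unified_model.py with huggingface_hub",
)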

Files changed (1)
  1. core/unified_model.py +602 -0
core/unified_model.py ADDED
@@ -0,0 +1,602 @@
+"""
+Unified Intelligent Tokenizer Model v6.0
+Purely learning-based - all core code integrated in one module
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import Dict, List, Optional, Tuple, Union
+
+
+class PositionalEncoding(nn.Module):
+    """
+    Sinusoidal positional encoding (as in the original Transformer)
+    Uses fixed sin/cos patterns instead of learnable position embeddings
+    """
+
+    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
+        super().__init__()
+        self.dropout = nn.Dropout(dropout)
+
+        # Create sinusoidal position encodings
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
+                             -(math.log(10000.0) / d_model))
+
+        pe[:, 0::2] = torch.sin(position * div_term)  # Even dimensions
+        pe[:, 1::2] = torch.cos(position * div_term)  # Odd dimensions
+
+        # Register as buffer (not trainable)
+        self.register_buffer('pe', pe.unsqueeze(0))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding to the input
+        Args:
+            x: (batch_size, seq_len, d_model)
+        """
+        x = x + self.pe[:, :x.size(1)]
+        return self.dropout(x)
+
+
+class ByteTokenizer:
+    """
+    Pure byte-level tokenizer - no language rules
+    """
+
+    def __init__(self, max_seq_len: int = 512):
+        self.max_seq_len = max_seq_len
+        self.PAD = 256
+        self.BOS = 257
+        self.EOS = 258
+        self.MASK = 259
+
+    def encode(self, text: str, add_special_tokens: bool = True) -> Dict[str, torch.Tensor]:
+        # Convert to UTF-8 bytes
+        byte_seq = list(text.encode('utf-8'))
+
+        # Truncate if needed
+        if len(byte_seq) > self.max_seq_len - 2:
+            byte_seq = byte_seq[:self.max_seq_len - 2]
+
+        # Add special tokens
+        if add_special_tokens:
+            byte_seq = [self.BOS] + byte_seq + [self.EOS]
+
+        input_ids = torch.tensor(byte_seq, dtype=torch.long)
+        attention_mask = torch.ones_like(input_ids)
+
+        return {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'length': len(input_ids)
+        }
+
+    def encode_batch(self, texts: List[str]) -> Dict[str, torch.Tensor]:
+        encoded = [self.encode(text) for text in texts]
+        max_len = min(max(e['length'] for e in encoded), self.max_seq_len)
+
+        batch_size = len(texts)
+        input_ids = torch.full((batch_size, max_len), self.PAD, dtype=torch.long)
+        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.float32)
+
+        for i, enc in enumerate(encoded):
+            seq_len = min(enc['length'], max_len)
+            input_ids[i, :seq_len] = enc['input_ids'][:seq_len]
+            attention_mask[i, :seq_len] = 1.0
+
+        return {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask
+        }
+
+    def decode(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
+        if isinstance(input_ids, torch.Tensor):
+            input_ids = input_ids.cpu().tolist()
+
+        if skip_special_tokens:
+            input_ids = [b for b in input_ids if b < 256]
+
+        try:
+            byte_array = bytes([min(b, 255) for b in input_ids if b != self.PAD])
+            return byte_array.decode('utf-8', errors='replace')
+        except Exception:
+            return "".join([chr(b) if b < 128 else '?' for b in input_ids if b < 256])
+
+
+class ByteEncoder(nn.Module):
+    """
+    5-Layer Encoder with Positional Encoding
+    Layer dimensions: [384, 384, 512, 640, 768] - updated
+    """
+
+    def __init__(
+        self,
+        vocab_size: int = 260,
+        hidden_dims: List[int] = [384, 384, 512, 640, 768],  # 512 added
+        num_heads: int = 8,
+        dropout: float = 0.1,
+        max_seq_len: int = 512
+    ):
+        super().__init__()
+
+        # Byte embedding
+        self.byte_embedding = nn.Embedding(vocab_size, hidden_dims[0])
+
+        # Positional encoding (sinusoidal)
+        self.pos_encoding = PositionalEncoding(hidden_dims[0], max_seq_len, dropout)
+
+        # 5 Transformer layers with dimension changes
+        self.layers = nn.ModuleList()
+        for i in range(len(hidden_dims)):
+            input_dim = hidden_dims[i - 1] if i > 0 else hidden_dims[0]
+            output_dim = hidden_dims[i]
+
+            # Projection layer if dimension changes
+            if input_dim != output_dim:
+                proj = nn.Linear(input_dim, output_dim)
+            else:
+                proj = None
+
+            # Transformer encoder layer
+            layer = nn.TransformerEncoderLayer(
+                d_model=output_dim,
+                nhead=num_heads,
+                dim_feedforward=output_dim * 4,
+                dropout=dropout,
+                activation='gelu',
+                batch_first=True,
+                norm_first=True
+            )
+
+            self.layers.append(nn.ModuleDict({
+                'projection': proj,
+                'transformer': layer,
+                'norm': nn.LayerNorm(output_dim)
+            }))
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None
+    ) -> Dict[str, torch.Tensor]:
+        # Embed bytes
+        x = self.byte_embedding(input_ids)
+
+        # Add positional encoding
+        x = self.pos_encoding(x)
+
+        # Prepare attention mask
+        if attention_mask is not None:
+            # Keep the attention mask as is; each layer below receives it
+            # as a key_padding_mask of shape (batch_size, seq_len)
+            pass
+
+        # Process through 5 layers
+        all_hidden_states = []
+        for layer_dict in self.layers:
+            # Project if needed
+            if layer_dict['projection'] is not None:
+                x = layer_dict['projection'](x)
+
+            # Transformer layer - properly handle mask
+            if attention_mask is not None:
+                # TransformerEncoderLayer expects key_padding_mask (batch, seq)
+                # where True means "ignore this position"
+                key_padding_mask = (attention_mask == 0)
+                x = layer_dict['transformer'](x, src_key_padding_mask=key_padding_mask)
+            else:
+                x = layer_dict['transformer'](x)
+            x = layer_dict['norm'](x)
+            all_hidden_states.append(x)
+
+        # Pool for sequence representation
+        if attention_mask is not None:
+            # Masked mean pooling - attention_mask is (batch, seq)
+            mask = attention_mask.unsqueeze(-1)  # (batch, seq, 1)
+            pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
+        else:
+            pooled = x.mean(dim=1)
+
+        return {
+            'last_hidden_state': x,
+            'pooled_output': pooled,
+            'all_hidden_states': all_hidden_states
+        }
+
+
+class CrossAttention(nn.Module):
+    """
+    Enhanced cross-attention for relation learning between sequences
+    Strengthened relation learning for connecting to the reasoning layer
+    """
+
+    def __init__(self, hidden_dim: int = 768, num_heads: int = 8, dropout: float = 0.1):
+        super().__init__()
+
+        self.cross_attn = nn.MultiheadAttention(
+            hidden_dim, num_heads, dropout, batch_first=True
+        )
+
+        # Enhanced relation classifier (8 types for richer relations)
+        # 0: identity, 1: similar, 2: different, 3: continuation
+        # 4: translation, 5: summary, 6: expansion, 7: contradiction
+        self.relation_head = nn.Sequential(
+            nn.Linear(hidden_dim * 2, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_dim // 2, 8)
+        )
+
+        # Gating mechanism for adaptive fusion
+        self.gate = nn.Sequential(
+            nn.Linear(hidden_dim * 2, hidden_dim),
+            nn.Sigmoid()
+        )
+
+        self.norm1 = nn.LayerNorm(hidden_dim)
+        self.norm2 = nn.LayerNorm(hidden_dim)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        query_mask: Optional[torch.Tensor] = None,
+        key_mask: Optional[torch.Tensor] = None
+    ) -> Dict[str, torch.Tensor]:
+        # Normalize inputs
+        query_norm = self.norm1(query)
+        key_norm = self.norm2(key)
+
+        # Fix key_mask dimension if needed
+        if key_mask is not None:
+            # Ensure key_mask matches the key sequence length
+            if key_mask.dim() == 2 and key_mask.size(1) != key.size(1):
+                # Create a new mask with correct dimensions
+                batch_size = key.size(0)
+                seq_len = key.size(1)
+                key_mask = torch.ones(batch_size, seq_len, dtype=key_mask.dtype, device=key_mask.device)
+
+        # Cross attention
+        attn_output, attn_weights = self.cross_attn(
+            query_norm, key_norm, key_norm,
+            key_padding_mask=(key_mask == 0) if key_mask is not None else None
+        )
+
+        # Residual connection
+        attn_output = attn_output + query
+
+        # Adaptive gating for fusion
+        gate_input = torch.cat([query.mean(dim=1), key.mean(dim=1)], dim=-1)
+        gate_weights = self.gate(gate_input).unsqueeze(1)
+
+        # Gated fusion: adaptively blend the cross-attention output with the query
+        fused_output = gate_weights * attn_output + (1 - gate_weights) * query
+
+        # Pool for relation classification
+        query_pooled = query.mean(dim=1) if query_mask is None else \
+            (query * query_mask.unsqueeze(-1)).sum(1) / query_mask.sum(1, keepdim=True).clamp(min=1e-9)
+        key_pooled = key.mean(dim=1) if key_mask is None else \
+            (key * key_mask.unsqueeze(-1)).sum(1) / key_mask.sum(1, keepdim=True).clamp(min=1e-9)
+
+        # Classify relations with the enhanced head
+        combined = torch.cat([query_pooled, key_pooled], dim=-1)
+        relation_logits = self.relation_head(combined)
+
+        return {
+            'cross_attention': fused_output,  # Gated fusion output
+            'attention_weights': attn_weights,
+            'relation_logits': relation_logits,
+            'gate_weights': gate_weights.squeeze(1)  # For analysis
+        }
+
+
+class TransformerDecoder(nn.Module):
+    """
+    Transformer Decoder with Positional Encoding
+    """
+
+    def __init__(
+        self,
+        vocab_size: int = 260,
+        hidden_dim: int = 768,
+        num_heads: int = 8,
+        num_layers: int = 6,
+        dropout: float = 0.1,
+        max_seq_len: int = 512
+    ):
+        super().__init__()
+
+        # Token embedding
+        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
+
+        # Positional encoding
+        self.pos_encoding = PositionalEncoding(hidden_dim, max_seq_len, dropout)
+
+        # Transformer decoder
+        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=hidden_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim * 4,
+            dropout=dropout,
+            activation='gelu',
+            batch_first=True,
+            norm_first=True
+        )
+
+        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
+
+        # Output projection
+        self.output_projection = nn.Linear(hidden_dim, vocab_size)
+
+        self.hidden_dim = hidden_dim
+        self.vocab_size = vocab_size
+
+    def forward(
+        self,
+        encoder_hidden: torch.Tensor,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        encoder_mask: Optional[torch.Tensor] = None,
+        decoder_mask: Optional[torch.Tensor] = None
+    ) -> Dict[str, torch.Tensor]:
+        batch_size = encoder_hidden.size(0)
+
+        # Start with BOS if no input
+        if decoder_input_ids is None:
+            decoder_input_ids = torch.full((batch_size, 1), 257, device=encoder_hidden.device)
+
+        # Embed and add positional encoding
+        dec_seq_len = decoder_input_ids.size(1)
+        x = self.token_embedding(decoder_input_ids)
+        x = self.pos_encoding(x)
+
+        # Create causal mask
+        causal_mask = torch.triu(
+            torch.ones(dec_seq_len, dec_seq_len, device=x.device) * float('-inf'),
+            diagonal=1
+        )
+
+        # Decoder forward - handle variable-length encoder outputs
+        # The encoder may compress the sequence, so memory (encoder_hidden) might be shorter
+        # than the decoder sequence. This is expected and correct behavior.
+        enc_seq_len = encoder_hidden.size(1)
+
+        # Adjust encoder mask if needed
+        if encoder_mask is not None:
+            if encoder_mask.size(1) != enc_seq_len:
+                # Encoder compressed the sequence; create a new mask for the compressed length
+                # All compressed positions are valid (not masked)
+                memory_key_padding_mask = torch.zeros(
+                    encoder_hidden.size(0), enc_seq_len,
+                    dtype=torch.bool, device=encoder_hidden.device
+                )
+            else:
+                memory_key_padding_mask = (encoder_mask == 0)
+        else:
+            memory_key_padding_mask = None
+
+        # Decoder attends to compressed encoder states via cross-attention
+        # This naturally handles different sequence lengths
+        decoder_output = self.transformer(
+            tgt=x,                    # Decoder sequence (original length)
+            memory=encoder_hidden,    # Encoder sequence (possibly compressed)
+            tgt_mask=causal_mask,
+            memory_key_padding_mask=memory_key_padding_mask,
+            tgt_key_padding_mask=(decoder_mask == 0) if decoder_mask is not None else None
+        )
+
+        # Project to vocabulary
+        logits = self.output_projection(decoder_output)
+
+        return {
+            'logits': logits,
+            'hidden_states': decoder_output
+        }
+
+    @torch.no_grad()
+    def generate(
+        self,
+        encoder_hidden: torch.Tensor,
+        encoder_mask: Optional[torch.Tensor] = None,
+        max_length: int = 128,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 0.95
+    ) -> torch.Tensor:
+        batch_size = encoder_hidden.size(0)
+        device = encoder_hidden.device
+
+        # Start with BOS
+        decoder_input_ids = torch.full((batch_size, 1), 257, device=device)
+
+        for _ in range(max_length - 1):
+            # Forward pass
+            outputs = self.forward(encoder_hidden, decoder_input_ids, encoder_mask)
+            next_token_logits = outputs['logits'][:, -1, :] / temperature
+
+            # Top-k filtering
+            if top_k > 0:
+                indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                next_token_logits[indices_to_remove] = float('-inf')
+
+            # Top-p filtering
+            if top_p < 1.0:
+                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                sorted_indices_to_remove[..., 0] = 0
+
+                indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+                next_token_logits[indices_to_remove] = float('-inf')
+
+            # Sample
+            probs = F.softmax(next_token_logits, dim=-1)
+            next_tokens = torch.multinomial(probs, 1)
+            decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)
+
+            # Stop at EOS
+            if (next_tokens == 258).all():  # EOS token
+                break
+
+        return decoder_input_ids
+
+
+class IntelligentTokenizerModel(nn.Module):
+    """
+    Complete Intelligent Tokenizer Model v6.0
+    Unified model - Encoder + Decoder + Cross-Attention
+    """
+
+    def __init__(
+        self,
+        vocab_size: int = 260,
+        encoder_dims: List[int] = [384, 384, 512, 640, 768],  # 512 added
+        decoder_hidden: int = 768,
+        num_heads: int = 8,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        max_seq_len: int = 512
+    ):
+        super().__init__()
+
+        # Components
+        self.tokenizer = ByteTokenizer(max_seq_len)
+        self.encoder = ByteEncoder(vocab_size, encoder_dims, num_heads, dropout, max_seq_len)
+        self.decoder = TransformerDecoder(vocab_size, decoder_hidden, num_heads, num_decoder_layers, dropout, max_seq_len)
+        self.cross_attention = CrossAttention(encoder_dims[-1], num_heads, dropout)
+
+    def forward(
+        self,
+        input_texts: Optional[List[str]] = None,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cross_attention: bool = True
+    ) -> Dict[str, torch.Tensor]:
+        # Tokenize if text input (and move to the model's device)
+        if input_texts is not None:
+            tokenized = self.tokenizer.encode_batch(input_texts)
+            input_ids = tokenized['input_ids'].to(next(self.parameters()).device)
+            attention_mask = tokenized['attention_mask'].to(next(self.parameters()).device)
+
+        # Check sequence length and record shapes
+        batch_size, seq_len = input_ids.shape
+        device = input_ids.device
+
+        # Encode
+        encoder_outputs = self.encoder(input_ids, attention_mask)
+        encoder_hidden = encoder_outputs['last_hidden_state']  # [batch, seq, 768]
+
+        # Dimension check
+        assert encoder_hidden.size(-1) == 768, f"Encoder dim mismatch: {encoder_hidden.size(-1)}"
+
+        # Decode
+        decoder_outputs = self.decoder(
+            encoder_hidden,
+            decoder_input_ids,
+            attention_mask
+        )
+        decoder_hidden = decoder_outputs['hidden_states']  # [batch, seq, 768]
+
+        # Cross-attention (relation learning on the final layer)
+        cross_attn_outputs = None
+        relation_logits = None
+
+        if use_cross_attention and decoder_hidden is not None:
+            # Cross-attention between decoder outputs and encoder outputs
+            cross_attn_outputs = self.cross_attention(
+                query=decoder_hidden,    # decoder acts as query
+                key=encoder_hidden,      # encoder acts as key/value
+                query_mask=None,         # the decoder mask is causal and handled separately
+                key_mask=attention_mask
+            )
+
+            # Relation learning outputs
+            relation_logits = cross_attn_outputs['relation_logits']
+
+            # Decoder representation enhanced by cross-attention
+            enhanced_decoder = decoder_hidden + cross_attn_outputs['cross_attention']
+
+            # Recompute final logits (after applying cross-attention)
+            if hasattr(self.decoder, 'output_projection'):
+                decoder_outputs['logits'] = self.decoder.output_projection(enhanced_decoder)
+
+        # Calculate loss if labels provided
+        loss = None
+        if labels is not None:
+            # Reconstruction loss
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.PAD)
+            recon_loss = loss_fct(
+                decoder_outputs['logits'].reshape(-1, decoder_outputs['logits'].size(-1)),
+                labels.reshape(-1)
+            )
+
+            # Relation loss (if cross-attention used)
+            relation_loss = 0
+            if relation_logits is not None:
+                # The self-relation should be identity (class 0)
+                batch_identity = torch.zeros(batch_size, dtype=torch.long, device=device)
+                relation_loss = F.cross_entropy(relation_logits, batch_identity) * 0.1
+
+            loss = recon_loss + relation_loss
+
+        return {
+            'loss': loss,
+            'logits': decoder_outputs['logits'],
+            'encoder_hidden_states': encoder_hidden,
+            'decoder_hidden_states': decoder_hidden,
+            'pooled_output': encoder_outputs['pooled_output'],
+            'cross_attention': cross_attn_outputs['cross_attention'] if cross_attn_outputs else None,
+            'relation_logits': relation_logits,
+            'all_encoder_states': encoder_outputs.get('all_hidden_states', None)
+        }
+
+    def encode_text(self, text: str) -> torch.Tensor:
+        """Encode a single text into its pooled representation"""
+        tokenized = self.tokenizer.encode(text)
+        # Move to the same device as the model
+        device = next(self.parameters()).device
+        input_ids = tokenized['input_ids'].unsqueeze(0).to(device)
+        attention_mask = tokenized['attention_mask'].unsqueeze(0).to(device)
+
+        with torch.no_grad():
+            outputs = self.encoder(input_ids, attention_mask)
+
+        return outputs['pooled_output'].squeeze(0)
+
+    def decode_representation(self, representation: torch.Tensor, max_length: int = 128) -> str:
+        """Decode a representation back to text"""
+        if representation.dim() == 1:
+            representation = representation.unsqueeze(0).unsqueeze(0)
+        elif representation.dim() == 2:
+            representation = representation.unsqueeze(1)
+
+        with torch.no_grad():
+            output_ids = self.decoder.generate(representation, max_length=max_length)
+
+        text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        return text
+
+    def compute_relation(self, text1: str, text2: str) -> torch.Tensor:
+        """Compute the relation distribution between two texts"""
+        # Encode both texts
+        enc1 = self.encode_text(text1).unsqueeze(0).unsqueeze(0)
+        enc2 = self.encode_text(text2).unsqueeze(0).unsqueeze(0)
+
+        # Compute cross-attention and relations
+        with torch.no_grad():
+            outputs = self.cross_attention(enc1, enc2)
+
+        return F.softmax(outputs['relation_logits'], dim=-1)
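
For orientation, here is a minimal usage sketch of the classes defined above. It is not part of the committed file; it assumes the repository root is on the Python path so that core.unified_model is importable, runs with randomly initialized weights, and simply feeds the input byte ids back in as decoder inputs and labels for a reconstruction-style forward pass.

# Hypothetical usage sketch for core/unified_model.py (not in the commit)
import torch
from core.unified_model import IntelligentTokenizerModel

model = IntelligentTokenizerModel()
model.eval()

texts = ["Hello, world!", "안녕하세요"]

# Byte-level tokenization, then a reconstruction-style forward pass:
# the byte ids serve as decoder inputs and as labels.
batch = model.tokenizer.encode_batch(texts)
with torch.no_grad():
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=batch['input_ids'],
        labels=batch['input_ids'],
    )
print(outputs['loss'])             # reconstruction + 0.1 * relation loss
print(outputs['logits'].shape)     # (batch, seq, 260): 256 bytes + PAD/BOS/EOS/MASK

# Relation probabilities over the 8 relation classes of CrossAttention.
probs = model.compute_relation("The cat sat.", "A cat was sitting.")
print(probs.shape)                 # (1, 8)

The forward pass returns both the loss and per-byte logits over the 260-symbol vocabulary, while compute_relation pools each text through the encoder and classifies the pair with the cross-attention relation head.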