Upload ModularStarEncoder
Browse files- modularStarEncoder.py +13 -7
modularStarEncoder.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
-
from
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union, List
|
|
@@ -347,12 +347,18 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
|
|
| 347 |
|
| 348 |
pooled_and_normalized.append(normalized_source_embedding)
|
| 349 |
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
|
| 358 |
|
|
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
+
from config import ModularStarEncoderConfig
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union, List
|
|
|
|
| 347 |
|
| 348 |
pooled_and_normalized.append(normalized_source_embedding)
|
| 349 |
|
| 350 |
+
if not self.till_layer:
|
| 351 |
+
return ModularStarEncoderOutput(
|
| 352 |
+
projected_pooled_normalized = pooled_and_normalized,
|
| 353 |
+
raw_hidden_states=source_embedding.hidden_states,
|
| 354 |
+
attentions=source_embedding.attentions,
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
return ModularStarEncoderOutput(
|
| 358 |
+
projected_pooled_normalized = pooled_and_normalized[0],
|
| 359 |
+
raw_hidden_states=source_embedding.hidden_states,
|
| 360 |
+
attentions=source_embedding.attentions,
|
| 361 |
+
)
|
| 362 |
|
| 363 |
|
| 364 |
|