Upload proto_model/modeling_proto.py with huggingface_hub
proto_model/modeling_proto.py (ADDED, +44 lines)
from transformers import PreTrainedModel
import torch

from .proto import ProtoModule
from .configuration_proto import ProtoConfig


class ProtoForMultiLabelClassification(PreTrainedModel):
    """Hugging Face wrapper around the prototype network in ProtoModule,
    so the model can be saved, shared, and loaded via from_pretrained."""

    config_class = ProtoConfig

    def __init__(self, config: ProtoConfig):
        super().__init__(config)
        # All ProtoModule hyperparameters are carried by ProtoConfig.
        self.proto_module = ProtoModule(
            pretrained_model=config.pretrained_model_name_or_path,
            num_classes=config.num_classes,
            label_order_path=config.label_order_path,
            use_sigmoid=config.use_sigmoid,
            use_cuda=config.use_cuda,
            lr_prototypes=config.lr_prototypes,
            lr_features=config.lr_features,
            lr_others=config.lr_others,
            num_training_steps=config.num_training_steps,
            num_warmup_steps=config.num_warmup_steps,
            loss=config.loss,
            save_dir=config.save_dir,
            use_attention=config.use_attention,
            dot_product=config.dot_product,
            normalize=config.normalize,
            final_layer=config.final_layer,
            reduce_hidden_size=config.reduce_hidden_size,
            use_prototype_loss=config.use_prototype_loss,
            prototype_vector_path=config.prototype_vector_path,
            attention_vector_path=config.attention_vector_path,
            eval_buckets=config.eval_buckets,
            seed=config.seed,
        )
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
        # ProtoModule consumes a batch dict; note the plural key
        # "attention_masks" expected by its interface.
        batch = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits, metadata = self.proto_module(batch)
        return {"logits": logits, "metadata": metadata}
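
Once this file sits on the Hub alongside configuration_proto.py, the model can be loaded through the standard transformers custom-code workflow. A minimal sketch, assuming the repo's config.json maps the class via auto_map; "your-username/proto-model" is a placeholder repo id and the BERT tokenizer is an assumption (use whichever backbone the config names), since neither is shown in this commit:

from transformers import AutoModel, AutoTokenizer
import torch

# Placeholder repo id; replace with the actual Hub repository.
model = AutoModel.from_pretrained(
    "your-username/proto-model",
    trust_remote_code=True,  # required so the repo's modeling_proto.py is executed
)

# Assumed tokenizer; bert-base-uncased returns token_type_ids by default,
# which this model's forward() requires.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("example input text", return_tensors="pt")

with torch.no_grad():
    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        token_type_ids=inputs["token_type_ids"],
    )
print(out["logits"].shape)

Note that forward() returns a plain dict with "logits" and "metadata" rather than a transformers ModelOutput, so downstream code should index it by key.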