row56 commited on
Commit
832f402
·
verified ·
1 Parent(s): b367dc5

Upload proto_model/modeling_proto.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. proto_model/modeling_proto.py +44 -0
proto_model/modeling_proto.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ import torch
3
+ from .proto import ProtoModule
4
+ from .configuration_proto import ProtoConfig
5
+
6
+ class ProtoForMultiLabelClassification(PreTrainedModel):
7
+ config_class = ProtoConfig
8
+
9
+ def __init__(self, config: ProtoConfig):
10
+ super().__init__(config)
11
+ self.proto_module = ProtoModule(
12
+ pretrained_model=config.pretrained_model_name_or_path,
13
+ num_classes=config.num_classes,
14
+ label_order_path=config.label_order_path,
15
+ use_sigmoid=config.use_sigmoid,
16
+ use_cuda=config.use_cuda,
17
+ lr_prototypes=config.lr_prototypes,
18
+ lr_features=config.lr_features,
19
+ lr_others=config.lr_others,
20
+ num_training_steps=config.num_training_steps,
21
+ num_warmup_steps=config.num_warmup_steps,
22
+ loss=config.loss,
23
+ save_dir=config.save_dir,
24
+ use_attention=config.use_attention,
25
+ dot_product=config.dot_product,
26
+ normalize=config.normalize,
27
+ final_layer=config.final_layer,
28
+ reduce_hidden_size=config.reduce_hidden_size,
29
+ use_prototype_loss=config.use_prototype_loss,
30
+ prototype_vector_path=config.prototype_vector_path,
31
+ attention_vector_path=config.attention_vector_path,
32
+ eval_buckets=config.eval_buckets,
33
+ seed=config.seed
34
+ )
35
+ self.init_weights()
36
+
37
+ def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
38
+ batch = {
39
+ "input_ids": input_ids,
40
+ "attention_masks": attention_mask,
41
+ "token_type_ids": token_type_ids,
42
+ }
43
+ logits, metadata = self.proto_module(batch)
44
+ return {"logits": logits, "metadata": metadata}