not-lain commited on
Commit
6bf8f07
·
verified ·
1 Parent(s): 822b8da

add process method

Browse files
Files changed (1) hide show
  1. modeling_tunbert.py +7 -0
modeling_tunbert.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch.nn as nn
2
  from transformers import PreTrainedModel, BertModel
3
  from transformers.modeling_outputs import SequenceClassifierOutput
@@ -38,6 +39,12 @@ class TunBERT(PreTrainedModel):
38
  loss = loss_func(logits.view(-1,self.config.type_vocab_size),labels.view(-1))
39
  return SequenceClassifierOutput(loss = loss, logits= logits, hidden_states=outputs.last_hidden_state,attentions=outputs.attentions)
40
 
 
 
 
 
 
 
41
 
42
  TunBertConfig.register_for_auto_class()
43
  TunBERT.register_for_auto_class("AutoModelForSequenceClassification")
 
1
+ import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel, BertModel
4
  from transformers.modeling_outputs import SequenceClassifierOutput
 
39
  loss = loss_func(logits.view(-1,self.config.type_vocab_size),labels.view(-1))
40
  return SequenceClassifierOutput(loss = loss, logits= logits, hidden_states=outputs.last_hidden_state,attentions=outputs.attentions)
41
 
42
+ def process(self,**inputs):
43
+ with torch.no_grad():
44
+ out = self.forward(**inputs)
45
+ out = torch.argmax(output.logits,dim=1)
46
+ return ["positive" if index == 0 else "negative" for index in out.tolist()]
47
+
48
 
49
  TunBertConfig.register_for_auto_class()
50
  TunBERT.register_for_auto_class("AutoModelForSequenceClassification")