File size: 2,219 Bytes
6bf8f07
83dab7c
a09becc
83dab7c
86f425a
d0c15af
83dab7c
9a22429
 
 
 
 
d0c15af
 
 
 
83dab7c
 
 
a09becc
83dab7c
 
 
 
 
 
 
d0c15af
 
 
 
 
 
 
822b8da
d0c15af
 
9edfd90
 
d0c15af
 
6bf8f07
 
 
fb3cc84
6bf8f07
 
a09becc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
from transformers import  PreTrainedModel, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from .config_tunbert import TunBertConfig

class classifier(nn.Module):
    """Classification head made of two stacked linear projections.

    ``layer0`` maps ``hidden_size -> hidden_size`` and ``layer1`` maps
    ``hidden_size -> type_vocab_size``; both use a bias term. No activation is
    applied between them, so the head composes to a single affine map.

    NOTE(review): ``config.type_vocab_size`` is used as the number of output
    classes; presumably it equals the label count — confirm against the config.
    The attribute names ``layer0``/``layer1`` become state_dict keys, so
    renaming them breaks loading previously saved weights.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        num_outputs = config.type_vocab_size
        self.layer0 = nn.Linear(in_features=width, out_features=width, bias=True)
        self.layer1 = nn.Linear(in_features=width, out_features=num_outputs, bias=True)

    def forward(self, tensor):
        """Apply ``layer0`` then ``layer1`` over the last dimension of *tensor*."""
        return self.layer1(self.layer0(tensor))


class TunBERT(PreTrainedModel):
    """BERT encoder with a two-layer head for sequence classification.

    ``forward`` runs ``BertModel``, applies dropout to the final hidden states,
    projects every position through ``classifier`` and keeps only position 0
    (the ``[CLS]`` token) as the per-sequence logits.
    """

    config_class = TunBertConfig

    def __init__(self, config):
        super().__init__(config)
        # These attribute names are state_dict keys; renaming them breaks
        # loading previously saved weights.
        self.BertModel = BertModel(config)
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.classifier = classifier(config)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None, labels=None):
        """Encode a batch and return classification logits (and a loss if labels are given).

        Args:
            input_ids: Token ids, forwarded to ``BertModel``.
            token_type_ids: Segment ids, forwarded to ``BertModel``.
            attention_mask: Padding mask, forwarded to ``BertModel``.
            labels: Optional targets. When not ``None`` they are flattened with
                ``view(-1)`` and scored with ``nn.CrossEntropyLoss``; presumably
                one class index per sequence — confirm against the training code.

        Returns:
            ``SequenceClassifierOutput`` whose ``logits`` are the head outputs at
            position 0 (``[batch, type_vocab_size]``); ``loss`` is ``None`` unless
            ``labels`` was given; ``hidden_states`` is the encoder's
            ``last_hidden_state`` and ``attentions`` is the encoder's ``attentions``.
        """
        # Pass encoder inputs by keyword. BertModel.forward's positional order is
        # (input_ids, attention_mask, token_type_ids, ...), so a positional call in
        # this method's own argument order feeds token_type_ids in as the attention
        # mask and the attention mask in as segment ids.
        # NOTE(review): if the published checkpoint was fine-tuned with the old
        # positional call, re-evaluate it after this fix.
        outputs = self.BertModel(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)  # [batch, seq, type_vocab_size]
        # Every sequence starts with [CLS], whose representation is used to classify
        # the whole sentence; the outputs for the remaining tokens are not needed.
        logits = logits[:, 0, :]  # [batch, type_vocab_size]

        loss = None
        if labels is not None:
            loss_func = nn.CrossEntropyLoss()
            loss = loss_func(logits.view(-1, self.config.type_vocab_size), labels.view(-1))

        # NOTE(review): by transformers convention ``hidden_states`` holds a tuple of
        # per-layer states; this passes the single last-layer tensor. Kept as-is so
        # existing callers are unaffected.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )

    def process(self, **inputs):
        """Classify a batch and return one label string per example.

        ``inputs`` are forwarded unchanged to ``forward`` under ``torch.no_grad()``.
        The arg-max class index 0 maps to ``"positive"``; any other index maps
        to ``"negative"``.

        NOTE: this does not switch the module to eval mode; if the model is in
        training mode, dropout stays active and results are non-deterministic.
        Call ``.eval()`` first for inference.
        """
        with torch.no_grad():
            out = self.forward(**inputs)
        out = torch.argmax(out.logits, dim=1)
        return ["positive" if index == 0 else "negative" for index in out.tolist()]


# Module-level side effect, executed on import: register the custom config and
# model classes with transformers' auto-class mechanism. The config uses the
# method's default auto class (presumably AutoConfig — confirm); the model is
# registered as "AutoModelForSequenceClassification".
TunBertConfig.register_for_auto_class()
TunBERT.register_for_auto_class("AutoModelForSequenceClassification")