joko333 committed on
Commit 41047a5 · 1 Parent(s): dd9aa69

Refactor BiLSTMAttentionBERT to use BiLSTMConfig for improved configuration management

Files changed (1)
1. utils/model.py  +19 -15
utils/model.py CHANGED
@@ -2,24 +2,28 @@ import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, AutoModel, PretrainedConfig
 
+class BiLSTMConfig(PretrainedConfig):
+    def __init__(self, hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.num_classes = num_classes
+        self.num_layers = num_layers
+        self.dropout = dropout
+
 class BiLSTMAttentionBERT(PreTrainedModel):
-    def __init__(self, hidden_dim, num_classes, num_layers, dropout):
-        super().__init__(PretrainedConfig())
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
         self.bert = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
-        self.lstm = nn.LSTM(768, hidden_dim, num_layers, batch_first=True, bidirectional=True)
-        self.dropout = nn.Dropout(dropout)
-        self.fc = nn.Linear(hidden_dim * 2, num_classes)
-
-    @classmethod
-    def from_pretrained(cls, model_path, hidden_dim, num_classes, num_layers, dropout):
-        model = cls(hidden_dim, num_classes, num_layers, dropout)
-        state_dict = torch.load(model_path, map_location='cpu')
-        model.load_state_dict(state_dict)
-        return model
+        self.lstm = nn.LSTM(768, config.hidden_dim, config.num_layers,
+                            batch_first=True, bidirectional=True)
+        self.dropout = nn.Dropout(config.dropout)
+        self.fc = nn.Linear(config.hidden_dim * 2, config.num_classes)
 
     def forward(self, input_ids, attention_mask):
-        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]
+        outputs = self.bert(input_ids, attention_mask=attention_mask)
+        bert_output = outputs[0]
         lstm_output, _ = self.lstm(bert_output)
         dropped = self.dropout(lstm_output[:, -1, :])
-        output = self.fc(dropped)
-        return output
+        logits = self.fc(dropped)
+        return logits
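
For context (not part of the commit), here is a minimal usage sketch of the refactored classes. It assumes the file above is importable as utils.model, uses the default hyperparameters baked into BiLSTMConfig (hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5), and the example sentence and printed shape are illustrative only:

# Minimal usage sketch; assumes the refactored module is importable as utils.model.
import torch
from transformers import AutoTokenizer
from utils.model import BiLSTMConfig, BiLSTMAttentionBERT

# Build the config and the model; the defaults mirror the values in the commit.
config = BiLSTMConfig(hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5)
model = BiLSTMAttentionBERT(config)
model.eval()

# Tokenize one sample sentence with the matching BioBERT tokenizer.
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
encoded = tokenizer("Aspirin reduces the risk of myocardial infarction.",
                    return_tensors='pt')

# Forward pass: BERT embeddings -> BiLSTM -> dropout on the last time step -> linear head.
with torch.no_grad():
    logits = model(encoded['input_ids'], encoded['attention_mask'])
print(logits.shape)  # torch.Size([1, 22]) with the default num_classes

One side note on the refactor: since the commit removes the custom from_pretrained classmethod, loading now goes through the inherited PreTrainedModel.from_pretrained, which typically also needs a config_class = BiLSTMConfig attribute on the model class to rebuild the config from a saved config.json; the sketch therefore instantiates the config explicitly.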