import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, PretrainedConfig


class BiLSTMConfig(PretrainedConfig):
    """Configuration for the BiLSTM-over-BioBERT classifier."""
    model_type = "bilstm_attention"

    def __init__(self, hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.dropout = dropout


class BiLSTMAttentionBERT(PreTrainedModel):
    config_class = BiLSTMConfig
    base_model_prefix = "bilstm_attention"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # BioBERT encoder; its hidden size (768) sets the LSTM input dimension.
        self.bert = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.2")
        self.lstm = nn.LSTM(
            self.bert.config.hidden_size,
            config.hidden_dim,
            config.num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(config.dropout)
        # The bidirectional LSTM concatenates forward and backward states, hence hidden_dim * 2.
        self.fc = nn.Linear(config.hidden_dim * 2, config.num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        bert_output = outputs[0]  # last_hidden_state: (batch, seq_len, hidden_size)
        lstm_output, _ = self.lstm(bert_output)
        # Pool by taking the final time step of the LSTM output as the sequence representation.
        dropped = self.dropout(lstm_output[:, -1, :])
        logits = self.fc(dropped)
        return logits
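
A minimal usage sketch, assuming the matching BioBERT tokenizer (dmis-lab/biobert-base-cased-v1.2) and a 22-class task as in the config defaults; the two example sentences are illustrative, not from the original post:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.2")
config = BiLSTMConfig(hidden_dim=128, num_classes=22, num_layers=2, dropout=0.5)
model = BiLSTMAttentionBERT(config)
model.eval()

# Illustrative inputs only; any tokenized text from the target task would do.
batch = tokenizer(
    ["Patient presents with acute chest pain.", "No adverse events were reported."],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    logits = model(batch["input_ids"], batch["attention_mask"])

print(logits.shape)  # expected: torch.Size([2, 22])

One design note: because the model pools the last LSTM time step, right-padded batches end on padding positions; masking before pooling, or using the first token's state, is a common alternative.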