Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from models.preprocess_stage.preprocess_lstm import preprocess_lstm | |
# Model hyper-parameters. They have to agree with the saved checkpoint and
# with the embedding matrix loaded below.
EMBEDDING_DIM = 128  # passed to nn.LSTM as input_size
HIDDEN_SIZE = 16     # LSTM hidden size, also the width of the attention layers
MAX_LEN = 125        # forwarded to preprocess_lstm as MAX_LEN

# Pretrained embedding table stored as a NumPy array on disk.
embedding_matrix = np.load('models/datasets/embedding_matrix.npy')

# Wrap the matrix in an embedding layer. freeze=True is the library default
# and is spelled out here only to make explicit that the table is not trained.
embedding_layer = nn.Embedding.from_pretrained(
    torch.FloatTensor(embedding_matrix),
    freeze=True,
)
class AtenttionTest(nn.Module):
    """Additive attention over the per-step outputs of an LSTM.

    Every time step is scored against the final hidden state; the scores are
    normalised with a softmax along the sequence axis and used to form a
    weighted sum of the (linearly projected) step outputs.
    """

    def __init__(self, hidden_size=HIDDEN_SIZE):
        """Create the projection and scoring layers.

        Args:
            hidden_size: Feature width of the LSTM outputs and hidden state.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Projection applied to every LSTM step output.
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        # Projection applied to the final hidden state.
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.tahn = nn.Tanh()
        # Reduces each combined feature vector to one scalar score.
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, outputs_lmst, h_n):
        """Compute the attention-pooled context and the attention weights.

        Args:
            outputs_lmst: LSTM step outputs; expected to be
                (batch, seq_len, hidden_size).
            h_n: Final LSTM hidden state; expected to be
                (1, batch, hidden_size) -- the leading axis is squeezed away.

        Returns:
            A pair ``(context, attention_weights)``: ``context`` is the
            weighted sum with shape (batch, hidden_size, 1) and
            ``attention_weights`` has shape (batch, seq_len).
        """
        projected_steps = self.fc1(outputs_lmst)
        projected_state = self.fc2(h_n.squeeze(0))

        # Broadcast the state projection across every time step.
        energies = self.tahn(projected_steps + projected_state.unsqueeze(1))

        # One score per step, normalised along the sequence axis.
        scores = self.fc3(energies).squeeze(2)
        attention_weights = torch.softmax(scores, dim=1)

        # The weighted sum is taken over the fc1-projected outputs,
        # not over the raw LSTM outputs.
        context = torch.bmm(
            projected_steps.transpose(1, 2),
            attention_weights.unsqueeze(2),
        )
        return context, attention_weights
class LSTMnn(nn.Module):
    """Single-output classifier: embedding -> LSTM -> attention -> MLP head."""

    def __init__(self):
        """Assemble the layers from the module-level hyper-parameters."""
        super().__init__()
        # Shared module-level embedding layer built from the saved matrix.
        self.embedding = embedding_layer
        self.lstm = nn.LSTM(
            input_size=EMBEDDING_DIM,
            hidden_size=HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,
        )
        self.attention = AtenttionTest(hidden_size=HIDDEN_SIZE)

        # The layer order determines the state_dict keys (fc_out.0, fc_out.3),
        # so it must stay exactly as it is.
        head_layers = [
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 1),
        ]
        self.fc_out = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the classifier on a batch of token-index sequences.

        Args:
            x: Integer tensor of token indices; expected to be
                (batch, seq_len) given ``batch_first=True``.

        Returns:
            A pair ``(probabilities, attention_weights)`` where
            ``probabilities`` is the sigmoid of the head output and
            ``attention_weights`` comes from the attention module.
        """
        embedded = self.embedding(x)
        lstm_outputs, lstm_state = self.lstm(embedded)
        h_n = lstm_state[0]  # the final cell state is not used

        context, attention_weights = self.attention(lstm_outputs, h_n)

        # Drop the trailing singleton axis before the dense head.
        logits = self.fc_out(context.squeeze(2))
        return torch.sigmoid(logits), attention_weights
# Instantiate the network once at import time and restore the trained
# parameters; map_location forces every tensor onto the CPU.
model = LSTMnn()
model.load_state_dict(
    torch.load(
        'models/weights/LSTMBestWeights.pt',
        map_location=torch.device('cpu'),
    )
)
def predict_3(text):
    """Classify a single text with the module-level LSTM model.

    Args:
        text: Raw input passed to ``preprocess_lstm`` together with
            ``MAX_LEN``. The result is presumably a fixed-length sequence
            of token indices -- confirm against ``preprocess_lstm``.

    Returns:
        int: The model's sigmoid output rounded with the built-in ``round``.
        Because ``round`` rounds half to even, an output of exactly 0.5
        yields 0.
    """
    preprocessed_text = preprocess_lstm(text, MAX_LEN=MAX_LEN)

    # Evaluation mode disables dropout in the dense head.
    model.eval()

    # Add a leading batch axis so the model sees a batch of one sample.
    batch = torch.tensor(preprocessed_text).unsqueeze(0)

    # Inference only: skip autograd bookkeeping so no graph is recorded.
    with torch.no_grad():
        probability, _ = model(batch)

    return round(probability.item())