Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchinfo import summary | |
# Conv1d -> BatchNorm1d -> Dropout1d building block (used by MMCNN_SUM).
class Conv1d_layer(nn.Module):
    """1-D convolution followed by batch normalisation and channel dropout.

    No activation is applied inside this block; MMCNN_SUM applies ReLU to
    its output.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels (also the BatchNorm width).
        kernel_size: Convolution kernel size. nn.Conv1d defaults to stride 1
            and no padding, so the sequence shrinks by ``kernel_size - 1``.
        dropout_p: Probability of zeroing an entire channel
            (``nn.Dropout1d``). Defaults to 0.5, the previously fixed value.
    """

    def __init__(self, in_channel, out_channel, kernel_size, *, dropout_p=0.5) -> None:
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size
        )
        self.batch_norm = nn.BatchNorm1d(out_channel)
        # Dropout1d zeroes whole channels rather than individual elements.
        self.dropout = nn.Dropout1d(p=dropout_p)

    def forward(self, x):
        """Apply conv -> batch norm -> dropout to ``x`` (e.g. shape (N, C_in, L))."""
        return self.dropout(self.batch_norm(self.conv(x)))
class CNN(nn.Module):
    """Baseline ECG-only classifier: three conv/ReLU/max-pool stages and an MLP head.

    Args:
        ecg_channels: Number of input signal channels (default 12).

    The first linear layer expects 48 * 122 = 5856 flattened features, which is
    what a 1000-sample input produces (1000 -conv7-> 994 -pool-> 497 -conv5->
    493 -pool-> 246 -conv3-> 244 -pool-> 122); the ``__main__`` summary uses
    exactly that input size.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "CNN"
        # Keep this creation order: it fixes seeded parameter initialisation
        # and the state_dict key order.
        self.conv1 = nn.Conv1d(ecg_channels, 16, kernel_size=7)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv1d(32, 48, kernel_size=3)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc0 = nn.Linear(48 * 122, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        self.activation = nn.ReLU()

    def forward(self, x, notes=None):
        """Return unnormalised scores (no softmax) for the ECG batch ``x``.

        ``notes`` is accepted but ignored, presumably so this model can be
        called with the same arguments as the multimodal variants.
        """
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        hidden = x.view(x.shape[0], -1)
        for fc in (self.fc0, self.fc1):
            hidden = self.activation(fc(hidden))
        scores = self.fc2(hidden)
        # Only drops a dimension when it has size 1; with 5 outputs on a
        # (batch, 5) tensor this is a no-op.
        return scores.squeeze(1)
class MMCNN_SUM(nn.Module):
    """Multimodal classifier that fuses ECG and note features by addition.

    ECG branch: three Conv1d_layer stages (conv + batch norm + channel
    dropout), each followed by ReLU and max-pooling, then two ReLU linear
    layers down to 128 features. Notes branch: flatten, linear 768 -> 128,
    ReLU. The two 128-d vectors are summed, layer-normalised and mapped to
    5 unnormalised scores.

    Args:
        ecg_channels: Number of input ECG channels (default 12).
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "MMCNN_SUM"
        # ECG branch. Keep this creation order: it fixes seeded parameter
        # initialisation and the state_dict key order.
        self.conv1 = Conv1d_layer(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = Conv1d_layer(16, 32, 5)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv3 = Conv1d_layer(32, 48, 3)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        # 48 channels x 122 time steps for a 1000-sample input.
        self.fc0 = nn.Linear(48 * 122, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        # Notes branch.
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return (batch, 5) scores for ECG batch ``x`` and note features ``notes``."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        ecg = x.view(x.shape[0], -1)
        ecg = self.activation(self.fc0(ecg))
        ecg = self.activation(self.fc1(ecg))

        # notes must flatten to 768 features per sample (fc_emb's input size);
        # presumably a text-encoder embedding -- confirm with the data pipeline.
        text = self.activation(self.fc_emb(notes.view(notes.shape[0], -1)))

        fused = self.norm(ecg + text)
        return self.fc2(fused).squeeze(1)
class MMCNN_CAT(nn.Module):
    """Multimodal classifier that fuses ECG and note features by concatenation.

    Each branch produces 128 features; they are concatenated into a 256-d
    vector and mapped to 5 unnormalised scores.

    Args:
        ecg_channels: Number of input ECG channels (default 12).
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "MMCNN_CAT"
        # ECG branch. Keep this creation order: it fixes seeded parameter
        # initialisation and the state_dict key order.
        self.conv1 = nn.Conv1d(ecg_channels, 16, kernel_size=7)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv1d(32, 48, kernel_size=3)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        # 48 channels x 122 time steps for a 1000-sample input.
        self.fc0 = nn.Linear(48 * 122, 512)
        self.fc1 = nn.Linear(512, 128)
        # Input is [ecg features (128) | note features (128)].
        self.fc2 = nn.Linear(2 * 128, 5)
        # Notes branch.
        self.fc_emb = nn.Linear(768, 128)
        # NOTE(review): registered but never used in forward(); removing it
        # would drop norm.weight / norm.bias from the state_dict.
        self.norm = nn.LayerNorm(128)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return (batch, 5) scores for ECG batch ``x`` and note features ``notes``."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        ecg = x.view(x.shape[0], -1)
        for fc in (self.fc0, self.fc1):
            ecg = self.activation(fc(ecg))

        # notes must flatten to 768 features per sample (fc_emb's input size).
        text = self.activation(self.fc_emb(notes.view(notes.shape[0], -1)))

        joint = torch.cat((ecg, text), dim=1)
        return self.fc2(joint).squeeze(1)
class MMCNN_ATT(nn.Module):
    """Multimodal classifier fusing ECG and note features with cross-attention.

    The layer-normalised note features form the attention query. The key/value
    sequence holds two tokens, the ECG features and the note features, so the
    attention weights (and the output) depend on both modalities. The result
    is mapped to 5 unnormalised scores.

    Args:
        ecg_channels: Number of input ECG channels (default 12).
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        # ECG processing layers. Keep this creation order: it fixes seeded
        # parameter initialisation and the state_dict key order.
        self.name = "MMCNN_ATT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        # 48 channels x 122 time steps for a 1000-sample input.
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        # Clinical notes processing layers.
        self.fc_emb = nn.Linear(768, 128)
        self.norm1 = nn.LayerNorm(128)  # normalises the ECG features
        self.norm2 = nn.LayerNorm(128)  # normalises the note features
        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return (batch, 5) scores for ECG batch ``x`` and note features ``notes``.

        Args:
            x: ECG batch whose conv stages must flatten to 5856 features per
                sample (e.g. 1000 time steps).
            notes: Note features that flatten to 768 values per sample.
        """
        # ECG processing.
        x = self.pool1(self.activation(self.conv1(x)))
        x = self.pool2(self.activation(self.conv2(x)))
        x = self.pool3(self.activation(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.activation(self.fc0(x))
        x = self.activation(self.fc1(x))
        ecg_token = self.norm1(x).unsqueeze(1)  # (batch, 1, 128)

        # Notes processing.
        notes = notes.view(notes.size(0), -1)
        notes = self.activation(self.fc_emb(notes))
        notes_token = self.norm2(notes).unsqueeze(1)  # (batch, 1, 128)

        # Attending over the ECG token alone would be degenerate: with a single
        # key the softmax weight is identically 1, so the output would not
        # depend on the query (the notes) at all and the notes branch would get
        # no gradient. Using both tokens as keys/values lets the notes query
        # weigh the two modalities.
        tokens = torch.cat((ecg_token, notes_token), dim=1)  # (batch, 2, 128)
        fused, _ = self.attention(notes_token, tokens, tokens, need_weights=False)
        return self.fc2(fused).squeeze(1)  # (batch, 1, 5) -> (batch, 5)
class MMCNN_SUM_ATT(nn.Module):
    """Multimodal classifier: summed ECG + note features, then self-attention.

    The 128-d ECG and note features are added and layer-normalised. The result
    is passed through multi-head self-attention as a one-token sequence and
    mapped to 5 unnormalised scores.

    Args:
        ecg_channels: Number of input ECG channels (default 12).
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "MMCNN_SUM_ATT"
        # ECG branch. Keep this creation order: it fixes seeded parameter
        # initialisation and the state_dict key order.
        self.conv1 = nn.Conv1d(ecg_channels, 16, kernel_size=7)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv1d(32, 48, kernel_size=3)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        # 48 channels x 122 time steps for a 1000-sample input.
        self.fc0 = nn.Linear(48 * 122, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        # Notes branch and fusion.
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)
        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return (batch, 5) scores for ECG batch ``x`` and note features ``notes``."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))
        ecg = x.view(x.shape[0], -1)
        for fc in (self.fc0, self.fc1):
            ecg = self.activation(fc(ecg))

        # notes must flatten to 768 features per sample (fc_emb's input size).
        text = self.activation(self.fc_emb(notes.view(notes.shape[0], -1)))

        fused = self.norm(ecg + text).unsqueeze(1)  # (batch, 1, 128)
        # NOTE(review): self-attention over a single token is degenerate -- the
        # softmax weight is always 1, so this reduces to out_proj(v_proj(fused)),
        # i.e. two stacked linear maps.
        attended, _ = self.attention(fused, fused, fused)
        return self.fc2(attended).squeeze(1)
if __name__ == "__main__":
    # Print a torchinfo layer summary of the baseline CNN for a single
    # 12-channel, 1000-sample input.
    summary(CNN(), input_size=(1, 12, 1000))