StructureFree-DTA / models.py
Amir Hallaji
version 0.1.0
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class ResidualInceptionBlock(nn.Module):
    """Inception-style 1D conv block with parallel branches and a residual connection."""

    def __init__(self, in_channels, out_channels, kernel_sizes=(1, 3), dropout=0.05):
        super().__init__()
        self.out_channels = out_channels
        num_branches = len(kernel_sizes)
        # Each branch emits an equal share of the output channels, so
        # out_channels must be divisible by the number of branches.
        assert out_channels % num_branches == 0, \
            "out_channels must be divisible by len(kernel_sizes)"
        branch_out_channels = out_channels // num_branches
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels, in_channels, kernel_size=1),  # 1x1 bottleneck
                nn.BatchNorm1d(in_channels),
                nn.ReLU(),
                nn.Conv1d(in_channels, branch_out_channels, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(branch_out_channels),
                nn.ReLU(),
                nn.Dropout(dropout),
            ) for k in kernel_sizes
        ])
        # 1x1 conv to match channel counts when the residual cannot be added directly.
        self.residual_adjust = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels else nn.Identity()
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        branch_outputs = [branch(x) for branch in self.branches]
        concatenated = torch.cat(branch_outputs, dim=1)  # (batch, out_channels, length)
        residual = self.residual_adjust(x)
        return self.relu(concatenated + residual)
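

# Illustrative shape check for ResidualInceptionBlock (a sketch only; the
# sizes below are arbitrary and are not the values AffinityPredictor uses,
# which runs the block on length-1 sequences):
#
#   block = ResidualInceptionBlock(in_channels=64, out_channels=128)
#   x = torch.randn(2, 64, 10)   # (batch, channels, length)
#   block(x).shape               # -> torch.Size([2, 128, 10])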


class AffinityPredictor(nn.Module):
    """Predicts drug-target binding affinity from a protein sequence and a SMILES string."""

    def __init__(self,
                 protein_model_name="facebook/esm2_t6_8M_UR50D",
                 molecule_model_name="DeepChem/ChemBERTa-77M-MLM",
                 hidden_sizes=(1024, 768, 512, 256, 1),
                 inception_out_channels=256,
                 dropout=0.01):
        super().__init__()
        self.protein_model = AutoModel.from_pretrained(protein_model_name)
        self.molecule_model = AutoModel.from_pretrained(molecule_model_name)
        # Trade compute for memory when fine-tuning both encoders.
        self.protein_model.gradient_checkpointing_enable()
        self.molecule_model.gradient_checkpointing_enable()
        prot_embedding_dim = self.protein_model.config.hidden_size
        mol_embedding_dim = self.molecule_model.config.hidden_size
        combined_dim = prot_embedding_dim + mol_embedding_dim
        # Note: inception_out_channels is unused as written; both blocks
        # keep the channel count at combined_dim.
        self.inc1 = ResidualInceptionBlock(combined_dim, combined_dim, dropout=dropout)
        self.inc2 = ResidualInceptionBlock(combined_dim, combined_dim, dropout=dropout)
        # MLP head: Linear + Mish for every hidden layer, a bare Linear
        # for the final single-unit output.
        layers = []
        input_dim = combined_dim  # after the inception blocks
        for output_dim in hidden_sizes:
            layers.append(nn.Linear(input_dim, output_dim))
            if output_dim != 1:
                layers.append(Mish())
            input_dim = output_dim
        self.regressor = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)
    def forward(self, batch):
        protein_input = {
            "input_ids": batch["protein_input_ids"],
            "attention_mask": batch["protein_attention_mask"],
        }
        molecule_input = {
            "input_ids": batch["molecule_input_ids"],
            "attention_mask": batch["molecule_attention_mask"],
        }
        # Mean-pool each encoder's token embeddings over all positions
        # (padding tokens included).
        protein_embedding = self.protein_model(**protein_input).last_hidden_state.mean(dim=1)    # (batch, prot_dim)
        molecule_embedding = self.molecule_model(**molecule_input).last_hidden_state.mean(dim=1)  # (batch, mol_dim)
        # Concatenate and add a length-1 "sequence" axis so the Conv1d blocks apply.
        combined_features = torch.cat((protein_embedding, molecule_embedding), dim=1).unsqueeze(2)  # (batch, combined_dim, 1)
        combined_features = self.inc1(combined_features)  # (batch, combined_dim, 1)
        combined_features = self.inc2(combined_features)  # (batch, combined_dim, 1)
        combined_features = combined_features.squeeze(2)  # (batch, combined_dim)
        output = self.regressor(self.dropout(combined_features))  # (batch, 1)
        return output
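

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the training pipeline. It
    # downloads both pretrained checkpoints from the Hugging Face Hub (network
    # access required); the protein sequence and SMILES string below are
    # arbitrary placeholders chosen for illustration.
    prot_tok = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    mol_tok = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")

    prot = prot_tok("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    mol = mol_tok("CC(=O)OC1=CC=CC=C1C(=O)O", return_tensors="pt")  # aspirin

    batch = {
        "protein_input_ids": prot["input_ids"],
        "protein_attention_mask": prot["attention_mask"],
        "molecule_input_ids": mol["input_ids"],
        "molecule_attention_mask": mol["attention_mask"],
    }

    model = AffinityPredictor()
    model.eval()
    with torch.no_grad():
        affinity = model(batch)
    print(affinity.shape)  # expected: torch.Size([1, 1])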