import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class Swish(nn.Module):
    """Swish activation: x * sigmoid(x)."""
    def forward(self, x):
        return x * torch.sigmoid(x)


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))

class ResidualInceptionBlock(nn.Module):
    """1D inception block with parallel convolution branches and a residual connection.

    Note: out_channels should be divisible by len(kernel_sizes) so the concatenated
    branch outputs match the residual path.
    """
    def __init__(self, in_channels, out_channels, kernel_sizes=(1, 3), dropout=0.05):
        super().__init__()
        self.out_channels = out_channels
        num_branches = len(kernel_sizes)
        branch_out_channels = out_channels // num_branches
        # Each branch: 1x1 bottleneck conv, then a conv with its own kernel size.
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(in_channels, in_channels, kernel_size=1),
                nn.BatchNorm1d(in_channels),
                nn.ReLU(),
                nn.Conv1d(in_channels, branch_out_channels, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(branch_out_channels),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            for k in kernel_sizes
        ])
        # 1x1 conv to match channel counts on the residual path when needed.
        self.residual_adjust = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (batch_size, in_channels, seq_len)
        branch_outputs = [branch(x) for branch in self.branches]
        concatenated = torch.cat(branch_outputs, dim=1)  # (batch_size, out_channels, seq_len)
        residual = self.residual_adjust(x)
        return self.relu(concatenated + residual)

class AffinityPredictor(nn.Module):
    """Predicts protein-ligand binding affinity from a protein sequence and a molecule SMILES.

    Pooled embeddings from two pretrained encoders (ESM-2 for the protein, ChemBERTa
    for the molecule) are concatenated, refined by two residual inception blocks,
    and regressed to a single affinity value.
    """
    def __init__(self,
                 protein_model_name="facebook/esm2_t6_8M_UR50D",
                 molecule_model_name="DeepChem/ChemBERTa-77M-MLM",
                 hidden_sizes=(1024, 768, 512, 256, 1),
                 inception_out_channels=256,
                 dropout=0.01):
        super().__init__()
        self.protein_model = AutoModel.from_pretrained(protein_model_name)
        self.molecule_model = AutoModel.from_pretrained(molecule_model_name)

        # Enable gradient checkpointing on both encoders to reduce memory use during training.
        self.protein_model.config.gradient_checkpointing = True
        self.protein_model.gradient_checkpointing_enable()
        self.molecule_model.config.gradient_checkpointing = True
        self.molecule_model.gradient_checkpointing_enable()

        prot_embedding_dim = self.protein_model.config.hidden_size
        mol_embedding_dim = self.molecule_model.config.hidden_size
        combined_dim = prot_embedding_dim + mol_embedding_dim

        self.inc1 = ResidualInceptionBlock(combined_dim, combined_dim, dropout=dropout)
        self.inc2 = ResidualInceptionBlock(combined_dim, combined_dim, dropout=dropout)

        # MLP regression head: Mish activations between hidden layers, none on the final output.
        layers = []
        input_dim = combined_dim  # feature size after the inception blocks
        for output_dim in hidden_sizes:
            layers.append(nn.Linear(input_dim, output_dim))
            if output_dim != 1:
                layers.append(Mish())
            input_dim = output_dim
        self.regressor = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)
    def forward(self, batch):
        protein_input = {
            "input_ids": batch["protein_input_ids"],
            "attention_mask": batch["protein_attention_mask"],
        }
        molecule_input = {
            "input_ids": batch["molecule_input_ids"],
            "attention_mask": batch["molecule_attention_mask"],
        }

        # Mean-pool the token embeddings from each encoder.
        protein_embedding = self.protein_model(**protein_input).last_hidden_state.mean(dim=1)    # (batch_size, prot_embedding_dim)
        molecule_embedding = self.molecule_model(**molecule_input).last_hidden_state.mean(dim=1)  # (batch_size, mol_embedding_dim)

        # Concatenate and add a length-1 "sequence" axis so the 1D inception blocks can be applied.
        combined_features = torch.cat((protein_embedding, molecule_embedding), dim=1).unsqueeze(2)  # (batch_size, combined_dim, 1)
        combined_features = self.inc1(combined_features)  # (batch_size, combined_dim, 1)
        combined_features = self.inc2(combined_features)
        combined_features = combined_features.squeeze(2)  # (batch_size, combined_dim)

        output = self.regressor(self.dropout(combined_features))  # (batch_size, 1)
        return output
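
# Minimal usage sketch (not part of the original file): it assumes the tokenizers
# are loaded from the same checkpoints as the encoders above, and the protein
# sequence / SMILES strings below are illustrative placeholders only.
if __name__ == "__main__":
    protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    molecule_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")

    model = AffinityPredictor()
    model.eval()

    protein_tokens = protein_tokenizer(
        ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],
        return_tensors="pt", padding=True, truncation=True,
    )
    molecule_tokens = molecule_tokenizer(
        ["CC(=O)OC1=CC=CC=C1C(=O)O"],
        return_tensors="pt", padding=True, truncation=True,
    )

    batch = {
        "protein_input_ids": protein_tokens["input_ids"],
        "protein_attention_mask": protein_tokens["attention_mask"],
        "molecule_input_ids": molecule_tokens["input_ids"],
        "molecule_attention_mask": molecule_tokens["attention_mask"],
    }

    with torch.no_grad():
        affinity = model(batch)  # (1, 1) predicted affinity
    print(affinity)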