import os

import gradio as gr
import torch
from transformers import AutoTokenizer

from models import AffinityPredictor
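# Note: AffinityPredictor is assumed to be defined in a local models.py shipped
# alongside this file; per the UI text below, it implements the StructureFree-DTA model.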
# Global variables for model and tokenizers
model = None
molecule_tokenizer = None
protein_tokenizer = None
device = None
def load_model():
    """Load the trained model and tokenizers"""
    global model, molecule_tokenizer, protein_tokenizer, device

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load tokenizers
    molecule_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
    protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

    # Initialize model with same configuration as training
    model = AffinityPredictor(
        protein_model_name="facebook/esm2_t6_8M_UR50D",
        molecule_model_name="DeepChem/ChemBERTa-77M-MLM",
        hidden_sizes=[1024, 768, 512, 256, 1],
        inception_out_channels=256,
        dropout=0.05
    )

    # Load the trained weights
    model_path = "Davis-Final.pth"
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        elif 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print("Model loaded successfully!")
    else:
        print(f"Warning: Model file {model_path} not found. Using randomly initialized weights.")

    model.to(device)
    model.eval()
    return True
def predict_affinity(smiles, sequence):
    """Predict drug-target affinity using the trained model"""
    global model, molecule_tokenizer, protein_tokenizer, device

    # Load model if not already loaded
    if model is None:
        try:
            load_model()
        except Exception as e:
            return f"Error loading model: {str(e)}"

    # Validate inputs
    if not smiles or not smiles.strip():
        return "Error: Please enter a valid SMILES string"
    if not sequence or not sequence.strip():
        return "Error: Please enter a valid protein sequence"

    try:
        model.eval()

        # Tokenize inputs
        molecule_encoding = molecule_tokenizer(
            [smiles.strip()],
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        protein_encoding = protein_tokenizer(
            [sequence.strip()],
            padding="max_length",
            truncation=True,
            max_length=1024,
            return_tensors="pt"
        )

        # Create batch dictionary
        batch = {
            "molecule_input_ids": molecule_encoding.input_ids.to(device),
            "molecule_attention_mask": molecule_encoding.attention_mask.to(device),
            "protein_input_ids": protein_encoding.input_ids.to(device),
            "protein_attention_mask": protein_encoding.attention_mask.to(device)
        }

        # Make prediction
        with torch.no_grad():
            prediction = model(batch)
            affinity_score = prediction.cpu().item()

        return f"Predicted Affinity Score: {affinity_score:.4f}"
    except Exception as e:
        return f"Error during prediction: {str(e)}"
# Load model on startup
print("Loading model...")
try:
    load_model()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Warning: Could not load model on startup: {e}")
with gr.Blocks(title="Molecule-Protein Affinity Predictor") as demo:
    gr.Markdown("## Molecule–Protein Affinity Prediction")
    gr.Markdown(
        "Enter a **Molecule SMILES string** and a **Protein amino acid sequence**, "
        "then click **Predict** to get the affinity score using the StructureFree-DTA model."
    )
    gr.Markdown(
        "### Example inputs:\n"
        "**SMILES:** `CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OCC(CC4=CC=CC=C4)N`\n"
        "\n**Protein:** `MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTAEKRKGGQTVDSGLPLLSVSDPFIPLQVPDAPEKLIEGLKSPDTSLLLPDLLPMTDPFGSTSDAVIEKADVAVESLIPGLEPPVPQRLPSQTESVTSNRTDSLTGEDSLLDCSLLSNPTTDLLEEFAPTAISAPVHKAAEDSNLISGFDVPEGSDKVAEDEFDPIPVLITKNPQGGHSRNSSGSSESSLPNLARSLLLVDQLIDL`"
    )

    with gr.Row():
        smiles_input = gr.Textbox(
            label="Molecule SMILES",
            placeholder="e.g. CC(=O)OC1=CC=CC=C1C(=O)O",
            lines=2
        )
        sequence_input = gr.Textbox(
            label="Protein Sequence",
            placeholder="e.g. MVLSPADKTNVKAA...",
            lines=5
        )

    predict_button = gr.Button("Predict", variant="primary")
    output = gr.Textbox(label="Affinity Score", interactive=False)

    predict_button.click(
        fn=predict_affinity,
        inputs=[smiles_input, sequence_input],
        outputs=output
    )
if __name__ == "__main__":
    demo.launch()
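    # When running outside Hugging Face Spaces, standard Gradio launch options
    # can be passed instead, e.g. demo.launch(server_name="0.0.0.0", server_port=7860)
    # for an explicit host/port, or demo.launch(share=True) for a temporary public link.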