# Amir Hallaji
# add new line
# 21d1d21
import gradio as gr
import torch
import os
from transformers import AutoTokenizer
from models import AffinityPredictor
# Shared module state: the predictor, both tokenizers and the torch device.
# Everything starts out unset (None); load_model() fills these in.
model = molecule_tokenizer = protein_tokenizer = device = None
def load_model():
    """Load the tokenizers and the affinity model into the module globals.

    The device, both tokenizers and the model are built in local variables
    and only published to the module-level globals once every step has
    succeeded. If anything raises part-way (for example a corrupt
    checkpoint), the globals keep their previous values, so ``model`` stays
    ``None`` after a failed first load and callers that test
    ``model is None`` will retry instead of silently using a half-initialised
    model.

    If the weights file ``Davis-Final.pth`` does not exist, a warning is
    printed and the model keeps its freshly initialised weights.

    Returns:
        bool: ``True`` when loading finished without raising.

    Raises:
        Exception: Whatever the tokenizer download, model construction,
            ``torch.load`` or ``load_state_dict`` raises is propagated.
    """
    global model, molecule_tokenizer, protein_tokenizer, device

    # Set device
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {target_device}")

    # Load tokenizers
    loaded_molecule_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
    loaded_protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

    # Initialize model with same configuration as training
    predictor = AffinityPredictor(
        protein_model_name="facebook/esm2_t6_8M_UR50D",
        molecule_model_name="DeepChem/ChemBERTa-77M-MLM",
        hidden_sizes=[1024, 768, 512, 256, 1],
        inception_out_channels=256,
        dropout=0.05
    )

    # Load the trained weights
    model_path = "Davis-Final.pth"
    if os.path.exists(model_path):
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be placed at this path. Consider
        # passing weights_only=True if the installed torch version supports it.
        checkpoint = torch.load(model_path, map_location=target_device)

        # Handle different checkpoint formats: the weights may be nested
        # under a well-known key or be the checkpoint object itself.
        state_dict = checkpoint
        for wrapper_key in ('model_state_dict', 'state_dict'):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
        predictor.load_state_dict(state_dict)
        print("Model loaded successfully!")
    else:
        print(f"Warning: Model file {model_path} not found. Using randomly initialized weights.")

    predictor.to(target_device)
    predictor.eval()

    # Publish only after every step above succeeded, so a failure never
    # leaves the globals in a partially initialised state.
    device = target_device
    molecule_tokenizer = loaded_molecule_tokenizer
    protein_tokenizer = loaded_protein_tokenizer
    model = predictor
    return True
def predict_affinity(smiles, sequence):
    """Return a displayable affinity prediction for one molecule/protein pair.

    Args:
        smiles: SMILES string of the molecule; surrounding whitespace is
            stripped before tokenization.
        sequence: Protein amino acid sequence; surrounding whitespace is
            stripped before tokenization.

    Returns:
        str: ``"Predicted Affinity Score: <value>"`` (four decimals) on
        success. Otherwise an ``"Error ..."`` message describing a failed
        model load, an empty input, or an exception raised while predicting.
    """
    # Lazy load: only attempted when no model has been published yet.
    if model is None:
        try:
            load_model()
        except Exception as load_error:
            return f"Error loading model: {str(load_error)}"

    # Normalise the inputs once; empty or whitespace-only values are rejected.
    smiles_text = smiles.strip() if smiles else ""
    if not smiles_text:
        return "Error: Please enter a valid SMILES string"

    sequence_text = sequence.strip() if sequence else ""
    if not sequence_text:
        return "Error: Please enter a valid protein sequence"

    try:
        model.eval()

        # Both tokenizers share these options; only max_length differs.
        tokenizer_options = {
            "padding": "max_length",
            "truncation": True,
            "return_tensors": "pt",
        }
        molecule_encoding = molecule_tokenizer(
            [smiles_text], max_length=128, **tokenizer_options
        )
        protein_encoding = protein_tokenizer(
            [sequence_text], max_length=1024, **tokenizer_options
        )

        # The model is called with a single dict holding all four tensors,
        # each moved to the active device.
        molecule_ids = molecule_encoding.input_ids.to(device)
        molecule_mask = molecule_encoding.attention_mask.to(device)
        protein_ids = protein_encoding.input_ids.to(device)
        protein_mask = protein_encoding.attention_mask.to(device)
        batch = {
            "molecule_input_ids": molecule_ids,
            "molecule_attention_mask": molecule_mask,
            "protein_input_ids": protein_ids,
            "protein_attention_mask": protein_mask,
        }

        # Inference only: no gradients are tracked.
        with torch.no_grad():
            prediction = model(batch)

        affinity_score = prediction.cpu().item()
        return f"Predicted Affinity Score: {affinity_score:.4f}"
    except Exception as prediction_error:
        return f"Error during prediction: {str(prediction_error)}"
# Load model on startup
print("Loading model...")
try:
    load_model()
except Exception as e:
    # Best-effort: a failed startup load must not stop the app from starting;
    # predict_affinity() attempts a load itself when `model` is still None.
    print(f"Warning: Could not load model on startup: {e}")
else:
    # load_model() already reports whether trained weights were found, so a
    # neutral message is printed here rather than claiming the weights loaded
    # (it returns normally even when the weights file is missing).
    print("Model initialization finished.")
# Gradio UI: heading and instructions, example inputs, the two input boxes
# side by side, then the Predict button and the read-only result box.
with gr.Blocks(title="Molecule-Protein Affinity Predictor") as demo:
    # Page heading and short usage instructions.
    gr.Markdown(value="## Molecule–Protein Affinity Prediction")
    gr.Markdown(
        value=(
            "Enter a **Molecule SMILES string** and a **Protein amino acid sequence** "
            "then click **Predict** to get the affinity score using the StructureFree-DTA model."
        )
    )

    # A ready-to-copy example pair for users to try.
    gr.Markdown(
        value=(
            "### Example inputs:\n"
            "**SMILES:** `CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OCC(CC4=CC=CC=C4)N`\n"
            "\n**Protein:** `MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTAEKRKGGQTVDSGLPLLSVSDPFIPLQVPDAPEKLIEGLKSPDTSLLLPDLLPMTDPFGSTSDAVIEKADVAVESLIPGLEPPVPQRLPSQTESVTSNRTDSLTGEDSLLDCSLLSNPTTDLLEEFAPTAISAPVHKAAEDSNLISGFDVPEGSDKVAEDEFDPIPVLITKNPQGGHSRNSSGSSESSLPNLARSLLLVDQLIDL`"
        )
    )

    # Molecule and protein inputs share one row.
    with gr.Row():
        smiles_input = gr.Textbox(
            lines=2,
            label="Molecule SMILES",
            placeholder="e.g. CC(=O)OC1=CC=CC=C1C(=O)O",
        )
        sequence_input = gr.Textbox(
            lines=5,
            label="Protein Sequence",
            placeholder="e.g. MVLSPADKTNVKAA...",
        )

    predict_button = gr.Button(value="Predict", variant="primary")
    output = gr.Textbox(label="Affinity Score", interactive=False)

    # Clicking the button sends both textbox values to predict_affinity and
    # shows the returned message in the output box.
    predict_button.click(predict_affinity, [smiles_input, sequence_input], output)
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()