"""Gradio demo for molecule–protein (drug–target) affinity prediction."""

import os

import gradio as gr
import torch
from transformers import AutoTokenizer

from models import AffinityPredictor

# Pretrained encoders; must match the configuration used during training.
MOLECULE_MODEL_NAME = "DeepChem/ChemBERTa-77M-MLM"
PROTEIN_MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
MODEL_PATH = "Davis-Final.pth"

# Tokenizer truncation/padding lengths used for inference.
MOLECULE_MAX_LENGTH = 128
PROTEIN_MAX_LENGTH = 1024

# Global variables for model and tokenizers (populated by load_model).
model = None
molecule_tokenizer = None
protein_tokenizer = None
device = None


def _extract_state_dict(checkpoint):
    """Return the model weights stored in *checkpoint*.

    Accepts a checkpoint that wraps the weights under ``"model_state_dict"``
    or ``"state_dict"``; anything else is assumed to be the state dict itself.
    """
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


def load_model():
    """Load the trained model and tokenizers into the module-level globals.

    Everything is built in local variables first and the globals are assigned
    only after every step succeeded. A failure part-way through (e.g. a
    checkpoint whose keys do not match the architecture) therefore never
    leaves a half-initialised model behind for ``predict_affinity`` to use.

    If ``MODEL_PATH`` does not exist, a warning is printed and the model keeps
    its freshly initialised weights.

    Returns:
        True once the model and tokenizers are ready.

    Raises:
        Any exception raised while downloading/loading the tokenizers,
        building the model, or loading the checkpoint.
    """
    global model, molecule_tokenizer, protein_tokenizer, device

    new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {new_device}")

    new_molecule_tokenizer = AutoTokenizer.from_pretrained(MOLECULE_MODEL_NAME)
    new_protein_tokenizer = AutoTokenizer.from_pretrained(PROTEIN_MODEL_NAME)

    # Initialize model with same configuration as training.
    new_model = AffinityPredictor(
        protein_model_name=PROTEIN_MODEL_NAME,
        molecule_model_name=MOLECULE_MODEL_NAME,
        hidden_sizes=[1024, 768, 512, 256, 1],
        inception_out_channels=256,
        dropout=0.05,
    )

    if os.path.exists(MODEL_PATH):
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(MODEL_PATH, map_location=new_device)
        new_model.load_state_dict(_extract_state_dict(checkpoint))
        print("Model loaded successfully!")
    else:
        print(
            f"Warning: Model file {MODEL_PATH} not found. "
            "Using randomly initialized weights."
        )

    new_model.to(new_device)
    new_model.eval()

    # Publish only after full success so a failed load can be retried later.
    device = new_device
    molecule_tokenizer = new_molecule_tokenizer
    protein_tokenizer = new_protein_tokenizer
    model = new_model
    return True


def predict_affinity(smiles, sequence):
    """Predict drug-target affinity for one SMILES string and protein sequence.

    Lazily (re)loads the model if it is not loaded yet.

    Args:
        smiles: Molecule SMILES string.
        sequence: Protein amino-acid sequence.

    Returns:
        A human-readable string: either the formatted score or an error
        message (errors are returned, not raised, so Gradio can display them).
    """
    if model is None:
        try:
            load_model()
        except Exception as e:
            return f"Error loading model: {str(e)}"

    if not smiles or not smiles.strip():
        return "Error: Please enter a valid SMILES string"
    if not sequence or not sequence.strip():
        return "Error: Please enter a valid protein sequence"

    try:
        model.eval()

        molecule_encoding = molecule_tokenizer(
            [smiles.strip()],
            padding="max_length",
            truncation=True,
            max_length=MOLECULE_MAX_LENGTH,
            return_tensors="pt",
        )
        protein_encoding = protein_tokenizer(
            [sequence.strip()],
            padding="max_length",
            truncation=True,
            max_length=PROTEIN_MAX_LENGTH,
            return_tensors="pt",
        )

        # Batch layout expected by AffinityPredictor.forward.
        batch = {
            "molecule_input_ids": molecule_encoding.input_ids.to(device),
            "molecule_attention_mask": molecule_encoding.attention_mask.to(device),
            "protein_input_ids": protein_encoding.input_ids.to(device),
            "protein_attention_mask": protein_encoding.attention_mask.to(device),
        }

        with torch.no_grad():
            prediction = model(batch)
            affinity_score = prediction.cpu().item()

        return f"Predicted Affinity Score: {affinity_score:.4f}"
    except Exception as e:
        return f"Error during prediction: {str(e)}"


# Load model on startup; load_model itself reports whether trained weights
# were found, so no extra success message is printed here.
print("Loading model...")
try:
    load_model()
except Exception as e:
    print(f"Warning: Could not load model on startup: {e}")


with gr.Blocks(title="Molecule-Protein Affinity Predictor") as demo:
    gr.Markdown("## Molecule–Protein Affinity Prediction")
    gr.Markdown(
        "Enter a **Molecule SMILES string** and a **Protein amino acid sequence**, "
        "then click **Predict** to get the affinity score using the StructureFree-DTA model."
    )
    gr.Markdown(
        "### Example inputs:\n"
        "**SMILES:** `CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OCC(CC4=CC=CC=C4)N`\n"
        "\n**Protein:** `MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTAEKRKGGQTVDSGLPLLSVSDPFIPLQVPDAPEKLIEGLKSPDTSLLLPDLLPMTDPFGSTSDAVIEKADVAVESLIPGLEPPVPQRLPSQTESVTSNRTDSLTGEDSLLDCSLLSNPTTDLLEEFAPTAISAPVHKAAEDSNLISGFDVPEGSDKVAEDEFDPIPVLITKNPQGGHSRNSSGSSESSLPNLARSLLLVDQLIDL`"
    )

    with gr.Row():
        smiles_input = gr.Textbox(
            label="Molecule SMILES",
            placeholder="e.g. CC(=O)OC1=CC=CC=C1C(=O)O",
            lines=2,
        )
        sequence_input = gr.Textbox(
            label="Protein Sequence",
            placeholder="e.g. MVLSPADKTNVKAA...",
            lines=5,
        )

    predict_button = gr.Button("Predict", variant="primary")
    output = gr.Textbox(label="Affinity Score", interactive=False)

    predict_button.click(
        fn=predict_affinity,
        inputs=[smiles_input, sequence_input],
        outputs=output,
    )


if __name__ == "__main__":
    demo.launch()