import gradio as gr
import torch
import os
from transformers import AutoTokenizer
from models import AffinityPredictor

# Global variables for model and tokenizers
model = None
molecule_tokenizer = None
protein_tokenizer = None
device = None
def load_model():
    """Load the trained model and tokenizers"""
    global model, molecule_tokenizer, protein_tokenizer, device

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load tokenizers
    molecule_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
    protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

    # Initialize model with same configuration as training
    model = AffinityPredictor(
        protein_model_name="facebook/esm2_t6_8M_UR50D",
        molecule_model_name="DeepChem/ChemBERTa-77M-MLM",
        hidden_sizes=[1024, 768, 512, 256, 1],
        inception_out_channels=256,
        dropout=0.05
    )

    # Load the trained weights
    model_path = "Davis-Final.pth"
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        elif 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print("Model loaded successfully!")
    else:
        print(f"Warning: Model file {model_path} not found. Using randomly initialized weights.")

    model.to(device)
    model.eval()
    return True
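
# Illustrative sketch (not part of the original app): when debugging an unfamiliar
# checkpoint before wiring it into load_model(), its top-level keys can be inspected
# to see which of the formats handled above applies. "Davis-Final.pth" is the file
# already referenced in load_model(); everything else here is only an example.
#
#     ckpt = torch.load("Davis-Final.pth", map_location="cpu")
#     print(type(ckpt), list(ckpt.keys())[:10] if isinstance(ckpt, dict) else None)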
def predict_affinity(smiles, sequence):
    """Predict drug-target affinity using the trained model"""
    global model, molecule_tokenizer, protein_tokenizer, device

    # Load model if not already loaded
    if model is None:
        try:
            load_model()
        except Exception as e:
            return f"Error loading model: {str(e)}"

    # Validate inputs
    if not smiles or not smiles.strip():
        return "Error: Please enter a valid SMILES string"
    if not sequence or not sequence.strip():
        return "Error: Please enter a valid protein sequence"

    try:
        model.eval()

        # Tokenize inputs
        molecule_encoding = molecule_tokenizer(
            [smiles.strip()],
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        protein_encoding = protein_tokenizer(
            [sequence.strip()],
            padding="max_length",
            truncation=True,
            max_length=1024,
            return_tensors="pt"
        )

        # Create batch dictionary
        batch = {
            "molecule_input_ids": molecule_encoding.input_ids.to(device),
            "molecule_attention_mask": molecule_encoding.attention_mask.to(device),
            "protein_input_ids": protein_encoding.input_ids.to(device),
            "protein_attention_mask": protein_encoding.attention_mask.to(device)
        }

        # Make prediction
        with torch.no_grad():
            prediction = model(batch)
            affinity_score = prediction.cpu().item()

        return f"Predicted Affinity Score: {affinity_score:.4f}"

    except Exception as e:
        return f"Error during prediction: {str(e)}"

# Load model on startup
print("Loading model...")
try:
    load_model()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Warning: Could not load model on startup: {e}")
with gr.Blocks(title="Molecule-Protein Affinity Predictor") as demo:
    gr.Markdown("## Molecule–Protein Affinity Prediction")
    gr.Markdown(
        "Enter a **Molecule SMILES string** and a **Protein amino acid sequence**, "
        "then click **Predict** to get the affinity score from the StructureFree-DTA model."
    )
    gr.Markdown(
        "### Example inputs:\n"
        "**SMILES:** `CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OCC(CC4=CC=CC=C4)N`\n"
        "\n**Protein:** `MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTAEKRKGGQTVDSGLPLLSVSDPFIPLQVPDAPEKLIEGLKSPDTSLLLPDLLPMTDPFGSTSDAVIEKADVAVESLIPGLEPPVPQRLPSQTESVTSNRTDSLTGEDSLLDCSLLSNPTTDLLEEFAPTAISAPVHKAAEDSNLISGFDVPEGSDKVAEDEFDPIPVLITKNPQGGHSRNSSGSSESSLPNLARSLLLVDQLIDL`"
    )
    with gr.Row():
        smiles_input = gr.Textbox(
            label="Molecule SMILES",
            placeholder="e.g. CC(=O)OC1=CC=CC=C1C(=O)O",
            lines=2
        )
        sequence_input = gr.Textbox(
            label="Protein Sequence",
            placeholder="e.g. MVLSPADKTNVKAA...",
            lines=5
        )

    predict_button = gr.Button("Predict", variant="primary")
    output = gr.Textbox(label="Affinity Score", interactive=False)

    predict_button.click(
        fn=predict_affinity,
        inputs=[smiles_input, sequence_input],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()
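
# Note (assumption, not in the original file): demo.launch() with defaults is what a
# hosted Space expects. For local debugging, Gradio's standard launch options can be
# passed instead, for example:
#
#     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)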