Amir Hallaji commited on
Commit
e2ba292
·
1 Parent(s): c46b695

version 0.1.0 using davis

Browse files
Files changed (3) hide show
  1. Davis-Final.pth +3 -0
  2. app.py +125 -11
  3. requirements.txt +5 -0
Davis-Final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c2c4f602839b78e253e22a312765691f3387dcbf4a478553d958c717d866c1b
3
+ size 67932195
app.py CHANGED
@@ -1,33 +1,146 @@
1
  import gradio as gr
2
- import random
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # Placeholder prediction function
6
  def predict_affinity(smiles, sequence):
7
- # In practice, you'd call your ML model here
8
- # For now, we just return a random score between 0 and 1
9
- score = round(random.uniform(0, 1), 4)
10
- return f"Predicted Affinity Score: {score}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
 
 
 
 
 
 
 
12
 
13
  with gr.Blocks(title="Molecule-Protein Affinity Predictor") as demo:
14
  gr.Markdown("## Molecule–Protein Affinity Prediction")
15
  gr.Markdown(
16
  "Enter a **Molecule SMILES string** and a **Protein amino acid sequence** "
17
- "then click **Predict** to get the affinity score."
 
 
 
 
 
 
18
  )
19
 
20
  with gr.Row():
21
  smiles_input = gr.Textbox(
22
  label="Molecule SMILES",
23
- placeholder="e.g. CC(=O)OC1=CC=CC=C1C(=O)O"
 
24
  )
25
  sequence_input = gr.Textbox(
26
  label="Protein Sequence",
27
- placeholder="e.g. MVLSPADKTNVKAA..."
 
28
  )
29
 
30
- predict_button = gr.Button("Predict")
31
  output = gr.Textbox(label="Affinity Score", interactive=False)
32
 
33
  predict_button.click(
@@ -36,4 +149,5 @@ with gr.Blocks(title="Molecule-Protein Affinity Predictor") as demo:
36
  outputs=output
37
  )
38
 
39
- demo.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer
5
+ from models import AffinityPredictor
6
 
7
+ # Global variables for model and tokenizers
8
+ model = None
9
+ molecule_tokenizer = None
10
+ protein_tokenizer = None
11
+ device = None
12
+
13
+ def load_model():
14
+ """Load the trained model and tokenizers"""
15
+ global model, molecule_tokenizer, protein_tokenizer, device
16
+
17
+ # Set device
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ print(f"Using device: {device}")
20
+
21
+ # Load tokenizers
22
+ molecule_tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
23
+ protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
24
+
25
+ # Initialize model with same configuration as training
26
+ model = AffinityPredictor(
27
+ protein_model_name="facebook/esm2_t6_8M_UR50D",
28
+ molecule_model_name="DeepChem/ChemBERTa-77M-MLM",
29
+ hidden_sizes=[1024, 768, 512, 256, 1],
30
+ inception_out_channels=256,
31
+ dropout=0.05
32
+ )
33
+
34
+ # Load the trained weights
35
+ model_path = "Davis-Final.pth"
36
+ if os.path.exists(model_path):
37
+ checkpoint = torch.load(model_path, map_location=device)
38
+ # Handle different checkpoint formats
39
+ if 'model_state_dict' in checkpoint:
40
+ model.load_state_dict(checkpoint['model_state_dict'])
41
+ elif 'state_dict' in checkpoint:
42
+ model.load_state_dict(checkpoint['state_dict'])
43
+ else:
44
+ model.load_state_dict(checkpoint)
45
+ print("Model loaded successfully!")
46
+ else:
47
+ print(f"Warning: Model file {model_path} not found. Using randomly initialized weights.")
48
+
49
+ model.to(device)
50
+ model.eval()
51
+
52
+ return True
53
 
 
54
  def predict_affinity(smiles, sequence):
55
+ """Predict drug-target affinity using the trained model"""
56
+ global model, molecule_tokenizer, protein_tokenizer, device
57
+
58
+ # Load model if not already loaded
59
+ if model is None:
60
+ try:
61
+ load_model()
62
+ except Exception as e:
63
+ return f"Error loading model: {str(e)}"
64
+
65
+ # Validate inputs
66
+ if not smiles or not smiles.strip():
67
+ return "Error: Please enter a valid SMILES string"
68
+
69
+ if not sequence or not sequence.strip():
70
+ return "Error: Please enter a valid protein sequence"
71
+
72
+ try:
73
+ model.eval()
74
+
75
+ # Tokenize inputs
76
+ molecule_encoding = molecule_tokenizer(
77
+ [smiles.strip()],
78
+ padding="max_length",
79
+ truncation=True,
80
+ max_length=128,
81
+ return_tensors="pt"
82
+ )
83
+
84
+ protein_encoding = protein_tokenizer(
85
+ [sequence.strip()],
86
+ padding="max_length",
87
+ truncation=True,
88
+ max_length=1024,
89
+ return_tensors="pt"
90
+ )
91
+
92
+ # Create batch dictionary
93
+ batch = {
94
+ "molecule_input_ids": molecule_encoding.input_ids.to(device),
95
+ "molecule_attention_mask": molecule_encoding.attention_mask.to(device),
96
+ "protein_input_ids": protein_encoding.input_ids.to(device),
97
+ "protein_attention_mask": protein_encoding.attention_mask.to(device)
98
+ }
99
+
100
+ # Make prediction
101
+ with torch.no_grad():
102
+ prediction = model(batch)
103
+ affinity_score = prediction.cpu().item()
104
+
105
+ return f"Predicted Affinity Score: {affinity_score:.4f}"
106
+
107
+ except Exception as e:
108
+ return f"Error during prediction: {str(e)}"
109
 
110
+ # Load model on startup
111
+ print("Loading model...")
112
+ try:
113
+ load_model()
114
+ print("Model loaded successfully!")
115
+ except Exception as e:
116
+ print(f"Warning: Could not load model on startup: {e}")
117
 
118
  with gr.Blocks(title="Molecule-Protein Affinity Predictor") as demo:
119
  gr.Markdown("## Molecule–Protein Affinity Prediction")
120
  gr.Markdown(
121
  "Enter a **Molecule SMILES string** and a **Protein amino acid sequence** "
122
+ "then click **Predict** to get the affinity score using the StructureFree-DTA model."
123
+ )
124
+
125
+ gr.Markdown(
126
+ "### Example inputs:\n"
127
+ "**SMILES:** `CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OCC(CC4=CC=CC=C4)N`\n"
128
+ "**Protein:** `MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTAEKRKGGQTVDSGLPLLSVSDPFIPLQVPDAPEKLIEGLKSPDTSLLLPDLLPMTDPFGSTSDAVIEKADVAVESLIPGLEPPVPQRLPSQTESVTSNRTDSLTGEDSLLDCSLLSNPTTDLLEEFAPTAISAPVHKAAEDSNLISGFDVPEGSDKVAEDEFDPIPVLITKNPQGGHSRNSSGSSESSLPNLARSLLLVDQLIDL`"
129
  )
130
 
131
  with gr.Row():
132
  smiles_input = gr.Textbox(
133
  label="Molecule SMILES",
134
+ placeholder="e.g. CC(=O)OC1=CC=CC=C1C(=O)O",
135
+ lines=2
136
  )
137
  sequence_input = gr.Textbox(
138
  label="Protein Sequence",
139
+ placeholder="e.g. MVLSPADKTNVKAA...",
140
+ lines=5
141
  )
142
 
143
+ predict_button = gr.Button("Predict", variant="primary")
144
  output = gr.Textbox(label="Affinity Score", interactive=False)
145
 
146
  predict_button.click(
 
149
  outputs=output
150
  )
151
 
152
+ if __name__ == "__main__":
153
+ demo.launch()
requirements.txt CHANGED
@@ -1 +1,6 @@
1
  gradio
 
 
 
 
 
 
1
  gradio
2
+ torch
3
+ transformers
4
+ scikit-learn
5
+ pandas
6
+ numpy