Spaces:
Runtime error
Runtime error
Commit
·
72dbfd7
1
Parent(s):
022cccc
both first demos now work
Browse files- new_app.py +14 -16
new_app.py
CHANGED
|
@@ -291,21 +291,23 @@ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
|
|
| 291 |
|
| 292 |
def run_model(self, sample_dict, model: Mammal):
|
| 293 |
# Generate Prediction
|
| 294 |
-
batch_dict = model.
|
| 295 |
-
[sample_dict],
|
| 296 |
-
output_scores=True,
|
| 297 |
-
return_dict_in_generate=True,
|
| 298 |
-
max_new_tokens=5,
|
| 299 |
-
)
|
| 300 |
return batch_dict
|
| 301 |
|
| 302 |
def decode_output(self,batch_dict, model_holder):
|
| 303 |
|
| 304 |
# Get output
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
|
| 311 |
def create_and_run_prompt(self,model_name,target_seq, drug_seq):
|
|
@@ -353,15 +355,11 @@ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
|
|
| 353 |
prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
|
| 354 |
|
| 355 |
with gr.Row():
|
| 356 |
-
decoded = gr.Textbox(label="Mammal output")
|
| 357 |
run_mammal.click(
|
| 358 |
fn=self.create_and_run_prompt,
|
| 359 |
inputs=[model_name_widget, target_textbox, drug_textbox],
|
| 360 |
-
outputs=[prompt_box, decoded, gr.Number(label="
|
| 361 |
-
)
|
| 362 |
-
with gr.Row():
|
| 363 |
-
gr.Markdown(
|
| 364 |
-
"```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
|
| 365 |
)
|
| 366 |
demo.visible = False
|
| 367 |
return demo
|
|
|
|
| 291 |
|
| 292 |
def run_model(self, sample_dict, model: Mammal):
|
| 293 |
# Generate Prediction
|
| 294 |
+
batch_dict = model.forward_encoder_only([sample_dict])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
return batch_dict
|
| 296 |
|
| 297 |
def decode_output(self,batch_dict, model_holder):
|
| 298 |
|
| 299 |
# Get output
|
| 300 |
+
batch_dict = DtiBindingdbKdTask.process_model_output(
|
| 301 |
+
batch_dict,
|
| 302 |
+
scalars_preds_processed_key="model.out.dti_bindingdb_kd",
|
| 303 |
+
norm_y_mean=5.79384684128215,
|
| 304 |
+
norm_y_std=1.33808027428196,
|
| 305 |
+
)
|
| 306 |
+
ans = (
|
| 307 |
+
"model.out.dti_bindingdb_kd",
|
| 308 |
+
float(batch_dict["model.out.dti_bindingdb_kd"][0]),
|
| 309 |
+
)
|
| 310 |
+
return ans
|
| 311 |
|
| 312 |
|
| 313 |
def create_and_run_prompt(self,model_name,target_seq, drug_seq):
|
|
|
|
| 355 |
prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
|
| 356 |
|
| 357 |
with gr.Row():
|
| 358 |
+
decoded = gr.Textbox(label="Mammal output key")
|
| 359 |
run_mammal.click(
|
| 360 |
fn=self.create_and_run_prompt,
|
| 361 |
inputs=[model_name_widget, target_textbox, drug_textbox],
|
| 362 |
+
outputs=[prompt_box, decoded, gr.Number(label="binding affinity")],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
)
|
| 364 |
demo.visible = False
|
| 365 |
return demo
|