Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import copy | |
| import os | |
| import torch | |
| import time | |
| from argparse import ArgumentParser, Namespace, FileType | |
| from rdkit.Chem import RemoveHs | |
| from functools import partial | |
| import numpy as np | |
| import pandas as pd | |
| from rdkit import RDLogger | |
| from rdkit.Chem import MolFromSmiles, AddHs | |
| from torch_geometric.loader import DataLoader | |
| import yaml | |
| print(os.getcwd()) | |
| print(os.listdir("datasets")) | |
| from datasets.process_mols import ( | |
| read_molecule, | |
| generate_conformer, | |
| write_mol_with_coords, | |
| ) | |
| from datasets.pdbbind import PDBBind | |
| from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule | |
| from utils.sampling import randomize_position, sampling | |
| from utils.utils import get_model | |
| from utils.visualise import PDBFile | |
| from tqdm import tqdm | |
| from datasets.esm_embedding_preparation import esm_embedding_prep | |
| import subprocess | |
# ---------------------------------------------------------------------------
# One-time setup executed at import: device, model configs, checkpoints,
# diffusion time schedule and the bookkeeping shared by every request.
# ---------------------------------------------------------------------------

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Each model directory ships a YAML file with its parameters; wrapping the
# parsed mapping in a Namespace gives attribute-style access (args.foo).
with open("workdir/paper_score_model/model_parameters.yml") as f:
    score_model_args = Namespace(**yaml.full_load(f))
with open("workdir/paper_confidence_model/model_parameters.yml") as f:
    confidence_args = Namespace(**yaml.full_load(f))

# t_to_sigma with the score model's arguments pre-bound.
t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)

# Score model: the checkpoint is deserialised onto the CPU first and the
# model is moved to `device` afterwards.
model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True)
state_dict = torch.load(
    "workdir/paper_score_model/best_ema_inference_epoch_model.pt",
    map_location=torch.device("cpu"),
)
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()

# Confidence model: same loading procedure, built with confidence_mode=True.
confidence_model = get_model(
    confidence_args,
    device,
    t_to_sigma=t_to_sigma,
    no_parallel=True,
    confidence_mode=True,
)
state_dict = torch.load(
    "workdir/paper_confidence_model/best_model_epoch75.pt",
    map_location=torch.device("cpu"),
)
confidence_model.load_state_dict(state_dict, strict=True)
confidence_model = confidence_model.to(device)
confidence_model.eval()

# A single 10-step schedule is shared by the translation, rotation and
# torsion components.
tr_schedule = get_t_schedule(inference_steps=10)
rot_schedule = tor_schedule = tr_schedule
print("common t schedule", tr_schedule)

# Bookkeeping that accumulates across calls to update().
failures = 0
skipped = 0
confidences_list = []
names_list = []
run_times = []
min_self_distances_list = []

# Number of copies of each complex that update() samples.
N = 10
def get_pdb(pdb_code="", filepath=""):
    """Resolve the protein input to the path of a PDB file.

    Args:
        pdb_code: PDB identifier typed by the user. When non-empty it takes
            precedence over ``filepath`` and the entry is fetched from
            files.rcsb.org into the current working directory.
        filepath: Uploaded-file object; only its ``name`` attribute is read.

    Returns:
        ``"<pdb_code>.pdb"`` when a code was given; otherwise
        ``filepath.name``, or ``None`` when ``filepath`` has no ``name``.

    Raises:
        ValueError: If ``pdb_code`` contains anything other than ASCII
            letters and digits.
    """
    code = (pdb_code or "").strip()
    if not code:
        # No code supplied: fall back to the upload (None if there is none).
        return getattr(filepath, "name", None)

    # The code comes straight from a text box and ends up in a URL and a
    # file name, so refuse anything that is not a plain identifier.
    if not (code.isascii() and code.isalnum()):
        raise ValueError(f"Invalid PDB code: {code!r}")

    # An argument list (no shell) keeps the user's text from being
    # interpreted as shell syntax. -q: quiet, -nc: keep an existing file.
    try:
        subprocess.run(
            ["wget", "-qnc", f"https://files.rcsb.org/view/{code}.pdb"],
            check=False,
        )
    except OSError as err:
        # The download stays best-effort, as it was with os.system().
        print(f"Could not run wget for {code}: {err}")
    return f"{code}.pdb"
def get_ligand(smiles="", filepath=""):
    """Choose the ligand description handed to the docking pipeline.

    Args:
        smiles: SMILES string typed by the user; wins whenever it is neither
            ``None`` nor the empty string.
        filepath: Uploaded-file object; only its ``name`` attribute is read.

    Returns:
        ``smiles`` when given; otherwise ``filepath.name``, or ``None`` when
        ``filepath`` has no ``name`` attribute.
    """
    smiles_given = smiles is not None and smiles != ""
    if smiles_given:
        return smiles
    # No SMILES: use the uploaded file, or None when nothing was uploaded.
    return getattr(filepath, "name", None)
def read_mol(molpath):
    """Return the complete text content of the file at ``molpath``.

    Args:
        molpath: Path of the text file to read.

    Returns:
        The file content as a single string (empty for an empty file).
    """
    with open(molpath, "r") as fp:
        # One read() replaces the former readlines() + per-line concatenation.
        return fp.read()
def molecule(input_pdb, ligand_pdb):
    """Build the HTML snippet that shows the protein and animates the ligand.

    The two files are read as text and embedded in a small stand-alone page
    that loads 3Dmol.js, draws the protein as a gray cartoon and plays the
    ligand frames once; a button replays the animation. The page is returned
    wrapped in an ``<iframe srcdoc=...>``.

    Args:
        input_pdb: Path of the protein PDB file.
        ligand_pdb: Path of the (multi-frame) ligand PDB file.

    Returns:
        An HTML string containing a single iframe element.
    """
    import html
    import json

    def _js_string(text):
        # A JSON string literal is also a valid JavaScript string literal,
        # so quotes, backticks, backslashes and newlines in the file cannot
        # break the script. "</" is escaped too so that the file content can
        # never close the surrounding <script> element.
        return json.dumps(text).replace("</", "<\\/")

    structure = _js_string(read_mol(input_pdb))
    mol = _js_string(read_mol(ligand_pdb))

    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 600px;
    height: 600px;
    position: relative;
    mx-auto:0
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<button id="startanimation">Replay diffusion process</button>
<div id="container" class="mol-container"></div>
<script>
let ligand = """
        + mol
        + """;
let structure = """
        + structure
        + """;
let viewer = null;
$(document).ready(function () {
    let element = $("#container");
    let config = { backgroundColor: "white" };
    viewer = $3Dmol.createViewer(element, config);
    viewer.addModel( structure, "pdb" );
    viewer.setStyle({}, {cartoon: {color: "gray"}});
    viewer.zoomTo();
    viewer.zoom(0.7);
    viewer.addModelsAsFrames(ligand, "pdb");
    viewer.animate({loop: "forward",reps: 1});
    viewer.getModel(1).setStyle({stick:{colorscheme:"magentaCarbon"}});
    viewer.render();
})
$("#startanimation").click(function() {
    viewer.animate({loop: "forward",reps: 1});
});
</script>
</body></html>"""
    )

    # The page is placed inside a quoted attribute, so it must be
    # attribute-escaped; otherwise the first apostrophe in the document
    # (e.g. an atom name such as O5') would terminate srcdoc early. The
    # browser decodes the entities again before parsing the inner page.
    srcdoc = html.escape(x, quote=True)

    return f"""<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def esm(protein_path, out_file):
    """Write the FASTA input for a protein and run the ESM-2 extraction script.

    Args:
        protein_path: Path of the protein PDB file.
        out_file: Path of the FASTA file that ``esm_embedding_prep`` writes
            and that is then passed to ``esm/scripts/extract.py``.

    Side effects:
        Sets ``HOME`` for this process (and therefore for the child process)
        to ``esm/model_weights`` and runs the extraction script, which is
        given ``data/esm2_output`` as its output directory.
    """
    esm_embedding_prep(out_file, protein_path)

    # NOTE(review): HOME is presumably redirected so that the extraction
    # script finds the model weights under esm/model_weights - confirm.
    os.environ["HOME"] = "esm/model_weights"

    # Argument list instead of a shell string: `out_file` is derived from
    # user input and must not be interpreted (or word-split) by a shell.
    return_code = subprocess.call(
        [
            "python",
            "esm/scripts/extract.py",
            "esm2_t33_650M_UR50D",
            out_file,
            "data/esm2_output",
            "--repr_layers",
            "33",
            "--include",
            "per_tok",
        ]
    )
    if return_code != 0:
        # Still best-effort, but no longer silent.
        print(f"ESM embedding extraction exited with status {return_code}")
def update(inp, file, ligand_inp, ligand_file):
    """Dock the ligand onto the protein and return what the UI has to show.

    Steps: resolve the two inputs, compute the ESM embeddings, build the
    score-model and confidence-model datasets, sample ``N`` poses, rank them
    by confidence (when available), write the ranked ``.sdf`` poses and the
    per-sample reverse-process ``.pdb`` trajectories below ``results/test``.

    Args:
        inp: PDB code typed by the user (may be empty).
        file: Uploaded protein file (may be ``None``).
        ligand_inp: SMILES string typed by the user (may be empty).
        ligand_file: Uploaded ligand file (may be ``None``).

    Returns:
        ``(viewer_html, dropdown_update, filenames, pdb_path)`` where
        ``filenames`` lists the trajectory files ordered by rank, or ``None``
        when sampling failed or no complex was processed.

    Raises:
        ValueError: If no protein or no ligand was provided.
    """
    # The counters are module-level; without this declaration the `+=`
    # statements below raise UnboundLocalError.
    global failures, skipped

    pdb_path = get_pdb(inp, file)
    ligand_path = get_ligand(ligand_inp, ligand_file)
    if pdb_path is None or ligand_path is None:
        # Fail before the expensive embedding step with a readable message.
        raise ValueError(
            "Both a protein (PDB code or file) and a ligand (SMILES or file) are required."
        )

    esm(
        pdb_path,
        f"data/{os.path.basename(pdb_path)}_prepared_for_esm.fasta",
    )

    protein_path_list = [pdb_path]
    ligand_descriptions = [ligand_path]
    no_random = False
    ode = False
    no_final_step_noise = False
    out_dir = "results/test"

    test_dataset = PDBBind(
        transform=None,
        root="",
        protein_path_list=protein_path_list,
        ligand_descriptions=ligand_descriptions,
        receptor_radius=score_model_args.receptor_radius,
        cache_path="data/cache",
        remove_hs=score_model_args.remove_hs,
        max_lig_size=None,
        c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors,
        matching=False,
        keep_original=False,
        popsize=score_model_args.matching_popsize,
        maxiter=score_model_args.matching_maxiter,
        all_atoms=score_model_args.all_atoms,
        atom_radius=score_model_args.atom_radius,
        atom_max_neighbors=score_model_args.atom_max_neighbors,
        esm_embeddings_path="data/esm2_output",
        require_ligand=True,
        num_workers=1,
        keep_local_structures=False,
    )
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

    # Same inputs, preprocessed with the confidence model's parameters.
    confidence_test_dataset = PDBBind(
        transform=None,
        root="",
        protein_path_list=protein_path_list,
        ligand_descriptions=ligand_descriptions,
        receptor_radius=confidence_args.receptor_radius,
        cache_path="data/cache",
        remove_hs=confidence_args.remove_hs,
        max_lig_size=None,
        c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors,
        matching=False,
        keep_original=False,
        popsize=confidence_args.matching_popsize,
        maxiter=confidence_args.matching_maxiter,
        all_atoms=confidence_args.all_atoms,
        atom_radius=confidence_args.atom_radius,
        atom_max_neighbors=confidence_args.atom_max_neighbors,
        esm_embeddings_path="data/esm2_output",
        require_ligand=True,
        num_workers=1,
    )
    confidence_complex_dict = {d.name: d for d in confidence_test_dataset}

    # The confidence model needs its own copy of the complex unless it
    # re-uses the score model's cache or weights.
    needs_confidence_data = confidence_model is not None and not (
        confidence_args.use_original_model_cache or confidence_args.transfer_weights
    )

    # Defined up front so the final return is safe even if every complex is
    # skipped.
    filenames = []

    for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
        if (
            needs_confidence_data
            and orig_complex_graph.name[0] not in confidence_complex_dict
        ):
            skipped += 1
            print(
                f"HAPPENING | The confidence dataset did not contain {orig_complex_graph.name[0]}. We are skipping this complex."
            )
            continue
        try:
            data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)]
            randomize_position(
                data_list,
                score_model_args.no_torsion,
                no_random,
                score_model_args.tr_sigma_max,
            )

            # One trajectory recorder per sample, seeded with the ligand,
            # the input ligand position and the randomized start position.
            lig = orig_complex_graph.mol[0]
            visualization_list = []
            for graph in data_list:
                trajectory = PDBFile(lig)
                trajectory.add(lig, 0, 0)
                trajectory.add(
                    (
                        orig_complex_graph["ligand"].pos
                        + orig_complex_graph.original_center
                    )
                    .detach()
                    .cpu(),
                    1,
                    0,
                )
                trajectory.add(
                    (graph["ligand"].pos + graph.original_center).detach().cpu(),
                    part=1,
                    order=1,
                )
                visualization_list.append(trajectory)

            start_time = time.time()
            if needs_confidence_data:
                confidence_data_list = [
                    copy.deepcopy(confidence_complex_dict[orig_complex_graph.name[0]])
                    for _ in range(N)
                ]
            else:
                confidence_data_list = None

            data_list, confidence = sampling(
                data_list=data_list,
                model=model,
                inference_steps=10,
                tr_schedule=tr_schedule,
                rot_schedule=rot_schedule,
                tor_schedule=tor_schedule,
                device=device,
                t_to_sigma=t_to_sigma,
                model_args=score_model_args,
                no_random=no_random,
                ode=ode,
                visualization_list=visualization_list,
                confidence_model=confidence_model,
                confidence_data_list=confidence_data_list,
                confidence_model_args=confidence_args,
                batch_size=1,
                no_final_step_noise=no_final_step_noise,
            )
            # Shift the predicted positions back by the original center.
            ligand_pos = np.asarray(
                [
                    complex_graph["ligand"].pos.cpu().numpy()
                    + orig_complex_graph.original_center.cpu().numpy()
                    for complex_graph in data_list
                ]
            )
            run_times.append(time.time() - start_time)

            if confidence is not None and isinstance(
                confidence_args.rmsd_classification_cutoff, list
            ):
                # With a list of cutoffs only the first column is used.
                confidence = confidence[:, 0]
            if confidence is not None:
                confidence = confidence.cpu().numpy()
                # Highest confidence first.
                re_order = np.argsort(confidence)[::-1]
                confidence = confidence[re_order]
                confidences_list.append(confidence)
                ligand_pos = ligand_pos[re_order]

            write_dir = (
                f'{out_dir}/index{idx}_{data_list[0]["name"][0].replace("/","-")}'
            )
            os.makedirs(write_dir, exist_ok=True)
            for rank, pos in enumerate(ligand_pos):
                mol_pred = copy.deepcopy(lig)
                if score_model_args.remove_hs:
                    mol_pred = RemoveHs(mol_pred)
                if rank == 0:
                    write_mol_with_coords(
                        mol_pred, pos, os.path.join(write_dir, f"rank{rank+1}.sdf")
                    )
                if confidence is not None:
                    write_mol_with_coords(
                        mol_pred,
                        pos,
                        os.path.join(
                            write_dir, f"rank{rank+1}_confidence{confidence[rank]:.2f}.sdf"
                        ),
                    )
                elif rank > 0:
                    # Without confidences there is no score to put in the
                    # file name; still write every pose (rank 1 was written
                    # above).
                    write_mol_with_coords(
                        mol_pred, pos, os.path.join(write_dir, f"rank{rank+1}.sdf")
                    )

            # Smallest distance between two different atoms of each pose;
            # the diagonal is masked with inf.
            self_distances = np.linalg.norm(
                ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1
            )
            self_distances = np.where(
                np.eye(self_distances.shape[2]), np.inf, self_distances
            )
            min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))

            # Trajectories are written in rank order. Without confidences the
            # samples keep their original order, so the rank is the index
            # into visualization_list.
            if confidence is not None:
                sample_order = re_order
            else:
                sample_order = range(len(ligand_pos))
            filenames = []
            for rank, batch_idx in enumerate(sample_order):
                trajectory_path = os.path.join(
                    write_dir, f"rank{rank+1}_reverseprocess.pdb"
                )
                visualization_list[batch_idx].write(trajectory_path)
                filenames.append(trajectory_path)
            names_list.append(orig_complex_graph.name[0])
        except Exception as e:
            print("Failed on", orig_complex_graph["name"], e)
            failures += 1
            return None

    if not filenames:
        # Nothing was produced (e.g. the only complex was skipped).
        return None

    labels = [f"rank {i+1}" for i in range(len(filenames))]
    return (
        molecule(pdb_path, filenames[0]),
        gr.Dropdown.update(choices=labels, value="rank 1"),
        filenames,
        pdb_path,
    )
def updateView(out, filenames, pdb):
    """Re-render the viewer for the sample picked in the dropdown.

    Args:
        out: Selected dropdown label of the form ``"rank <k>"`` with ``k``
            starting at 1.
        filenames: Trajectory files ordered by rank (index 0 is rank 1).
        pdb: Path of the protein PDB file.

    Returns:
        The viewer HTML produced by ``molecule`` for the selected sample.
    """
    rank = int(out.replace("rank", ""))
    # Labels are 1-based while the list is 0-based; indexing with `rank`
    # showed the next sample and raised IndexError for the last one.
    return molecule(pdb, filenames[rank - 1])
# ---------------------------------------------------------------------------
# User interface.
# NOTE(review): the nesting below was reconstructed from a flattened copy of
# the file; confirm in particular that the button belongs inside the Box.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Header
    gr.Markdown("# DiffDock")
    gr.Markdown(
        ">**DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking**, Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi, arXiv:2210.01776 [GitHub](https://github.com/gcorso/diffdock)"
    )
    gr.Markdown("Runs the diffusion model `10` times with `10` inference steps")

    # Inputs: protein on the left, ligand on the right.
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Protein")
                inp = gr.Textbox(
                    label="Input structure",
                    placeholder="PDB Code or upload file below",
                )
                file = gr.File(label="Input PDB", file_count="single")
            with gr.Column():
                gr.Markdown("## Ligand")
                ligand_inp = gr.Textbox(
                    label="SMILES string",
                    placeholder="Provide SMILES input or upload mol2/sdf file below",
                )
                ligand_file = gr.File(label="Input Ligand", file_count="single")
        btn = gr.Button("Run predictions")

    # Outputs: two hidden state holders, the rank selector and the viewer.
    gr.Markdown("## Output")
    pdb = gr.Variable()
    filenames = gr.Variable()
    out = gr.Dropdown(label="Ranked samples", interactive=True)
    mol = gr.HTML()

    # One bundled example that fills the two file inputs.
    gr.Examples(
        examples=[
            [
                None,
                "examples/1a46_protein_processed.pdb",
                None,
                "examples/1a46_ligand.sdf",
            ]
        ],
        inputs=[inp, file, ligand_inp, ligand_file],
        outputs=[mol, out],
    )

    # Wiring: the button runs the docking; changing the selected rank
    # re-renders the viewer from the stored state.
    btn.click(
        update,
        [inp, file, ligand_inp, ligand_file],
        [mol, out, filenames, pdb],
    )
    out.change(updateView, [out, filenames, pdb], mol)

demo.launch()