|
|
import warnings |
|
|
|
|
|
import torch |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import Draw, AllChem |
|
|
from rdkit.Chem import SanitizeFlags |
|
|
from src.analysis.metrics import check_mol |
|
|
from src import utils |
|
|
from src.data.molecule_builder import build_molecule |
|
|
from src.data.misc import protein_letters_1to3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
                    atom_decoder=None, aa_decoder=None, residue_decoder=None,
                    aa_atom_index=None):
    """Build one RDKit molecule per sample from a batched pocket dictionary.

    Nodes are grouped by the distinct values found in ``pocket['mask']``. Each
    group is turned into coordinates plus element symbols according to
    ``pocket_representation`` and handed to ``build_molecule``.

    Args:
        pocket: Mapping with node-aligned entries ``'x'``, ``'one_hot'``,
            ``'atom_mask'`` and ``'mask'``. ``'v'`` is additionally read for
            the ``'CA+'`` representation.
        pocket_representation: Either ``'side_chain_bead'`` or ``'CA+'``.
        atom_encoder: Maps an element symbol to the atom-type index passed to
            ``build_molecule``.
        atom_decoder: Forwarded unchanged to ``build_molecule``.
        aa_decoder: Indexed by the argmax of ``one_hot`` (``'CA+'`` only). The
            result is used as key into ``aa_atom_index`` and
            ``protein_letters_1to3``.
        residue_decoder: Indexed by the argmax over the last
            ``len(residue_decoder)`` columns of ``one_hot``
            (``'side_chain_bead'`` only).
        aa_atom_index: Maps an amino-acid key to a ``{atom_name: index}``
            mapping; ``index`` addresses that residue's row of ``atom_mask``
            and ``v`` (``'CA+'`` only).

    Returns:
        A list with one molecule (as returned by ``build_molecule``) per
        distinct value in ``pocket['mask']``.

    Raises:
        NotImplementedError: If ``pocket_representation`` is not one of the
            two supported values (raised once the first sample is processed).
    """
    rdpockets = []
    for i in torch.unique(pocket['mask']):

        # Boolean selector for the nodes of the current sample.
        sample = pocket['mask'] == i

        node_coord = pocket['x'][sample]
        h = pocket['one_hot'][sample]
        atom_mask = pocket['atom_mask'][sample]

        # Filled only by the 'CA+' branch; one entry per emitted atom.
        pdb_infos = []

        if pocket_representation == 'side_chain_bead':
            coord = node_coord

            node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)]
            # Nodes decoded as 'CA' are represented by carbon, every other
            # node type by fluorine.
            atom_types = ['C' if r == 'CA' else 'F' for r in node_types]

        elif pocket_representation == 'CA+':
            aa_types = [aa_decoder[b] for b in h.argmax(-1)]
            side_chain_vec = pocket['v'][sample]

            coord = []
            atom_types = []
            for resi, (xyz, aa, vec, am) in enumerate(zip(node_coord, aa_types, side_chain_vec, atom_mask)):

                for atom_name, idx in aa_atom_index[aa].items():

                    # `not` instead of `~`: bitwise inversion is only a
                    # logical negation for bool tensors. For an integer mask
                    # (~1 == -2, ~0 == -1) it is always truthy and every atom
                    # would be skipped as missing.
                    if not am[idx]:
                        warnings.warn(f"Missing atom {atom_name} in {aa}:{resi}")
                        continue

                    # Atom position = node position + per-atom offset vector.
                    coord.append(xyz + vec[idx])
                    # The first character of the atom name is used as the
                    # element symbol.
                    atom_types.append(atom_name[0])

                    info = Chem.AtomPDBResidueInfo()
                    info.SetResidueName(protein_letters_1to3[aa])
                    info.SetResidueNumber(resi + 1)
                    info.SetOccupancy(1.0)
                    info.SetTempFactor(0.0)
                    info.SetName(f' {atom_name:<3}')
                    pdb_infos.append(info)

            coord = torch.stack(coord, dim=0)

        else:
            raise NotImplementedError(f"{pocket_representation} residue representation not supported")

        atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
        rdmol = build_molecule(coord, atom_types, atom_decoder=atom_decoder)

        # Attach residue info only when there is exactly one record per atom
        # of the built molecule; otherwise the molecule is returned without it.
        if len(pdb_infos) == len(rdmol.GetAtoms()):
            for a, info in zip(rdmol.GetAtoms(), pdb_infos):
                a.SetPDBResidueInfo(info)

        rdpockets.append(rdmol)

    return rdpockets
|
|
|
|
|
|
|
|
def mols_to_pdbfile(rdmols, filename, flavor=0):
    """Write molecules as numbered models into one PDB file.

    Every molecule is wrapped in a ``MODEL`` / ``ENDMDL`` pair (numbering
    starts at 1) and a single ``END`` line closes the file.

    Args:
        rdmols: Iterable of molecules accepted by ``Chem.MolToPDBBlock``.
        filename: Path of the file to (over)write.
        flavor: Forwarded to ``Chem.MolToPDBBlock``.

    Returns:
        The full text that was written to ``filename``.
    """
    pieces = []
    for model_number, mol in enumerate(rdmols, start=1):
        block_lines = Chem.MolToPDBBlock(mol, flavor=flavor).split("\n")
        # Drop the last two split items -- presumably the block's own END
        # record and the empty string after its final newline (confirm
        # against RDKit's output) -- so END appears only once per file.
        body = "\n".join(block_lines[:-2])

        pieces.append(f"MODEL{model_number:>9}\n")
        pieces.append(body + "\n")
        pieces.append("ENDMDL\n")
    pieces.append("END\n")

    pdb_str = "".join(pieces)

    with open(filename, 'w') as f:
        f.write(pdb_str)

    return pdb_str
|
|
|
|
|
|
|
|
def mol_as_pdb(rdmol, filename=None, bfactor=None):
    """Render a molecule as a PDB block, optionally saving it to disk.

    The input is copied first, so ``rdmol`` itself is not modified. Aromatic
    flags are cleared on all atoms and bonds of the copy before export.

    Args:
        rdmol: Molecule to export.
        filename: If given, the PDB text is also written to this path.
        bfactor: Name of an atom property. When given, every atom gets fresh
            PDB residue info (residue ``UNL`` number 1, hetero atom,
            occupancy 1.0) whose temperature factor is that property's value.
            A missing property raises ``KeyError``.

    Returns:
        The PDB block as a string.
    """
    # Work on a copy so the caller's molecule keeps its flags and atom info.
    mol_copy = Chem.Mol(rdmol)

    for atom in mol_copy.GetAtoms():
        atom.SetIsAromatic(False)
    for bond in mol_copy.GetBonds():
        bond.SetIsAromatic(False)

    if bfactor is not None:
        for atom in mol_copy.GetAtoms():
            temp_factor = atom.GetPropsAsDict()[bfactor]

            residue_info = Chem.AtomPDBResidueInfo()
            residue_info.SetResidueName('UNL')
            residue_info.SetResidueNumber(1)
            residue_info.SetName(f' {atom.GetSymbol():<3}')
            residue_info.SetIsHeteroAtom(True)
            residue_info.SetOccupancy(1.0)
            residue_info.SetTempFactor(temp_factor)
            atom.SetPDBResidueInfo(residue_info)

    pdb_str = Chem.MolToPDBBlock(mol_copy)

    if filename is None:
        return pdb_str

    with open(filename, 'w') as f:
        f.write(pdb_str)
    return pdb_str
|
|
|
|
|
|
|
|
def draw_grid(molecules, mols_per_row=5, fig_size=(200, 200),
              label=check_mol,
              highlight_atom=lambda atom: False,
              highlight_bond=lambda bond: False):
    """Draw molecules on a grid with per-molecule legends and highlights.

    Each molecule is copied, given 2D coordinates and prepared for drawing
    without kekulization; the inputs are left untouched.

    Args:
        molecules: Iterable of RDKit molecules.
        mols_per_row: Number of grid columns.
        fig_size: Size of each grid cell, passed as ``subImgSize``.
        label: Callable applied to each *prepared* copy; its result is shown
            in the legend after the molecule's position ``[i]``.
        highlight_atom: Predicate on an atom; atoms for which it is truthy
            are highlighted.
        highlight_bond: Predicate on a bond; bonds for which it is truthy
            are highlighted.

    Returns:
        Whatever ``Draw.MolsToGridImage`` returns for the prepared molecules.
    """
    prepared = []
    atom_hits = []
    bond_hits = []
    for source_mol in molecules:
        canvas_mol = Chem.Mol(source_mol)
        # SANITIZE_NONE: run the sanitization entry point with every
        # operation switched off.
        Chem.SanitizeMol(canvas_mol, sanitizeOps=SanitizeFlags.SANITIZE_NONE)
        AllChem.Compute2DCoords(canvas_mol)
        canvas_mol = Draw.rdMolDraw2D.PrepareMolForDrawing(canvas_mol,
                                                           kekulize=False)
        prepared.append(canvas_mol)

        atom_hits.append([atom.GetIdx() for atom in canvas_mol.GetAtoms()
                          if highlight_atom(atom)])
        bond_hits.append([bond.GetIdx() for bond in canvas_mol.GetBonds()
                          if highlight_bond(bond)])

    options = Draw.rdMolDraw2D.MolDrawOptions()
    # Molecules were prepared above; stop the drawer from preparing them again.
    options.prepareMolsBeforeDrawing = False
    options.highlightBondWidthMultiplier = 20

    captions = [f'[{i}] {label(mol)}' for i, mol in enumerate(prepared)]

    return Draw.MolsToGridImage(prepared,
                                molsPerRow=mols_per_row,
                                subImgSize=fig_size,
                                drawOptions=options,
                                highlightAtomLists=atom_hits,
                                highlightBondLists=bond_hits,
                                legends=captions)
|
|
|