# DrugFlow / src/analysis/visualization_utils.py
# mority's picture
# Upload 53 files
# 6e7d4ba verified
import warnings
import torch
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Chem import SanitizeFlags
from src.analysis.metrics import check_mol
from src import utils
from src.data.molecule_builder import build_molecule
from src.data.misc import protein_letters_1to3
# def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
# atom_decoder=None, aa_decoder=None, residue_decoder=None,
# aa_atom_index=None):
#
# rdpockets = []
# for i in torch.unique(pocket['mask']):
#
# node_coord = pocket['x'][pocket['mask'] == i]
# h = pocket['one_hot'][pocket['mask'] == i]
#
# if pocket_representation == 'side_chain_bead':
# coord = node_coord
#
# node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)]
# atom_types = ['C' if r == 'CA' else 'F' for r in node_types]
#
# elif pocket_representation == 'CA+':
# aa_types = [aa_decoder[b] for b in h.argmax(-1)]
# side_chain_vec = pocket['v'][pocket['mask'] == i]
#
# coord = []
# atom_types = []
# for xyz, aa, vec in zip(node_coord, aa_types, side_chain_vec):
# # C_alpha
# coord.append(xyz)
# atom_types.append('C')
#
# # all other atoms
# for atom_name, idx in aa_atom_index[aa].items():
# coord.append(xyz + vec[idx])
# atom_types.append(atom_name[0])
#
# coord = torch.stack(coord, dim=0)
#
# else:
# raise NotImplementedError(f"{pocket_representation} residue representation not supported")
#
# atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
# rdpockets.append(build_molecule(coord, atom_types, atom_decoder=atom_decoder))
#
# return rdpockets
def pocket_to_rdkit(pocket, pocket_representation, atom_encoder=None,
                    atom_decoder=None, aa_decoder=None, residue_decoder=None,
                    aa_atom_index=None):
    """Build one RDKit molecule per pocket contained in a batched pocket dict.

    Args:
        pocket: Mapping of batched tensors. ``pocket['mask']`` assigns every
            node to a sample; ``pocket['x']`` (coordinates) and
            ``pocket['one_hot']`` (node features) are sliced with it. The
            ``'CA+'`` representation additionally reads ``pocket['v']`` and
            ``pocket['atom_mask']``.
        pocket_representation: Either ``'side_chain_bead'`` or ``'CA+'``.
        atom_encoder: Lookup from element symbol to atom-type index.
        atom_decoder: Forwarded unchanged to ``build_molecule``.
        aa_decoder: Lookup from feature argmax to amino-acid key
            (``'CA+'`` only).
        residue_decoder: Lookup from feature argmax to bead name
            (``'side_chain_bead'`` only).
        aa_atom_index: For each amino-acid key, a mapping
            ``atom name -> index`` into that residue's rows of
            ``pocket['v']`` / ``pocket['atom_mask']`` (``'CA+'`` only).

    Returns:
        A list with one molecule (as returned by ``build_molecule``) per
        unique value in ``pocket['mask']``.

    Raises:
        NotImplementedError: If ``pocket_representation`` is not supported
            (raised while processing the first sample).
    """
    rdpockets = []
    for i in torch.unique(pocket['mask']):
        # Node selector for this sample; computed once and reused below.
        in_sample = pocket['mask'] == i

        node_coord = pocket['x'][in_sample]
        h = pocket['one_hot'][in_sample]

        pdb_infos = []

        if pocket_representation == 'side_chain_bead':
            coord = node_coord

            # Only the trailing len(residue_decoder) feature columns encode
            # the bead type.
            node_types = [residue_decoder[b] for b in h[:, -len(residue_decoder):].argmax(-1)]
            # 'CA' beads are drawn as carbon, every other bead as fluorine.
            atom_types = ['C' if r == 'CA' else 'F' for r in node_types]

        elif pocket_representation == 'CA+':
            aa_types = [aa_decoder[b] for b in h.argmax(-1)]
            side_chain_vec = pocket['v'][in_sample]
            # 'atom_mask' is only needed (and only required) for this
            # representation.
            atom_mask = pocket['atom_mask'][in_sample]

            coord = []
            atom_types = []
            for resi, (xyz, aa, vec, am) in enumerate(zip(node_coord, aa_types, side_chain_vec, atom_mask)):

                # CA not treated differently with updated atom dictionary
                for atom_name, idx in aa_atom_index[aa].items():

                    # `not` instead of `~`: bitwise NOT of an integer 0/1
                    # mask entry is always truthy, which would flag every
                    # atom as missing.
                    if not am[idx]:
                        warnings.warn(f"Missing atom {atom_name} in {aa}:{resi}")
                        continue

                    # Atom position = node coordinate + per-atom offset.
                    coord.append(xyz + vec[idx])
                    # First character of the atom name is used as element.
                    atom_types.append(atom_name[0])

                    info = Chem.AtomPDBResidueInfo()
                    # info.SetChainId('A')
                    info.SetResidueName(protein_letters_1to3[aa])
                    info.SetResidueNumber(resi + 1)
                    info.SetOccupancy(1.0)
                    info.SetTempFactor(0.0)
                    info.SetName(f' {atom_name:<3}')
                    pdb_infos.append(info)

            coord = torch.stack(coord, dim=0)

        else:
            raise NotImplementedError(f"{pocket_representation} residue representation not supported")

        atom_types = torch.tensor([atom_encoder[a] for a in atom_types])
        rdmol = build_molecule(coord, atom_types, atom_decoder=atom_decoder)

        # Residue records are attached only when there is exactly one per
        # built atom; otherwise they are skipped silently (always the case
        # for 'side_chain_bead', which collects none).
        if len(pdb_infos) == len(rdmol.GetAtoms()):
            for a, info in zip(rdmol.GetAtoms(), pdb_infos):
                a.SetPDBResidueInfo(info)

        rdpockets.append(rdmol)

    return rdpockets
def mols_to_pdbfile(rdmols, filename, flavor=0):
    """Write a sequence of molecules to a single multi-model PDB file.

    Every molecule becomes one ``MODEL`` ... ``ENDMDL`` section (numbered
    from 1) and a single ``END`` line closes the file.

    Args:
        rdmols: Iterable of RDKit molecules.
        filename: Path of the file to write (overwritten if it exists).
        flavor: Passed to ``Chem.MolToPDBBlock``.

    Returns:
        The full PDB text that was written.
    """
    pieces = []
    for model_number, mol in enumerate(rdmols, start=1):
        pdb_lines = Chem.MolToPDBBlock(mol, flavor=flavor).split("\n")
        pieces.append(f"MODEL{model_number:>9}\n")
        # Dropping the last two split fields removes the per-molecule END
        # record, so that only one END terminates the whole file.
        pieces.append("\n".join(pdb_lines[:-2]) + "\n")
        pieces.append("ENDMDL\n")
    pieces.append("END\n")

    pdb_str = "".join(pieces)

    with open(filename, 'w') as handle:
        handle.write(pdb_str)

    return pdb_str
def mol_as_pdb(rdmol, filename=None, bfactor=None):
    """Render a molecule as a PDB block, optionally saving it to disk.

    Aromatic flags are cleared on a copy of the molecule before export.
    When ``bfactor`` is given, every atom receives PDB residue info
    (residue ``UNL`` 1, hetero atom, occupancy 1.0) whose temperature
    factor is the atom property named ``bfactor``.

    Args:
        rdmol: RDKit molecule; it is copied, not modified.
        filename: Optional output path. Nothing is written when ``None``.
        bfactor: Optional name of the per-atom property to store in the
            B-factor column. Every atom must carry this property.

    Returns:
        The PDB block as a string.
    """
    def _residue_info(atom):
        # Look the property up first so a missing key fails before any
        # residue record is created.
        temp_factor = atom.GetPropsAsDict()[bfactor]

        record = Chem.AtomPDBResidueInfo()
        record.SetResidueName('UNL')
        record.SetResidueNumber(1)
        record.SetName(f' {atom.GetSymbol():<3}')
        record.SetIsHeteroAtom(True)
        record.SetOccupancy(1.0)
        record.SetTempFactor(temp_factor)
        return record

    mol_copy = Chem.Mol(rdmol)  # copy

    for atom in mol_copy.GetAtoms():
        atom.SetIsAromatic(False)
    for bond in mol_copy.GetBonds():
        bond.SetIsAromatic(False)

    if bfactor is not None:
        for atom in mol_copy.GetAtoms():
            atom.SetPDBResidueInfo(_residue_info(atom))

    pdb_str = Chem.MolToPDBBlock(mol_copy)

    if filename is None:
        return pdb_str

    with open(filename, 'w') as handle:
        handle.write(pdb_str)
    return pdb_str
def draw_grid(molecules, mols_per_row=5, fig_size=(200, 200),
              label=check_mol,
              highlight_atom=lambda atom: False,
              highlight_bond=lambda bond: False):
    """Draw molecules as a 2D grid image with optional highlights.

    Args:
        molecules: Iterable of RDKit molecules; each one is copied before
            being prepared for drawing.
        mols_per_row: Number of molecules per grid row.
        fig_size: Size of each sub-image, passed as ``subImgSize``.
        label: Callable applied to each prepared molecule; its result is
            shown in the legend after the molecule's index.
        highlight_atom: Predicate on an atom; atoms for which it is truthy
            are highlighted.
        highlight_bond: Predicate on a bond; bonds for which it is truthy
            are highlighted.

    Returns:
        Whatever ``Draw.MolsToGridImage`` returns for the prepared
        molecules.
    """
    def _prepared_copy(source_mol):
        prepared = Chem.Mol(source_mol)  # copy
        # SANITIZE_NONE is passed, so no sanitization operation is requested.
        Chem.SanitizeMol(prepared, sanitizeOps=SanitizeFlags.SANITIZE_NONE)
        AllChem.Compute2DCoords(prepared)
        return Draw.rdMolDraw2D.PrepareMolForDrawing(prepared, kekulize=False)

    draw_mols = []
    marked_atoms = []
    marked_bonds = []
    for mol in molecules:
        prepared_mol = _prepared_copy(mol)
        draw_mols.append(prepared_mol)

        atom_ids = [atom.GetIdx() for atom in prepared_mol.GetAtoms() if highlight_atom(atom)]
        marked_atoms.append(atom_ids)

        bond_ids = [bond.GetIdx() for bond in prepared_mol.GetBonds() if highlight_bond(bond)]
        marked_bonds.append(bond_ids)

    # Molecules were prepared above, so the drawer is told not to do it again.
    draw_options = Draw.rdMolDraw2D.MolDrawOptions()
    draw_options.prepareMolsBeforeDrawing = False
    draw_options.highlightBondWidthMultiplier = 20

    # Legends are computed from the prepared copies, not the input molecules.
    legends = [f'[{i}] {label(mol)}' for i, mol in enumerate(draw_mols)]

    return Draw.MolsToGridImage(draw_mols,
                                molsPerRow=mols_per_row,
                                subImgSize=fig_size,
                                drawOptions=draw_options,
                                highlightAtomLists=marked_atoms,
                                highlightBondLists=marked_bonds,
                                legends=legends)