|
|
import io |
|
|
from itertools import accumulate, chain |
|
|
from copy import deepcopy |
|
|
import random |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from rdkit import Chem |
|
|
from torch_scatter import scatter_mean |
|
|
from Bio.PDB import StructureBuilder, Chain, Model, Structure |
|
|
from Bio.PDB.PICIO import read_PIC, write_PIC |
|
|
from scipy.ndimage import gaussian_filter |
|
|
from pdb import set_trace |
|
|
|
|
|
from src.constants import FLOAT_TYPE, INT_TYPE |
|
|
from src.constants import atom_encoder, bond_encoder, aa_encoder, residue_encoder, residue_bond_encoder, aa_atom_index |
|
|
from src import utils |
|
|
from src.data.misc import protein_letters_3to1, is_aa |
|
|
from src.data.normal_modes import pdb_to_normal_modes |
|
|
from src.data.nerf import get_nerf_params, ic_to_coords |
|
|
import src.data.so3_utils as so3 |
|
|
|
|
|
|
|
|
class TensorDict(dict):
    """Dict subclass whose tensor values can be moved or detached in bulk.

    Non-tensor values are left untouched by all transformations.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _apply(self, func: str, *args, **kwargs):
        """Call the tensor method named ``func`` on every tensor value in place."""
        for key in self:
            value = self[key]
            if torch.is_tensor(value):
                self[key] = getattr(value, func)(*args, **kwargs)
        return self

    def cuda(self):
        """Move all tensor values to the default CUDA device."""
        return self.to('cuda')

    def cpu(self):
        """Move all tensor values to the CPU."""
        return self.to('cpu')

    def to(self, device):
        """Move all tensor values to ``device``."""
        return self._apply("to", device)

    def detach(self):
        """Detach all tensor values from the autograd graph."""
        return self._apply("detach")

    def __repr__(self):
        def describe(value):
            # Tensors are summarized by shape, lists by length, the rest opaquely.
            if isinstance(value, torch.Tensor):
                return "%r" % list(value.size())
            if isinstance(value, list):
                return "[%r,]" % len(value)
            else:
                return "?"

        body = ', '.join(f'{k}={describe(v)}' for k, v in self.items())
        return f"{type(self).__name__}({body})"
|
|
|
|
|
|
|
|
def collate_entity(batch):
    """Collate a list of per-item dicts into one batched dict.

    Node/edge features are concatenated, bond indices are shifted by the
    cumulative node counts, and 'mask' / 'bond_mask' are rebuilt from scratch
    alongside 'x' / 'bond_one_hot'.
    """
    collated = {}
    for key in batch[0].keys():

        if key == 'name':
            # Names remain a plain python list.
            collated[key] = [item[key] for item in batch]

        elif key in ('size', 'n_bonds'):
            collated[key] = torch.tensor([item[key] for item in batch])

        elif key == 'bonds':
            # Shift node indices of each item by the number of nodes that
            # precede it in the batch.
            offsets = list(accumulate([item['size'] for item in batch], initial=0))
            collated[key] = torch.cat(
                [item[key] + offsets[i] for i, item in enumerate(batch)], dim=1)

        elif key == 'residues':
            collated[key] = list(chain.from_iterable(item[key] for item in batch))

        elif key in {'mask', 'bond_mask'}:
            # Rebuilt below from 'x' / 'bond_one_hot'.
            pass

        else:
            collated[key] = torch.cat([item[key] for item in batch], dim=0)

        # Recreate the batch-assignment masks next to their reference keys.
        if key == 'x':
            collated['mask'] = torch.cat(
                [i * torch.ones(len(item[key]), dtype=torch.int64, device=item[key].device)
                 for i, item in enumerate(batch)], dim=0)
        if key == 'bond_one_hot':
            collated['bond_mask'] = torch.cat(
                [i * torch.ones(len(item[key]), dtype=torch.int64, device=item[key].device)
                 for i, item in enumerate(batch)], dim=0)

    return collated
|
|
|
|
|
|
|
|
def split_entity(
        batch,
        *,
        index_types=frozenset({'bonds'}),
        edge_types=frozenset({'bond_one_hot', 'bond_mask'}),
        no_split=frozenset({'name', 'size', 'n_bonds'}),
        skip=frozenset({'fragments'}),
        batch_mask=None,
        edge_mask=None
):
    """ Splits a batch into items and returns a list.

    Args:
        batch: collated dict as produced by collate_entity().
        index_types: keys holding node indices that must be shifted back to
            per-item numbering when split. (Defaults are frozensets to avoid
            the mutable-default-argument pitfall.)
        edge_types: keys aligned with edges; split via the edge mask.
        no_split: keys passed through unchanged to every item.
        skip: keys omitted from the output entirely.
        batch_mask: per-node batch assignment; defaults to batch['mask'].
        edge_mask: per-edge batch assignment; defaults to batch['bond_mask'].

    Returns:
        List of per-item dicts.
    """
    batch_mask = batch["mask"] if batch_mask is None else batch_mask
    edge_mask = batch["bond_mask"] if edge_mask is None else edge_mask
    sizes = batch['size'] if 'size' in batch else torch.unique(batch_mask, return_counts=True)[1].tolist()

    # NOTE(review): the batch size is always derived from batch['mask'], even
    # when a custom batch_mask is supplied — confirm this is intended.
    batch_size = len(torch.unique(batch['mask']))
    out = {}
    for prop in batch.keys():
        if prop in skip:
            continue
        if prop in no_split:
            out[prop] = batch[prop]

        elif prop in index_types:
            # Undo the node-index offsets applied during collation.
            offsets = list(accumulate(sizes[:-1], initial=0))
            out[prop] = utils.batch_to_list_for_indices(batch[prop], edge_mask, offsets)

        elif prop in edge_types:
            out[prop] = utils.batch_to_list(batch[prop], edge_mask)

        else:
            out[prop] = utils.batch_to_list(batch[prop], batch_mask)

    # Transpose the dict-of-lists into a list-of-dicts, one per item.
    out = [{k: v[i] for k, v in out.items()} for i in range(batch_size)]
    return out
|
|
|
|
|
|
|
|
def repeat_items(batch, repeats):
    """Tile every item of ``batch`` ``repeats`` times and re-collate."""
    items = split_entity(batch)
    tiled = [item for _ in range(repeats) for item in items]
    return type(batch)(**collate_entity(tiled))
|
|
|
|
|
|
|
|
def get_side_chain_bead_coord(biopython_residue):
    """
    Places side chain bead at the location of the farthest side chain atom.

    Returns None for glycine (no side chain) and the CB position for alanine.
    """
    resname = biopython_residue.get_resname()
    if resname == 'GLY':
        return None
    if resname == 'ALA':
        return biopython_residue['CB'].get_coord()

    ca = biopython_residue['CA'].get_coord()
    # Heavy side-chain atoms only: drop backbone atoms and hydrogens.
    backbone = {'N', 'CA', 'C', 'O'}
    coords = np.stack([
        atom.get_coord() for atom in biopython_residue.get_atoms()
        if atom.id not in backbone and atom.element != 'H'
    ])

    # Pick the atom farthest from the alpha carbon.
    dist_sq = np.sum((coords - ca[None, :]) ** 2, axis=-1)
    return coords[np.argmax(dist_sq), :]
|
|
|
|
|
|
|
|
def get_side_chain_vectors(res, index_dict, size=None):
    """Return a (size, 3) array of atom offsets from CA, slotted via ``index_dict``.

    Atoms not present in the residue leave their slot as zeros.
    """
    if size is None:
        # Smallest array that can hold the largest index in the table.
        size = 1 + max(idx for aa in index_dict.values() for idx in aa.values())

    one_letter = protein_letters_3to1[res.get_resname()]

    vectors = np.zeros((size, 3))
    for atom in res.get_atoms():
        name = atom.get_name()
        if name in index_dict[one_letter]:
            slot = index_dict[one_letter][name]
            vectors[slot] = atom.get_coord() - res['CA'].get_coord()

    return vectors
|
|
|
|
|
|
|
|
def get_normal_modes(res, normal_mode_dict):
    """Look up the CA normal-mode vectors of a residue by (chain id, resseq, 'CA')."""
    key = (res.get_parent().id, res.id[1], 'CA')
    return normal_mode_dict[key]
|
|
|
|
|
|
|
|
def get_torsion_angles(res, device=None):
    """
    Return the five side-chain chi angles in radians.

    Angles that are undefined for this residue are filled with NaN.
    """
    chi_names = ('chi1', 'chi2', 'chi3', 'chi4', 'chi5')

    ic = res.internal_coord
    degrees = []
    for name in chi_names:
        angle = ic.get_angle(name)
        degrees.append(float('nan') if angle is None else angle)

    # Internal coordinates report degrees; convert to radians.
    return torch.tensor(degrees, device=device) * np.pi / 180
|
|
|
|
|
|
|
|
def apply_torsion_angles(res, chi_angles):
    """
    Set side chain torsion angles of a biopython residue object with
    internal coordinates.

    chi_angles are given in radians; the residue's Cartesian coordinates are
    rebuilt afterwards and the (mutated) residue is returned.
    """
    chi_names = ('chi1', 'chi2', 'chi3', 'chi4', 'chi5')

    # Internal coordinates work in degrees.
    degrees = chi_angles * 180 / np.pi

    ic = res.internal_coord
    for name, value in zip(chi_names, degrees):
        # Skip angles that do not exist for this residue type.
        if ic.pick_angle(name) is None:
            continue
        ic.bond_set(name, value)

    # Rebuild Cartesian coordinates from the updated internal coordinates.
    res.parent.internal_to_atom_coordinates(verbose=False)
    return res
|
|
|
|
|
|
|
|
def prepare_internal_coord(res):
    """Serialize a residue's internal coordinates to a PIC-format string.

    The residue is re-parented into a throwaway single-chain structure so that
    Biopython can compute its internal coordinates in isolation.
    """
    scaffold = Structure.Structure('X')
    scaffold.header = {}
    model = Model.Model(0)
    scaffold.add(model)
    tmp_chain = Chain.Chain('X')
    model.add(tmp_chain)
    tmp_chain.add(res)
    res.set_parent(tmp_chain)

    # Derive internal coordinates, then dump them as PIC text.
    tmp_chain.atom_to_internal_coordinates()

    buffer = io.StringIO()
    write_PIC(scaffold, buffer)
    return buffer.getvalue()
|
|
|
|
|
|
|
|
def residue_from_internal_coord(ic_string):
    """Rebuild a residue (with Cartesian coordinates) from a PIC-format string."""
    struct = read_PIC(io.StringIO(ic_string), quick=True)
    # structure -> model -> chain -> residue
    residue = struct.child_list[0].child_list[0].child_list[0]
    residue.parent.internal_to_atom_coordinates(verbose=False)
    return residue
|
|
|
|
|
|
|
|
def prepare_pocket(biopython_residues, amino_acid_encoder, residue_encoder,
                   residue_bond_encoder, pocket_representation='side_chain_bead',
                   compute_nerf_params=False, compute_bb_frames=False,
                   nma_input=None):
    """Featurize a list of pocket residues as a coarse-grained graph of tensors.

    Args:
        biopython_residues: residues to include; re-sorted by (chain id, resseq).
        amino_acid_encoder: one-letter amino-acid code -> integer id.
        residue_encoder: coarse node type ('CA', 'SS') -> integer id
            (only used by the 'side_chain_bead' representation).
        residue_bond_encoder: edge type ('CA-CA', 'CA-SS') -> integer id.
        pocket_representation:
            'side_chain_bead' - one CA node per residue plus one side-chain
                bead node for every residue that has a side chain;
            'CA+' - one CA node per residue with per-atom offset vectors as
                vector features.
        compute_nerf_params: also attach NeRF internal-coordinate parameters.
        compute_bb_frames: also attach backbone frames ('axis_angle').
        nma_input: optional normal-mode input (precomputed dict or a PDB path);
            only supported for the 'CA+' representation.

    Returns:
        (pocket: dict of tensors, biopython_residues: the sorted residue list)
    """
    assert nma_input is None or pocket_representation == 'CA+', \
        "vector features are only supported for CA+ pockets"

    # Deterministic node ordering: by chain id, then residue number.
    biopython_residues = sorted(biopython_residues, key=lambda x: (x.parent.id, x.id[1]))

    if nma_input is not None:
        # Normal modes may be precomputed or derived from a PDB file on the fly.
        if isinstance(nma_input, dict):
            nma_dict = nma_input
        else:
            nma_dict = pdb_to_normal_modes(str(nma_input))

    if pocket_representation == 'side_chain_bead':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')
        side_chain_coords = []
        side_chain_aa_types = []
        edges = []
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa
            side_chain_coord = get_side_chain_bead_coord(res)
            if side_chain_coord is not None:
                side_chain_coords.append(side_chain_coord)
                side_chain_aa_types.append(aa)
                # Side-chain nodes are appended after all CA nodes, so their
                # graph index is offset by the total residue count.
                edges.append((i, len(ca_coords) + len(side_chain_coords) - 1))
                edge_types.append(residue_bond_encoder['CA-SS'])

            # Connect consecutive residues along the backbone.
            # NOTE(review): this checks residue numbering only, not chain
            # identity — confirm chain breaks cannot yield consecutive numbers.
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        side_chain_coords = np.stack(side_chain_coords)
        # Node order: all CA nodes first, then all side-chain beads.
        pocket_coords = np.concatenate([ca_coords, side_chain_coords], axis=0)
        pocket_coords = torch.from_numpy(pocket_coords)

        # Per-node features: amino-acid identity plus node kind (CA vs SS).
        amino_acid_onehot = F.one_hot(
            torch.cat([torch.from_numpy(ca_types), torch.tensor(side_chain_aa_types, dtype=torch.int64)], dim=0),
            num_classes=len(amino_acid_encoder)
        )
        side_chain_onehot = np.concatenate([
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['CA']),
                    [len(ca_coords), 1]),
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['SS']),
                    [len(side_chain_coords), 1])
        ], axis=0)
        side_chain_onehot = torch.from_numpy(side_chain_onehot)
        pocket_onehot = torch.cat([amino_acid_onehot, side_chain_onehot], dim=1)

        # This representation carries no vector-valued features.
        vector_features = None
        nma_features = None

        edges = torch.tensor(edges).T
        edge_types = F.one_hot(torch.tensor(edge_types), num_classes=len(residue_bond_encoder))

    elif pocket_representation == 'CA+':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')

        # One vector slot per indexed side-chain atom across all amino acids.
        v_dim = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
        vec_feats = np.zeros((len(biopython_residues), v_dim, 3), dtype='float32')
        # Five normal-mode vectors stored per residue — assumes this matches
        # the output of pdb_to_normal_modes; TODO confirm.
        nf_nma = 5
        nma_feats = np.zeros((len(biopython_residues), nf_nma, 3), dtype='float32')

        edges = []
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa

            # Side-chain atom offsets relative to the alpha carbon.
            vec_feats[i] = get_side_chain_vectors(res, aa_atom_index, v_dim)
            if nma_input is not None:
                nma_feats[i] = get_normal_modes(res, nma_dict)

            # Backbone CA-CA edges between consecutively numbered residues
            # (same caveat about chain boundaries as above).
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        pocket_coords = torch.from_numpy(ca_coords)

        pocket_onehot = F.one_hot(torch.from_numpy(ca_types),
                                  num_classes=len(amino_acid_encoder))

        vector_features = torch.from_numpy(vec_feats)
        nma_features = torch.from_numpy(nma_feats)

        # A pocket without consecutive residues has no edges; keep the tensor
        # shapes consistent in that case.
        if len(edges) < 1:
            edges = torch.empty(2, 0)
            edge_types = torch.empty(0, len(residue_bond_encoder))
        else:
            edges = torch.tensor(edges).T
            edge_types = F.one_hot(torch.tensor(edge_types),
                                   num_classes=len(residue_bond_encoder))

    else:
        raise NotImplementedError(
            f"Pocket representation '{pocket_representation}' not implemented")

    pocket = {
        'x': pocket_coords.to(dtype=FLOAT_TYPE),
        'one_hot': pocket_onehot.to(dtype=FLOAT_TYPE),
        'size': torch.tensor([len(pocket_coords)], dtype=INT_TYPE),
        'mask': torch.zeros(len(pocket_coords), dtype=INT_TYPE),
        'bonds': edges.to(INT_TYPE),
        'bond_one_hot': edge_types.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(edges.size(1), dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(edge_types)], dtype=INT_TYPE),
    }

    if vector_features is not None:
        pocket['v'] = vector_features.to(dtype=FLOAT_TYPE)

    if nma_input is not None:
        pocket['nma_vec'] = nma_features.to(dtype=FLOAT_TYPE)

    if compute_nerf_params:
        # Per-residue internal-coordinate parameters, stacked over residues.
        nerf_params = [get_nerf_params(r) for r in biopython_residues]
        nerf_params = {k: torch.stack([x[k] for x in nerf_params], dim=0)
                       for k in nerf_params[0].keys()}
        pocket.update(nerf_params)

    if compute_bb_frames:
        # Backbone frame derived from the N, CA and C positions of each residue.
        n_xyz = torch.from_numpy(np.stack([r['N'].get_coord() for r in biopython_residues]))
        ca_xyz = torch.from_numpy(np.stack([r['CA'].get_coord() for r in biopython_residues]))
        c_xyz = torch.from_numpy(np.stack([r['C'].get_coord() for r in biopython_residues]))
        pocket['axis_angle'], _ = get_bb_transform(n_xyz, ca_xyz, c_xyz)

    return pocket, biopython_residues
|
|
|
|
|
|
|
|
def encode_atom(rd_atom, atom_encoder):
    """Map an RDKit atom to its integer type id.

    Dedicated explicit-hydrogen ('XH') and charged ('X+' / 'X-') variants are
    preferred when the encoder vocabulary distinguishes them.
    """
    symbol = rd_atom.GetSymbol().capitalize()

    # Prefer an "element with one explicit hydrogen" type if present.
    if rd_atom.GetNumExplicitHs() == 1 and f'{symbol}H' in atom_encoder:
        return atom_encoder[f'{symbol}H']

    # Then charged variants, if the vocabulary has them.
    formal_charge = rd_atom.GetFormalCharge()
    if formal_charge == 1 and f'{symbol}+' in atom_encoder:
        return atom_encoder[f'{symbol}+']
    if formal_charge == -1 and f'{symbol}-' in atom_encoder:
        return atom_encoder[f'{symbol}-']

    return atom_encoder[symbol]
|
|
|
|
|
|
|
|
def prepare_ligand(rdmol, atom_encoder, bond_encoder):
    """Convert an RDKit molecule into the tensor-dict ligand representation.

    Bonds are stored densely: one entry per unordered atom pair (upper
    triangle of the adjacency matrix), with 'NOBOND' marking absent bonds.
    """
    # Strip hydrogens unless the vocabulary models them explicitly.
    if 'H' not in atom_encoder:
        rdmol = Chem.RemoveAllHs(rdmol, sanitize=False)

    coords = torch.from_numpy(rdmol.GetConformer().GetPositions())

    atom_types = torch.tensor([encode_atom(a, atom_encoder) for a in rdmol.GetAtoms()])
    one_hot = F.one_hot(atom_types, num_classes=len(atom_encoder))

    # Symmetric adjacency matrix of bond-type indices, 'NOBOND' by default.
    n_atoms = rdmol.GetNumAtoms()
    adjacency = np.ones((n_atoms, n_atoms)) * bond_encoder['NOBOND']
    for bond in rdmol.GetBonds():
        a = bond.GetBeginAtomIdx()
        b = bond.GetEndAtomIdx()
        adjacency[a, b] = bond_encoder[str(bond.GetBondType())]
        adjacency[b, a] = adjacency[a, b]

    # Enumerate every unordered pair exactly once (upper triangle, no diagonal).
    pair_indices = np.stack(np.triu_indices(len(coords), k=1), axis=0)

    pair_types = adjacency[pair_indices[0], pair_indices[1]].astype('int64')
    pair_indices = torch.from_numpy(pair_indices)
    pair_onehot = F.one_hot(torch.from_numpy(pair_types), num_classes=len(bond_encoder))

    return {
        'x': coords.to(dtype=FLOAT_TYPE),
        'one_hot': one_hot.to(dtype=FLOAT_TYPE),
        'mask': torch.zeros(len(coords), dtype=INT_TYPE),
        'bonds': pair_indices.to(INT_TYPE),
        'bond_one_hot': pair_onehot.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(pair_indices.size(1), dtype=INT_TYPE),
        'size': torch.tensor([len(coords)], dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(pair_onehot)], dtype=INT_TYPE),
    }
|
|
|
|
|
|
|
|
def process_raw_molecule_with_empty_pocket(rdmol):
    """Featurize a ligand and pair it with a pocket containing no atoms."""
    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)

    # Same keys (and key order) as a real pocket, but every tensor is empty.
    layout = [('x', FLOAT_TYPE), ('one_hot', FLOAT_TYPE), ('size', INT_TYPE),
              ('mask', INT_TYPE), ('bonds', INT_TYPE),
              ('bond_one_hot', FLOAT_TYPE), ('bond_mask', INT_TYPE),
              ('n_bonds', INT_TYPE)]
    pocket = {key: torch.tensor([], dtype=dt) for key, dt in layout}
    return ligand, pocket
|
|
|
|
|
|
|
|
def process_raw_pair(biopython_model, rdmol, dist_cutoff=None,
                     pocket_representation='side_chain_bead',
                     compute_nerf_params=False, compute_bb_frames=False,
                     nma_input=None, return_pocket_pdb=False):
    """Featurize a protein/ligand pair.

    Args:
        biopython_model: Biopython model the pocket residues are taken from.
        rdmol: RDKit molecule of the ligand.
        dist_cutoff: if given, keep only residues having at least one atom
            within this distance of any ligand atom; otherwise keep all
            standard amino acids.
        pocket_representation, compute_nerf_params, compute_bb_frames,
            nma_input: forwarded to prepare_pocket().
        return_pocket_pdb: additionally attach a Biopython structure of the
            selected pocket residues under pocket['pocket_pdb'].

    Returns:
        (ligand dict, pocket dict) of tensors.
    """
    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)

    # Select pocket residues: standard amino acids near the ligand.
    pocket_residues = []
    for residue in biopython_model.get_residues():
        # Skip hetero residues, waters and non-standard amino acids.
        if not is_aa(residue.get_resname(), standard=True):
            continue

        res_coords = torch.from_numpy(np.array([a.get_coord() for a in residue.get_atoms()]))
        # Minimum pairwise atom distance between this residue and the ligand.
        if dist_cutoff is None or (((res_coords[:, None, :] - ligand['x'][None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
            pocket_residues.append(residue)

    pocket, pocket_residues = prepare_pocket(
        pocket_residues, aa_encoder, residue_encoder, residue_bond_encoder,
        pocket_representation, compute_nerf_params, compute_bb_frames, nma_input
    )

    if return_pocket_pdb:
        # Assemble a minimal structure containing just the pocket residues,
        # preserving their original chain ids.
        builder = StructureBuilder.StructureBuilder()
        builder.init_structure("")
        builder.init_model(0)
        pocket_struct = builder.get_structure()
        for residue in pocket_residues:
            chain = residue.get_parent().get_id()

            # Create each chain on first use.
            if not pocket_struct[0].has_id(chain):
                builder.init_chain(chain)

            pocket_struct[0][chain].add(residue)

        pocket['pocket_pdb'] = pocket_struct

    return ligand, pocket
|
|
|
|
|
|
|
|
class AppendVirtualNodes:
    """Pad a ligand with virtual atoms up to a fixed maximum size.

    Virtual coordinates are drawn from a Gaussian fitted to the real atom
    coordinates (mean plus Cholesky factor of the covariance, scaled by
    ``scale``). The dense pairwise bond list is rebuilt so that every pair
    involving a virtual node is 'NOBOND'.
    """

    def __init__(self, atom_encoder, bond_encoder, max_ligand_size, scale=1.0):
        self.max_size = max_ligand_size
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        # Type indices reserved for "no atom" / "no bond".
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        self.scale = scale

    def __call__(self, ligand, max_size=None, eps=1e-6):
        """Append virtual nodes to ``ligand`` (mutated in place and returned).

        eps: jitter added to the covariance so the Cholesky factorization
        succeeds even for degenerate (e.g. planar) molecules.
        """
        if max_size is None:
            max_size = self.max_size

        # NOTE(review): if ligand['size'] is a 1-element tensor here, n_virt
        # becomes a tensor too and torch.randn below expects ints — confirm
        # callers pass an int 'size'.
        n_virt = max_size - ligand['size']

        # Sample virtual coordinates from N(mu, scale^2 * C) fitted to the
        # real atom positions.
        C = torch.cov(ligand['x'].T)
        L = torch.linalg.cholesky(C + torch.eye(3) * eps)
        mu = ligand['x'].mean(0, keepdim=True)
        virt_coords = mu + torch.randn(n_virt, 3) @ L.T * self.scale

        virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder))
        # True for the appended virtual nodes, False for real atoms.
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = max_size

        # Rebuild the dense upper-triangular pair list for the padded size.
        new_bonds = torch.triu_indices(max_size, max_size, offset=1)

        # Scatter the existing bond types into a dense matrix, then read the
        # new pair list back out; untouched pairs stay 'NOBOND'.
        bond_types = torch.ones(max_size, max_size, dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        new_row, new_col = new_bonds
        bond_types = bond_types[new_row, new_col]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand
|
|
|
|
|
|
|
|
class AppendVirtualNodesInCoM:
    """Append a random number of virtual atoms at the ligand's center of mass."""

    def __init__(self, atom_encoder, bond_encoder, add_min=0, add_max=10):
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        # Type indices reserved for "no atom" / "no bond".
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        # Inclusive bounds for the number of virtual nodes to add.
        self.add_min = add_min
        self.add_max = add_max

    def __call__(self, ligand):
        """Append virtual nodes to ``ligand`` (mutated in place and returned)."""
        # Number of virtual nodes is sampled uniformly per call.
        n_virt = random.randint(self.add_min, self.add_max)

        # All virtual nodes sit exactly at the center of mass.
        virt_coords = ligand['x'].mean(0, keepdim=True).repeat(n_virt, 1)

        virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx, num_classes=len(self.atom_encoder))
        # True for the appended virtual nodes, False for real atoms.
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool), torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = len(ligand['x'])

        # Rebuild the dense upper-triangular pair list for the new size.
        new_bonds = torch.triu_indices(ligand['size'], ligand['size'], offset=1)

        # Scatter the existing bond types into a dense matrix, then read the
        # new pair list back out; untouched pairs stay 'NOBOND'.
        bond_types = torch.ones(ligand['size'], ligand['size'], dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        new_row, new_col = new_bonds
        bond_types = bond_types[new_row, new_col]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(bond_types, num_classes=len(self.bond_encoder)).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand
|
|
|
|
|
|
|
|
def rdmol_to_smiles(rdmol):
    """Canonical SMILES of ``rdmol``, ignoring stereochemistry and hydrogens."""
    # Work on a copy so the input molecule is never mutated.
    stripped = Chem.Mol(rdmol)
    Chem.RemoveStereochemistry(stripped)
    return Chem.MolToSmiles(Chem.RemoveHs(stripped))
|
|
|
|
|
|
|
|
def get_n_nodes(lig_positions, pocket_positions, smooth_sigma=None): |
|
|
|
|
|
n_nodes_lig = [len(x) for x in lig_positions] |
|
|
n_nodes_pocket = [len(x) for x in pocket_positions] |
|
|
|
|
|
joint_histogram = np.zeros((np.max(n_nodes_lig) + 1, |
|
|
np.max(n_nodes_pocket) + 1)) |
|
|
|
|
|
for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket): |
|
|
joint_histogram[nlig, npocket] += 1 |
|
|
|
|
|
print(f'Original histogram: {np.count_nonzero(joint_histogram)}/' |
|
|
f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled') |
|
|
|
|
|
|
|
|
if smooth_sigma is not None: |
|
|
filtered_histogram = gaussian_filter( |
|
|
joint_histogram, sigma=smooth_sigma, order=0, mode='constant', |
|
|
cval=0.0, truncate=4.0) |
|
|
|
|
|
print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/' |
|
|
f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled') |
|
|
|
|
|
joint_histogram = filtered_histogram |
|
|
|
|
|
return joint_histogram |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_type_histogram(one_hot, type_encoder):
    """Count argmax-decoded type occurrences over a list of one-hot arrays."""
    stacked = np.concatenate(one_hot, axis=0)

    # Decoder: column index -> type name (relies on dict insertion order).
    names = list(type_encoder.keys())
    histogram = dict.fromkeys(type_encoder, 0)
    for idx in stacked.argmax(1):
        histogram[names[idx]] += 1

    return histogram
|
|
|
|
|
|
|
|
def get_residue_with_resi(pdb_chain, resi):
    """Return the unique residue in ``pdb_chain`` with sequence number ``resi``."""
    matches = [r for r in pdb_chain.get_residues() if r.id[1] == resi]
    assert len(matches) == 1
    return matches[0]
|
|
|
|
|
|
|
|
def get_pocket_from_ligand(pdb_model, ligand, dist_cutoff=8.0):
    """Select pocket residues within ``dist_cutoff`` of a reference ligand.

    Args:
        pdb_model: Biopython model searched for pocket residues.
        ligand: either a path to an SDF file, or a '<chain>:<resi>' string
            identifying a ligand residue inside ``pdb_model``.
        dist_cutoff: maximum heavy-atom distance to the ligand.

    Returns:
        List of standard amino-acid residues close to the ligand.
    """
    if ligand.endswith(".sdf"):
        # External ligand file; there is no residue to exclude from the model.
        rdmol = Chem.SDMolSupplier(str(ligand))[0]
        ligand_coords = torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
        resi = None
    else:
        # Ligand is part of the structure, identified as '<chain>:<resi>'.
        chain, resi = ligand.split(':')
        # BUGFIX: residue.id[1] is an int, so the sequence number must be
        # compared as an int — otherwise the ligand residue is never excluded
        # from the pocket below.
        resi = int(resi)
        ligand = get_residue_with_resi(pdb_model[chain], resi)
        ligand_coords = torch.from_numpy(
            np.array([a.get_coord() for a in ligand.get_atoms()]))

    pocket_residues = []
    for residue in pdb_model.get_residues():
        # Skip the ligand residue itself.
        # NOTE(review): this matches by sequence number only, so an
        # equally-numbered residue in another chain is skipped too — confirm.
        if residue.id[1] == resi:
            continue

        res_coords = torch.from_numpy(
            np.array([a.get_coord() for a in residue.get_atoms()]))
        if is_aa(residue.get_resname(), standard=True) \
                and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff:
            pocket_residues.append(residue)

    return pocket_residues
|
|
|
|
|
|
|
|
def encode_residues(biopython_residues, type_encoder, level='atom',
                    remove_H=True):
    """Encode residues as (coordinates, one-hot types) at atom or residue level.

    At 'atom' level one entry is produced per (optionally heavy) atom; at
    'residue' level one entry per residue, anchored at its alpha carbon.
    """
    assert level in {'atom', 'residue'}

    if level == 'atom':
        entities = [atom for res in biopython_residues for atom in res.get_atoms()
                    if not (remove_H and atom.element == 'H')]
        types = [atom.element.capitalize() for atom in entities]
    else:
        entities = [res['CA'] for res in biopython_residues]
        types = [protein_letters_3to1[res.get_resname()] for res in biopython_residues]

    coord = torch.tensor(np.stack([entity.get_coord() for entity in entities]))
    one_hot = F.one_hot(torch.tensor([type_encoder[t] for t in types]),
                        num_classes=len(type_encoder))

    return coord, one_hot
|
|
|
|
|
|
|
|
def center_data(ligand, pocket):
    """Translate each ligand so its sample's pocket center of mass is at the origin.

    Falls back to the ligand's own per-sample center of mass when the pocket
    has no coordinates.
    """
    if pocket['x'].numel() > 0:
        com = pocket.center()
    else:
        com = scatter_mean(ligand['x'], ligand['mask'], dim=0)

    ligand['x'] = ligand['x'] - com[ligand['mask']]
    return ligand, pocket
|
|
|
|
|
|
|
|
def get_bb_transform(n_xyz, ca_xyz, c_xyz):
    """
    Compute translation and rotation of the canoncical backbone frame (triangle N-Ca-C) from a position with
    Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame

    Args:
        n_xyz: (n, 3)
        ca_xyz: (n, 3)
        c_xyz: (n, 3)

    Returns:
        axis-angle representation of the rotation, shape (n, 3) # rotation matrix of shape (n, 3, 3)
        translation vector of shape (n, 3)
    """

    def rotation_matrix(angle, axis):
        # Batch of rotation matrices about a single coordinate axis ('x'/'y'/'z').
        axis_mapping = {'x': 0, 'y': 1, 'z': 2}
        axis = axis_mapping[axis]
        vector = torch.zeros(len(angle), 3)
        vector[:, axis] = 1

        return so3.matrix_from_rotation_vector(angle.view(-1, 1) * vector)

    # Ca defines the translation; everything below works in Ca-centered coords.
    translation = ca_xyz
    n_xyz = n_xyz - translation
    c_xyz = c_xyz - translation

    # Rotation about y (applied transposed) that zeroes N's z-component.
    # NOTE(review): angle sign conventions assumed consistent with the so3
    # helper — confirm against src.data.so3_utils.
    theta_y = torch.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
    Ry = rotation_matrix(theta_y, 'y')
    Ry = Ry.transpose(2, 1)
    n_xyz = torch.einsum('noi,ni->no', Ry, n_xyz)

    # Rotation about z that zeroes N's y-component, i.e. puts N on the x-axis.
    theta_z = torch.arctan2(n_xyz[:, 1], n_xyz[:, 0])
    Rz = rotation_matrix(theta_z, 'z')
    Rz = Rz.transpose(2, 1)

    # With N fixed on x, rotate about x so C lands in the xy-plane.
    c_xyz = torch.einsum('noj,nji,ni->no', Rz, Ry, c_xyz)
    theta_x = torch.arctan2(c_xyz[:, 2], c_xyz[:, 1])
    Rx = rotation_matrix(theta_x, 'x')
    Rx = Rx.transpose(2, 1)

    # Invert the alignment rotations: R maps the canonical frame to the
    # global frame (the transform we want to return).
    Ry = Ry.transpose(2, 1)
    Rz = Rz.transpose(2, 1)
    Rx = Rx.transpose(2, 1)
    R = torch.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)

    return so3.rotation_vector_from_matrix(R), translation
|
|
|
|
|
|
|
|
class Residues(TensorDict):
    """
    Dictionary-like container for residues that supports some basic transformations.
    """

    # Every key this container may hold.
    KEYS = {'x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec', 'fixed_coord',
            'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral',
            'chi_indices', 'axis_angle', 'mask', 'bond_mask'}

    # Keys holding absolute coordinates (translate together with the residue).
    COORD_KEYS = {'x', 'fixed_coord'}

    # Keys holding direction vectors (rotate, but do not translate).
    VECTOR_KEYS = {'v', 'nma_vec'}

    # Properties changed by both side-chain and backbone updates.
    MUTABLE_PROPS_SS_AND_BB = {'v'}

    # Properties changed only by side-chain (torsion) updates.
    MUTABLE_PROPS_SS = {'chi'}

    # Properties changed only by backbone (frame) updates.
    MUTABLE_PROPS_BB = {'x', 'fixed_coord', 'axis_angle', 'nma_vec'}

    # Properties never altered by the transformations in this class.
    IMMUTABLE_PROPS = {'mask', 'one_hot', 'bonds', 'bond_one_hot', 'bond_mask',
                       'atom_mask', 'nerf_indices', 'length', 'theta',
                       'ddihedral', 'chi_indices', 'name', 'size', 'n_bonds'}

    def copy(self):
        """Shallow copy: tensor values are shared with the original."""
        data = super().copy()
        return Residues(**data)

    def deepcopy(self):
        """Deep copy: tensors are cloned, all other values deep-copied."""
        data = {k: v.clone() if torch.is_tensor(v) else deepcopy(v)
                for k, v in self.items()}
        return Residues(**data)

    def center(self):
        """Subtract each sample's center of mass from all coordinates (in place).

        Returns the per-sample centers of mass so callers can undo or reuse
        the translation.
        """
        com = scatter_mean(self['x'], self['mask'], dim=0)
        self['x'] = self['x'] - com[self['mask']]
        self['fixed_coord'] = self['fixed_coord'] - com[self['mask']].unsqueeze(1)
        return com

    def set_empty_v(self):
        """Clear the per-residue vector features."""
        self['v'] = torch.tensor([], device=self['x'].device)

    @torch.no_grad()
    def set_chi(self, chi_angles):
        """Overwrite the first five chi angles and rebuild 'v' via NeRF.

        The new vector features are reconstructed atom coordinates expressed
        relative to each residue's position 'x'.
        """
        self['chi'][:, :5] = chi_angles
        nerf_params = {k: self[k] for k in ['fixed_coord', 'atom_mask',
                                            'nerf_indices', 'length', 'theta',
                                            'chi', 'ddihedral', 'chi_indices']}
        self['v'] = ic_to_coords(**nerf_params) - self['x'].unsqueeze(1)

    @torch.no_grad()
    def set_frame(self, new_ca_coord, new_axis_angle):
        """Rigidly move each residue to a new backbone frame (in place).

        Backbone coordinates and vector features are rotated by the relative
        rotation between old and new frames, then re-centered on the new
        coordinates.
        """
        bb_coord = self['fixed_coord']
        bb_coord = bb_coord - self['x'].unsqueeze(1)
        rotmat_before = so3.matrix_from_rotation_vector(self['axis_angle'])
        rotmat_after = so3.matrix_from_rotation_vector(new_axis_angle)
        # Relative rotation taking the old orientation to the new one.
        rotmat_diff = rotmat_after @ rotmat_before.transpose(-1, -2)
        bb_coord = torch.einsum('boi,bai->bao', rotmat_diff, bb_coord)
        bb_coord = bb_coord + new_ca_coord.unsqueeze(1)

        self['x'] = new_ca_coord
        self['axis_angle'] = new_axis_angle
        self['fixed_coord'] = bb_coord
        self['v'] = torch.einsum('boi,bai->bao', rotmat_diff, self['v'])

    @staticmethod
    def empty(device):
        """A placeholder Residues object with a single zeroed entry."""
        return Residues(
            x=torch.zeros(1, 3, device=device).float(),
            mask=torch.zeros(1, 1, device=device).long(),
            size=torch.zeros(1, device=device).long(),
        )
|
|
|
|
|
|
|
|
def randomize_tensors(tensor_dict, exclude_keys=None):
    """Replace tensors with random tensors with the same shape.

    Floating tensors become standard normal samples; integer tensors become
    uniform integers in [-42, 42). Keys in ``exclude_keys`` and non-tensor
    values are left untouched. The dict is mutated in place and returned.
    """
    excluded = set(exclude_keys) if exclude_keys is not None else set()
    for key, value in tensor_dict.items():
        if not isinstance(value, torch.Tensor) or key in excluded:
            continue
        if torch.is_floating_point(value):
            tensor_dict[key] = torch.randn_like(value)
        else:
            tensor_dict[key] = torch.randint_like(value, low=-42, high=42)
    return tensor_dict
|
|
|