import io
from itertools import accumulate, chain
from copy import deepcopy
import random

import torch
import torch.nn.functional as F
import numpy as np
from rdkit import Chem
from torch_scatter import scatter_mean
from Bio.PDB import StructureBuilder, Chain, Model, Structure
from Bio.PDB.PICIO import read_PIC, write_PIC
from scipy.ndimage import gaussian_filter
from pdb import set_trace

from src.constants import FLOAT_TYPE, INT_TYPE
from src.constants import atom_encoder, bond_encoder, aa_encoder, residue_encoder, residue_bond_encoder, aa_atom_index
from src import utils
from src.data.misc import protein_letters_3to1, is_aa
from src.data.normal_modes import pdb_to_normal_modes
from src.data.nerf import get_nerf_params, ic_to_coords
import src.data.so3_utils as so3


class TensorDict(dict):
    """
    Dictionary whose tensor values can be moved/detached together.

    Non-tensor values are stored as-is and ignored by the tensor operations.
    All tensor operations modify the dictionary in place and return `self`.
    """

    def __init__(self, **kwargs):
        super(TensorDict, self).__init__(**kwargs)

    def _apply(self, func: str, *args, **kwargs):
        """ Apply the tensor method named `func` to all tensor values. """
        for k, v in self.items():
            if torch.is_tensor(v):
                self[k] = getattr(v, func)(*args, **kwargs)
        return self

    def cuda(self):
        return self.to('cuda')

    def cpu(self):
        return self.to('cpu')

    def to(self, device):
        return self._apply("to", device)

    def detach(self):
        return self._apply("detach")

    def __repr__(self):
        def val_to_str(val):
            # tensors are summarised by their shape, lists by their length
            if isinstance(val, torch.Tensor):
                return "%r" % list(val.size())
            if isinstance(val, list):
                return "[%r,]" % len(val)
            else:
                return "?"

        return f"{type(self).__name__}({', '.join(f'{k}={val_to_str(v)}' for k, v in self.items())})"


def collate_entity(batch):
    """
    Merge a list of per-sample dictionaries into one batched dictionary.

    The keys of the first sample decide which properties are collated:
      - 'name': kept as a list
      - 'size', 'n_bonds': stacked into one tensor
      - 'bonds': concatenated along dim 1 after shifting the indices of each
        sample by the number of nodes of all preceding samples
      - 'residues': flattened into a single list
      - 'mask', 'bond_mask': rebuilt from scratch (see below)
      - everything else: concatenated along dim 0

    The batch masks 'mask' and 'bond_mask' assign the sample index (starting
    at zero) to every row of 'x' and 'bond_one_hot', respectively.
    """
    out = {}
    for prop in batch[0].keys():

        if prop == 'name':
            out[prop] = [x[prop] for x in batch]

        elif prop == 'size' or prop == 'n_bonds':
            out[prop] = torch.tensor([x[prop] for x in batch])

        elif prop == 'bonds':
            # index offset
            offset = list(accumulate([x['size'] for x in batch], initial=0))
            out[prop] = torch.cat([x[prop] + offset[i] for i, x in enumerate(batch)], dim=1)

        elif prop == 'residues':
            out[prop] = list(chain.from_iterable(x[prop] for x in batch))

        elif prop in {'mask', 'bond_mask'}:
            pass  # batch masks will be written later

        else:
            out[prop] = torch.cat([x[prop] for x in batch], dim=0)

        # Create batch masks
        # make sure indices in batch start at zero (needed for torch_scatter)
        if prop == 'x':
            out['mask'] = torch.cat(
                [i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device)
                 for i, x in enumerate(batch)], dim=0)
        if prop == 'bond_one_hot':
            # TODO: this is not necessary as it can be computed on-the-fly as
            #  bond_mask = mask[bonds[0]] or bond_mask = mask[bonds[1]]
            out['bond_mask'] = torch.cat(
                [i * torch.ones(len(x[prop]), dtype=torch.int64, device=x[prop].device)
                 for i, x in enumerate(batch)], dim=0)

    return out


def split_entity(
        batch,
        *,
        index_types={'bonds'},
        edge_types={'bond_one_hot', 'bond_mask'},
        no_split={'name', 'size', 'n_bonds'},
        skip={'fragments'},
        batch_mask=None,
        edge_mask=None
):
    """
    Splits a batch into items and returns a list.

    Args:
        batch: batched dictionary (e.g. the output of `collate_entity`)
        index_types: keys holding node indices; split with `edge_mask` and
            shifted back by the per-sample node offsets
        edge_types: keys holding per-edge values; split with `edge_mask`
        no_split: keys that are already indexable per sample
        skip: keys that are dropped
        batch_mask: per-node sample index; defaults to `batch['mask']`
        edge_mask: per-edge sample index; defaults to `batch['bond_mask']`
    Returns:
        list with one dictionary per sample
    """
    batch_mask = batch["mask"] if batch_mask is None else batch_mask
    edge_mask = batch["bond_mask"] if edge_mask is None else edge_mask
    sizes = batch['size'] if 'size' in batch else torch.unique(batch_mask, return_counts=True)[1].tolist()

    # Count the samples with the mask that is actually used for splitting so
    # that an explicitly passed `batch_mask` is respected.
    batch_size = len(torch.unique(batch_mask))

    out = {}
    for prop in batch.keys():
        if prop in skip:
            continue
        if prop in no_split:
            out[prop] = batch[prop]  # already a list

        elif prop in index_types:
            offsets = list(accumulate(sizes[:-1], initial=0))
            out[prop] = utils.batch_to_list_for_indices(batch[prop], edge_mask, offsets)

        elif prop in edge_types:
            out[prop] = utils.batch_to_list(batch[prop], edge_mask)

        else:
            out[prop] = utils.batch_to_list(batch[prop], batch_mask)

    out = [{k: v[i] for k, v in out.items()} for i in range(batch_size)]
    return out


def repeat_items(batch, repeats):
    """
    Repeat the whole batch `repeats` times (sample order a, b, ..., a, b, ...).

    The result has the same container type as `batch`.
    """
    batch_list = split_entity(batch)
    out = collate_entity([x for _ in range(repeats) for x in batch_list])
    return type(batch)(**out)


def get_side_chain_bead_coord(biopython_residue):
    """
    Places side chain bead at the location of the farthest side chain atom.

    Returns None for glycine and the CB coordinate for alanine. For all other
    residues, the heavy side chain atom with the largest distance to CA is used.
    """
    if biopython_residue.get_resname() == 'GLY':
        return None
    if biopython_residue.get_resname() == 'ALA':
        return biopython_residue['CB'].get_coord()

    ca_coord = biopython_residue['CA'].get_coord()
    side_chain_atoms = [a for a in biopython_residue.get_atoms()
                        if a.id not in {'N', 'CA', 'C', 'O'} and a.element != 'H']
    side_chain_coords = np.stack([a.get_coord() for a in side_chain_atoms])

    # squared distances are sufficient to find the farthest atom
    atom_idx = np.argmax(np.sum((side_chain_coords - ca_coord[None, :]) ** 2, axis=-1))

    return side_chain_coords[atom_idx, :]


def get_side_chain_vectors(res, index_dict, size=None):
    """
    Collect the vectors pointing from CA to the atoms of a residue.

    Args:
        res: biopython residue
        index_dict: maps one-letter residue code -> {atom name -> row index}
        size: number of rows of the output; defaults to the largest index in
            `index_dict` plus one
    Returns:
        array of shape (size, 3); rows of atoms that are not present are zero
    """
    if size is None:
        size = max([x for aa in index_dict.values() for x in aa.values()]) + 1

    resname = protein_letters_3to1[res.get_resname()]

    out = np.zeros((size, 3))
    for atom in res.get_atoms():
        if atom.get_name() in index_dict[resname]:
            idx = index_dict[resname][atom.get_name()]
            out[idx] = atom.get_coord() - res['CA'].get_coord()

    return out


def get_normal_modes(res, normal_mode_dict):
    """ Look up the normal modes of a residue's CA atom by (chain id, residue number, 'CA'). """
    nm = normal_mode_dict[(res.get_parent().id, res.id[1], 'CA')]  # (n_modes, 3)
    return nm


def get_torsion_angles(res, device=None):
    """
    Return the five chi angles. Missing angles are filled with NaN.

    The values returned by the residue's internal coordinates are scaled by
    pi / 180, i.e. they are presumably degrees converted to radians.
    """
    ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5']

    ic_res = res.internal_coord
    chi_angles = [ic_res.get_angle(chi) for chi in ANGLES]
    chi_angles = [chi if chi is not None else float('nan') for chi in chi_angles]

    return torch.tensor(chi_angles, device=device) * np.pi / 180


def apply_torsion_angles(res, chi_angles):
    """
    Set side chain torsion angles of a biopython residue object with
    internal coordinates.

    `chi_angles` is scaled by 180 / pi before it is applied (inverse of
    `get_torsion_angles`). Angles that do not exist for the residue are
    skipped. The atom coordinates of the parent chain are rebuilt afterwards.
    """
    ANGLES = ['chi1', 'chi2', 'chi3', 'chi4', 'chi5']

    chi_angles = chi_angles * 180 / np.pi

    ic_res = res.internal_coord
    for chi, angle in zip(ANGLES, chi_angles):
        if ic_res.pick_angle(chi) is None:
            continue
        ic_res.bond_set(chi, angle)

    res.parent.internal_to_atom_coordinates(verbose=False)

    return res


def prepare_internal_coord(res):
    """
    Serialise the internal coordinates of a single residue.

    The residue is moved into a new single-residue structure (its parent
    pointer is updated) and written in PIC format.

    Returns:
        PIC file content as a string
    """
    # Make new structure with a single residue
    new_struct = Structure.Structure('X')
    new_struct.header = {}
    new_model = Model.Model(0)
    new_struct.add(new_model)
    new_chain = Chain.Chain('X')
    new_model.add(new_chain)
    new_chain.add(res)
    res.set_parent(new_chain)  # update pointer

    # Compute internal coordinates
    new_chain.atom_to_internal_coordinates()

    pic_io = io.StringIO()
    write_PIC(new_struct, pic_io)
    return pic_io.getvalue()


def residue_from_internal_coord(ic_string):
    """ Rebuild a residue (incl. atom coordinates) from a PIC string (see `prepare_internal_coord`). """
    pic_io = io.StringIO(ic_string)
    struct = read_PIC(pic_io, quick=True)
    res = struct.child_list[0].child_list[0].child_list[0]
    res.parent.internal_to_atom_coordinates(verbose=False)
    return res


def _encode_residue_bonds(edges, edge_types, n_bond_types):
    """
    Convert python lists of residue-level bonds into tensors.

    Returns an index tensor of shape (2, n_edges) and a one-hot tensor of
    shape (n_edges, n_bond_types). Both are empty but correctly shaped if
    there are no bonds.
    """
    if len(edges) < 1:
        return torch.empty(2, 0), torch.empty(0, n_bond_types)
    return torch.tensor(edges).T, F.one_hot(torch.tensor(edge_types), num_classes=n_bond_types)


def prepare_pocket(biopython_residues, amino_acid_encoder, residue_encoder,
                   residue_bond_encoder, pocket_representation='side_chain_bead',
                   compute_nerf_params=False, compute_bb_frames=False,
                   nma_input=None):
    """
    Featurise a list of pocket residues.

    Args:
        biopython_residues: residues of the pocket
        amino_acid_encoder: maps one-letter residue code -> class index
        residue_encoder: maps node kind ('CA', 'SS') -> class index
        residue_bond_encoder: maps bond kind ('CA-CA', 'CA-SS') -> class index
        pocket_representation: 'side_chain_bead' (one CA node per residue plus
            one side chain bead for every non-glycine residue) or 'CA+' (one
            CA node per residue with vector features)
        compute_nerf_params: add the output of `get_nerf_params` (stacked over
            residues) to the pocket
        compute_bb_frames: add the backbone orientation as 'axis_angle'
        nma_input: normal modes, either a preprocessed dictionary or anything
            `pdb_to_normal_modes` accepts after conversion to str; only
            supported for 'CA+'
    Returns:
        pocket dictionary and the residues sorted by (chain id, residue number)
    """
    assert nma_input is None or pocket_representation == 'CA+', \
        "vector features are only supported for CA+ pockets"

    # sort residues
    biopython_residues = sorted(biopython_residues, key=lambda x: (x.parent.id, x.id[1]))

    if nma_input is not None:
        # preprocessed normal mode eigenvectors
        if isinstance(nma_input, dict):
            nma_dict = nma_input
        # PDB file
        else:
            nma_dict = pdb_to_normal_modes(str(nma_input))

    n_bond_types = len(residue_bond_encoder)

    if pocket_representation == 'side_chain_bead':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')
        side_chain_coords = []
        side_chain_aa_types = []
        edges = []  # CA-CA and CA-side_chain
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa
            side_chain_coord = get_side_chain_bead_coord(res)
            if side_chain_coord is not None:
                side_chain_coords.append(side_chain_coord)
                side_chain_aa_types.append(aa)
                # side chain beads are stored after all CA nodes
                edges.append((i, len(ca_coords) + len(side_chain_coords) - 1))
                edge_types.append(residue_bond_encoder['CA-SS'])

            # add edges between contiguous CA atoms
            # NOTE(review): only the residue number is compared, so the last
            #  residue of one chain is also linked to the first residue of the
            #  next chain if their numbers happen to be consecutive - confirm
            #  whether this is intended
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        # Coordinates
        # (np.stack fails on an empty list, e.g. for glycine-only pockets)
        side_chain_coords = np.stack(side_chain_coords) if side_chain_coords else np.zeros((0, 3))
        pocket_coords = np.concatenate([ca_coords, side_chain_coords], axis=0)
        pocket_coords = torch.from_numpy(pocket_coords)

        # Features
        amino_acid_onehot = F.one_hot(
            torch.cat([torch.from_numpy(ca_types), torch.tensor(side_chain_aa_types, dtype=torch.int64)], dim=0),
            num_classes=len(amino_acid_encoder)
        )
        side_chain_onehot = np.concatenate([
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['CA']),
                    [len(ca_coords), 1]),
            np.tile(np.eye(1, len(residue_encoder), residue_encoder['SS']),
                    [len(side_chain_coords), 1])
        ], axis=0)
        side_chain_onehot = torch.from_numpy(side_chain_onehot)
        pocket_onehot = torch.cat([amino_acid_onehot, side_chain_onehot], dim=1)

        vector_features = None
        nma_features = None

        # Bonds
        edges, edge_types = _encode_residue_bonds(edges, edge_types, n_bond_types)

    elif pocket_representation == 'CA+':
        ca_coords = np.zeros((len(biopython_residues), 3))
        ca_types = np.zeros(len(biopython_residues), dtype='int64')

        v_dim = max([x for aa in aa_atom_index.values() for x in aa.values()]) + 1
        vec_feats = np.zeros((len(biopython_residues), v_dim, 3), dtype='float32')
        nf_nma = 5
        nma_feats = np.zeros((len(biopython_residues), nf_nma, 3), dtype='float32')

        edges = []  # CA-CA and CA-side_chain
        edge_types = []
        last_res_id = None
        for i, res in enumerate(biopython_residues):
            aa = amino_acid_encoder[protein_letters_3to1[res.get_resname()]]
            ca_coords[i, :] = res['CA'].get_coord()
            ca_types[i] = aa

            vec_feats[i] = get_side_chain_vectors(res, aa_atom_index, v_dim)
            if nma_input is not None:
                nma_feats[i] = get_normal_modes(res, nma_dict)

            # add edges between contiguous CA atoms
            # NOTE(review): chain ids are not compared here either (see the
            #  'side_chain_bead' branch)
            if i > 0 and res.id[1] == last_res_id + 1:
                edges.append((i - 1, i))
                edge_types.append(residue_bond_encoder['CA-CA'])

            last_res_id = res.id[1]

        # Coordinates
        pocket_coords = torch.from_numpy(ca_coords)

        # Features
        pocket_onehot = F.one_hot(torch.from_numpy(ca_types),
                                  num_classes=len(amino_acid_encoder))

        vector_features = torch.from_numpy(vec_feats)
        nma_features = torch.from_numpy(nma_feats)

        # Bonds
        edges, edge_types = _encode_residue_bonds(edges, edge_types, n_bond_types)

    else:
        raise NotImplementedError(
            f"Pocket representation '{pocket_representation}' not implemented")

    pocket = {
        'x': pocket_coords.to(dtype=FLOAT_TYPE),
        'one_hot': pocket_onehot.to(dtype=FLOAT_TYPE),
        'size': torch.tensor([len(pocket_coords)], dtype=INT_TYPE),
        'mask': torch.zeros(len(pocket_coords), dtype=INT_TYPE),
        'bonds': edges.to(INT_TYPE),
        'bond_one_hot': edge_types.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(edges.size(1), dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(edge_types)], dtype=INT_TYPE),
    }

    if vector_features is not None:
        pocket['v'] = vector_features.to(dtype=FLOAT_TYPE)

    if nma_input is not None:
        pocket['nma_vec'] = nma_features.to(dtype=FLOAT_TYPE)

    if compute_nerf_params:
        nerf_params = [get_nerf_params(r) for r in biopython_residues]
        nerf_params = {k: torch.stack([x[k] for x in nerf_params], dim=0)
                       for k in nerf_params[0].keys()}
        pocket.update(nerf_params)

    if compute_bb_frames:
        n_xyz = torch.from_numpy(np.stack([r['N'].get_coord() for r in biopython_residues]))
        ca_xyz = torch.from_numpy(np.stack([r['CA'].get_coord() for r in biopython_residues]))
        c_xyz = torch.from_numpy(np.stack([r['C'].get_coord() for r in biopython_residues]))
        pocket['axis_angle'], _ = get_bb_transform(n_xyz, ca_xyz, c_xyz)

    return pocket, biopython_residues


def encode_atom(rd_atom, atom_encoder):
    """
    Map an RDKit atom to its class index.

    Special classes are used if they exist in `atom_encoder`, in this order:
    '<element>H' for exactly one explicit hydrogen, '<element>+' for a formal
    charge of +1 and '<element>-' for a formal charge of -1. Otherwise the
    index of the plain element symbol is returned.
    """
    element = rd_atom.GetSymbol().capitalize()

    explicitHs = rd_atom.GetNumExplicitHs()
    if explicitHs == 1 and f'{element}H' in atom_encoder:
        return atom_encoder[f'{element}H']

    charge = rd_atom.GetFormalCharge()
    if charge == 1 and f'{element}+' in atom_encoder:
        return atom_encoder[f'{element}+']
    if charge == -1 and f'{element}-' in atom_encoder:
        return atom_encoder[f'{element}-']

    return atom_encoder[element]


def prepare_ligand(rdmol, atom_encoder, bond_encoder):
    """
    Featurise an RDKit molecule as a fully connected graph.

    Hydrogens are removed if 'H' is not part of `atom_encoder`. Every
    unordered atom pair (i < j) becomes an edge; pairs without a chemical
    bond get the 'NOBOND' type.

    Returns:
        dictionary with coordinates, one-hot atom types, bond indices,
        one-hot bond types, (all-zero) batch masks and sizes
    """
    # remove H atoms if not in atom_encoder
    if 'H' not in atom_encoder:
        rdmol = Chem.RemoveAllHs(rdmol, sanitize=False)

    # Coordinates
    ligand_coord = rdmol.GetConformer().GetPositions()
    ligand_coord = torch.from_numpy(ligand_coord)

    # Features
    ligand_onehot = F.one_hot(
        torch.tensor([encode_atom(a, atom_encoder) for a in rdmol.GetAtoms()]),
        num_classes=len(atom_encoder)
    )

    # Bonds
    adj = np.ones((rdmol.GetNumAtoms(), rdmol.GetNumAtoms())) * bond_encoder['NOBOND']
    for b in rdmol.GetBonds():
        i = b.GetBeginAtomIdx()
        j = b.GetEndAtomIdx()
        adj[i, j] = bond_encoder[str(b.GetBondType())]
        adj[j, i] = adj[i, j]  # undirected graph

    # molecular graph is undirected -> don't save redundant information
    bonds = np.stack(np.triu_indices(len(ligand_coord), k=1), axis=0)
    bond_types = adj[bonds[0], bonds[1]].astype('int64')
    bonds = torch.from_numpy(bonds)
    bond_types = F.one_hot(torch.from_numpy(bond_types), num_classes=len(bond_encoder))

    ligand = {
        'x': ligand_coord.to(dtype=FLOAT_TYPE),
        'one_hot': ligand_onehot.to(dtype=FLOAT_TYPE),
        'mask': torch.zeros(len(ligand_coord), dtype=INT_TYPE),
        'bonds': bonds.to(INT_TYPE),
        'bond_one_hot': bond_types.to(FLOAT_TYPE),
        'bond_mask': torch.zeros(bonds.size(1), dtype=INT_TYPE),
        'size': torch.tensor([len(ligand_coord)], dtype=INT_TYPE),
        'n_bonds': torch.tensor([len(bond_types)], dtype=INT_TYPE),
    }

    return ligand


def process_raw_molecule_with_empty_pocket(rdmol):
    """ Featurise a molecule and pair it with a pocket that only contains empty tensors. """
    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)
    pocket = {
        'x': torch.tensor([], dtype=FLOAT_TYPE),
        'one_hot': torch.tensor([], dtype=FLOAT_TYPE),
        'size': torch.tensor([], dtype=INT_TYPE),
        'mask': torch.tensor([], dtype=INT_TYPE),
        'bonds': torch.tensor([], dtype=INT_TYPE),
        'bond_one_hot': torch.tensor([], dtype=FLOAT_TYPE),
        'bond_mask': torch.tensor([], dtype=INT_TYPE),
        'n_bonds': torch.tensor([], dtype=INT_TYPE),
    }
    return ligand, pocket


def process_raw_pair(biopython_model, rdmol, dist_cutoff=None,
                     pocket_representation='side_chain_bead',
                     compute_nerf_params=False, compute_bb_frames=False,
                     nma_input=None, return_pocket_pdb=False):
    """
    Featurise a ligand and the surrounding pocket residues.

    Args:
        biopython_model: biopython model that contains the protein
        rdmol: RDKit molecule of the ligand
        dist_cutoff: a standard amino acid is part of the pocket if any of its
            atoms is closer than this to any ligand atom; None keeps all
            standard amino acids
        pocket_representation, compute_nerf_params, compute_bb_frames,
        nma_input: forwarded to `prepare_pocket`
        return_pocket_pdb: additionally store a biopython structure with the
            pocket residues under the key 'pocket_pdb'
    Returns:
        ligand dictionary, pocket dictionary
    """
    # Process ligand
    ligand = prepare_ligand(rdmol, atom_encoder, bond_encoder)

    # Find interacting pocket residues based on distance cutoff
    pocket_residues = []
    for residue in biopython_model.get_residues():

        # Remove non-standard amino acids and HETATMs
        if not is_aa(residue.get_resname(), standard=True):
            continue

        res_coords = torch.from_numpy(np.array([a.get_coord() for a in residue.get_atoms()]))
        if dist_cutoff is None or (((res_coords[:, None, :] - ligand['x'][None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
            pocket_residues.append(residue)

    pocket, pocket_residues = prepare_pocket(
        pocket_residues, aa_encoder, residue_encoder, residue_bond_encoder,
        pocket_representation, compute_nerf_params, compute_bb_frames, nma_input
    )

    if return_pocket_pdb:
        builder = StructureBuilder.StructureBuilder()
        builder.init_structure("")
        builder.init_model(0)
        pocket_struct = builder.get_structure()
        for residue in pocket_residues:
            # (not called `chain` to avoid shadowing itertools.chain)
            chain_id = residue.get_parent().get_id()

            # init chain if necessary
            if not pocket_struct[0].has_id(chain_id):
                builder.init_chain(chain_id)

            # add residue
            pocket_struct[0][chain_id].add(residue)

        pocket['pocket_pdb'] = pocket_struct

    return ligand, pocket


class AppendVirtualNodes:
    """
    Pad a ligand with virtual ('NOATOM') nodes up to a fixed size.

    The coordinates of the virtual nodes are sampled from a Gaussian with the
    mean and covariance of the ligand coordinates (covariance scaled through
    `scale`). All bonds that involve a virtual node are of type 'NOBOND'.
    The ligand dictionary is modified in place.
    """

    def __init__(self, atom_encoder, bond_encoder, max_ligand_size, scale=1.0):
        self.max_size = max_ligand_size
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        self.scale = scale

    def __call__(self, ligand, max_size=None, eps=1e-6):
        if max_size is None:
            max_size = self.max_size

        n_virt = max_size - ligand['size']

        # eps on the diagonal keeps the covariance positive definite
        C = torch.cov(ligand['x'].T)
        L = torch.linalg.cholesky(C + torch.eye(3) * eps)
        mu = ligand['x'].mean(0, keepdim=True)
        virt_coords = mu + torch.randn(n_virt, 3) @ L.T * self.scale

        # insert virtual atom column
        virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx,
                                 num_classes=len(self.atom_encoder))
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool),
                               torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = max_size

        # Bonds
        new_bonds = torch.triu_indices(max_size, max_size, offset=1)

        # start from 'NOBOND' everywhere and copy the existing bond types
        bond_types = torch.ones(max_size, max_size, dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        new_row, new_col = new_bonds
        bond_types = bond_types[new_row, new_col]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(
            bond_types, num_classes=len(self.bond_encoder)
        ).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand


class AppendVirtualNodesInCoM:
    """
    Append a random number of virtual ('NOATOM') nodes to a ligand.

    The number of virtual nodes is drawn uniformly from [add_min, add_max]
    (both inclusive) and all of them are placed at the mean of the ligand
    coordinates. All bonds that involve a virtual node are of type 'NOBOND'.
    The ligand dictionary is modified in place.
    """

    def __init__(self, atom_encoder, bond_encoder, add_min=0, add_max=10):
        self.atom_encoder = atom_encoder
        self.bond_encoder = bond_encoder
        self.vidx = atom_encoder['NOATOM']
        self.bidx = bond_encoder['NOBOND']
        self.add_min = add_min
        self.add_max = add_max

    def __call__(self, ligand):
        n_virt = random.randint(self.add_min, self.add_max)

        # all virtual coordinates in the CoM
        virt_coords = ligand['x'].mean(0, keepdim=True).repeat(n_virt, 1)

        # insert virtual atom column
        virt_one_hot = F.one_hot(torch.ones(n_virt, dtype=torch.int64) * self.vidx,
                                 num_classes=len(self.atom_encoder))
        virt_mask = torch.cat([torch.zeros(ligand['size'], dtype=bool),
                               torch.ones(n_virt, dtype=bool)])

        ligand['x'] = torch.cat([ligand['x'], virt_coords])
        ligand['one_hot'] = torch.cat(([ligand['one_hot'], virt_one_hot]))
        ligand['virtual_mask'] = virt_mask
        ligand['size'] = len(ligand['x'])

        # Bonds
        new_bonds = torch.triu_indices(ligand['size'], ligand['size'], offset=1)

        # start from 'NOBOND' everywhere and copy the existing bond types
        bond_types = torch.ones(ligand['size'], ligand['size'], dtype=INT_TYPE) * self.bidx
        row, col = ligand['bonds']
        bond_types[row, col] = ligand['bond_one_hot'].argmax(dim=1)
        new_row, new_col = new_bonds
        bond_types = bond_types[new_row, new_col]

        ligand['bonds'] = new_bonds
        ligand['bond_one_hot'] = F.one_hot(
            bond_types, num_classes=len(self.bond_encoder)
        ).to(ligand['bond_one_hot'].dtype)
        ligand['n_bonds'] = len(ligand['bond_one_hot'])

        return ligand


def rdmol_to_smiles(rdmol):
    """ SMILES of a copy of the molecule without stereochemistry and hydrogens. """
    mol = Chem.Mol(rdmol)
    Chem.RemoveStereochemistry(mol)
    mol = Chem.RemoveHs(mol)
    return Chem.MolToSmiles(mol)


def get_n_nodes(lig_positions, pocket_positions, smooth_sigma=None):
    """
    Joint histogram of the number of ligand and pocket nodes.

    Args:
        lig_positions: one sequence of node positions per ligand
        pocket_positions: one sequence of node positions per pocket
        smooth_sigma: standard deviation of an optional Gaussian filter
    Returns:
        array of shape (max ligand size + 1, max pocket size + 1)
    """
    # Joint distribution of ligand's and pocket's number of nodes
    n_nodes_lig = [len(x) for x in lig_positions]
    n_nodes_pocket = [len(x) for x in pocket_positions]

    joint_histogram = np.zeros((np.max(n_nodes_lig) + 1,
                                np.max(n_nodes_pocket) + 1))

    for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket):
        joint_histogram[nlig, npocket] += 1

    print(f'Original histogram: {np.count_nonzero(joint_histogram)}/'
          f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled')

    # Smooth the histogram
    if smooth_sigma is not None:
        filtered_histogram = gaussian_filter(
            joint_histogram, sigma=smooth_sigma, order=0, mode='constant',
            cval=0.0, truncate=4.0)

        print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/'
              f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled')

        joint_histogram = filtered_histogram

    return joint_histogram


def get_type_histogram(one_hot, type_encoder):
    """
    Count how often every type occurs.

    Args:
        one_hot: sequence of one-hot arrays, concatenated along axis 0
        type_encoder: dictionary whose key order defines the class indices
    Returns:
        dictionary type -> count (zero for types that do not occur)
    """
    one_hot = np.concatenate(one_hot, axis=0)

    decoder = list(type_encoder.keys())
    counts = {k: 0 for k in type_encoder.keys()}
    for a in [decoder[x] for x in one_hot.argmax(1)]:
        counts[a] += 1

    return counts


def get_residue_with_resi(pdb_chain, resi):
    """ Return the unique residue with residue number `resi` from a chain. """
    res = [x for x in pdb_chain.get_residues() if x.id[1] == resi]
    assert len(res) == 1
    return res[0]


def get_pocket_from_ligand(pdb_model, ligand, dist_cutoff=8.0):
    """
    Select all standard amino acids close to a reference ligand.

    Args:
        pdb_model: biopython model
        ligand: path of an SDF file (first molecule is used) or the location
            of a residue inside `pdb_model` in <chain>:<resi> format
        dist_cutoff: a residue is selected if any of its atoms is closer than
            this to any ligand atom
    Returns:
        list of biopython residues
    """
    ligand = str(ligand)  # also accept path objects

    if ligand.endswith(".sdf"):
        # ligand as sdf file
        rdmol = Chem.SDMolSupplier(ligand)[0]
        ligand_coords = torch.from_numpy(rdmol.GetConformer().GetPositions()).float()
        ligand_residue = None
    else:
        # ligand contained in PDB; given in <chain>:<resi> format
        chain_id, resi = ligand.split(':')
        ligand_residue = get_residue_with_resi(pdb_model[chain_id], int(resi))
        ligand_coords = torch.from_numpy(
            np.array([a.get_coord() for a in ligand_residue.get_atoms()]))

    pocket_residues = []
    for residue in pdb_model.get_residues():
        # Compare by identity: a comparison of the residue number alone would
        # also exclude residues with the same number in other chains.
        if residue is ligand_residue:
            continue  # skip ligand itself

        res_coords = torch.from_numpy(
            np.array([a.get_coord() for a in residue.get_atoms()]))
        if is_aa(residue.get_resname(), standard=True) \
                and torch.cdist(res_coords, ligand_coords).min() < dist_cutoff:
            pocket_residues.append(residue)

    return pocket_residues


def encode_residues(biopython_residues, type_encoder, level='atom',
                    remove_H=True):
    """
    Coordinates and one-hot types of residues at atom or residue level.

    Args:
        biopython_residues: list of residues
        type_encoder: maps element symbol ('atom') or one-letter residue code
            ('residue') -> class index
        level: 'atom' uses all atoms, 'residue' uses one node (CA) per residue
        remove_H: skip hydrogens (only relevant for level='atom')
    Returns:
        coordinates, one-hot types
    """
    assert level in {'atom', 'residue'}

    if level == 'atom':
        entities = [a for res in biopython_residues for a in res.get_atoms()
                    if (a.element != 'H' or not remove_H)]
        types = [a.element.capitalize() for a in entities]
    else:
        entities = [res['CA'] for res in biopython_residues]
        types = [protein_letters_3to1[res.get_resname()] for res in biopython_residues]

    coord = torch.tensor(np.stack([e.get_coord() for e in entities]))
    one_hot = F.one_hot(torch.tensor([type_encoder[t] for t in types]),
                        num_classes=len(type_encoder))

    return coord, one_hot


def center_data(ligand, pocket):
    """
    Move the origin to the pocket's center of mass (per sample).

    If the pocket is empty, the ligand's center of mass is used instead.
    Both inputs are modified in place; `pocket` must provide `center()`
    (see `Residues`) unless it is empty.
    """
    if pocket['x'].numel() > 0:
        pocket_com = pocket.center()
    else:
        pocket_com = scatter_mean(ligand['x'], ligand['mask'], dim=0)

    ligand['x'] = ligand['x'] - pocket_com[ligand['mask']]
    return ligand, pocket


def get_bb_transform(n_xyz, ca_xyz, c_xyz):
    """
    Compute translation and rotation of the canonical backbone frame (triangle N-Ca-C) from a position with
    Ca at the origin, N on the x-axis and C in the xy-plane to the global position of the backbone frame

    Args:
        n_xyz: (n, 3)
        ca_xyz: (n, 3)
        c_xyz: (n, 3)

    Returns:
        axis-angle representation of the rotation, shape (n, 3)
        translation vector of shape (n, 3)
    """

    def rotation_matrix(angle, axis):
        axis_mapping = {'x': 0, 'y': 1, 'z': 2}
        axis = axis_mapping[axis]
        # allocate next to `angle` so that inputs on other devices work too
        vector = torch.zeros(len(angle), 3, dtype=angle.dtype, device=angle.device)
        vector[:, axis] = 1
        return so3.matrix_from_rotation_vector(angle.view(-1, 1) * vector)

    translation = ca_xyz
    n_xyz = n_xyz - translation
    c_xyz = c_xyz - translation

    # Find rotation matrix that aligns the coordinate systems.
    # The transposed (= inverse) rotations move the global frame step by step
    # into the canonical position; the angles found on the way define the
    # forward rotation that is returned.

    # rotate around y-axis to move N into the xy-plane
    theta_y = torch.arctan2(n_xyz[:, 2], -n_xyz[:, 0])
    Ry = rotation_matrix(theta_y, 'y')
    Ry_inv = Ry.transpose(2, 1)
    n_xyz = torch.einsum('noi,ni->no', Ry_inv, n_xyz)

    # rotate around z-axis to move N onto the x-axis
    theta_z = torch.arctan2(n_xyz[:, 1], n_xyz[:, 0])
    Rz = rotation_matrix(theta_z, 'z')
    Rz_inv = Rz.transpose(2, 1)

    # rotate around x-axis to move C into the xy-plane
    c_xyz = torch.einsum('noj,nji,ni->no', Rz_inv, Ry_inv, c_xyz)
    theta_x = torch.arctan2(c_xyz[:, 2], c_xyz[:, 1])
    Rx = rotation_matrix(theta_x, 'x')

    # Final rotation matrix
    R = torch.einsum('nok,nkj,nji->noi', Ry, Rz, Rx)

    return so3.rotation_vector_from_matrix(R), translation


class Residues(TensorDict):
    """
    Dictionary-like container for residues that supports some basic transformations.
    """

    # all keys
    KEYS = {'x', 'one_hot', 'bonds', 'bond_one_hot', 'v', 'nma_vec', 'fixed_coord',
            'atom_mask', 'nerf_indices', 'length', 'theta', 'chi', 'ddihedral',
            'chi_indices', 'axis_angle', 'mask', 'bond_mask'}

    # coordinate-type values, shape (..., 3)
    COORD_KEYS = {'x', 'fixed_coord'}

    # vector-type values, shape (n_residues, n_feat, 3)
    VECTOR_KEYS = {'v', 'nma_vec'}

    # properties that change if the side chains and/or backbones are updated
    MUTABLE_PROPS_SS_AND_BB = {'v'}

    # properties that only change if the side chains are updated
    MUTABLE_PROPS_SS = {'chi'}

    # properties that only change if the backbones are updated
    MUTABLE_PROPS_BB = {'x', 'fixed_coord', 'axis_angle', 'nma_vec'}

    # properties that remain fixed in all cases
    IMMUTABLE_PROPS = {'mask', 'one_hot', 'bonds', 'bond_one_hot', 'bond_mask',
                       'atom_mask', 'nerf_indices', 'length', 'theta',
                       'ddihedral', 'chi_indices', 'name', 'size', 'n_bonds'}

    def copy(self):
        """ Shallow copy (values are shared with the original). """
        data = super().copy()
        return Residues(**data)

    def deepcopy(self):
        """ Deep copy; tensors are cloned, everything else is deep-copied. """
        data = {k: v.clone() if torch.is_tensor(v) else deepcopy(v)
                for k, v in self.items()}
        return Residues(**data)

    def center(self):
        """
        Subtract the per-sample mean of 'x' from all coordinates (in place).

        Returns:
            the subtracted centers, one row per sample
        """
        com = scatter_mean(self['x'], self['mask'], dim=0)
        self['x'] = self['x'] - com[self['mask']]
        # 'fixed_coord' only exists if NeRF parameters were computed
        if 'fixed_coord' in self:
            self['fixed_coord'] = self['fixed_coord'] - com[self['mask']].unsqueeze(1)
        return com

    def set_empty_v(self):
        """ Replace the vector features by an empty tensor. """
        self['v'] = torch.tensor([], device=self['x'].device)

    @torch.no_grad()
    def set_chi(self, chi_angles):
        """ Overwrite the first five chi angles and recompute the vector features 'v'. """
        self['chi'][:, :5] = chi_angles
        nerf_params = {k: self[k] for k in ['fixed_coord', 'atom_mask',
                                            'nerf_indices', 'length', 'theta',
                                            'chi', 'ddihedral', 'chi_indices']}
        # vector features are stored relative to 'x'
        self['v'] = ic_to_coords(**nerf_params) - self['x'].unsqueeze(1)

    @torch.no_grad()
    def set_frame(self, new_ca_coord, new_axis_angle):
        """
        Move the residues to new positions and orientations.

        'fixed_coord' and 'v' are rotated by the difference between the new
        and the current orientation; 'x' and 'axis_angle' are overwritten.
        """
        bb_coord = self['fixed_coord']
        bb_coord = bb_coord - self['x'].unsqueeze(1)
        rotmat_before = so3.matrix_from_rotation_vector(self['axis_angle'])
        rotmat_after = so3.matrix_from_rotation_vector(new_axis_angle)
        rotmat_diff = rotmat_after @ rotmat_before.transpose(-1, -2)
        bb_coord = torch.einsum('boi,bai->bao', rotmat_diff, bb_coord)
        bb_coord = bb_coord + new_ca_coord.unsqueeze(1)

        self['x'] = new_ca_coord
        self['axis_angle'] = new_axis_angle
        self['fixed_coord'] = bb_coord
        self['v'] = torch.einsum('boi,bai->bao', rotmat_diff, self['v'])

    @staticmethod
    def empty(device):
        """ Placeholder container with a single node at the origin and size zero. """
        return Residues(
            x=torch.zeros(1, 3, device=device).float(),
            mask=torch.zeros(1, 1, device=device).long(),
            size=torch.zeros(1, device=device).long(),
        )


def randomize_tensors(tensor_dict, exclude_keys=None):
    """Replace tensors with random tensors with the same shape."""
    exclude_keys = set() if exclude_keys is None else set(exclude_keys)
    for k, v in tensor_dict.items():
        if isinstance(v, torch.Tensor) and k not in exclude_keys:
            if torch.is_floating_point(v):
                tensor_dict[k] = torch.randn_like(v)
            else:
                tensor_dict[k] = torch.randint_like(v, low=-42, high=42)
    return tensor_dict