DrugFlow / src /utils.py
mority's picture
Upload 53 files
6e7d4ba verified
import warnings
from typing import Union, Iterable
import random
# import argparse
from argparse import Namespace
import numpy as np
import torch
from rdkit import Chem, RDLogger
from rdkit.Chem import KekulizeException, AtomKekulizeException
import networkx as nx
from networkx.algorithms import isomorphism
from torch_scatter import scatter_add, scatter_mean
class Queue():
def __init__(self, max_len=50):
self.items = []
self.max_len = max_len
def __len__(self):
return len(self.items)
def add(self, item):
self.items.insert(0, item)
if len(self) > self.max_len:
self.items.pop()
def mean(self):
return np.mean(self.items)
def std(self):
return np.std(self.items)
def reverse_tensor(x):
return x[torch.arange(x.size(0) - 1, -1, -1)]
#####
def sum_except_batch(x, indices):
if len(x.size()) < 2:
x = x.unsqueeze(-1)
return scatter_add(x.sum(list(range(1, len(x.size())))), indices, dim=0)
def remove_mean_batch(x, batch_mask, dim_size=None):
# Compute center of mass per sample
mean = scatter_mean(x, batch_mask, dim=0, dim_size=dim_size)
x = x - mean[batch_mask]
return x, mean
def assert_mean_zero(x, batch_mask, thresh=1e-2, eps=1e-10):
largest_value = x.abs().max().item()
error = scatter_add(x, batch_mask, dim=0).abs().max().item()
rel_error = error / (largest_value + eps)
assert rel_error < thresh, f'Mean is not zero, relative_error {rel_error}'
def bvm(v, m):
"""
Batched vector-matrix product of the form out = v @ m
:param v: (b, n_in)
:param m: (b, n_in, n_out)
:return: (b, n_out)
"""
# return (v.unsqueeze(1) @ m).squeeze()
return torch.bmm(v.unsqueeze(1), m).squeeze(1)
def get_grad_norm(
parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
norm_type: float = 2.0) -> torch.Tensor:
"""
Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
total_norm = torch.norm(torch.stack(
[torch.norm(p.grad.detach(), norm_type).to(device) for p in
parameters]), norm_type)
return total_norm
def write_xyz_file(coords, atom_types, filename):
out = f"{len(coords)}\n\n"
assert len(coords) == len(atom_types)
for i in range(len(coords)):
out += f"{atom_types[i]} {coords[i, 0]:.3f} {coords[i, 1]:.3f} {coords[i, 2]:.3f}\n"
with open(filename, 'w') as f:
f.write(out)
def write_sdf_file(sdf_path, molecules, catch_errors=True, connected=False):
with Chem.SDWriter(str(sdf_path)) as w:
for mol in molecules:
try:
if mol is None:
raise ValueError("Mol is None.")
w.write(get_largest_connected_component(mol) if connected else mol)
except (RuntimeError, ValueError) as e:
if not catch_errors:
raise e
if isinstance(e, (KekulizeException, AtomKekulizeException)):
w.SetKekulize(False)
w.write(get_largest_connected_component(mol) if connected else mol)
w.SetKekulize(True)
warnings.warn(f"Mol saved without kekulization.")
else:
# write empty mol to preserve the original order
w.write(Chem.Mol())
warnings.warn(f"Erroneous mol replaced with empty dummy.")
def get_largest_connected_component(mol):
try:
frags = Chem.GetMolFrags(mol, asMols=True)
newmol = max(frags, key=lambda m: m.GetNumAtoms())
except:
newmol = mol
return newmol
def write_chain(filename, rdmol_chain):
with open(filename, 'w') as f:
f.write("".join([Chem.MolToXYZBlock(m) for m in rdmol_chain]))
def combine_sdfs(sdf_list, out_file):
all_content = []
for sdf in sdf_list:
with open(sdf, 'r') as f:
all_content.append(f.read())
combined_str = '$$$$\n'.join(all_content)
with open(out_file, 'w') as f:
f.write(combined_str)
def batch_to_list(data, batch_mask, keep_order=True):
if keep_order: # preserve order of elements within each sample
data_list = [data[batch_mask == i]
for i in torch.unique(batch_mask, sorted=True)]
return data_list
# make sure batch_mask is increasing
idx = torch.argsort(batch_mask)
batch_mask = batch_mask[idx]
data = data[idx]
chunk_sizes = torch.unique(batch_mask, return_counts=True)[1].tolist()
return torch.split(data, chunk_sizes)
def batch_to_list_for_indices(indices, batch_mask, offsets=None):
# (2, n) -> (n, 2)
split = batch_to_list(indices.T, batch_mask)
# rebase indices at zero & (n, 2) -> (2, n)
if offsets is None:
warnings.warn("Trying to infer index offset from smallest element in "
"batch. This might be wrong.")
split = [x.T - x.min() for x in split]
else:
# Typically 'offsets' would be accumulate(sizes[:-1], initial=0)
assert len(offsets) == len(split) or indices.numel() == 0
split = [x.T - offset for x, offset in zip(split, offsets)]
return split
def num_nodes_to_batch_mask(n_samples, num_nodes, device):
assert isinstance(num_nodes, int) or len(num_nodes) == n_samples
if isinstance(num_nodes, torch.Tensor):
num_nodes = num_nodes.to(device)
sample_inds = torch.arange(n_samples, device=device)
return torch.repeat_interleave(sample_inds, num_nodes)
def rdmol_to_nxgraph(rdmol):
graph = nx.Graph()
for atom in rdmol.GetAtoms():
# Add the atoms as nodes
graph.add_node(atom.GetIdx(), atom_type=atom.GetAtomicNum())
# Add the bonds as edges
for bond in rdmol.GetBonds():
graph.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
return graph
def calc_rmsd(mol_a, mol_b):
""" Calculate RMSD of two molecules with unknown atom correspondence. """
graph_a = rdmol_to_nxgraph(mol_a)
graph_b = rdmol_to_nxgraph(mol_b)
gm = isomorphism.GraphMatcher(
graph_a, graph_b,
node_match=lambda na, nb: na['atom_type'] == nb['atom_type'])
isomorphisms = list(gm.isomorphisms_iter())
if len(isomorphisms) < 1:
return None
all_rmsds = []
for mapping in isomorphisms:
atom_types_a = [atom.GetAtomicNum() for atom in mol_a.GetAtoms()]
atom_types_b = [mol_b.GetAtomWithIdx(mapping[i]).GetAtomicNum()
for i in range(mol_b.GetNumAtoms())]
assert atom_types_a == atom_types_b
conf_a = mol_a.GetConformer()
coords_a = np.array([conf_a.GetAtomPosition(i)
for i in range(mol_a.GetNumAtoms())])
conf_b = mol_b.GetConformer()
coords_b = np.array([conf_b.GetAtomPosition(mapping[i])
for i in range(mol_b.GetNumAtoms())])
diff = coords_a - coords_b
rmsd = np.sqrt(np.mean(np.sum(diff * diff, axis=1)))
all_rmsds.append(rmsd)
if len(isomorphisms) > 1:
print("More than one isomorphism found. Returning minimum RMSD.")
return min(all_rmsds)
def set_deterministic(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def disable_rdkit_logging():
# RDLogger.DisableLog('rdApp.*')
RDLogger.DisableLog('rdApp.info')
RDLogger.DisableLog('rdApp.error')
RDLogger.DisableLog('rdApp.warning')
# class Namespace(argparse.Namespace):
# """Simple definition of a Namespace class that supports Namespace- and
# dictionary-like access."""
# def __getitem__(self, key):
# # return vars(self)[key]
# return self.__dict__[key]
#
# def __setitem__(self, key, value):
# self.__dict__[key] = value
#
# def __getattr__(self, item):
# """Supports other dictionary functionalities, e.g. get(), items(), etc."""
# # return getattr(vars(self), item)
# return getattr(self.__dict__, item)
def dict_to_namespace(input_dict):
""" Recursively convert a nested dictionary into a Namespace object. """
if isinstance(input_dict, dict):
output_namespace = Namespace()
output = output_namespace.__dict__
for key, value in input_dict.items():
output[key] = dict_to_namespace(value)
return output_namespace
elif isinstance(input_dict, Namespace):
# recurse as Namespace might contain dictionaries
return dict_to_namespace(input_dict.__dict__)
else:
return input_dict
def namespace_to_dict(x):
""" Recursively convert a nested Namespace object into a dictionary. """
if not (isinstance(x, Namespace) or isinstance(x, dict)):
return x
if isinstance(x, Namespace):
x = vars(x)
# recurse
output = {}
for key, value in x.items():
output[key] = namespace_to_dict(value)
return output