import argparse
import os
import pickle
import sys
import warnings
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
from rdkit import Chem
from scipy.spatial.distance import jensenshannon
from scipy.stats import wasserstein_distance
from tqdm import tqdm

# Put the directory three levels above this file on sys.path so that the
# `src.*` imports below resolve when the file is executed as a script.
# This must run before those imports.
basedir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(basedir))

from src.data.data_utils import atom_encoder, bond_encoder, encode_atom
from src.sbdd_metrics.evaluation import VALIDITY_METRIC_NAME, aggregated_metrics, collection_metrics, get_data_type
from src.sbdd_metrics.metrics import FullEvaluator


# Metric-name -> data-type mapping taken from a default-constructed evaluator.
# It is passed to `get_data_type` in `flatten_distribution` to decide whether a
# metric value is a list. Both names are bound to the same object.
DATA_TYPES = data_types = FullEvaluator().dtypes

# Columns of `metrics_detailed.csv` that are read back as per-molecule value
# lists in `prepare_baseline_data` and compared with the Wasserstein distance.
MEDCHEM_PROPS = [
    'medchem.qed',
    'medchem.sa',
    'medchem.logp',
    'medchem.lipinski',
    'medchem.size',
    'medchem.n_rotatable_bonds',
    'energy.energy',
]

# Docking-related columns of `metrics_detailed.csv`; handled exactly like
# MEDCHEM_PROPS (read in `prepare_baseline_data`, compared via Wasserstein distance).
DOCKING_PROPS = [
    'gnina.vina_score',
    'gnina.gnina_score',
    'gnina.vina_efficiency',
    'gnina.gnina_efficiency',
]

# Interaction metrics included in the distance table: the raw values and the
# '.normalized' variants that `flatten_distribution` derives by dividing by
# the molecule size. 'interactions.HB' is itself derived there as
# HBAcceptor + HBDonor.
RELEVANT_INTERACTIONS = [
    'interactions.HBAcceptor',
    'interactions.HBDonor',
    'interactions.HB',
    'interactions.PiStacking',
    'interactions.Hydrophobic',

    'interactions.HBAcceptor.normalized',
    'interactions.HBDonor.normalized',
    'interactions.HB.normalized',
    'interactions.PiStacking.normalized',
    'interactions.Hydrophobic.normalized',
]


def compute_discrete_distributions(smiles, name):
    """Compute normalized atom-type and bond-type histograms over SMILES strings.

    Each SMILES is parsed with RDKit, hydrogens are removed (without
    re-sanitizing), and every remaining atom and bond is counted under the
    index given by `encode_atom` / `bond_encoder`.

    Args:
        smiles: Iterable of SMILES strings. Entries that are not strings
            (e.g. NaN from a CSV column) or that RDKit cannot parse are skipped.
        name: Label shown on the progress bar.

    Returns:
        Tuple ``(atom_distribution, bond_distribution)`` of 1-D float arrays
        with lengths ``len(atom_encoder)`` and ``len(bond_encoder)``. Each
        array sums to 1, or is all zeros if nothing was counted.

    Raises:
        KeyError: If a bond type is missing from `bond_encoder`.
    """
    atom_counter = Counter()
    bond_counter = Counter()

    for smi in tqdm(smiles, desc=name):
        if not isinstance(smi, str):
            # Missing values read from a CSV arrive as float NaN.
            continue
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # RDKit signals a parse failure by returning None; passing that
            # on to RemoveAllHs would raise.
            continue
        mol = Chem.RemoveAllHs(mol, sanitize=False)

        for atom in mol.GetAtoms():
            try:
                encoded_atom = encode_atom(atom, atom_encoder=atom_encoder)
            except KeyError:
                # Atom not representable with the encoder: leave it out.
                continue
            atom_counter[encoded_atom] += 1

        for bond in mol.GetBonds():
            bond_counter[bond_encoder[str(bond.GetBondType())]] += 1

    atom_distribution = np.zeros(len(atom_encoder))
    bond_distribution = np.zeros(len(bond_encoder))

    # The encoded values are used directly as positions in the histograms.
    for k, v in atom_counter.items():
        atom_distribution[k] = v
    for k, v in bond_counter.items():
        bond_distribution[k] = v

    # Normalize only non-empty histograms; 0 / 0 would yield NaNs and a
    # RuntimeWarning.
    atom_total = atom_distribution.sum()
    if atom_total > 0:
        atom_distribution = atom_distribution / atom_total
    bond_total = bond_distribution.sum()
    if bond_total > 0:
        bond_distribution = bond_distribution / bond_total

    return atom_distribution, bond_distribution


def flatten_distribution(data, name, table):
    """Pool per-sample metric values into one flat list per metric.

    Only items whose ``'medchem.valid'`` entry is exactly ``True`` are used.
    For each such item the function

    * adds ``'interactions.HB'`` as HBAcceptor + HBDonor when both are present,
    * adds ``'<key>.normalized'`` for every ``'interactions*'`` key, dividing
      by the molecule size looked up in `table`,
    * appends (or, for list-typed metrics, extends with) every non-None value
      to the output list of its key.

    The derived entries are written back into the items of `data`, i.e. the
    input dictionaries are modified in place.

    Args:
        data: Iterable of per-sample metric dicts. Each dict must contain
            ``'medchem.valid'``; ``'sdf_file'`` and ``'sample'`` are needed
            when interaction metrics are present.
        name: Label shown on the progress bar.
        table: DataFrame with columns ``'sdf_file'``, ``'sample'`` and
            ``'medchem.size'``; provides the sizes used for normalization.

    Returns:
        ``defaultdict(list)`` mapping metric name to the pooled values. The
        auxiliary keys ``'sample'``, ``'sdf_file'``, ``'pdb_file'`` are
        excluded, as are ``'energy.energy'`` values with magnitude above 1000.
    """
    aux = ('sample', 'sdf_file', 'pdb_file')
    normalized_suffix = '.normalized'
    method_distributions = defaultdict(list)

    # sdf_file -> sample index -> molecule size. Plain column iteration is
    # equivalent to iterrows() here but avoids building a Series per row.
    sdf2sample2size = defaultdict(dict)
    for sdf_file, sample, size in zip(table['sdf_file'], table['sample'], table['medchem.size']):
        sdf2sample2size[sdf_file][int(sample)] = size

    for item in tqdm(data, desc=name):
        # Identity check on purpose: None, NaN and other truthy values are
        # rejected. NOTE(review): numpy.bool_ values would be rejected too --
        # confirm the evaluator stores plain Python bools.
        if item['medchem.valid'] is not True:
            continue

        # Total hydrogen bonds; skipped when either count is missing or None
        # (None + int would raise).
        acceptors = item.get('interactions.HBAcceptor')
        donors = item.get('interactions.HBDonor')
        if acceptors is not None and donors is not None:
            item['interactions.HB'] = acceptors + donors

        # Size-normalized interaction counts. None values cannot be divided,
        # and keys that already carry the suffix are not normalized again, so
        # calling this function twice on the same data is harmless.
        raw_interactions = {
            key: value for key, value in item.items()
            if key.startswith('interactions')
            and not key.endswith(normalized_suffix)
            and value is not None
        }
        if raw_interactions:
            # The size depends only on the item, so look it up once.
            size = sdf2sample2size.get(item['sdf_file'], {}).get(int(item['sample']))
            if size is not None:
                item.update({
                    key + normalized_suffix: value / size
                    for key, value in raw_interactions.items()
                })

        for key, value in item.items():
            if value is None or key in aux:
                continue
            # Drop energy outliers.
            if key == 'energy.energy' and abs(value) > 1000:
                continue

            if get_data_type(key, DATA_TYPES, default=type(value)) == list:
                # List-valued metrics are concatenated, not nested.
                method_distributions[key].extend(value)
            else:
                method_distributions[key].append(value)

    return method_distributions


def prepare_baseline_data(root_path, baseline_name):
    """Load the metrics stored in a directory and turn them into distributions.

    Reads ``metrics_detailed.csv`` and ``metrics_data.pkl`` from `root_path`,
    keeps only rows flagged by ``'medchem.valid'``, and builds

    * continuous distributions: the output of `flatten_distribution`, with the
      entries for ``MEDCHEM_PROPS + DOCKING_PROPS`` replaced by the non-NaN
      values of the corresponding CSV columns;
    * discrete distributions: atom-type and bond-type histograms of the
      ``'representation.smiles'`` column.

    Args:
        root_path: Directory containing the two metric files.
        baseline_name: Label used for the progress bars.

    Returns:
        Tuple ``(distributions, discrete_distributions)`` where the second
        element is a dict with keys ``'atom_types'`` and ``'bond_types'``.
    """
    metrics_detailed = pd.read_csv(f'{root_path}/metrics_detailed.csv')
    metrics_detailed = metrics_detailed[metrics_detailed['medchem.valid']]

    # Close the file deterministically instead of leaving it to the GC.
    # NOTE: pickle executes arbitrary code on load -- only use files produced
    # by this pipeline.
    with open(f'{root_path}/metrics_data.pkl', 'rb') as f:
        raw_data = pickle.load(f)
    distributions = flatten_distribution(raw_data, name=baseline_name, table=metrics_detailed)

    distributions['energy.energy'] = [v for v in distributions['energy.energy'] if -1000 <= v <= 1000]
    # NOTE(review): 'energy.energy' is part of MEDCHEM_PROPS, so the loop below
    # overwrites the filtered list above with the unfiltered CSV column.
    # Behavior is kept as is; confirm whether the [-1000, 1000] filter was
    # meant to apply to the final values.
    for prop in MEDCHEM_PROPS + DOCKING_PROPS:
        distributions[prop] = metrics_detailed[prop].dropna().values.tolist()

    smiles = metrics_detailed['representation.smiles']
    atom_distribution, bond_distribution = compute_discrete_distributions(smiles, name=baseline_name)
    discrete_distributions = {
        'atom_types': atom_distribution,
        'bond_types': bond_distribution,
    }

    return distributions, discrete_distributions


# Script entry point:
#   1. merge the per-target `metrics_data_*.pkl` / `metrics_detailed_*.csv`
#      files found in --in_dir into single files in --out_dir,
#   2. compute aggregated (and, with --reference_smiles, collection) metrics,
#   3. with --crossdocked_dir, compute distances between the sample
#      distributions and the reference distributions from that directory.
if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--in_dir', type=Path, required=True, help='Directory with samples')
    p.add_argument('--out_dir', type=str, required=True, help='Output directory')
    p.add_argument('--n_samples', type=int, required=False, default=None, help='N samples per target')
    p.add_argument('--reference_smiles', type=str, default=None, help='Path to the .npy file with reference SMILES (optional)')
    p.add_argument('--crossdocked_dir', type=str, required=False, default=None, help='Crossdocked data dir for computing distances between distributions')
    args = p.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print('Combining data')
    data = []
    # Sorted so that the merged output does not depend on directory listing order.
    # NOTE: pickle executes arbitrary code on load -- only use trusted files.
    for file_path in tqdm(sorted(args.in_dir.glob('metrics_data_*.pkl'))):
        with open(file_path, 'rb') as f:
            d = pickle.load(f)
        if args.n_samples is not None:
            d = d[:args.n_samples]
        data += d
    with open(out_dir / 'metrics_data.pkl', 'wb') as f:
        pickle.dump(data, f)

    print('Combining detailed metrics')
    tables = []
    for file_path in tqdm(sorted(args.in_dir.glob('metrics_detailed_*.csv'))):
        table = pd.read_csv(file_path)
        if args.n_samples is not None:
            table = table.head(args.n_samples)
        tables.append(table)

    if not tables:
        # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
        raise FileNotFoundError(f"No 'metrics_detailed_*.csv' files found in {args.in_dir}")

    table_detailed = pd.concat(tables)
    table_detailed.to_csv(out_dir / 'metrics_detailed.csv', index=False)

    print('Computing aggregated metrics')
    evaluator = FullEvaluator(gnina='gnina', reduce='reduce')
    table_aggregated = aggregated_metrics(
        table_detailed,
        data_types=evaluator.dtypes,
        validity_metric_name=VALIDITY_METRIC_NAME
    )

    if args.reference_smiles is not None:
        reference_smiles = np.load(args.reference_smiles)
        col_metrics = collection_metrics(
            table=table_detailed,
            reference_smiles=reference_smiles,
            validity_metric_name=VALIDITY_METRIC_NAME,
            exclude_evaluators=[],
        )
        table_aggregated = pd.concat([table_aggregated, col_metrics])

    table_aggregated.to_csv(out_dir / 'metrics_aggregated.csv', index=False)

    if args.crossdocked_dir is not None:
        crossdocked_dir = Path(args.crossdocked_dir)

        # Reference distributions are cached next to the reference data and
        # reused on later runs.
        precomputed_distr_path = crossdocked_dir / 'crossdocked_distributions.pkl'
        precomputed_discrete_distr_path = crossdocked_dir / 'crossdocked_discrete_distributions.pkl'

        if precomputed_distr_path.exists() and precomputed_discrete_distr_path.exists():
            with open(precomputed_distr_path, 'rb') as f:
                crossdocked_distributions = pickle.load(f)
            with open(precomputed_discrete_distr_path, 'rb') as f:
                crossdocked_discrete_distributions = pickle.load(f)
        else:
            # Explicit check instead of `assert`, which is stripped under -O.
            for required in ('metrics_detailed.csv', 'metrics_data.pkl'):
                required_path = crossdocked_dir / required
                if not required_path.exists():
                    raise FileNotFoundError(f'Required reference file not found: {required_path}')

            crossdocked_distributions, crossdocked_discrete_distributions = prepare_baseline_data(
                root_path=args.crossdocked_dir,
                baseline_name='crossdocked'
            )

            with open(precomputed_distr_path, 'wb') as f:
                pickle.dump(crossdocked_distributions, f)
            with open(precomputed_discrete_distr_path, 'wb') as f:
                pickle.dump(crossdocked_discrete_distributions, f)

        # Pick the five most populated 'geometry.*' metrics of the reference
        # whose second dot-separated component contains exactly 2 (bonds) or
        # exactly 3 (angles) alphabetic characters.
        # NOTE(review): a two-letter element symbol would be counted as two
        # atoms by this rule -- confirm against the geometry key format.
        top_geometry_keys = {}
        for n_letters in (2, 3):
            candidates = [
                (k, len(v)) for k, v in crossdocked_distributions.items()
                if k.startswith('geometry.') and sum(s.isalpha() for s in k.split('.')[1]) == n_letters
            ]
            candidates = sorted(candidates, key=lambda t: t[1], reverse=True)[:5]
            top_geometry_keys[n_letters] = [t[0] for t in candidates]
        top_5_bonds = top_geometry_keys[2]
        top_5_angles = top_geometry_keys[3]

        distributions, discrete_distributions = prepare_baseline_data(args.out_dir, 'samples')

        # Single-row table; the 'method' column is a placeholder label.
        distances = {'method': 'method'}
        relevant_columns = set(MEDCHEM_PROPS + DOCKING_PROPS + RELEVANT_INTERACTIONS + top_5_bonds + top_5_angles)

        # Wasserstein distance for continuous metrics.
        for metric, values in distributions.items():
            if metric not in relevant_columns:
                continue

            ref = crossdocked_distributions.get(metric)
            cur = [x for x in values if not pd.isna(x)]
            if ref is None or not cur:
                continue

            try:
                distances[f'WD.{metric}'] = wasserstein_distance(ref, cur)
            except (ValueError, TypeError) as err:
                # E.g. an empty or non-numeric reference distribution. Report
                # it and keep going instead of dropping into a debugger.
                warnings.warn(f'Skipping WD.{metric}: {err}')

        # Jensen-Shannon distance for the discrete (histogram) metrics.
        for metric, ref in crossdocked_discrete_distributions.items():
            cur = discrete_distributions.get(metric)
            if ref is not None and cur is not None:
                distances[f'JS.{metric}'] = jensenshannon(p=ref, q=cur)

        dist_table = pd.DataFrame([distances])
        dist_table.to_csv(out_dir / 'metrics_distances.csv', index=False)