fairchem_leaderboard / evaluator.py
mshuaibi's picture
space for time
5433f8c
raw
history blame
5.91 kB
import json
import logging
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
from fairchem.data.omol.modules.evaluator import (
    ligand_pocket,
    ligand_strain,
    geom_conformers,
    protonation_energies,
    unoptimized_ie_ea,
    distance_scaling,
    unoptimized_spin_gap,
)
# Registry of the OMol evaluation tasks: maps the task name received as
# `eval_type` to the fairchem evaluator function that computes its metrics.
# Each function is called as `fn(annotations_data, submission_data)`.
OMOL_EVAL_FUNCTIONS = {
    "Ligand pocket": ligand_pocket,
    "Ligand strain": ligand_strain,
    "Conformers": geom_conformers,
    "Protonation": protonation_energies,
    "IE_EA": unoptimized_ie_ea,
    "Distance scaling": distance_scaling,
    "Spin gap": unoptimized_spin_gap,
}
# Maps each subset name reported by `s2ef_metrics` to the `data_ids` values
# (as stored in the annotations file) that belong to that subset.
OMOL_DATA_ID_MAPPING = {
    "metal_complexes": ["metal_complexes"],
    "electrolytes": ["elytes"],
    "biomolecules": ["biomolecules"],
    "neutral_organics": [
        "ani2x",
        "orbnet_denali",
        "geom_orca6",
        "trans1x",
        "rgd",
    ],
}
def reorder(ref: np.ndarray, to_reorder: np.ndarray) -> np.ndarray:
    """
    Get the ordering so that `to_reorder[ordering]` == ref.

    eg:
        ref = [c, a, b]
        to_reorder = [b, a, c]
        order = reorder(ref, to_reorder)  # [2, 1, 0]
        assert ref == to_reorder[order]

    Parameters
    ----------
    ref : np.ndarray
        Reference array. Must not contain duplicates.
    to_reorder : np.ndarray
        Array to re-order. Must not contain duplicates.
        Items must be the same as in `ref`.

    Returns
    -------
    np.ndarray
        Integer index array to apply on `to_reorder`. It is always of an
        integer dtype, so it remains usable as an index even when the
        inputs are empty.

    Raises
    ------
    ValueError
        If either input contains duplicates, or if the two inputs do not
        hold exactly the same items.
    """
    # Validation is done with explicit raises rather than `assert`, because
    # assert statements are removed when Python runs with `-O`.
    ref_items = set(ref)
    if len(ref_items) != len(ref):
        raise ValueError("`ref` contains duplicate items.")

    item_to_idx = {item: idx for idx, item in enumerate(to_reorder)}
    if len(item_to_idx) != len(to_reorder):
        raise ValueError("`to_reorder` contains duplicate items.")

    if ref_items != set(item_to_idx):
        raise ValueError("`ref` and `to_reorder` do not contain the same items.")

    # Force an integer dtype: np.array([]) would otherwise default to float64,
    # which numpy refuses to use as an index.
    return np.array([item_to_idx[item] for item in ref], dtype=np.intp)
def get_order(path_submission: Path, path_annotations: Path) -> np.ndarray:
    """
    Compute the index order that aligns a submission with the annotations.

    Both files are `.npz` archives containing an `ids` array. The returned
    order satisfies `submission_ids[order] == annotations_ids`.

    Parameters
    ----------
    path_submission : Path
        Path to the submission `.npz` file.
    path_annotations : Path
        Path to the annotations `.npz` file.

    Returns
    -------
    np.ndarray
        Index array to apply on per-system submission arrays.

    Raises
    ------
    ValueError
        If the two files do not hold the same set of IDs, or if either file
        contains duplicate IDs.
    """
    # NOTE(review): pickle loading is left disabled for the submission,
    # presumably because it is a user-provided file; confirm before changing.
    with np.load(path_submission) as data:
        submission_ids = data["ids"]
    with np.load(path_annotations, allow_pickle=True) as data:
        annotations_ids = data["ids"]

    # `.tolist()` yields plain Python values, so the IDs quoted in error
    # messages read as 'abc' rather than as numpy scalar reprs.
    submission_set = set(submission_ids.tolist())
    annotations_set = set(annotations_ids.tolist())

    if submission_set != annotations_set:
        # Sort so the reported sample is deterministic between runs.
        missing_ids = sorted(annotations_set - submission_set, key=str)
        unexpected_ids = sorted(submission_set - annotations_set, key=str)
        details = (
            f"{len(missing_ids)} missing IDs: ({missing_ids[:3]}, ...)\n"
            f"{len(unexpected_ids)} unexpected IDs: ({unexpected_ids[:3]}, ...)"
        )
        raise ValueError(f"IDs don't match.\n{details}")

    # Equal sets do not rule out repeated entries; report them explicitly
    # instead of relying on checks further down the call chain.
    for label, ids, unique in (
        ("submission", submission_ids, submission_set),
        ("annotations", annotations_ids, annotations_set),
    ):
        n_duplicates = len(ids) - len(unique)
        if n_duplicates:
            raise ValueError(
                f"Duplicate IDs found in {label}: "
                f"{n_duplicates} repeated entries."
            )

    return reorder(annotations_ids, submission_ids)
def s2ef_metrics(
    annotations_path: Path,
    submission_filename: Path,
    subsets: Sequence[str] = ("all",),
) -> Dict[str, float]:
    """
    Compute energy and force mean absolute errors for a submission.

    The submission `.npz` must provide `ids`, `energy`, `forces` and `natoms`;
    the annotations `.npz` must provide `ids`, `energy`, `forces` and
    `data_ids`. The submission is first re-ordered to match the annotations.

    Parameters
    ----------
    annotations_path : Path
        Path to the annotations `.npz` file.
    submission_filename : Path
        Path to the submission `.npz` file.
    subsets : sequence of str
        Subsets to report. "all" selects every system; any other name is
        looked up in `OMOL_DATA_ID_MAPPING` and selects the systems whose
        `data_ids` entry is listed there.

    Returns
    -------
    Dict[str, float]
        `"{subset}_energy_mae"` and `"{subset}_forces_mae"` for each subset.
        Both are NaN when a subset selects no system.

    Raises
    ------
    ValueError
        If the submitted energies contain infinite values, or if the IDs of
        the two files do not match (raised while computing the order).
    """
    order = get_order(submission_filename, annotations_path)

    with np.load(submission_filename) as data:
        energy = data["energy"][order]
        # `forces` holds all systems back to back; cut it at the cumulative
        # atom counts to recover one chunk per system.
        per_system_forces = np.split(
            data["forces"], np.cumsum(data["natoms"])[:-1]
        )
    # A plain list avoids numpy's object-array construction, which changes
    # shape when every system happens to have the same number of atoms.
    forces = [per_system_forces[idx] for idx in order]

    # Row positions (in annotation order) whose energy is +/-inf.
    # NOTE(review): NaN energies are not rejected here and will surface as
    # NaN metrics; confirm whether they should be refused as well.
    inf_energy_ids = np.unique(np.nonzero(np.isinf(energy))[0])
    if inf_energy_ids.size:
        raise ValueError(
            f"Inf values found in `energy` for IDs: ({inf_energy_ids[:3].tolist()}, ...)"
        )

    with np.load(annotations_path, allow_pickle=True) as data:
        target_forces = data["forces"]
        target_energy = data["energy"]
        target_data_ids = data["data_ids"]

    metrics = {}
    for subset in subsets:
        if subset == "all":
            subset_mask = np.ones(len(target_data_ids), dtype=bool)
        else:
            allowed_ids = set(OMOL_DATA_ID_MAPPING.get(subset, []))
            # Explicit bool dtype: an empty list would otherwise give a
            # float64 array, which cannot be used as a mask.
            subset_mask = np.array(
                [data_id in allowed_ids for data_id in target_data_ids],
                dtype=bool,
            )

        selected = np.flatnonzero(subset_mask)
        if selected.size == 0:
            # Nothing to average over: report NaN for both metrics rather
            # than dividing by zero.
            metrics[f"{subset}_energy_mae"] = float("nan")
            metrics[f"{subset}_forces_mae"] = float("nan")
            continue

        sub_energy = energy[subset_mask]
        sub_target_energy = target_energy[subset_mask]
        metrics[f"{subset}_energy_mae"] = np.mean(
            np.abs(sub_target_energy - sub_energy)
        )

        abs_error_sum = 0.0
        natoms = 0
        for idx in selected:
            sub_forces = forces[idx]
            abs_error_sum += np.sum(np.abs(target_forces[idx] - sub_forces))
            natoms += sub_forces.shape[0]
        # Mean over every force component; assumes 3 components per atom,
        # i.e. per-system force arrays shaped (natoms, 3).
        metrics[f"{subset}_forces_mae"] = abs_error_sum / (3 * natoms)

    return metrics
def omol_evaluations(
    annotations_path: Path,
    submission_filename: Path,
    eval_type: str,
) -> Dict[str, float]:
    """
    Run one of the OMol evaluation tasks on a JSON submission.

    Both files are JSON objects keyed by entry name. The submission must
    contain exactly the same entries as the annotations.

    Parameters
    ----------
    annotations_path : Path
        Path to the annotations JSON file.
    submission_filename : Path
        Path to the submission JSON file.
    eval_type : str
        Name of the task; must be a key of `OMOL_EVAL_FUNCTIONS`.

    Returns
    -------
    Dict[str, float]
        The metrics returned by the task's evaluation function, called as
        `eval_fn(annotations_data, submission_data)`.

    Raises
    ------
    ValueError
        If `eval_type` is unknown, if the submission is not a JSON object,
        or if the submission and annotations entries differ.
    """
    # Resolve the task before touching any file: a `.get()` miss would
    # return None and only fail later with an opaque
    # "'NoneType' object is not callable".
    try:
        eval_fn = OMOL_EVAL_FUNCTIONS[eval_type]
    except KeyError:
        raise ValueError(f"Unknown eval_type: {eval_type}") from None

    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(submission_filename, encoding="utf-8") as f:
        submission_data = json.load(f)
    with open(annotations_path, encoding="utf-8") as f:
        annotations_data = json.load(f)

    if not isinstance(submission_data, dict):
        raise ValueError(
            "Submission must be a JSON object mapping entry names to results."
        )

    submission_entries = set(submission_data.keys())
    annotation_entries = set(annotations_data.keys())
    if submission_entries != annotation_entries:
        missing = annotation_entries - submission_entries
        unexpected = submission_entries - annotation_entries
        raise ValueError(
            f"Submission and annotations entries do not match.\n"
            f"Missing entries in submission: {missing}\n"
            f"Unexpected entries in submission: {unexpected}"
        )

    return eval_fn(annotations_data, submission_data)
def evaluate(
    annotations_path: Path,
    submission_filename: Path,
    eval_type: str,
):
    """
    Dispatch a submission to the evaluation matching `eval_type`.

    Parameters
    ----------
    annotations_path : Path
        Path to the annotations file.
    submission_filename : Path
        Path to the submission file.
    eval_type : str
        "Validation" or "Test" for the energy/force metrics, or one of the
        keys of `OMOL_EVAL_FUNCTIONS` for a task-specific evaluation.

    Returns
    -------
    The metrics produced by `s2ef_metrics` or `omol_evaluations`.

    Raises
    ------
    ValueError
        If `eval_type` is not recognised.
    """
    # Energy/force splits: report the overall score plus one per subset.
    if eval_type == "Validation" or eval_type == "Test":
        reported_subsets = [
            "all",
            "metal_complexes",
            "electrolytes",
            "biomolecules",
            "neutral_organics",
        ]
        return s2ef_metrics(
            annotations_path,
            submission_filename,
            subsets=reported_subsets,
        )

    # Task-specific evaluations registered in OMOL_EVAL_FUNCTIONS.
    if eval_type in OMOL_EVAL_FUNCTIONS:
        return omol_evaluations(
            annotations_path,
            submission_filename,
            eval_type,
        )

    raise ValueError(f"Unknown eval_type: {eval_type}")