from collections import defaultdict

import numpy as np
import pandas as pd
from scipy.stats import spearmanr

from constants import ASSAY_LIST, ASSAY_HIGHER_IS_BETTER

FOLD_COL = "hierarchical_cluster_IgG_isotype_stratified_fold"


def recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, frac: float = 0.1) -> float:
    """Calculate recall, TP / (TP + FN), for the top fraction of true values.

    A recall of 1 means that the top fraction of true values is also the top
    fraction of predicted values. There is no penalty for ranking within the
    top k differently.

    Args:
        y_true (np.ndarray): true values with shape (num_data,)
        y_pred (np.ndarray): predicted values with shape (num_data,)
        frac (float, optional): fraction of data points to treat as the top. Defaults to 0.1.

    Returns:
        float: recall over the top k data points
    """
    top_k = int(len(y_true) * frac)
    y_true, y_pred = np.array(y_true).flatten(), np.array(y_pred).flatten()
    # Indices of the k largest true and predicted values; order within the top k does not matter.
    true_top_k = np.argsort(y_true)[-top_k:]
    predicted_top_k = np.argsort(y_pred)[-top_k:]

    return len(set(true_top_k) & set(predicted_top_k)) / top_k
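

# Worked example for recall_at_k (a minimal sketch; the numbers are chosen only
# for illustration): with y_true = np.arange(10) and frac=0.3, the top-3 true
# indices are {7, 8, 9}; a prediction such as [0, 1, 2, 3, 4, 5, 9, 6, 8, 7]
# ranks only indices 8 and 9 in its own top 3, so the recall is 2/3:
#
#     >>> recall_at_k(np.arange(10), np.array([0, 1, 2, 3, 4, 5, 9, 6, 8, 7]), frac=0.3)
#     0.6666666666666666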


def get_metrics(
    predictions_series: pd.Series, target_series: pd.Series, assay_col: str
) -> dict[str, float]:
    """Compute the Spearman correlation and top-10% recall for a single assay."""
    results_dict = {
        "spearman": spearmanr(
            predictions_series, target_series, nan_policy="omit"
        ).correlation
    }

    y_true = target_series.values
    y_pred = predictions_series.values
    # For assays where lower values are better, flip the sign so "top" still means "best".
    if not ASSAY_HIGHER_IS_BETTER[assay_col]:
        y_true = -1 * y_true
        y_pred = -1 * y_pred
    results_dict["top_10_recall"] = recall_at_k(y_true=y_true, y_pred=y_pred, frac=0.1)
    return results_dict
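

# Hypothetical call (a sketch; assumes "HIC" is one of the assay keys in
# ASSAY_HIGHER_IS_BETTER and that df carries the merged "_pred"/"_true" columns
# produced by evaluate below):
#
#     >>> get_metrics(df["HIC_pred"], df["HIC_true"], "HIC")
#     {'spearman': ..., 'top_10_recall': ...}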


def get_metrics_cross_validation(
    predictions_series: pd.Series,
    target_series: pd.Series,
    folds_series: pd.Series,
    assay_col: str,
) -> dict[str, float]:
    """Compute get_metrics within each cross-validation fold and average across folds."""
    results_dict = defaultdict(list)
    if folds_series.nunique() != 5:
        raise ValueError(f"Expected 5 folds, got {folds_series.nunique()}")
    for fold in folds_series.unique():
        predictions_series_fold = predictions_series[folds_series == fold]
        target_series_fold = target_series[folds_series == fold]
        results = get_metrics(predictions_series_fold, target_series_fold, assay_col)

        for key, value in results.items():
            results_dict[key].append(value)

    # Average each metric over the five folds.
    return {key: np.mean(values) for key, values in results_dict.items()}
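

# Note: get_metrics_cross_validation reports the mean of the per-fold Spearman
# correlations, not a single correlation pooled across all five folds.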


def _get_result_for_assay(df_merged, assay_col, dataset_name):
    """
    Return a dictionary with the results for a single assay.
    """
    if dataset_name == "GDPa1_cross_validation":
        results = get_metrics_cross_validation(
            df_merged[assay_col + "_pred"],
            df_merged[assay_col + "_true"],
            df_merged[FOLD_COL],
            assay_col,
        )
    elif dataset_name == "GDPa1":
        results = get_metrics(
            df_merged[assay_col + "_pred"], df_merged[assay_col + "_true"], assay_col
        )
    elif dataset_name == "Heldout Test Set":
        # Metrics are not computed for the held-out test set; report NaN placeholders.
        results = {"spearman": np.nan, "top_10_recall": np.nan}
    else:
        raise ValueError(f"Unknown dataset_name: {dataset_name}")
    results["assay"] = assay_col
    return results


def _get_error_result(assay_col, dataset_name, error):
    """
    Return a dictionary with the error message instead of metrics.
    Used when _get_result_for_assay fails.
    """
    print(f"Error evaluating {assay_col}: {error}")

    return {
        "dataset": dataset_name,
        "assay": assay_col,
        "spearman": error,
        "top_10_recall": error,
    }


def evaluate(predictions_df, target_df, dataset_name="GDPa1"):
    """
    Evaluate a single model, where the predictions dataframe has columns named by property,
    e.g. my_model.csv has columns antibody_name, HIC, Tm2.

    TODO (Lood): copied from the GitHub repo; this should be moved over here.
    """
    properties_in_preds = [col for col in predictions_df.columns if col in ASSAY_LIST]
    df_merged = pd.merge(
        target_df[["antibody_name", FOLD_COL] + ASSAY_LIST],
        predictions_df[["antibody_name"] + properties_in_preds],
        on="antibody_name",
        how="left",
        suffixes=("_true", "_pred"),
    )
    results_list = []

    for assay_col in properties_in_preds:
        try:
            results = _get_result_for_assay(df_merged, assay_col, dataset_name)
            results_list.append(results)
        except Exception as e:
            error_result = _get_error_result(assay_col, dataset_name, e)
            results_list.append(error_result)

    results_df = pd.DataFrame(results_list)
    return results_df
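

if __name__ == "__main__":
    # Minimal usage sketch: the CSV paths below are placeholders, not files from
    # the original project. The predictions CSV needs an "antibody_name" column
    # plus assay columns; the target CSV additionally needs FOLD_COL and all
    # columns in ASSAY_LIST.
    predictions_df = pd.read_csv("my_model.csv")
    target_df = pd.read_csv("gdpa1_targets.csv")
    print(evaluate(predictions_df, target_df, dataset_name="GDPa1"))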