diff --git a/polyreact/__init__.py b/polyreact/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..969432e99d32979ba38ff17da7e3638b88b042c7 --- /dev/null +++ b/polyreact/__init__.py @@ -0,0 +1,10 @@ +"""Polyreactivity prediction package.""" + +from importlib import metadata + +__all__ = ["__version__"] + +try: + __version__ = metadata.version("polyreact") +except metadata.PackageNotFoundError: # pragma: no cover + __version__ = "0.0.0" diff --git a/polyreact/__pycache__/__init__.cpython-311.pyc b/polyreact/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4152a71b04dfea95de14053790fd69092132ca20 Binary files /dev/null and b/polyreact/__pycache__/__init__.cpython-311.pyc differ diff --git a/polyreact/__pycache__/api.cpython-311.pyc b/polyreact/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5586b5aa74fe469a3314120719f7f2d07b75c14 Binary files /dev/null and b/polyreact/__pycache__/api.cpython-311.pyc differ diff --git a/polyreact/__pycache__/config.cpython-311.pyc b/polyreact/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb6f98089771acaed393e89bdd9f0f73e0416d38 Binary files /dev/null and b/polyreact/__pycache__/config.cpython-311.pyc differ diff --git a/polyreact/__pycache__/predict.cpython-311.pyc b/polyreact/__pycache__/predict.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eac3148cbe39969d5d105da08d8bea0959b2f77b Binary files /dev/null and b/polyreact/__pycache__/predict.cpython-311.pyc differ diff --git a/polyreact/__pycache__/train.cpython-311.pyc b/polyreact/__pycache__/train.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15f9464f06c60e01d852d0da6074b374553c6cbd Binary files /dev/null and b/polyreact/__pycache__/train.cpython-311.pyc differ diff --git a/polyreact/api.py b/polyreact/api.py new 
file mode 100644 index 0000000000000000000000000000000000000000..768d172b4eceb6aa24b9dcf784cfc51cb3aaf3d8 --- /dev/null +++ b/polyreact/api.py @@ -0,0 +1,121 @@ +"""Public Python API for polyreactivity prediction.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Iterable + +import copy + +import joblib +import pandas as pd +from sklearn.preprocessing import StandardScaler + +from .config import Config, load_config +from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline + + +def predict_batch( # noqa: ANN003 + records: Iterable[dict], + *, + config: Config | str | Path | None = None, + backend: str | None = None, + plm_model: str | None = None, + weights: str | Path | None = None, + heavy_only: bool = True, + batch_size: int = 8, + device: str | None = None, + cache_dir: str | None = None, +) -> pd.DataFrame: + """Predict polyreactivity scores for a batch of sequences.""" + + records_list = list(records) + if not records_list: + return pd.DataFrame(columns=["id", "score", "pred"]) + + artifact = _load_artifact(weights) + + if config is None: + artifact_config = artifact.get("config") + if isinstance(artifact_config, Config): + config = copy.deepcopy(artifact_config) + else: + config = load_config("configs/default.yaml") + elif isinstance(config, (str, Path)): + config = load_config(config) + else: + config = copy.deepcopy(config) + + if backend: + config.feature_backend.type = backend + if plm_model: + config.feature_backend.plm_model_name = plm_model + if device: + config.device = device + if cache_dir: + config.feature_backend.cache_dir = cache_dir + + pipeline = _restore_pipeline(config, artifact) + trained_model = artifact["model"] + + frame = pd.DataFrame(records_list) + if frame.empty: + raise ValueError("Prediction requires at least one record.") + if "id" not in frame.columns: + frame["id"] = frame.get("sequence_id", range(len(frame))).astype(str) + if "heavy_seq" in frame.columns: 
+ frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str) + else: + heavy_series = frame.get("heavy") + if heavy_series is None: + heavy_series = pd.Series([""] * len(frame)) + frame["heavy_seq"] = heavy_series.fillna("").astype(str) + + if "light_seq" in frame.columns: + frame["light_seq"] = frame["light_seq"].fillna("").astype(str) + else: + light_series = frame.get("light") + if light_series is None: + light_series = pd.Series([""] * len(frame)) + frame["light_seq"] = light_series.fillna("").astype(str) + + if heavy_only: + frame["light_seq"] = "" + if frame["heavy_seq"].str.len().eq(0).all(): + raise ValueError("No heavy chain sequences provided for prediction.") + + features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size) + scores = trained_model.predict_proba(features) + preds = (scores >= 0.5).astype(int) + + return pd.DataFrame( + { + "id": frame["id"].astype(str), + "score": scores, + "pred": preds, + } + ) + + +def _load_artifact(weights: str | Path | None) -> dict: + if weights is None: + msg = "Prediction requires a path to model weights" + raise ValueError(msg) + artifact = joblib.load(weights) + if not isinstance(artifact, dict): + msg = "Model artifact must be a dictionary" + raise ValueError(msg) + return artifact + + +def _restore_pipeline(config: Config, artifact: dict) -> FeaturePipeline: + pipeline = build_feature_pipeline(config) + state = artifact.get("feature_state") + if isinstance(state, FeaturePipelineState): + pipeline.load_state(state) + if pipeline.backend.type in {"plm", "concat"} and pipeline._plm_scaler is None: + pipeline._plm_scaler = StandardScaler() + return pipeline + + msg = "Model artifact is missing feature pipeline state" + raise ValueError(msg) diff --git a/polyreact/benchmarks/__pycache__/reproduce_paper.cpython-311.pyc b/polyreact/benchmarks/__pycache__/reproduce_paper.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d8d887c5539a280ffb7907c5222f2ef6661c145f Binary files /dev/null and b/polyreact/benchmarks/__pycache__/reproduce_paper.cpython-311.pyc differ diff --git a/polyreact/benchmarks/__pycache__/run_benchmarks.cpython-311.pyc b/polyreact/benchmarks/__pycache__/run_benchmarks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb4adc6d1414cc97b87a24967d6d11c869c4600 Binary files /dev/null and b/polyreact/benchmarks/__pycache__/run_benchmarks.cpython-311.pyc differ diff --git a/polyreact/benchmarks/reproduce_paper.ipynb b/polyreact/benchmarks/reproduce_paper.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c8bc17d4524fe6ec9992fef52092185a7dfbc86a --- /dev/null +++ b/polyreact/benchmarks/reproduce_paper.ipynb @@ -0,0 +1,25 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Polyreactivity Benchmark Notebook\n", + "\n", + "This notebook will reproduce paper results once the pipeline is implemented." 
] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/polyreact/benchmarks/reproduce_paper.py b/polyreact/benchmarks/reproduce_paper.py new file mode 100644 index 0000000000000000000000000000000000000000..43ce01651a1d5e9106fc27a227293a52ae33c406 --- /dev/null +++ b/polyreact/benchmarks/reproduce_paper.py @@ -0,0 +1,1020 @@ +"""Reproduce key metrics and visualisations for the polyreactivity model.""" + +from __future__ import annotations + +import argparse +import copy +import json +import subprocess +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence + +import joblib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import yaml +from scipy.stats import pearsonr, spearmanr +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import roc_curve +from sklearn.model_selection import KFold +from sklearn.preprocessing import StandardScaler + +from polyreact import train as train_module +from polyreact.config import load_config +from polyreact.features.anarsi import AnarciNumberer +from polyreact.features.pipeline import FeaturePipeline +from polyreact.models.ordinal import ( + fit_negative_binomial_model, + fit_poisson_model, + pearson_dispersion, + regression_metrics, +) + + +@dataclass(slots=True) +class DatasetSpec: + name: str + path: Path + display: str + + +DISPLAY_LABELS = { + "jain": "Jain (2017)", + "shehata": "Shehata PSR (398)", + "shehata_curated": "Shehata curated (88)", + "harvey": "Harvey (2022)", +} + +RAW_LABELS = { + "jain": "jain2017", + "shehata": "shehata2019", + "shehata_curated": "shehata2019_curated", + "harvey": "harvey2022", +} + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Reproduce 
paper-style metrics and plots") + parser.add_argument( + "--train-data", + default="data/processed/boughter_counts_rebuilt.csv", + help="Reconstructed Boughter dataset path.", + ) + parser.add_argument( + "--full-data", + default="data/processed/boughter_counts_rebuilt.csv", + help="Dataset (including mild flags) for correlation analysis.", + ) + parser.add_argument("--jain", default="data/processed/jain.csv") + parser.add_argument( + "--shehata", + default="data/processed/shehata_full.csv", + help="Full Shehata PSR panel (398 sequences) in processed CSV form.", + ) + parser.add_argument( + "--shehata-curated", + default="data/processed/shehata_curated.csv", + help="Optional curated subset of Shehata et al. (88 sequences).", + ) + parser.add_argument("--harvey", default="data/processed/harvey.csv") + parser.add_argument("--output-dir", default="artifacts/paper") + parser.add_argument("--config", default="configs/default.yaml") + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--rebuild", action="store_true") + parser.add_argument( + "--bootstrap-samples", + type=int, + default=1000, + help="Bootstrap resamples for metrics confidence intervals.", + ) + parser.add_argument( + "--bootstrap-alpha", + type=float, + default=0.05, + help="Alpha for bootstrap confidence intervals (default 0.05 → 95%).", + ) + parser.add_argument( + "--human-only", + action="store_true", + help=( + "Restrict the main cross-validation run to human HIV and influenza families" + " (legacy behaviour). By default all Boughter families, including mouse IgA," + " participate in CV as in Sakhnini et al." 
+ ), + ) + parser.add_argument( + "--skip-flag-regression", + action="store_true", + help="Skip ELISA flag regression diagnostics (Poisson/NB).", + ) + parser.add_argument( + "--skip-lofo", + action="store_true", + help="Skip leave-one-family-out experiments.", + ) + parser.add_argument( + "--skip-descriptor-variants", + action="store_true", + help="Skip descriptor-only benchmark variants.", + ) + parser.add_argument( + "--skip-fragment-variants", + action="store_true", + help="Skip CDR fragment ablation benchmarks.", + ) + return parser + + +def _config_to_dict(config) -> Dict[str, Any]: + data = asdict(config) + data.pop("raw", None) + return data + + +def _deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: + result = copy.deepcopy(base) + for key, value in overrides.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + result[key] = _deep_merge(result.get(key, {}), value) + else: + result[key] = value + return result + + +def _write_variant_config( + base_config: Dict[str, Any], + overrides: Dict[str, Any], + target_path: Path, +) -> Path: + merged = _deep_merge(base_config, overrides) + target_path.parent.mkdir(parents=True, exist_ok=True) + with target_path.open("w", encoding="utf-8") as handle: + yaml.safe_dump(merged, handle, sort_keys=False) + return target_path + + +def _collect_metric_records(variant: str, metrics: pd.DataFrame) -> list[dict[str, Any]]: + tracked = { + "roc_auc", + "pr_auc", + "accuracy", + "f1", + "f1_positive", + "f1_negative", + "precision", + "sensitivity", + "specificity", + "brier", + "ece", + "mce", + } + records: list[dict[str, Any]] = [] + for _, row in metrics.iterrows(): + metric_name = row["metric"] + if metric_name not in tracked: + continue + record = {"variant": variant, "metric": metric_name} + for column in metrics.columns: + if column == "metric": + continue + record[column] = float(row[column]) if pd.notna(row[column]) else np.nan + records.append(record) + return 
records + + +def _dump_coefficients(model_path: Path, output_path: Path) -> None: + artifact = joblib.load(model_path) + trained = artifact["model"] + estimator = getattr(trained, "estimator", None) + if estimator is None or not hasattr(estimator, "coef_"): + return + coefs = estimator.coef_[0] + feature_state = artifact.get("feature_state") + feature_names: list[str] + if feature_state is not None and getattr(feature_state, "feature_names", None): + feature_names = list(feature_state.feature_names) + else: + feature_names = [f"f{i}" for i in range(len(coefs))] + coeff_df = pd.DataFrame( + { + "feature": feature_names, + "coef": coefs, + "abs_coef": np.abs(coefs), + } + ).sort_values("abs_coef", ascending=False) + coeff_df.to_csv(output_path, index=False) + + +def _summarise_predictions(preds: pd.DataFrame) -> pd.DataFrame: + records: list[dict[str, Any]] = [] + for split, group in preds.groupby("split"): + stats = { + "split": split, + "n_samples": int(len(group)), + "positives": int(group["y_true"].sum()), + "positive_rate": float(group["y_true"].mean()) if len(group) else np.nan, + "score_mean": float(group["y_score"].mean()) if len(group) else np.nan, + "score_std": float(group["y_score"].std(ddof=1)) if len(group) > 1 else np.nan, + } + records.append(stats) + return pd.DataFrame(records) + + +def _summarise_raw_dataset(path: Path, name: str) -> dict[str, Any]: + df = pd.read_csv(path) + summary: dict[str, Any] = { + "dataset": name, + "path": str(path), + "rows": int(len(df)), + } + if "label" in df.columns: + positives = int(df["label"].sum()) + summary["positives"] = positives + summary["positive_rate"] = float(df["label"].mean()) if len(df) else np.nan + if "reactivity_count" in df.columns: + summary["reactivity_count_mean"] = float(df["reactivity_count"].mean()) + summary["reactivity_count_median"] = float(df["reactivity_count"].median()) + summary["reactivity_count_max"] = int(df["reactivity_count"].max()) + if "smp" in df.columns: + summary["smp_mean"] 
= float(df["smp"].mean()) + summary["smp_median"] = float(df["smp"].median()) + summary["smp_max"] = float(df["smp"].max()) + summary["smp_min"] = float(df["smp"].min()) + summary["unique_heavy"] = int(df["heavy_seq"].nunique()) if "heavy_seq" in df.columns else np.nan + return summary + + +def _extract_region_sequence(sequence: str, regions: List[str], numberer: AnarciNumberer) -> str: + if not sequence: + return "" + upper_regions = [region.upper() for region in regions] + if upper_regions == ["VH"]: + return sequence + try: + numbered = numberer.number_sequence(sequence) + except Exception: + return "" + fragments: list[str] = [] + for region in upper_regions: + if region == "VH": + return sequence + fragment = numbered.regions.get(region) + if not fragment: + return "" + fragments.append(fragment) + return "".join(fragments) + + +def _make_region_dataset( + frame: pd.DataFrame, regions: List[str], numberer: AnarciNumberer +) -> tuple[pd.DataFrame, dict[str, Any]]: + records: list[dict[str, Any]] = [] + dropped = 0 + for record in frame.to_dict(orient="records"): + new_seq = _extract_region_sequence(record.get("heavy_seq", ""), regions, numberer) + if not new_seq: + dropped += 1 + continue + updated = record.copy() + updated["heavy_seq"] = new_seq + updated["light_seq"] = "" + records.append(updated) + result = pd.DataFrame(records, columns=frame.columns) + summary = { + "regions": "+".join(regions), + "input_rows": int(len(frame)), + "retained_rows": int(len(result)), + "dropped_rows": int(dropped), + } + return result, summary + + +def run_train( + *, + train_path: Path, + eval_specs: Sequence[DatasetSpec], + output_dir: Path, + model_path: Path, + config: str, + batch_size: int, + include_species: list[str] | None = None, + include_families: list[str] | None = None, + exclude_families: list[str] | None = None, + keep_duplicates: bool = False, + group_column: str | None = "lineage", + train_loader: str | None = None, + bootstrap_samples: int = 200, + 
bootstrap_alpha: float = 0.05, +) -> None: + args: list[str] = [ + "--config", + str(config), + "--train", + str(train_path), + "--report-to", + str(output_dir), + "--save-to", + str(model_path), + "--batch-size", + str(batch_size), + ] + + if eval_specs: + args.append("--eval") + args.extend(str(spec.path) for spec in eval_specs) + + if train_loader: + args.extend(["--train-loader", train_loader]) + if eval_specs: + args.append("--eval-loaders") + args.extend(spec.name for spec in eval_specs) + if include_species: + args.append("--include-species") + args.extend(include_species) + if include_families: + args.append("--include-families") + args.extend(include_families) + if exclude_families: + args.append("--exclude-families") + args.extend(exclude_families) + if keep_duplicates: + args.append("--keep-train-duplicates") + if group_column: + args.extend(["--cv-group-column", group_column]) + else: + args.append("--no-group-cv") + args.extend(["--bootstrap-samples", str(bootstrap_samples)]) + args.extend(["--bootstrap-alpha", str(bootstrap_alpha)]) + + exit_code = train_module.main(args) + if exit_code != 0: + raise RuntimeError(f"Training command failed with exit code {exit_code}") + + +def compute_spearman(model_path: Path, dataset_path: Path, batch_size: int) -> tuple[float, float, pd.DataFrame]: + artifact = joblib.load(model_path) + config = artifact["config"] + pipeline_state = artifact["feature_state"] + trained_model = artifact["model"] + + pipeline = FeaturePipeline(backend=config.feature_backend, descriptors=config.descriptors, device=config.device) + pipeline.load_state(pipeline_state) + + dataset = pd.read_csv(dataset_path) + features = pipeline.transform(dataset, heavy_only=True, batch_size=batch_size) + scores = trained_model.predict_proba(features) + dataset = dataset.copy() + dataset["score"] = scores + + stat, pvalue = spearmanr(dataset["reactivity_count"], dataset["score"]) + return float(stat), float(pvalue), dataset + + +def plot_accuracy( + 
metrics: pd.DataFrame, + output_path: Path, + eval_specs: Sequence[DatasetSpec], +) -> None: + row = metrics.loc[metrics["metric"] == "accuracy"].iloc[0] + labels = ["Train CV"] + [spec.display for spec in eval_specs] + values = [row.get("train_cv_mean", np.nan)] + [row.get(spec.name, np.nan) for spec in eval_specs] + + fig, ax = plt.subplots(figsize=(6, 4)) + xs = np.arange(len(labels)) + ax.bar(xs, values, color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]) + ax.set_xticks(xs, labels) + ax.set_ylim(0.0, 1.05) + ax.set_ylabel("Accuracy") + ax.set_title("Polyreactivity accuracy overview") + for x, val in zip(xs, values, strict=False): + if np.isnan(val): + continue + ax.text(x, val + 0.02, f"{val:.3f}", ha="center", va="bottom") + fig.tight_layout() + fig.savefig(output_path, dpi=300) + plt.close(fig) + + +def plot_rocs( + preds: pd.DataFrame, + output_path: Path, + eval_specs: Sequence[DatasetSpec], +) -> None: + mapping = {"train_cv_oof": "Train CV"} + for spec in eval_specs: + mapping[spec.name] = spec.display + fig, ax = plt.subplots(figsize=(6, 6)) + for split, label in mapping.items(): + subset = preds[preds["split"] == split] + if subset.empty: + continue + fpr, tpr, _ = roc_curve(subset["y_true"], subset["y_score"]) + ax.plot(fpr, tpr, label=label) + ax.plot([0, 1], [0, 1], linestyle="--", color="gray") + ax.set_xlabel("False positive rate") + ax.set_ylabel("True positive rate") + ax.set_title("ROC curves") + ax.legend() + fig.tight_layout() + fig.savefig(output_path, dpi=300) + plt.close(fig) + + +def plot_flags_scatter(data: pd.DataFrame, spearman_stat: float, output_path: Path) -> None: + rng = np.random.default_rng(42) + jitter = rng.uniform(-0.1, 0.1, size=len(data)) + x = data["reactivity_count"].to_numpy(dtype=float) + jitter + y = data["score"].to_numpy(dtype=float) + + fig, ax = plt.subplots(figsize=(6, 4)) + ax.scatter(x, y, alpha=0.5, s=10) + ax.set_xlabel("ELISA flag count") + ax.set_ylabel("Predicted probability") + ax.set_title(f"Prediction vs 
flag count (Spearman={spearman_stat:.2f})") + fig.tight_layout() + fig.savefig(output_path, dpi=300) + plt.close(fig) + + +def run_lofo( + full_df: pd.DataFrame, + *, + families: list[str], + config: str, + batch_size: int, + output_dir: Path, + bootstrap_samples: int, + bootstrap_alpha: float, +) -> pd.DataFrame: + results: list[dict[str, float]] = [] + for family in families: + family_lower = family.lower() + holdout = full_df[full_df["family"].str.lower() == family_lower].copy() + train = full_df[full_df["family"].str.lower() != family_lower].copy() + if holdout.empty or train.empty: + continue + + train_path = output_dir / f"train_lofo_{family_lower}.csv" + holdout_path = output_dir / f"eval_lofo_{family_lower}.csv" + train.to_csv(train_path, index=False) + holdout.to_csv(holdout_path, index=False) + + run_dir = output_dir / f"lofo_{family_lower}" + run_dir.mkdir(parents=True, exist_ok=True) + model_path = run_dir / "model.joblib" + + run_train( + train_path=train_path, + eval_specs=[ + DatasetSpec( + name="boughter", + path=holdout_path, + display=f"{family.title()} holdout", + ) + ], + output_dir=run_dir, + model_path=model_path, + config=config, + batch_size=batch_size, + keep_duplicates=True, + include_species=None, + include_families=None, + exclude_families=None, + group_column="lineage", + train_loader="boughter", + bootstrap_samples=bootstrap_samples, + bootstrap_alpha=bootstrap_alpha, + ) + + metrics = pd.read_csv(run_dir / "metrics.csv") + evaluation_cols = [ + col + for col in metrics.columns + if col not in {"metric", "train_cv_mean", "train_cv_std"} + ] + if not evaluation_cols: + continue + eval_col = evaluation_cols[0] + def _metric_value(name: str) -> float: + series = metrics.loc[metrics["metric"] == name, eval_col] + return float(series.values[0]) if not series.empty else float("nan") + + results.append( + { + "family": family, + "accuracy": _metric_value("accuracy"), + "roc_auc": _metric_value("roc_auc"), + "pr_auc": _metric_value("pr_auc"), 
+ "sensitivity": _metric_value("sensitivity"), + "specificity": _metric_value("specificity"), + } + ) + + return pd.DataFrame(results) + + + +def run_flag_regression( + train_path: Path, + *, + output_dir: Path, + config_path: str, + batch_size: int, + n_splits: int = 5, +) -> None: + df = pd.read_csv(train_path) + if "reactivity_count" not in df.columns: + return + + config = load_config(config_path) + kfold = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed) + + metrics_rows: list[dict[str, float]] = [] + preds_rows: list[dict[str, float]] = [] + + for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(df), start=1): + train_split = df.iloc[train_idx].reset_index(drop=True) + val_split = df.iloc[val_idx].reset_index(drop=True) + + pipeline = FeaturePipeline( + backend=config.feature_backend, + descriptors=config.descriptors, + device=config.device, + ) + X_train = pipeline.fit_transform(train_split, heavy_only=True, batch_size=batch_size) + scaler = StandardScaler() + X_train_scaled = scaler.fit_transform(X_train) + y_train = train_split["reactivity_count"].to_numpy(dtype=float) + # Train a logistic head to obtain probabilities as a 1-D feature + clf = LogisticRegression( + C=config.model.C, + class_weight=config.model.class_weight, + max_iter=2000, + solver="lbfgs", + ) + clf.fit(X_train_scaled, train_split["label"].to_numpy(dtype=int)) + prob_train = clf.predict_proba(X_train_scaled)[:, 1] + + X_val = pipeline.transform(val_split, heavy_only=True, batch_size=batch_size) + X_val_scaled = scaler.transform(X_val) + y_val = val_split["reactivity_count"].to_numpy(dtype=float) + prob_val = clf.predict_proba(X_val_scaled)[:, 1] + + poisson_X_train = prob_train.reshape(-1, 1) + poisson_X_val = prob_val.reshape(-1, 1) + model = fit_poisson_model(poisson_X_train, y_train) + poisson_preds = model.predict(poisson_X_val) + + n_params = poisson_X_train.shape[1] + 1 # include intercept + dof = max(len(y_val) - n_params, 1) + variance_to_mean = 
float(np.var(y_val, ddof=1) / np.mean(y_val)) if np.mean(y_val) else float("nan") + + spearman_val = float(spearmanr(y_val, poisson_preds).statistic) + try: + pearson_val = float(pearsonr(y_val, poisson_preds)[0]) + except Exception: # pragma: no cover - fallback if correlation fails + pearson_val = float("nan") + + poisson_metrics = regression_metrics(y_val, poisson_preds) + poisson_metrics.update( + { + "spearman": spearman_val, + "pearson": pearson_val, + "pearson_dispersion": pearson_dispersion(y_val, poisson_preds, dof=dof), + "variance_to_mean": variance_to_mean, + "fold": fold_idx, + "model": "poisson", + "status": "ok", + } + ) + metrics_rows.append(poisson_metrics) + + nb_preds: np.ndarray | None = None + nb_model = None + try: + nb_model = fit_negative_binomial_model(poisson_X_train, y_train) + nb_preds = nb_model.predict(poisson_X_val) + if not np.all(np.isfinite(nb_preds)): + raise ValueError("negative binomial produced non-finite predictions") + except Exception: + nb_metrics = { + "spearman": float("nan"), + "pearson": float("nan"), + "pearson_dispersion": float("nan"), + "variance_to_mean": variance_to_mean, + "alpha": float("nan"), + "fold": fold_idx, + "model": "negative_binomial", + "status": "failed", + } + metrics_rows.append(nb_metrics) + else: + spearman_nb = float(spearmanr(y_val, nb_preds).statistic) + try: + pearson_nb = float(pearsonr(y_val, nb_preds)[0]) + except Exception: # pragma: no cover + pearson_nb = float("nan") + + nb_metrics = regression_metrics(y_val, nb_preds) + nb_metrics.update( + { + "spearman": spearman_nb, + "pearson": pearson_nb, + "pearson_dispersion": pearson_dispersion(y_val, nb_preds, dof=dof), + "variance_to_mean": variance_to_mean, + "alpha": nb_model.alpha, + "fold": fold_idx, + "model": "negative_binomial", + "status": "ok", + } + ) + metrics_rows.append(nb_metrics) + + records = list(val_split.itertuples(index=False)) + for idx, row in enumerate(records): + row_id = getattr(row, "id", idx) + y_true_val = 
float(getattr(row, "reactivity_count")) + preds_rows.append( + { + "fold": fold_idx, + "model": "poisson", + "id": row_id, + "y_true": y_true_val, + "y_pred": float(poisson_preds[idx]), + } + ) + if nb_preds is not None: + preds_rows.append( + { + "fold": fold_idx, + "model": "negative_binomial", + "id": row_id, + "y_true": y_true_val, + "y_pred": float(nb_preds[idx]), + } + ) + + metrics_df = pd.DataFrame(metrics_rows) + metrics_df.to_csv(output_dir / "flag_regression_folds.csv", index=False) + + summary_records: list[dict[str, float]] = [] + for model_name, group in metrics_df.groupby("model"): + for column in group.columns: + if column in {"fold", "model", "status"}: + continue + values = group[column].dropna() + if values.empty: + continue + summary_records.append( + { + "model": model_name, + "metric": column, + "mean": float(values.mean()), + "std": float(values.std(ddof=1)) if len(values) > 1 else float("nan"), + } + ) + if summary_records: + pd.DataFrame(summary_records).to_csv( + output_dir / "flag_regression_metrics.csv", index=False + ) + + if preds_rows: + pd.DataFrame(preds_rows).to_csv(output_dir / "flag_regression_preds.csv", index=False) + +def run_descriptor_variants( + base_config: Dict[str, Any], + *, + train_path: Path, + eval_specs: Sequence[DatasetSpec], + output_dir: Path, + batch_size: int, + include_species: List[str] | None, + include_families: List[str] | None, + bootstrap_samples: int, + bootstrap_alpha: float, +) -> None: + variants = [ + ( + "descriptors_full_vh", + { + "feature_backend": {"type": "descriptors"}, + "descriptors": { + "use_anarci": True, + "regions": ["CDRH1", "CDRH2", "CDRH3"], + "features": [ + "length", + "charge", + "hydropathy", + "aromaticity", + "pI", + "net_charge", + ], + }, + }, + ), + ( + "descriptors_cdrh3_pi", + { + "feature_backend": {"type": "descriptors"}, + "descriptors": { + "use_anarci": True, + "regions": ["CDRH3"], + "features": ["pI"], + }, + }, + ), + ( + "descriptors_cdrh3_top5", + { + 
"feature_backend": {"type": "descriptors"}, + "descriptors": { + "use_anarci": True, + "regions": ["CDRH3"], + "features": [ + "pI", + "net_charge", + "charge", + "hydropathy", + "length", + ], + }, + }, + ), + ] + + configs_dir = output_dir / "configs" + configs_dir.mkdir(parents=True, exist_ok=True) + summary_records: list[dict[str, Any]] = [] + + for name, overrides in variants: + variant_config_path = _write_variant_config( + base_config, + overrides, + configs_dir / f"{name}.yaml", + ) + variant_output = output_dir / name + variant_output.mkdir(parents=True, exist_ok=True) + model_path = variant_output / "model.joblib" + + run_train( + train_path=train_path, + eval_specs=eval_specs, + output_dir=variant_output, + model_path=model_path, + config=str(variant_config_path), + batch_size=batch_size, + include_species=include_species, + include_families=include_families, + keep_duplicates=True, + group_column="lineage", + train_loader="boughter", + bootstrap_samples=bootstrap_samples, + bootstrap_alpha=bootstrap_alpha, + ) + + metrics_path = variant_output / "metrics.csv" + if metrics_path.exists(): + metrics_df = pd.read_csv(metrics_path) + summary_records.extend(_collect_metric_records(name, metrics_df)) + + _dump_coefficients(model_path, variant_output / "coefficients.csv") + + if summary_records: + pd.DataFrame(summary_records).to_csv(output_dir / "summary.csv", index=False) + + +def run_fragment_variants( + config_path: str, + *, + train_path: Path, + eval_specs: Sequence[DatasetSpec], + output_dir: Path, + batch_size: int, + include_species: List[str] | None, + include_families: List[str] | None, + bootstrap_samples: int, + bootstrap_alpha: float, +) -> None: + numberer = AnarciNumberer() + specs = [ + ("vh_full", ["VH"]), + ("cdrh1", ["CDRH1"]), + ("cdrh2", ["CDRH2"]), + ("cdrh3", ["CDRH3"]), + ("cdrh123", ["CDRH1", "CDRH2", "CDRH3"]), + ] + + summary_rows: list[dict[str, Any]] = [] + metric_summary_rows: list[dict[str, Any]] = [] + + for name, regions in 
specs: + variant_dir = output_dir / name + variant_dir.mkdir(parents=True, exist_ok=True) + dataset_dir = variant_dir / "datasets" + dataset_dir.mkdir(parents=True, exist_ok=True) + + train_df = pd.read_csv(train_path) + train_variant, train_summary = _make_region_dataset(train_df, regions, numberer) + train_variant_path = dataset_dir / "train.csv" + train_variant.to_csv(train_variant_path, index=False) + + eval_variant_specs: list[DatasetSpec] = [] + for spec in eval_specs: + eval_df = pd.read_csv(spec.path) + transformed, eval_summary = _make_region_dataset(eval_df, regions, numberer) + eval_path = dataset_dir / f"{spec.name}.csv" + transformed.to_csv(eval_path, index=False) + eval_variant_specs.append( + DatasetSpec(name=spec.name, path=eval_path, display=spec.display) + ) + eval_summary.update({"variant": name, "dataset": spec.name}) + summary_rows.append(eval_summary) + + train_summary.update({"variant": name, "dataset": "train"}) + summary_rows.append(train_summary) + + run_train( + train_path=train_variant_path, + eval_specs=eval_variant_specs, + output_dir=variant_dir, + model_path=variant_dir / "model.joblib", + config=config_path, + batch_size=batch_size, + include_species=include_species, + include_families=include_families, + keep_duplicates=True, + group_column="lineage", + train_loader="boughter", + bootstrap_samples=bootstrap_samples, + bootstrap_alpha=bootstrap_alpha, + ) + + metrics_path = variant_dir / "metrics.csv" + if metrics_path.exists(): + metrics_df = pd.read_csv(metrics_path) + metric_records = _collect_metric_records(name, metrics_df) + for record in metric_records: + record["variant_type"] = "fragment" + metric_summary_rows.extend(metric_records) + + if summary_rows: + pd.DataFrame(summary_rows).to_csv(output_dir / "fragment_dataset_summary.csv", index=False) + if metric_summary_rows: + pd.DataFrame(metric_summary_rows).to_csv(output_dir / "fragment_metrics_summary.csv", index=False) + + +def main(argv: list[str] | None = None) -> int: + 
def main(argv: list[str] | None = None) -> int:
    """Run the full evaluation pipeline: main training run, plots, flag
    regression, Spearman correlation, LOFO and ablation variants.

    Returns 0 on success; raises RuntimeError if the optional dataset
    rebuild step fails.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Optionally regenerate the training CSV from raw ELISA counts before
    # anything else reads it.
    if args.rebuild:
        rebuild_cmd = [
            "python",
            "scripts/rebuild_boughter_from_counts.py",
            "--output",
            str(args.train_data),
        ]
        if subprocess.run(rebuild_cmd, check=False).returncode != 0:
            raise RuntimeError("Dataset rebuild failed")

    train_path = Path(args.train_data)

    def _make_spec(name: str, path_str: str) -> DatasetSpec | None:
        # Missing files are silently skipped so partial checkouts still run.
        path = Path(path_str)
        if not path.exists():
            return None
        display = DISPLAY_LABELS.get(name, name.replace("_", " ").title())
        return DatasetSpec(name=name, path=path, display=display)

    # Deduplicate eval datasets by resolved path (e.g. shehata vs
    # shehata_curated pointing at the same file).
    eval_specs: list[DatasetSpec] = []
    seen_paths: set[Path] = set()
    for name, path_str in [
        ("jain", args.jain),
        ("shehata", args.shehata),
        ("shehata_curated", args.shehata_curated),
        ("harvey", args.harvey),
    ]:
        spec = _make_spec(name, path_str)
        if spec is not None:
            resolved = spec.path.resolve()
            if resolved in seen_paths:
                continue
            seen_paths.add(resolved)
            eval_specs.append(spec)

    base_config = load_config(args.config)
    base_config_dict = _config_to_dict(base_config)

    main_output = output_dir / "main"
    main_output.mkdir(parents=True, exist_ok=True)
    model_path = main_output / "model.joblib"

    # --human-only restricts training to human sequences from the HIV and
    # influenza families; None means "no filter" downstream.
    main_include_species = ["human"] if args.human_only else None
    main_include_families = ["hiv", "influenza"] if args.human_only else None

    run_train(
        train_path=train_path,
        eval_specs=eval_specs,
        output_dir=main_output,
        model_path=model_path,
        config=args.config,
        batch_size=args.batch_size,
        include_species=main_include_species,
        include_families=main_include_families,
        keep_duplicates=True,
        group_column="lineage",
        train_loader="boughter",
        bootstrap_samples=args.bootstrap_samples,
        bootstrap_alpha=args.bootstrap_alpha,
    )

    # run_train is expected to have written these two artifacts.
    metrics = pd.read_csv(main_output / "metrics.csv")
    preds = pd.read_csv(main_output / "preds.csv")

    plot_accuracy(metrics, main_output / "accuracy_overview.png", eval_specs)
    plot_rocs(preds, main_output / "roc_overview.png", eval_specs)

    if not args.skip_flag_regression:
        run_flag_regression(
            train_path=train_path,
            output_dir=main_output,
            config_path=args.config,
            batch_size=args.batch_size,
        )

    split_summary = _summarise_predictions(preds)
    split_summary.to_csv(main_output / "dataset_split_summary.csv", index=False)

    # Spearman correlation between predicted probability and ELISA flag
    # counts on the full (unfiltered) dataset.
    spearman_stat, spearman_p, corr_df = compute_spearman(
        model_path=model_path,
        dataset_path=Path(args.full_data),
        batch_size=args.batch_size,
    )
    plot_flags_scatter(corr_df, spearman_stat, main_output / "prob_vs_flags.png")
    (main_output / "spearman_flags.json").write_text(
        json.dumps({"spearman": spearman_stat, "p_value": spearman_p}, indent=2)
    )
    corr_df.to_csv(main_output / "prob_vs_flags.csv", index=False)

    # Leave-one-family-out cross-validation.
    if not args.skip_lofo:
        full_df = pd.read_csv(args.train_data)
        lofo_dir = output_dir / "lofo_runs"
        lofo_dir.mkdir(parents=True, exist_ok=True)
        lofo_df = run_lofo(
            full_df,
            families=["influenza", "hiv", "mouse_iga"],
            config=args.config,
            batch_size=args.batch_size,
            output_dir=lofo_dir,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )
        lofo_df.to_csv(output_dir / "lofo_metrics.csv", index=False)

    # Feature-ablation variants (descriptor backends).
    if not args.skip_descriptor_variants:
        descriptor_dir = output_dir / "descriptor_variants"
        descriptor_dir.mkdir(parents=True, exist_ok=True)
        run_descriptor_variants(
            base_config_dict,
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=descriptor_dir,
            batch_size=args.batch_size,
            include_species=main_include_species,
            include_families=main_include_families,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )

    # Region/fragment ablations (CDR subsets etc.).
    if not args.skip_fragment_variants:
        fragment_dir = output_dir / "fragment_variants"
        fragment_dir.mkdir(parents=True, exist_ok=True)
        run_fragment_variants(
            args.config,
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=fragment_dir,
            batch_size=args.batch_size,
            include_species=main_include_species,
            include_families=main_include_families,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )

    # Raw per-dataset summary table for the report.
    raw_summaries = []
    raw_summaries.append(_summarise_raw_dataset(train_path, "boughter_rebuilt"))
    for spec in eval_specs:
        summary_name = RAW_LABELS.get(spec.name, spec.name)
        raw_summaries.append(_summarise_raw_dataset(spec.path, summary_name))
    pd.DataFrame(raw_summaries).to_csv(output_dir / "raw_dataset_summary.csv", index=False)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
import train as train_cli + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +DEFAULT_TRAIN = PROJECT_ROOT / "tests" / "fixtures" / "boughter.csv" +DEFAULT_EVAL = [ + PROJECT_ROOT / "tests" / "fixtures" / "jain.csv", + PROJECT_ROOT / "tests" / "fixtures" / "shehata.csv", + PROJECT_ROOT / "tests" / "fixtures" / "harvey.csv", +] + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run polyreactivity benchmarks") + parser.add_argument( + "--config", + default="configs/default.yaml", + help="Path to configuration YAML file.", + ) + parser.add_argument( + "--train", + default=str(DEFAULT_TRAIN), + help="Training dataset CSV (defaults to bundled fixture).", + ) + parser.add_argument( + "--eval", + nargs="+", + default=[str(path) for path in DEFAULT_EVAL], + help="Evaluation dataset CSV paths (>=1).", + ) + parser.add_argument( + "--report-dir", + default="artifacts", + help="Directory to write metrics, predictions, and plots.", + ) + parser.add_argument( + "--model-path", + default="artifacts/model.joblib", + help="Destination for the trained model artifact.", + ) + parser.add_argument( + "--backend", + choices=["descriptors", "plm", "concat"], + help="Override feature backend during training.", + ) + parser.add_argument("--plm-model", help="Optional PLM model override.") + parser.add_argument("--cache-dir", help="Embedding cache directory override.") + parser.add_argument( + "--device", + choices=["auto", "cpu", "cuda"], + help="Device override for embeddings.", + ) + parser.add_argument( + "--paired", + action="store_true", + help="Use paired heavy/light chains when available.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=8, + help="Batch size for PLM embedding batches.", + ) + return parser + + +def main(argv: List[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + + if len(args.eval) < 1: + parser.error("Provide at least one evaluation dataset via --eval.") 
+ + report_dir = Path(args.report_dir) + report_dir.mkdir(parents=True, exist_ok=True) + + train_args: list[str] = [ + "--config", + args.config, + "--train", + args.train, + "--save-to", + str(Path(args.model_path)), + "--report-to", + str(report_dir), + "--batch-size", + str(args.batch_size), + ] + + train_args.extend(["--eval", *args.eval]) + + if args.backend: + train_args.extend(["--backend", args.backend]) + if args.plm_model: + train_args.extend(["--plm-model", args.plm_model]) + if args.cache_dir: + train_args.extend(["--cache-dir", args.cache_dir]) + if args.device: + train_args.extend(["--device", args.device]) + if args.paired: + train_args.append("--paired") + + return train_cli.main(train_args) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/polyreact/config.py b/polyreact/config.py new file mode 100644 index 0000000000000000000000000000000000000000..16922bed92635ef9ff9939f1f6267b62f0982905 --- /dev/null +++ b/polyreact/config.py @@ -0,0 +1,160 @@ +"""Configuration helpers for the polyreactivity project.""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +import importlib.resources as pkg_resources +from importlib.resources.abc import Traversable +from pathlib import Path +from typing import Any, Sequence + +import yaml + + +@dataclass(slots=True) +class FeatureBackendSettings: + type: str = "plm" + plm_model_name: str = "facebook/esm2_t12_35M_UR50D" + layer_pool: str = "mean" + cache_dir: str = ".cache/embeddings" + standardize: bool = True + + +@dataclass(slots=True) +class DescriptorSettings: + use_anarci: bool = True + regions: Sequence[str] = field(default_factory=lambda: ["CDRH1", "CDRH2", "CDRH3"]) + features: Sequence[str] = field( + default_factory=lambda: [ + "length", + "charge", + "hydropathy", + "aromaticity", + "pI", + "net_charge", + ] + ) + ph: float = 7.4 + + +@dataclass(slots=True) +class ModelSettings: + head: str = "logreg" + C: float = 1.0 + class_weight: Any = 
"balanced" + + +@dataclass(slots=True) +class CalibrationSettings: + method: str | None = "isotonic" + + +@dataclass(slots=True) +class TrainingSettings: + cv_folds: int = 10 + scoring: str = "roc_auc" + n_jobs: int = -1 + + +@dataclass(slots=True) +class IOSettings: + outputs_dir: str = "artifacts" + preds_filename: str = "preds.csv" + metrics_filename: str = "metrics.csv" + + +@dataclass(slots=True) +class Config: + seed: int = 42 + device: str = "auto" + feature_backend: FeatureBackendSettings = field(default_factory=FeatureBackendSettings) + descriptors: DescriptorSettings = field(default_factory=DescriptorSettings) + model: ModelSettings = field(default_factory=ModelSettings) + calibration: CalibrationSettings = field(default_factory=CalibrationSettings) + training: TrainingSettings = field(default_factory=TrainingSettings) + io: IOSettings = field(default_factory=IOSettings) + + raw: dict[str, Any] = field(default_factory=dict) + + +def _merge_section(default: Any, data: dict[str, Any] | None) -> Any: + if data is None: + return default + merged = asdict(default) | data + return type(default)(**merged) + + +def load_config(path: str | Path | None = None) -> Config: + """Load a YAML configuration file into a strongly-typed ``Config`` object.""" + + data = _read_config_data(path) + + feature_backend = _merge_section(FeatureBackendSettings(), data.get("feature_backend")) + descriptors = _merge_section(DescriptorSettings(), data.get("descriptors")) + model = _merge_section(ModelSettings(), data.get("model")) + calibration = _merge_section(CalibrationSettings(), data.get("calibration")) + training = _merge_section(TrainingSettings(), data.get("training")) + io_settings = _merge_section(IOSettings(), data.get("io")) + + config = Config( + seed=int(data.get("seed", 42)), + device=str(data.get("device", "auto")), + feature_backend=feature_backend, + descriptors=descriptors, + model=model, + calibration=calibration, + training=training, + io=io_settings, + raw=data, + 
) + return config + + +def _read_config_data(path: str | Path | None) -> dict[str, Any]: + """Return mapping data from YAML or the bundled default.""" + + if path is None: + resource = pkg_resources.files("polyreact.configs") / "default.yaml" + return _load_yaml_resource(resource) + + resolved = _resolve_config_path(Path(path)) + if resolved is not None: + return _load_yaml_file(resolved) + + resource_root = pkg_resources.files("polyreact") + resource = resource_root / Path(path).as_posix() + if resource.is_file(): + return _load_yaml_resource(resource) + + msg = f"Configuration file not found: {path}" + raise FileNotFoundError(msg) + + +def _resolve_config_path(path: Path) -> Path | None: + if path.exists(): + return path + + if not path.is_absolute(): + candidate = Path(__file__).resolve().parent / path + if candidate.exists(): + return candidate + + return None + + +def _load_yaml_file(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as handle: + return _parse_yaml(handle.read()) + + +def _load_yaml_resource(resource: Traversable) -> dict[str, Any]: + with resource.open("r", encoding="utf-8") as handle: + return _parse_yaml(handle.read()) + + +def _parse_yaml(text: str) -> dict[str, Any]: + parsed = yaml.safe_load(text) or {} + if not isinstance(parsed, dict): # pragma: no cover - safeguard + msg = "Configuration must be a mapping at the top level" + raise ValueError(msg) + return parsed diff --git a/polyreact/configs/__init__.py b/polyreact/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..74b6fab289cf713af651f06dcc0926c42442a9d5 --- /dev/null +++ b/polyreact/configs/__init__.py @@ -0,0 +1 @@ +"""Configuration package data.""" diff --git a/polyreact/configs/default.yaml b/polyreact/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b4e9d151f0151627d21e6934b710a866d36c77a --- /dev/null +++ b/polyreact/configs/default.yaml @@ -0,0 +1,34 @@ +seed: 42 +device: 
"auto" +feature_backend: + type: "plm" + plm_model_name: "facebook/esm1v_t33_650M_UR90S_1" + layer_pool: "mean" + cache_dir: ".cache/embeddings" +descriptors: + use_anarci: true + regions: + - "CDRH1" + - "CDRH2" + - "CDRH3" + features: + - "length" + - "charge" + - "hydropathy" + - "aromaticity" + - "pI" + - "net_charge" +model: + head: "logreg" + C: 0.1 + class_weight: "balanced" +calibration: + method: "isotonic" +training: + cv_folds: 10 + scoring: "roc_auc" + n_jobs: -1 +io: + outputs_dir: "artifacts" + preds_filename: "preds.csv" + metrics_filename: "metrics.csv" diff --git a/polyreact/data_loaders/__init__.py b/polyreact/data_loaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4532649754d4516db0051c8dda80da5c090861ec --- /dev/null +++ b/polyreact/data_loaders/__init__.py @@ -0,0 +1,3 @@ +"""Dataset loaders for polyreactivity benchmarks.""" + +__all__ = ["boughter", "jain", "shehata", "harvey", "utils"] diff --git a/polyreact/data_loaders/__pycache__/__init__.cpython-311.pyc b/polyreact/data_loaders/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2145f834aed3bc27f5e4b19ad4b24bd17ffded2 Binary files /dev/null and b/polyreact/data_loaders/__pycache__/__init__.cpython-311.pyc differ diff --git a/polyreact/data_loaders/__pycache__/boughter.cpython-311.pyc b/polyreact/data_loaders/__pycache__/boughter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73f57ad433e7adbaf0430a579acfe435274b2dbd Binary files /dev/null and b/polyreact/data_loaders/__pycache__/boughter.cpython-311.pyc differ diff --git a/polyreact/data_loaders/__pycache__/harvey.cpython-311.pyc b/polyreact/data_loaders/__pycache__/harvey.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45b2465ae02b7d025aea5f3419d13affea79b90d Binary files /dev/null and b/polyreact/data_loaders/__pycache__/harvey.cpython-311.pyc differ diff --git 
a/polyreact/data_loaders/__pycache__/jain.cpython-311.pyc b/polyreact/data_loaders/__pycache__/jain.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5734e6e50010aec0a080e2a58bbe843969de3277 Binary files /dev/null and b/polyreact/data_loaders/__pycache__/jain.cpython-311.pyc differ diff --git a/polyreact/data_loaders/__pycache__/shehata.cpython-311.pyc b/polyreact/data_loaders/__pycache__/shehata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2163d6c132e48da6d76b3288e1afb97b143b541 Binary files /dev/null and b/polyreact/data_loaders/__pycache__/shehata.cpython-311.pyc differ diff --git a/polyreact/data_loaders/__pycache__/utils.cpython-311.pyc b/polyreact/data_loaders/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47355803442beab304c9404407a61826fe7fa799 Binary files /dev/null and b/polyreact/data_loaders/__pycache__/utils.cpython-311.pyc differ diff --git a/polyreact/data_loaders/boughter.py b/polyreact/data_loaders/boughter.py new file mode 100644 index 0000000000000000000000000000000000000000..3008aadc31432f57f7f7375f945e1c742552ad4f --- /dev/null +++ b/polyreact/data_loaders/boughter.py @@ -0,0 +1,76 @@ +"""Loader for the Boughter et al. 
2020 dataset.""" + +from __future__ import annotations + +from typing import Iterable + +import numpy as np +import pandas as pd + +from .utils import LOGGER, standardize_frame + +_COLUMN_ALIASES = { + "id": ("sequence_id",), + "heavy_seq": ("heavy", "heavy_chain"), + "light_seq": ("light", "light_chain"), + "label": ("polyreactive",), +} + + +def _find_flag_columns(columns: Iterable[str]) -> list[str]: + flag_cols: list[str] = [] + for column in columns: + normalized = column.lower().replace(" ", "") + if "flag" in normalized: + flag_cols.append(column) + return flag_cols + + +def _apply_flag_policy(frame: pd.DataFrame, flag_columns: list[str]) -> pd.DataFrame: + if not flag_columns: + return frame + + flag_values = ( + frame[flag_columns] + .apply(pd.to_numeric, errors="coerce") + .fillna(0.0) + ) + flag_binary = (flag_values > 0).astype(int) + flags_total = flag_binary.sum(axis=1) + + specific_mask = flags_total == 0 + nonspecific_mask = flags_total >= 4 + keep_mask = specific_mask | nonspecific_mask + + dropped = int((~keep_mask).sum()) + if dropped: + LOGGER.info("Dropped %s mildly polyreactive sequences (1-3 ELISA flags)", dropped) + + filtered = frame.loc[keep_mask].copy() + filtered["flags_total"] = flags_total.loc[keep_mask].astype(int) + filtered["label"] = np.where(nonspecific_mask.loc[keep_mask], 1, 0) + filtered["polyreactive"] = filtered["label"] + return filtered + + +def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame: + """Load the Boughter dataset into the canonical format.""" + + frame = pd.read_csv(path_or_url) + flag_columns = _find_flag_columns(frame.columns) + frame = _apply_flag_policy(frame, flag_columns) + + label_series = frame.get("label") + if label_series is not None: + frame = frame[label_series.isin({0, 1})].copy() + + standardized = standardize_frame( + frame, + source="boughter2020", + heavy_only=heavy_only, + column_aliases=_COLUMN_ALIASES, + is_test=False, + ) + if "flags_total" in frame.columns and 
"flags_total" not in standardized.columns: + standardized["flags_total"] = frame["flags_total"].to_numpy(dtype=int) + return standardized diff --git a/polyreact/data_loaders/harvey.py b/polyreact/data_loaders/harvey.py new file mode 100644 index 0000000000000000000000000000000000000000..d216ea76e52b6093349083a84c8b2e34f7017990 --- /dev/null +++ b/polyreact/data_loaders/harvey.py @@ -0,0 +1,39 @@ +"""Loader for the Harvey et al. 2022 dataset.""" + +from __future__ import annotations + +import pandas as pd + +from .utils import standardize_frame + +_COLUMN_ALIASES = { + "id": ("id", "clone_id"), + "heavy_seq": ("heavy", "heavy_chain", "sequence"), + "light_seq": ("light", "light_chain"), + "label": ("polyreactive", "is_polyreactive"), +} + +_LABEL_MAP = { + "polyreactive": 1, + "non-polyreactive": 0, + "positive": 1, + "negative": 0, + 1: 1, + 0: 0, + "1": 1, + "0": 0, +} + + +def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame: + """Load the Harvey dataset into the canonical format.""" + + frame = pd.read_csv(path_or_url) + return standardize_frame( + frame, + source="harvey2022", + heavy_only=heavy_only, + column_aliases=_COLUMN_ALIASES, + label_map=_LABEL_MAP, + is_test=True, + ) diff --git a/polyreact/data_loaders/jain.py b/polyreact/data_loaders/jain.py new file mode 100644 index 0000000000000000000000000000000000000000..5abe587f6e8f5e3d69ebdb7440d04ebc71a3f883 --- /dev/null +++ b/polyreact/data_loaders/jain.py @@ -0,0 +1,41 @@ +"""Loader for the Jain et al. 
2017 dataset.""" + +from __future__ import annotations + +import pandas as pd + +from .utils import standardize_frame + +_COLUMN_ALIASES = { + "id": ("id", "antibody_id"), + "heavy_seq": ("heavy", "heavy_sequence", "H_chain"), + "light_seq": ("light", "light_sequence", "L_chain"), + "label": ("class", "polyreactive"), +} + +_LABEL_MAP = { + "polyreactive": 1, + "non-polyreactive": 0, + "reactive": 1, + "non-reactive": 0, + 1: 1, + 0: 0, + 1.0: 1, + 0.0: 0, + "1": 1, + "0": 0, +} + + +def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame: + """Load the Jain dataset into the canonical format.""" + + frame = pd.read_csv(path_or_url) + return standardize_frame( + frame, + source="jain2017", + heavy_only=heavy_only, + column_aliases=_COLUMN_ALIASES, + label_map=_LABEL_MAP, + is_test=True, + ) diff --git a/polyreact/data_loaders/shehata.py b/polyreact/data_loaders/shehata.py new file mode 100644 index 0000000000000000000000000000000000000000..14946b26ae6f961fe42a22e062364f4cbe430995 --- /dev/null +++ b/polyreact/data_loaders/shehata.py @@ -0,0 +1,202 @@ +"""Loader for the Shehata et al. 
(2019) PSR dataset.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Iterable + +import pandas as pd + +from .utils import standardize_frame + +SHEHATA_SOURCE = "shehata2019" + +_COLUMN_ALIASES = { + "id": ( + "antibody_id", + "antibody", + "antibody name", + "antibody_name", + "sequence_name", + "Antibody Name", + ), + "heavy_seq": ( + "heavy", + "heavy_chain", + "heavy aa", + "heavy_sequence", + "vh", + "vh_sequence", + "heavy chain aa", + "Heavy Chain AA", + ), + "light_seq": ( + "light", + "light_chain", + "light aa", + "light_sequence", + "vl", + "vl_sequence", + "light chain aa", + "Light Chain AA", + ), + "label": ( + "polyreactive", + "binding_class", + "binding class", + "psr_class", + "psr binding", + "psr classification", + "Binding class", + "Binding Class", + ), +} + +_LABEL_MAP = { + "polyreactive": 1, + "non-polyreactive": 0, + "positive": 1, + "negative": 0, + "high": 1, + "low": 0, + "pos": 1, + "neg": 0, + 1: 1, + 0: 0, + 1.0: 1, + 0.0: 0, + "1": 1, + "0": 0, +} + +_PSR_SCORE_ALIASES: tuple[str, ...] 
= ( + "psr score", + "psr_score", + "psr overall score", + "overall score", + "psr z", + "psr_z", +) + + +def _clean_sequence(sequence: object) -> str: + if isinstance(sequence, str): + return "".join(sequence.split()).upper() + return "" + + +def _maybe_extract_psr_scores(frame: pd.DataFrame) -> pd.DataFrame: + scores: dict[str, pd.Series] = {} + for column in frame.columns: + lowered = column.strip().lower() + if any(alias in lowered for alias in _PSR_SCORE_ALIASES): + key = lowered.replace(" ", "_") + scores[key] = frame[column] + if not scores: + return pd.DataFrame(index=frame.index) + renamed = {} + for name, series in scores.items(): + cleaned_name = name + for prefix in ("psr_", "overall_"): + if cleaned_name.startswith(prefix): + cleaned_name = cleaned_name[len(prefix) :] + break + cleaned_name = cleaned_name.replace("__", "_") + cleaned_name = cleaned_name.replace("(", "").replace(")", "") + cleaned_name = cleaned_name.replace("-", "_") + renamed[f"psr_{cleaned_name}"] = pd.to_numeric(series, errors="coerce") + return pd.DataFrame(renamed) + + +def _pick_source_label(path: Path | None) -> str: + if path is None: + return SHEHATA_SOURCE + stem = path.stem.lower() + if "curated" in stem or "subset" in stem: + return f"{SHEHATA_SOURCE}_curated" + return SHEHATA_SOURCE + + +def _standardize( + frame: pd.DataFrame, + *, + heavy_only: bool, + source: str, +) -> pd.DataFrame: + standardized = standardize_frame( + frame, + source=source, + heavy_only=heavy_only, + column_aliases=_COLUMN_ALIASES, + label_map=_LABEL_MAP, + is_test=True, + ) + + psr_scores = _maybe_extract_psr_scores(frame) + + mask = standardized["heavy_seq"].map(_clean_sequence) != "" + standardized = standardized.loc[mask].copy() + standardized.reset_index(drop=True, inplace=True) + standardized["heavy_seq"] = standardized["heavy_seq"].map(_clean_sequence) + standardized["light_seq"] = standardized["light_seq"].map(_clean_sequence) + + if not psr_scores.empty: + psr_scores = psr_scores.loc[mask] 
+ psr_scores = psr_scores.reset_index(drop=True) + for column in psr_scores.columns: + standardized[column] = psr_scores[column].reset_index(drop=True) + + return standardized + + +def _read_excel(path: Path, *, heavy_only: bool) -> pd.DataFrame: + excel = pd.ExcelFile(path, engine="openpyxl") + sheet_candidates: Iterable[str] = excel.sheet_names + + def _score(name: str) -> tuple[int, str]: + lowered = name.lower() + priority = 0 + if "psr" in lowered or "polyreactivity" in lowered: + priority = 2 + elif "sheet" not in lowered: + priority = 1 + return (-priority, name) + + sheet_name = sorted(sheet_candidates, key=_score)[0] + raw = excel.parse(sheet_name) + raw = raw.dropna(how="all") + return _standardize(raw, heavy_only=heavy_only, source=_pick_source_label(path)) + + +def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame: + """Load the Shehata dataset into the canonical format. + + Supports both pre-processed CSV exports and the original Excel supplement + (*.xls/*.xlsx). Additional PSR score columns are preserved when available. 
+ """ + + lower = path_or_url.lower() + source_override: str | None = None + if lower.startswith("http://") or lower.startswith("https://"): + if lower.endswith((".xls", ".xlsx")): + raw = pd.read_excel(path_or_url, engine="openpyxl") + return _standardize(raw, heavy_only=heavy_only, source=SHEHATA_SOURCE) + frame = pd.read_csv(path_or_url) + return _standardize(frame, heavy_only=heavy_only, source=SHEHATA_SOURCE) + + path = Path(path_or_url) + source_override = _pick_source_label(path) + if path.suffix.lower() in {".xls", ".xlsx"}: + engine = "openpyxl" if path.suffix.lower() == ".xlsx" else None + if engine: + frame = _read_excel(path, heavy_only=heavy_only) + else: + frame = pd.read_excel(path, engine=None) + frame = _standardize(frame, heavy_only=heavy_only, source=source_override) + frame["source"] = source_override + return frame + + frame = pd.read_csv(path) + standardized = _standardize(frame, heavy_only=heavy_only, source=source_override) + standardized["source"] = source_override + return standardized diff --git a/polyreact/data_loaders/utils.py b/polyreact/data_loaders/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b896b17dc652c07393c6e07dcb7d8741b4d411fd --- /dev/null +++ b/polyreact/data_loaders/utils.py @@ -0,0 +1,186 @@ +"""Utility helpers for dataset loading.""" + +from __future__ import annotations + +import logging +from typing import Iterable, Sequence + +import pandas as pd + +EXPECTED_COLUMNS = ("id", "heavy_seq", "light_seq", "label") +OPTIONAL_COLUMNS = ("source", "is_test") + +LOGGER = logging.getLogger("polyreact.data") + +_DEFAULT_ALIASES: dict[str, Sequence[str]] = { + "id": ("id", "sequence_id", "antibody_id", "uid"), + "heavy_seq": ("heavy_seq", "heavy", "heavy_chain", "H", "H_chain"), + "light_seq": ("light_seq", "light", "light_chain", "L", "L_chain"), + "label": ("label", "polyreactive", "is_polyreactive", "class", "target"), +} + +DEFAULT_LABEL_MAP: dict[str | int | float | bool, int] = { + 1: 1, + 0: 
0, + "1": 1, + "0": 0, + True: 1, + False: 0, + "true": 1, + "false": 0, + "polyreactive": 1, + "non-polyreactive": 0, + "poly": 1, + "non": 0, + "positive": 1, + "negative": 0, +} + + +def _normalize_label_key(value: object) -> object: + if isinstance(value, str): + trimmed = value.strip().lower() + if trimmed in { + "polyreactive", + "non-polyreactive", + "poly", + "non", + "positive", + "negative", + "high", + "low", + "pos", + "neg", + "1", + "0", + "true", + "false", + }: + return trimmed + if trimmed.isdigit(): + return trimmed + return value + + +def ensure_columns(frame: pd.DataFrame, *, heavy_only: bool = True) -> pd.DataFrame: + """Validate and coerce dataframe columns to the canonical format.""" + + frame = frame.copy() + for column in ("id", "heavy_seq", "label"): + if column not in frame.columns: + msg = f"Required column '{column}' missing from dataframe" + raise KeyError(msg) + + if "light_seq" not in frame.columns: + frame["light_seq"] = "" + + if heavy_only: + frame["light_seq"] = "" + + frame["id"] = frame["id"].astype(str) + frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str) + frame["light_seq"] = frame["light_seq"].fillna("").astype(str) + frame["label"] = frame["label"].astype(int) + + ordered = list(EXPECTED_COLUMNS) + [ + col for col in frame.columns if col not in EXPECTED_COLUMNS + ] + return frame[ordered] + + +def standardize_frame( + frame: pd.DataFrame, + *, + source: str, + heavy_only: bool = True, + column_aliases: dict[str, Sequence[str]] | None = None, + label_map: dict[str | int | float | bool, int] | None = None, + is_test: bool | None = None, +) -> pd.DataFrame: + """Rename columns using aliases and coerce labels to integers.""" + + aliases = {**_DEFAULT_ALIASES} + if column_aliases: + for key, values in column_aliases.items(): + aliases[key] = tuple(values) + tuple(aliases.get(key, ())) + + rename_map: dict[str, str] = {} + for target, candidates in aliases.items(): + if target in frame.columns: + continue + for 
candidate in candidates: + if candidate in frame.columns and candidate not in rename_map: + rename_map[candidate] = target + break + + normalized = frame.rename(columns=rename_map).copy() + + if "light_seq" not in normalized.columns: + normalized["light_seq"] = "" + + label_lookup = label_map or DEFAULT_LABEL_MAP + normalized["label"] = normalized["label"].map(lambda x: label_lookup.get(_normalize_label_key(x))) + + if normalized["label"].isnull().any(): + msg = "Label column contains unmapped or missing values" + raise ValueError(msg) + + normalized["source"] = source + if is_test is not None: + normalized["is_test"] = bool(is_test) + + normalized = ensure_columns(normalized, heavy_only=heavy_only) + return normalized + + +def deduplicate_sequences( + frames: Iterable[pd.DataFrame], + *, + heavy_only: bool = True, + key_columns: Sequence[str] | None = None, + keep_intra_frames: set[int] | None = None, +) -> list[pd.DataFrame]: + """Remove duplicate entries across multiple dataframes with configurable keys.""" + + if key_columns is None: + key_columns = ["heavy_seq"] if heavy_only else ["heavy_seq", "light_seq"] + keep_intra_frames = keep_intra_frames or set() + + seen: set[tuple[str, ...]] = set() + cleaned: list[pd.DataFrame] = [] + + for frame_idx, frame in enumerate(frames): + valid_columns = [col for col in key_columns if col in frame.columns] + if not valid_columns: + valid_columns = ["heavy_seq"] + + mask: list[bool] = [] + frame_seen: set[tuple[str, ...]] = set() + allow_intra = frame_idx in keep_intra_frames + + for values in frame[valid_columns].itertuples(index=False, name=None): + key = tuple(_normalise_key_value(value) for value in values) + if key in seen: + mask.append(False) + continue + if not allow_intra and key in frame_seen: + mask.append(False) + continue + mask.append(True) + frame_seen.add(key) + seen.update(frame_seen) + filtered = frame.loc[mask].reset_index(drop=True) + removed = len(frame) - len(filtered) + if removed: + dataset = "" + if 
"source" in frame.columns and not frame["source"].empty: + dataset = str(frame["source"].iloc[0]) + LOGGER.info("Removed %s duplicate sequences from %s", removed, dataset) + cleaned.append(filtered) + return cleaned + + +def _normalise_key_value(value: object) -> str: + if value is None or (isinstance(value, float) and pd.isna(value)): + return "" + return str(value).strip() diff --git a/polyreact/features/__init__.py b/polyreact/features/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..46d551c4b851b577582915d7e073ea0918413c3d --- /dev/null +++ b/polyreact/features/__init__.py @@ -0,0 +1,13 @@ +"""Feature backends for polyreactivity prediction.""" + +from . import anarsi, descriptors, plm +from .pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline + +__all__ = [ + "anarsi", + "descriptors", + "plm", + "FeaturePipeline", + "FeaturePipelineState", + "build_feature_pipeline", +] diff --git a/polyreact/features/__pycache__/__init__.cpython-311.pyc b/polyreact/features/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab95e6cacc5617937fed5d3e3e1b52e003e7507c Binary files /dev/null and b/polyreact/features/__pycache__/__init__.cpython-311.pyc differ diff --git a/polyreact/features/__pycache__/anarsi.cpython-311.pyc b/polyreact/features/__pycache__/anarsi.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..446f2f8042ce5cea9f66396e2c83985789347dcd Binary files /dev/null and b/polyreact/features/__pycache__/anarsi.cpython-311.pyc differ diff --git a/polyreact/features/__pycache__/descriptors.cpython-311.pyc b/polyreact/features/__pycache__/descriptors.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e51f2ad2fef2a0fbb2ed6989ac10365de57a18c5 Binary files /dev/null and b/polyreact/features/__pycache__/descriptors.cpython-311.pyc differ diff --git 
@dataclass(slots=True)
class NumberedResidue:
    """Single residue with IMGT numbering metadata."""

    # 1-based position assigned by the numbering tool.
    position: int
    # Insertion code ("" when absent); distinguishes residues at one position.
    insertion: str
    amino_acid: str


@dataclass(slots=True)
class NumberedSequence:
    """Container for numbering results and derived regions."""

    # Original input sequence passed to the numberer.
    sequence: str
    # Numbering scheme actually used (e.g. "imgt").
    scheme: str
    # Normalised chain type: "H" or "L".
    chain_type: str
    residues: list[NumberedResidue]
    # Region name -> concatenated residues; also carries a "full" entry.
    regions: dict[str, str]


# Inclusive IMGT position ranges for each heavy-chain region (FR1..FR4).
_IMGT_HEAVY_REGIONS: Sequence[Tuple[str, int, int]] = (
    ("FR1", 1, 26),
    ("CDRH1", 27, 38),
    ("FR2", 39, 55),
    ("CDRH2", 56, 65),
    ("FR3", 66, 104),
    ("CDRH3", 105, 117),
    ("FR4", 118, 128),
)

# Light-chain boundaries match the heavy chain; only the CDR labels differ.
_IMGT_LIGHT_REGIONS: Sequence[Tuple[str, int, int]] = (
    ("FR1", 1, 26),
    ("CDRL1", 27, 38),
    ("FR2", 39, 55),
    ("CDRL2", 56, 65),
    ("FR3", 66, 104),
    ("CDRL3", 105, 117),
    ("FR4", 118, 128),
)

# (scheme, chain) -> region table; only the IMGT scheme is defined here.
_REGION_MAP: dict[Tuple[str, str], Sequence[Tuple[str, int, int]]] = {
    ("imgt", "H"): _IMGT_HEAVY_REGIONS,
    ("imgt", "L"): _IMGT_LIGHT_REGIONS,
}

_VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
_DEFAULT_SCHEME = "imgt"
_DEFAULT_CHAIN = "H"


# Lazily-created module-level singleton used when no numberer is supplied.
# NOTE(review): creation is not lock-guarded — presumably single-threaded use
# is assumed; confirm before calling from multiple threads.
_DEFAULT_NUMBERER: AnarciNumberer | None = None


def _sanitize_sequence(sequence: str) -> str:
    """Uppercase and keep only the 20 standard amino-acid letters."""
    return "".join(residue for residue in sequence.upper() if residue in _VALID_AMINO_ACIDS)


def get_default_numberer() -> AnarciNumberer:
    """Return (creating on first use) the shared heavy-chain CPU numberer."""
    global _DEFAULT_NUMBERER
    if _DEFAULT_NUMBERER is None:
        _DEFAULT_NUMBERER = AnarciNumberer(chain_type=_DEFAULT_CHAIN, cpu=True, ncpu=1, verbose=False)
    return _DEFAULT_NUMBERER


def trim_variable_domain(
    sequence: str,
    *,
    numberer: AnarciNumberer | None = None,
    scheme: str = _DEFAULT_SCHEME,
    chain_type: str = _DEFAULT_CHAIN,
    fallback_length: int = 130,
) -> str:
    """Return the FR1–FR4 variable domain for a heavy/light chain sequence.

    Falls back to the first ``fallback_length`` sanitized residues whenever
    numbering fails or yields no region content; returns "" when the input
    contains no valid amino-acid characters.
    """

    cleaned = _sanitize_sequence(sequence)
    if not cleaned:
        return ""

    active_numberer = numberer or get_default_numberer()
    try:
        numbered = active_numberer.number_sequence(cleaned)
    except Exception:  # pragma: no cover - best effort safeguard
        # Numbering is best-effort: any failure degrades to a raw prefix.
        return cleaned[:fallback_length]

    region_sets = _region_boundaries(scheme, chain_type)
    pieces: list[str] = []
    for region_name, _start, _end in region_sets:
        segment = numbered.regions.get(region_name, "")
        if segment:
            pieces.append(segment)
    trimmed = "".join(pieces)
    if not trimmed:
        # No region content for this scheme/chain: use the numberer's full
        # reconstructed sequence, then a raw prefix as a last resort.
        trimmed = numbered.regions.get("full", "")
    if not trimmed:
        trimmed = cleaned[:fallback_length]
    return trimmed


def _normalise_chain_type(chain_type: str) -> str:
    """Map tool-reported chain labels to canonical "H"/"L"; pass others through."""
    upper = chain_type.upper()
    if upper in {"H", "HV"}:
        return "H"
    if upper in {"L", "K", "LV", "KV"}:
        return "L"
    return upper
chain_type: str = "H", + cpu: bool = True, + ncpu: int = 1, + verbose: bool = False, + ) -> None: + if Anarcii is None: # pragma: no cover - optional dependency guard + msg = ( + "anarcii is required for numbering but is not installed." + " Install 'anarcii' to enable ANARCI-based features." + ) + raise ImportError(msg) + self.scheme = scheme + self.expected_chain_type = _normalise_chain_type(chain_type) + self.cpu = cpu + self.ncpu = ncpu + self.verbose = verbose + self._runner = None + + def _ensure_runner(self) -> Anarcii: + if self._runner is None: + self._runner = Anarcii( + seq_type="antibody", + mode="accuracy", + batch_size=1, + cpu=self.cpu, + ncpu=self.ncpu, + verbose=self.verbose, + ) + return self._runner + + def number_sequence(self, sequence: str) -> NumberedSequence: + """Return numbering metadata for a single amino-acid sequence.""" + + runner = self._ensure_runner() + output = runner.number([sequence]) + record = next(iter(output.values())) + if record.get("error"): + raise RuntimeError(f"ANARCI failed: {record['error']}") + + scheme = record.get("scheme", self.scheme) + detected_chain = record.get("chain_type", self.expected_chain_type) + normalised_chain = _normalise_chain_type(detected_chain) + if self.expected_chain_type and normalised_chain != self.expected_chain_type: + msg = ( + f"Expected chain type {self.expected_chain_type!r} but got" + f" {normalised_chain!r}" + ) + raise ValueError(msg) + + residues = [ + NumberedResidue(position=pos, insertion=ins, amino_acid=aa) + for (pos, ins), aa in record["numbering"] + ] + regions = _extract_regions( + residues=residues, + scheme=scheme, + chain_type=normalised_chain, + ) + return NumberedSequence( + sequence=sequence, + scheme=scheme, + chain_type=normalised_chain, + residues=residues, + regions=regions, + ) + + +@lru_cache(maxsize=32) +def _region_boundaries(scheme: str, chain_type: str) -> Sequence[Tuple[str, int, int]]: + key = (scheme.lower(), chain_type.upper()) + return 
_REGION_MAP.get(key, ()) + + +def _extract_regions( + *, + residues: Sequence[NumberedResidue], + scheme: str, + chain_type: str, +) -> dict[str, str]: + boundaries = _region_boundaries(scheme, chain_type) + slots: Dict[str, List[str]] = {name: [] for name, _, _ in boundaries} + slots["full"] = [] + + for residue in residues: + aa = residue.amino_acid + if aa == "-": + continue + slots["full"].append(aa) + for name, start, end in boundaries: + if start <= residue.position <= end: + slots.setdefault(name, []).append(aa) + break + + return {key: "".join(value) for key, value in slots.items()} diff --git a/polyreact/features/descriptors.py b/polyreact/features/descriptors.py new file mode 100644 index 0000000000000000000000000000000000000000..4476f9282a9ec1d6d4f74ffec96d11a720a58c3f --- /dev/null +++ b/polyreact/features/descriptors.py @@ -0,0 +1,146 @@ +"""Sequence descriptor features for polyreactivity prediction.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Sequence + +import numpy as np +import pandas as pd +from Bio.SeqUtils.ProtParam import ProteinAnalysis +from sklearn.preprocessing import StandardScaler + +from .anarsi import AnarciNumberer, NumberedSequence + +_VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY") + + +@dataclass(slots=True) +class DescriptorConfig: + """Configuration for descriptor-based features.""" + + use_anarci: bool = True + regions: Sequence[str] = ("CDRH1", "CDRH2", "CDRH3") + features: Sequence[str] = ( + "length", + "charge", + "hydropathy", + "aromaticity", + "pI", + "net_charge", + ) + ph: float = 7.4 + + +class DescriptorFeaturizer: + """Compute descriptor features with optional ANARCI-based regions.""" + + def __init__( + self, + *, + config: DescriptorConfig, + numberer: AnarciNumberer | None = None, + standardize: bool = True, + ) -> None: + self.config = config + self.numberer = numberer if not config.use_anarci else numberer or AnarciNumberer() + self.standardize = 
standardize + self.scaler = StandardScaler() if standardize else None + self.feature_names_: list[str] | None = None + + def fit(self, sequences: Iterable[str]) -> "DescriptorFeaturizer": + table = self.compute_feature_table(sequences) + values = table.to_numpy(dtype=float) + if self.standardize and self.scaler is not None: + self.scaler.fit(values) + self.feature_names_ = list(table.columns) + return self + + def transform(self, sequences: Iterable[str]) -> np.ndarray: + if self.feature_names_ is None: + msg = "DescriptorFeaturizer must be fitted before calling transform." + raise RuntimeError(msg) + table = self.compute_feature_table(sequences) + values = table.to_numpy(dtype=float) + if self.standardize and self.scaler is not None: + values = self.scaler.transform(values) + return values + + def fit_transform(self, sequences: Iterable[str]) -> np.ndarray: + table = self.compute_feature_table(sequences) + values = table.to_numpy(dtype=float) + if self.standardize and self.scaler is not None: + self.scaler.fit(values) + values = self.scaler.transform(values) + self.feature_names_ = list(table.columns) + return values + + def compute_feature_table(self, sequences: Iterable[str]) -> pd.DataFrame: + rows: list[dict[str, float]] = [] + for sequence in sequences: + regions = self._prepare_regions(sequence) + if not self.config.use_anarci: + region_names = ["FULL"] + else: + region_names = [region.upper() for region in self.config.regions] + row: dict[str, float] = {} + for region_name in region_names: + normalized_name = region_name.upper() + region_sequence = regions.get(normalized_name, "") + for feature_name in self.config.features: + column = f"{normalized_name}_{feature_name}" + row[column] = _compute_feature( + region_sequence, + feature_name, + ph=self.config.ph, + ) + rows.append(row) + + if not self.config.use_anarci: + region_names = ["FULL"] + else: + region_names = [region.upper() for region in self.config.regions] + columns = [ + f"{region}_{feature}" + 
for region in region_names + for feature in self.config.features + ] + frame = pd.DataFrame(rows, columns=columns) + return frame.fillna(0.0) + + def _prepare_regions(self, sequence: str) -> dict[str, str]: + if not self.config.use_anarci: + return {"FULL": sequence} + + try: + numbered: NumberedSequence = self.numberer.number_sequence(sequence) + except (RuntimeError, ValueError): + return {} + return {key.upper(): value for key, value in numbered.regions.items()} + + +def _sanitize_sequence(sequence: str) -> str: + return "".join(residue for residue in sequence.upper() if residue in _VALID_AMINO_ACIDS) + + +def _compute_feature(sequence: str, feature_name: str, *, ph: float) -> float: + sanitized = _sanitize_sequence(sequence) + if not sanitized: + return 0.0 + + analysis = ProteinAnalysis(sanitized) + if feature_name == "length": + return float(len(sanitized)) + if feature_name == "hydropathy": + return float(analysis.gravy()) + if feature_name == "aromaticity": + return float(analysis.aromaticity()) + if feature_name == "pI": + return float(analysis.isoelectric_point()) + if feature_name == "net_charge": + return float(analysis.charge_at_pH(ph)) + if feature_name == "charge": + net = analysis.charge_at_pH(ph) + return float(net / len(sanitized)) + msg = f"Unsupported feature: {feature_name}" + raise ValueError(msg) diff --git a/polyreact/features/pipeline.py b/polyreact/features/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..99e214fbb14384f18e7b1bd9d7ac312c50ee17c8 --- /dev/null +++ b/polyreact/features/pipeline.py @@ -0,0 +1,343 @@ +"""Feature pipeline construction utilities.""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Iterable, Sequence + +import numpy as np +from sklearn.preprocessing import StandardScaler + +from ..config import Config, DescriptorSettings, FeatureBackendSettings +from .descriptors import DescriptorConfig, DescriptorFeaturizer +from .plm 
import PLMEmbedder + + +@dataclass(slots=True) +class FeaturePipelineState: + backend_type: str + descriptor_featurizer: DescriptorFeaturizer | None + plm_scaler: StandardScaler | None + descriptor_config: DescriptorConfig | None + plm_model_name: str | None + plm_layer_pool: str | None + cache_dir: str | None + device: str + feature_names: list[str] = field(default_factory=list) + + +class FeaturePipeline: + """Fit/transform feature matrices according to configuration.""" + + def __init__( + self, + *, + backend: FeatureBackendSettings, + descriptors: DescriptorSettings, + device: str, + cache_dir_override: str | None = None, + plm_model_override: str | None = None, + layer_pool_override: str | None = None, + ) -> None: + self.backend = backend + self.descriptor_settings = descriptors + self.device = device + self.cache_dir_override = cache_dir_override + self.plm_model_override = plm_model_override + self.layer_pool_override = layer_pool_override + + self._descriptor: DescriptorFeaturizer | None = None + self._plm: PLMEmbedder | None = None + self._plm_scaler: StandardScaler | None = None + self._feature_names: list[str] = [] + + def fit_transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray: # noqa: ANN001 + backend_type = self.backend.type if self.backend.type else "descriptors" + self._validate_heavy_support(backend_type, heavy_only) + sequences = _extract_sequences(df, heavy_only=heavy_only) + + if backend_type == "descriptors": + self._descriptor = _build_descriptor_featurizer(self.descriptor_settings) + features = self._descriptor.fit_transform(sequences) + self._feature_names = list(self._descriptor.feature_names_ or []) + self._plm = None + self._plm_scaler = None + return features.astype(np.float32) + + if backend_type == "plm": + self._descriptor = None + self._plm = _build_plm_embedder( + self.backend, + device=self.device, + cache_dir_override=self.cache_dir_override, + plm_model_override=self.plm_model_override, + 
layer_pool_override=self.layer_pool_override, + ) + embeddings = self._plm.embed(sequences, batch_size=batch_size) + if self.backend.standardize: + self._plm_scaler = StandardScaler() + embeddings = self._plm_scaler.fit_transform(embeddings) + else: + self._plm_scaler = None + self._feature_names = [f"plm_{i}" for i in range(embeddings.shape[1])] + return embeddings.astype(np.float32) + + if backend_type == "concat": + descriptor = _build_descriptor_featurizer(self.descriptor_settings) + desc_features = descriptor.fit_transform(sequences) + plm = _build_plm_embedder( + self.backend, + device=self.device, + cache_dir_override=self.cache_dir_override, + plm_model_override=self.plm_model_override, + layer_pool_override=self.layer_pool_override, + ) + embeddings = plm.embed(sequences, batch_size=batch_size) + if self.backend.standardize: + plm_scaler = StandardScaler() + embeddings = plm_scaler.fit_transform(embeddings) + else: + plm_scaler = None + self._descriptor = descriptor + self._plm = plm + self._plm_scaler = plm_scaler + self._feature_names = list(descriptor.feature_names_ or []) + [ + f"plm_{i}" for i in range(embeddings.shape[1]) + ] + return np.concatenate([desc_features, embeddings], axis=1).astype(np.float32) + + msg = f"Unsupported feature backend: {backend_type}" + raise ValueError(msg) + + def fit(self, df, *, heavy_only: bool, batch_size: int = 8) -> "FeaturePipeline": # noqa: ANN001 + backend_type = self.backend.type if self.backend.type else "descriptors" + self._validate_heavy_support(backend_type, heavy_only) + sequences = _extract_sequences(df, heavy_only=heavy_only) + + if backend_type == "descriptors": + self._descriptor = _build_descriptor_featurizer(self.descriptor_settings) + self._descriptor.fit(sequences) + self._feature_names = list(self._descriptor.feature_names_ or []) + self._plm = None + self._plm_scaler = None + elif backend_type == "plm": + self._descriptor = None + self._plm = _build_plm_embedder( + self.backend, + 
device=self.device, + cache_dir_override=self.cache_dir_override, + plm_model_override=self.plm_model_override, + layer_pool_override=self.layer_pool_override, + ) + embeddings = self._plm.embed(sequences, batch_size=batch_size) + if self.backend.standardize: + self._plm_scaler = StandardScaler() + embeddings = self._plm_scaler.fit_transform(embeddings) + else: + self._plm_scaler = None + self._feature_names = [f"plm_{i}" for i in range(embeddings.shape[1])] + elif backend_type == "concat": + descriptor = _build_descriptor_featurizer(self.descriptor_settings) + desc_features = descriptor.fit_transform(sequences) + plm = _build_plm_embedder( + self.backend, + device=self.device, + cache_dir_override=self.cache_dir_override, + plm_model_override=self.plm_model_override, + layer_pool_override=self.layer_pool_override, + ) + embeddings = plm.embed(sequences, batch_size=batch_size) + if self.backend.standardize: + plm_scaler = StandardScaler() + embeddings = plm_scaler.fit_transform(embeddings) + else: + plm_scaler = None + self._descriptor = descriptor + self._plm = plm + self._plm_scaler = plm_scaler + self._feature_names = list(descriptor.feature_names_ or []) + [ + f"plm_{i}" for i in range(embeddings.shape[1]) + ] + else: # pragma: no cover - defensive branch + msg = f"Unsupported feature backend: {backend_type}" + raise ValueError(msg) + return self + + def transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray: # noqa: ANN001 + backend_type = self.backend.type if self.backend.type else "descriptors" + self._validate_heavy_support(backend_type, heavy_only) + sequences = _extract_sequences(df, heavy_only=heavy_only) + + if backend_type == "descriptors": + if self._descriptor is None: + msg = "Descriptor featurizer is not fitted" + raise RuntimeError(msg) + features = self._descriptor.transform(sequences) + elif backend_type == "plm": + if self._plm is None: + msg = "PLM embedder is not initialised" + raise RuntimeError(msg) + embeddings = 
    @property
    def feature_names(self) -> list[str]:
        """Column names of the matrices produced by this pipeline."""
        return self._feature_names

    def get_state(self) -> FeaturePipelineState:
        """Snapshot the fitted pipeline for serialisation.

        The descriptor featurizer's numbering runner is cleared first —
        presumably so the snapshot stays picklable without the ANARCII
        model; NOTE(review): the in-memory featurizer loses its runner as a
        side effect (it will be rebuilt lazily) — confirm this is intended.
        """
        descriptor = self._descriptor
        if descriptor is not None and descriptor.numberer is not None:
            if hasattr(descriptor.numberer, "_runner"):
                descriptor.numberer._runner = None  # type: ignore[attr-defined]
        return FeaturePipelineState(
            backend_type=self.backend.type,
            descriptor_featurizer=descriptor,
            plm_scaler=self._plm_scaler,
            descriptor_config=_build_descriptor_config(self.descriptor_settings),
            plm_model_name=self._effective_plm_model_name,
            plm_layer_pool=self._effective_layer_pool,
            cache_dir=self._effective_cache_dir,
            device=self.device,
            feature_names=self._feature_names,
        )

    def load_state(self, state: FeaturePipelineState) -> None:
        """Restore a snapshot produced by :meth:`get_state`.

        Backend settings and overrides are overwritten from the state; the
        PLM embedder is re-instantiated (weights load lazily), while the
        descriptor featurizer and scaler are taken from the state as-is.
        """
        self.backend.type = state.backend_type
        if state.plm_model_name:
            self.backend.plm_model_name = state.plm_model_name
            self.plm_model_override = state.plm_model_name
        if state.plm_layer_pool:
            self.backend.layer_pool = state.plm_layer_pool
            self.layer_pool_override = state.plm_layer_pool
        if state.cache_dir:
            self.backend.cache_dir = state.cache_dir
            self.cache_dir_override = state.cache_dir
        if state.descriptor_config:
            self.descriptor_settings = DescriptorSettings(
                use_anarci=state.descriptor_config.use_anarci,
                regions=tuple(state.descriptor_config.regions),
                features=tuple(state.descriptor_config.features),
                ph=state.descriptor_config.ph,
            )
        self._descriptor = state.descriptor_featurizer
        self._plm_scaler = state.plm_scaler
        self._feature_names = state.feature_names
        if self.backend.type in {"plm", "concat"}:
            self._plm = _build_plm_embedder(
                self.backend,
                device=self.device,
                cache_dir_override=self.backend.cache_dir,
                plm_model_override=self.backend.plm_model_name,
                layer_pool_override=self.backend.layer_pool,
            )
        else:
            self._plm = None

    @property
    def _effective_plm_model_name(self) -> str | None:
        """Model name actually in effect; None for non-PLM backends."""
        if self.backend.type not in {"plm", "concat"}:
            return None
        return self.plm_model_override or self.backend.plm_model_name

    @property
    def _effective_layer_pool(self) -> str | None:
        """Pooling strategy actually in effect; None for non-PLM backends."""
        if self.backend.type not in {"plm", "concat"}:
            return None
        return self.layer_pool_override or self.backend.layer_pool

    @property
    def _effective_cache_dir(self) -> str | None:
        """Cache directory actually in effect; None for non-PLM backends."""
        if self.backend.type not in {"plm", "concat"}:
            return None
        if self.cache_dir_override is not None:
            return self.cache_dir_override
        return self.backend.cache_dir
+ raise ValueError(msg) + + +def build_feature_pipeline( + config: Config, + *, + backend_override: str | None = None, + plm_model_override: str | None = None, + cache_dir_override: str | None = None, + layer_pool_override: str | None = None, +) -> FeaturePipeline: + backend = FeatureBackendSettings(**asdict(config.feature_backend)) + if backend_override: + backend.type = backend_override + pipeline = FeaturePipeline( + backend=backend, + descriptors=config.descriptors, + device=config.device, + cache_dir_override=cache_dir_override, + plm_model_override=plm_model_override, + layer_pool_override=layer_pool_override, + ) + return pipeline + + +def _build_descriptor_featurizer(settings: DescriptorSettings) -> DescriptorFeaturizer: + descriptor_config = _build_descriptor_config(settings) + return DescriptorFeaturizer(config=descriptor_config, standardize=True) + + +def _build_descriptor_config(settings: DescriptorSettings) -> DescriptorConfig: + return DescriptorConfig( + use_anarci=settings.use_anarci, + regions=tuple(settings.regions), + features=tuple(settings.features), + ph=settings.ph, + ) + + +def _build_plm_embedder( + backend: FeatureBackendSettings, + *, + device: str, + cache_dir_override: str | None, + plm_model_override: str | None, + layer_pool_override: str | None, +) -> PLMEmbedder: + model_name = plm_model_override or backend.plm_model_name + cache_dir = cache_dir_override or backend.cache_dir + layer_pool = layer_pool_override or backend.layer_pool + return PLMEmbedder( + model_name=model_name, + layer_pool=layer_pool, + device=device, + cache_dir=cache_dir, + ) + + +def _extract_sequences(df, heavy_only: bool) -> Sequence[str]: # noqa: ANN001 + if heavy_only or "light_seq" not in df.columns: + return df["heavy_seq"].fillna("").astype(str).tolist() + heavy = df["heavy_seq"].fillna("").astype(str) + light = df["light_seq"].fillna("").astype(str) + return (heavy + "|" + light).tolist() diff --git a/polyreact/features/plm.py b/polyreact/features/plm.py 
class _ESMTokenizer:
    """Callable wrapper that mimics Hugging Face tokenizers for ESM models."""

    def __init__(self, alphabet) -> None:  # noqa: ANN001
        self.alphabet = alphabet
        # The batch converter maps (label, sequence) pairs to padded token tensors.
        self._batch_converter = alphabet.get_batch_converter()

    def __call__(
        self,
        sequences: Sequence[str],
        *,
        return_tensors: str = "pt",
        padding: bool = True,  # noqa: FBT002
        truncation: bool = True,  # noqa: FBT002
        add_special_tokens: bool = True,  # noqa: FBT002
        return_special_tokens_mask: bool = True,  # noqa: FBT002
    ) -> dict[str, torch.Tensor]:
        # padding/truncation/add_special_tokens exist only for HF interface
        # compatibility; the ESM batch converter applies its own behaviour.
        if return_tensors != "pt":  # pragma: no cover - defensive branch
            msg = "ESM tokenizer only supports return_tensors='pt'"
            raise ValueError(msg)
        data = [(str(idx), (seq or "").upper()) for idx, seq in enumerate(sequences)]
        _labels, _strings, tokens = self._batch_converter(data)
        # Attention covers everything except padding; special tokens
        # (pad/cls/eos) are flagged separately so pooling can exclude them.
        attention_mask = (tokens != self.alphabet.padding_idx).long()
        special_tokens = torch.zeros_like(tokens)
        specials = {
            self.alphabet.padding_idx,
            self.alphabet.cls_idx,
            self.alphabet.eos_idx,
        }
        for special in specials:
            special_tokens |= tokens == special
        output: dict[str, torch.Tensor] = {
            "input_ids": tokens,
            "attention_mask": attention_mask,
        }
        if return_special_tokens_mask:
            output["special_tokens_mask"] = special_tokens.long()
        return output


class _ESMModelWrapper(nn.Module):
    """Adapter providing a Hugging Face style interface for ESM models."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        # Final transformer layer index; its representations become
        # last_hidden_state in forward().
        self.layer_index = getattr(model, "num_layers", None)
        if self.layer_index is None:
            msg = "Unable to determine final layer for ESM model"
            raise AttributeError(msg)

    def eval(self) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.eval()
        return self

    def to(self, device: str) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.to(device)
        return self

    def forward(self, input_ids: torch.Tensor, **_):  # noqa: ANN003
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                repr_layers=[self.layer_index],
                return_contacts=False,
            )
        hidden = outputs["representations"][self.layer_index]
        # SimpleNamespace mirrors transformers' output attribute access.
        return SimpleNamespace(last_hidden_state=hidden)

    # Alias so calls skip nn.Module.__call__ (and its hook machinery).
    __call__ = forward


@dataclass(slots=True)
class PLMConfig:
    """Bundle of embedder settings (model, pooling, cache location, device)."""

    # NOTE(review): not referenced by code visible in this module — possibly
    # kept for external callers; confirm before removing.
    model_name: str = "facebook/esm1v_t33_650M_UR90S_1"
    layer_pool: str = "mean"
    cache_dir: Path = Path(".cache/embeddings")
    device: str = "auto"
+class PLMEmbedder: + """Embed amino-acid sequences using a transformer model with caching.""" + + def __init__( + self, + model_name: str = "facebook/esm1v_t33_650M_UR90S_1", + *, + layer_pool: str = "mean", + device: str = "auto", + cache_dir: str | Path | None = None, + numberer: AnarciNumberer | None = None, + model_loader: ModelLoader | None = None, + ) -> None: + self.model_name = model_name + self.layer_pool = layer_pool + self.device = self._resolve_device(device) + self.cache_dir = Path(cache_dir or ".cache/embeddings") + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.numberer = numberer + self.model_loader = model_loader + self._tokenizer: object | None = None + self._model: nn.Module | None = None + + @staticmethod + def _resolve_device(device: str) -> str: + if device == "auto": + return "cuda" if torch.cuda.is_available() else "cpu" + return device + + @property + def tokenizer(self): # noqa: D401 + if self._tokenizer is None: + tokenizer, model = self._load_model_components() + self._tokenizer = tokenizer + self._model = model + return self._tokenizer + + @property + def model(self) -> nn.Module: + if self._model is None: + tokenizer, model = self._load_model_components() + self._tokenizer = tokenizer + self._model = model + return self._model + + def _load_model_components(self) -> Tuple[object, nn.Module]: + if self.model_loader is not None: + tokenizer, model = self.model_loader(self.model_name, self.device) + return tokenizer, model + + if self._is_esm1v_model(self.model_name): + return self._load_esm_model() + + if AutoModel is None or AutoTokenizer is None: # pragma: no cover - optional dependency + msg = "transformers must be installed to use PLMEmbedder" + raise ImportError(msg) + tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True) + model.eval() + model.to(self.device) + return tokenizer, model + + def _load_esm_model(self) -> 
    def _load_esm_model(self) -> Tuple[object, nn.Module]:
        """Load an ESM-1v checkpoint and wrap it behind the HF-style interface."""
        if esm is None:  # pragma: no cover - optional dependency
            msg = (
                "The 'esm' package is required to use ESM-1v models."
            )
            raise ImportError(msg)

        normalized = self._canonical_esm_name(self.model_name)
        loader = _ESM1V_LOADERS.get(normalized)
        if loader is None:  # pragma: no cover - guard branch
            msg = f"Unsupported ESM-1v model: {self.model_name}"
            raise ValueError(msg)

        model, alphabet = loader()
        model.eval()
        model.to(self.device)
        tokenizer = _ESMTokenizer(alphabet)
        wrapper = _ESMModelWrapper(model)
        return tokenizer, wrapper

    @staticmethod
    def _canonical_esm_name(model_name: str) -> str:
        """Lower-case the name and strip any hub namespace prefix ("org/...")."""
        name = model_name.lower()
        if "/" in name:
            name = name.split("/")[-1]
        return name

    @classmethod
    def _is_esm1v_model(cls, model_name: str) -> bool:
        """True when the canonicalised name denotes an ESM-1v checkpoint."""
        return cls._canonical_esm_name(model_name).startswith("esm1v")

    def embed(self, sequences: Iterable[str], *, batch_size: int = 8) -> np.ndarray:
        """Embed *sequences*, reading/writing a per-sequence .npy disk cache.

        Cache layout is ``cache_dir/<model>/<sha1(sequence)>.npy``.  Cached
        entries are loaded in parallel; the remaining *unique* sequences are
        embedded in batches and written back to the cache.

        Returns:
            float32 array of shape (len(sequences), hidden_dim); an empty
            (0, 0) array when no sequences are given.
        """
        batch_sequences = list(sequences)
        if not batch_sequences:
            return np.empty((0, 0), dtype=np.float32)

        outputs: List[np.ndarray | None] = [None] * len(batch_sequences)
        # sequence -> [(output index, cache path), ...] for cache misses;
        # duplicates of one sequence are computed once and fanned out.
        unique_to_compute: dict[str, List[Tuple[int, Path]]] = {}
        model_dir = self.cache_dir / self._normalized_model_name()
        model_dir.mkdir(parents=True, exist_ok=True)

        cache_hits: list[tuple[int, Path]] = []
        for idx, sequence in enumerate(batch_sequences):
            cache_path = self._sequence_cache_path(model_dir, sequence)
            if cache_path.exists():
                cache_hits.append((idx, cache_path))
            else:
                unique_to_compute.setdefault(sequence, []).append((idx, cache_path))

        if cache_hits:
            # Parallel np.load: cache reads are I/O bound.
            loaders = [path for _, path in cache_hits]
            max_workers = min(len(loaders), 32)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for (idx, _), embedding in zip(cache_hits, executor.map(np.load, loaders), strict=True):
                    outputs[idx] = embedding

        if unique_to_compute:
            embeddings = self._compute_embeddings(list(unique_to_compute.keys()), batch_size=batch_size)
            for sequence, embedding in zip(unique_to_compute.keys(), embeddings, strict=True):
                targets = unique_to_compute[sequence]
                for idx, cache_path in targets:
                    outputs[idx] = embedding
                    # NOTE(review): duplicate sequences save to the same path
                    # once per occurrence — harmless but redundant.
                    np.save(cache_path, embedding)
        if any(item is None for item in outputs):  # pragma: no cover - safety
            msg = "Failed to compute embeddings for all sequences"
            raise RuntimeError(msg)
        array_outputs = [np.asarray(item, dtype=np.float32) for item in outputs]  # type: ignore[arg-type]
        return np.stack(array_outputs, axis=0)
special_mask, sequence) + embeddings.append(embedding) + return embeddings + + def _tokenize(self, tokenizer, sequences: Sequence[str]): + if hasattr(tokenizer, "__call__"): + return tokenizer( + list(sequences), + return_tensors="pt", + padding=True, + truncation=True, + add_special_tokens=True, + return_special_tokens_mask=True, + ) + msg = "Tokenizer does not implement __call__" + raise TypeError(msg) + + def _pool_hidden( + self, + hidden: torch.Tensor, + attention_mask: torch.Tensor | None, + special_mask: torch.Tensor | None, + sequence: str, + ) -> np.ndarray: + if attention_mask is None: + attention = torch.ones(hidden.size(0), dtype=torch.float32) + else: + attention = attention_mask.to(dtype=torch.float32) + if special_mask is not None: + attention = attention * (1.0 - special_mask.to(dtype=torch.float32)) + if attention.sum() == 0: + attention = torch.ones_like(attention) + + if self.layer_pool == "mean": + return self._masked_mean(hidden, attention) + if self.layer_pool == "cls": + return hidden[0].detach().cpu().numpy() + if self.layer_pool == "per_token_mean_cdrh3": + return self._pool_cdrh3(hidden, attention, sequence) + msg = f"Unsupported layer pool: {self.layer_pool}" + raise ValueError(msg) + + @staticmethod + def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> np.ndarray: + weights = mask.unsqueeze(-1) + weighted = hidden * weights + denom = weights.sum() + if denom == 0: + pooled = hidden.mean(dim=0) + else: + pooled = weighted.sum(dim=0) / denom + return pooled.detach().cpu().numpy() + + def _pool_cdrh3(self, hidden: torch.Tensor, mask: torch.Tensor, sequence: str) -> np.ndarray: + numberer = self.numberer + if numberer is None: + numberer = AnarciNumberer() + self.numberer = numberer + numbered = numberer.number_sequence(sequence) + cdr = numbered.regions.get("CDRH3", "") + if not cdr: + return self._masked_mean(hidden, mask) + sequence_upper = sequence.upper() + start = sequence_upper.find(cdr.upper()) + if start == -1: + return 
self._masked_mean(hidden, mask) + residues_idx = mask.nonzero(as_tuple=False).squeeze(-1).tolist() + if not residues_idx: + return self._masked_mean(hidden, mask) + end = start + len(cdr) + if end > len(residues_idx): + return self._masked_mean(hidden, mask) + cdr_token_positions = residues_idx[start:end] + if not cdr_token_positions: + return self._masked_mean(hidden, mask) + cdr_mask = torch.zeros_like(mask) + for pos in cdr_token_positions: + cdr_mask[pos] = 1.0 + return self._masked_mean(hidden, cdr_mask) + + def _sequence_cache_path(self, model_dir: Path, sequence: str) -> Path: + digest = hashlib.sha1(sequence.encode("utf-8")).hexdigest() + return model_dir / f"{digest}.npy" + + def _normalized_model_name(self) -> str: + if self._is_esm1v_model(self.model_name): + return self._canonical_esm_name(self.model_name) + return self.model_name.replace("/", "_") diff --git a/polyreact/models/__init__.py b/polyreact/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26d80b461dddb6dd14cb29d3ca0e3cbe3c628c86 --- /dev/null +++ b/polyreact/models/__init__.py @@ -0,0 +1,3 @@ +"""Models for polyreactivity classification.""" + +__all__ = ["linear", "calibrate", "ordinal"] diff --git a/polyreact/models/__pycache__/__init__.cpython-311.pyc b/polyreact/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95d66f00a23b5e357f03771ee8367b0f95a77b01 Binary files /dev/null and b/polyreact/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/polyreact/models/__pycache__/calibrate.cpython-311.pyc b/polyreact/models/__pycache__/calibrate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8df2f6eb983f97bb243a70c0c1f73e4d2c8d5c Binary files /dev/null and b/polyreact/models/__pycache__/calibrate.cpython-311.pyc differ diff --git a/polyreact/models/__pycache__/linear.cpython-311.pyc b/polyreact/models/__pycache__/linear.cpython-311.pyc new file mode 
"""Probability calibration helpers."""

from typing import Any

import numpy as np
from sklearn.calibration import CalibratedClassifierCV


def fit_calibrator(
    estimator: Any,
    X: np.ndarray,
    y: np.ndarray,
    *,
    method: str = "isotonic",
    cv: int | str | None = "prefit",
) -> CalibratedClassifierCV:
    """Fit a ``CalibratedClassifierCV`` on top of a pre-trained estimator."""
    wrapper = CalibratedClassifierCV(estimator, method=method, cv=cv)
    wrapper.fit(X, y)
    return wrapper
"balanced" + max_iter: int = 1000 + + +@dataclass(slots=True) +class TrainedModel: + """Container for trained estimators and optional calibration.""" + + estimator: Any + calibrator: Any | None = None + vectorizer_name: str = "" + feature_meta: dict[str, Any] = field(default_factory=dict) + metrics_cv: dict[str, float] = field(default_factory=dict) + + def predict(self, X: np.ndarray) -> np.ndarray: + if self.calibrator is not None and hasattr(self.calibrator, "predict"): + return self.calibrator.predict(X) + return self.estimator.predict(X) + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + if self.calibrator is not None and hasattr(self.calibrator, "predict_proba"): + probs = self.calibrator.predict_proba(X) + return probs[:, 1] + if hasattr(self.estimator, "predict_proba"): + probs = self.estimator.predict_proba(X) + return probs[:, 1] + if hasattr(self.estimator, "decision_function"): + scores = self.estimator.decision_function(X) + return 1.0 / (1.0 + np.exp(-scores)) + msg = "Estimator does not support probability prediction" + raise AttributeError(msg) + + +def build_estimator( + *, config: LinearModelConfig, random_state: int | None = 42 +) -> Any: + """Construct an unfitted linear estimator based on configuration.""" + + if config.head == "logreg": + estimator = LogisticRegression( + C=config.C, + max_iter=config.max_iter, + class_weight=config.class_weight, + solver="liblinear", + random_state=random_state, + ) + elif config.head == "linear_svm": + estimator = LinearSVC( + C=config.C, + class_weight=config.class_weight, + max_iter=config.max_iter, + random_state=random_state, + ) + else: # pragma: no cover - defensive branch + msg = f"Unsupported head type: {config.head}" + raise ValueError(msg) + return estimator + + +def train_linear_model( + X: np.ndarray, + y: np.ndarray, + *, + config: LinearModelConfig, + random_state: int | None = 42, +) -> TrainedModel: + """Fit a linear classifier on the provided feature matrix.""" + + estimator = 
build_estimator(config=config, random_state=random_state) + if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000: + estimator.set_params(solver="lbfgs") + estimator.fit(X, y) + return TrainedModel(estimator=estimator) diff --git a/polyreact/models/ordinal.py b/polyreact/models/ordinal.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5208ca3b8b66072ac39df54aa987fea725fcf6 --- /dev/null +++ b/polyreact/models/ordinal.py @@ -0,0 +1,106 @@ +"""Ordinal/count modeling utilities for flag regression.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import statsmodels.api as sm +from statsmodels.discrete.discrete_model import NegativeBinomialResults +from sklearn.linear_model import PoissonRegressor +from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score + +try: + from sklearn.metrics import root_mean_squared_error +except ImportError: # pragma: no cover - fallback for older sklearn + def root_mean_squared_error(y_true, y_pred): + return mean_squared_error(y_true, y_pred, squared=False) + + +@dataclass(slots=True) +class PoissonModel: + """Wrapper storing a fitted Poisson regression model.""" + + estimator: PoissonRegressor + + def predict(self, X: np.ndarray) -> np.ndarray: + return self.estimator.predict(X) + + +def fit_poisson_model( + X: np.ndarray, + y: np.ndarray, + *, + alpha: float = 1e-6, + max_iter: int = 1000, +) -> PoissonModel: + """Train a Poisson regression model on count targets.""" + + model = PoissonRegressor(alpha=alpha, max_iter=max_iter) + model.fit(X, y) + return PoissonModel(estimator=model) + + +@dataclass(slots=True) +class NegativeBinomialModel: + """Wrapper storing a fitted negative-binomial regression model.""" + + result: NegativeBinomialResults + + def predict(self, X: np.ndarray) -> np.ndarray: + X_const = sm.add_constant(X, has_constant="add") + return self.result.predict(X_const) + + @property + def alpha(self) -> float: + 
params = np.asarray(self.result.params, dtype=float) + exog_dim = self.result.model.exog.shape[1] + if params.size > exog_dim: + # statsmodels stores log(alpha) as the final coefficient + return float(np.exp(params[-1])) + model_alpha = getattr(self.result.model, "alpha", None) + if model_alpha is not None: + return float(model_alpha) + return float("nan") + + +def fit_negative_binomial_model( + X: np.ndarray, + y: np.ndarray, + *, + max_iter: int = 200, +) -> NegativeBinomialModel: + """Train a negative binomial regression model (NB2).""" + + X_const = sm.add_constant(X, has_constant="add") + model = sm.NegativeBinomial(y, X_const, loglike_method="nb2") + result = model.fit(maxiter=max_iter, disp=False) + return NegativeBinomialModel(result=result) + + +def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]: + """Return standard regression metrics for count predictions.""" + + mae = mean_absolute_error(y_true, y_pred) + rmse = root_mean_squared_error(y_true, y_pred) + r2 = r2_score(y_true, y_pred) + return { + "mae": float(mae), + "rmse": float(rmse), + "r2": float(r2), + } + + +def pearson_dispersion( + y_true: np.ndarray, + y_pred: np.ndarray, + *, + dof: int, +) -> float: + """Compute Pearson dispersion (chi-square / dof).""" + + eps = 1e-8 + adjusted = np.maximum(y_pred, eps) + resid = (y_true - y_pred) / np.sqrt(adjusted) + denom = max(dof, 1) + return float(np.sum(resid**2) / denom) diff --git a/polyreact/predict.py b/polyreact/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..83916a068b5447333d857d88d843fde61dd7a9d6 --- /dev/null +++ b/polyreact/predict.py @@ -0,0 +1,106 @@ +"""Command-line interface for polyreactivity predictions.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import pandas as pd + +from .api import predict_batch +from .config import load_config +from .utils.io import read_table, write_table + + +def build_parser() -> 
argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Polyreactivity prediction CLI") + parser.add_argument( + "--input", + required=True, + help="Path to input CSV or JSONL file with sequences.", + ) + parser.add_argument( + "--output", + required=True, + help="Path to write predictions CSV.", + ) + parser.add_argument( + "--config", + default="configs/default.yaml", + help="Path to configuration YAML file.", + ) + parser.add_argument( + "--backend", + choices=["plm", "descriptors", "concat"], + help="Override feature backend from config.", + ) + parser.add_argument( + "--plm-model", + help="Override PLM model name.", + ) + parser.add_argument( + "--weights", + required=True, + help="Path to trained model artifact (joblib).", + ) + parser.add_argument( + "--heavy-only", + dest="heavy_only", + action="store_true", + default=True, + help="Use only heavy chains (default).", + ) + parser.add_argument( + "--paired", + dest="heavy_only", + action="store_false", + help="Use paired heavy/light chains if available.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=8, + help="Batch size for model inference (PLM backend).", + ) + parser.add_argument( + "--device", + choices=["auto", "cpu", "cuda"], + help="Computation device override.", + ) + parser.add_argument( + "--cache-dir", + help="Cache directory for embeddings.", + ) + return parser + + +def main(argv: list[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + + config = load_config(args.config) + df = read_table(args.input) + + if "heavy_seq" not in df.columns and "heavy" not in df.columns: + parser.error("Input file must contain a 'heavy_seq' column (or 'heavy').") + if df.get("heavy_seq", df.get("heavy", "")).fillna("").str.len().eq(0).all(): + parser.error("At least one non-empty heavy sequence is required.") + + predictions = predict_batch( + df.to_dict("records"), + config=config, + backend=args.backend, + plm_model=args.plm_model, + 
weights=args.weights, + heavy_only=args.heavy_only, + batch_size=args.batch_size, + device=args.device, + cache_dir=args.cache_dir, + ) + + write_table(predictions, args.output) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/polyreact/train.py b/polyreact/train.py new file mode 100644 index 0000000000000000000000000000000000000000..0d70ad0676c81be48e18562b56554ca65ce310aa --- /dev/null +++ b/polyreact/train.py @@ -0,0 +1,619 @@ +"""Training entrypoint for the polyreactivity model.""" + +from __future__ import annotations + +import argparse +import json +import subprocess +from pathlib import Path +from typing import Any, Sequence + +import joblib +import numpy as np +import pandas as pd +from sklearn.metrics import roc_auc_score +from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold +from sklearn.linear_model import LogisticRegression + +from .config import Config, load_config +from .data_loaders import boughter, harvey, jain, shehata +from .data_loaders.utils import deduplicate_sequences +from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline +from .models.calibrate import fit_calibrator +from .models.linear import LinearModelConfig, TrainedModel, build_estimator, train_linear_model +from .utils.io import write_table +from .utils.logging import configure_logging +from .utils.metrics import bootstrap_metric_intervals, compute_metrics +from .utils.plots import plot_precision_recall, plot_reliability_curve, plot_roc_curve +from .utils.seeds import set_global_seeds + +DATASET_LOADERS = { + "boughter": boughter.load_dataframe, + "jain": jain.load_dataframe, + "shehata": shehata.load_dataframe, + "shehata_curated": shehata.load_dataframe, + "harvey": harvey.load_dataframe, +} + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Train polyreactivity model") + parser.add_argument("--config", default="configs/default.yaml", 
help="Config file") + parser.add_argument("--train", required=True, help="Training dataset path") + parser.add_argument( + "--eval", + nargs="*", + default=[], + help="Evaluation dataset paths", + ) + parser.add_argument( + "--save-to", + default="artifacts/model.joblib", + help="Path to save trained model artifact", + ) + parser.add_argument( + "--report-to", + default="artifacts", + help="Directory for metrics, predictions, and plots", + ) + parser.add_argument( + "--train-loader", + choices=list(DATASET_LOADERS.keys()), + help="Optional explicit loader for training dataset", + ) + parser.add_argument( + "--eval-loaders", + nargs="*", + help="Optional explicit loaders for evaluation datasets (aligned with --eval order)", + ) + parser.add_argument( + "--backend", + choices=["plm", "descriptors", "concat"], + help="Override feature backend", + ) + parser.add_argument("--plm-model", help="Override PLM model name") + parser.add_argument("--cache-dir", help="Override embedding cache directory") + parser.add_argument("--device", choices=["auto", "cpu", "cuda"], help="Device override") + parser.add_argument("--batch-size", type=int, default=8, help="Batch size for embeddings") + parser.add_argument( + "--heavy-only", + action="store_true", + default=True, + help="Use heavy chains only (default true)", + ) + parser.add_argument( + "--paired", + dest="heavy_only", + action="store_false", + help="Use paired heavy/light chains when available.", + ) + parser.add_argument( + "--include-families", + nargs="*", + help="Optional list of family names to retain in the training dataset", + ) + parser.add_argument( + "--exclude-families", + nargs="*", + help="Optional list of family names to drop from the training dataset", + ) + parser.add_argument( + "--include-species", + nargs="*", + help="Optional list of species (e.g. 
human, mouse) to retain", + ) + parser.add_argument( + "--cv-group-column", + default="lineage", + help="Column name used to group samples during cross-validation (default: lineage)", + ) + parser.add_argument( + "--no-group-cv", + action="store_true", + help="Disable group-aware cross-validation even if group column is present", + ) + parser.add_argument( + "--keep-train-duplicates", + action="store_true", + help="Keep duplicate keys within the training dataset when deduplicating across splits", + ) + parser.add_argument( + "--dedupe-key-columns", + nargs="*", + help="Columns used to detect duplicates across datasets (defaults to heavy/light sequences)", + ) + parser.add_argument( + "--bootstrap-samples", + type=int, + default=200, + help="Number of bootstrap resamples for confidence intervals (0 to disable).", + ) + parser.add_argument( + "--bootstrap-alpha", + type=float, + default=0.05, + help="Alpha for two-sided bootstrap confidence intervals (default 0.05 → 95% CI).", + ) + parser.add_argument( + "--write-train-in-sample", + action="store_true", + help=( + "Persist in-sample metrics on the full training set; disabled by default to avoid" + " over-optimistic reporting." + ), + ) + return parser + + +def _infer_loader(path: str, explicit: str | None) -> tuple[str, callable]: + if explicit: + return explicit, DATASET_LOADERS[explicit] + lower = Path(path).stem.lower() + for name, loader in DATASET_LOADERS.items(): + if name in lower: + return name, loader + msg = f"Could not infer loader for dataset: {path}. Provide --train-loader/--eval-loaders." 
+ raise ValueError(msg) + + +def _load_dataset(path: str, loader_name: str, loader_fn, *, heavy_only: bool) -> pd.DataFrame: + frame = loader_fn(path, heavy_only=heavy_only) + frame["source"] = loader_name + return frame + + +def _apply_dataset_filters( + frame: pd.DataFrame, + *, + include_families: Sequence[str] | None, + exclude_families: Sequence[str] | None, + include_species: Sequence[str] | None, +) -> pd.DataFrame: + filtered = frame.copy() + if include_families: + families = {fam.lower() for fam in include_families} + if "family" in filtered.columns: + filtered = filtered[ + filtered["family"].astype(str).str.lower().isin(families) + ] + if exclude_families: + families_ex = {fam.lower() for fam in exclude_families} + if "family" in filtered.columns: + filtered = filtered[ + ~filtered["family"].astype(str).str.lower().isin(families_ex) + ] + if include_species: + species_set = {spec.lower() for spec in include_species} + if "species" in filtered.columns: + filtered = filtered[ + filtered["species"].astype(str).str.lower().isin(species_set) + ] + return filtered.reset_index(drop=True) + + +def main(argv: Sequence[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + + config = load_config(args.config) + if args.device: + config.device = args.device + if args.backend: + config.feature_backend.type = args.backend + if args.cache_dir: + config.feature_backend.cache_dir = args.cache_dir + if args.plm_model: + config.feature_backend.plm_model_name = args.plm_model + + logger = configure_logging() + set_global_seeds(config.seed) + _log_environment(logger) + + heavy_only = args.heavy_only + + train_name, train_loader = _infer_loader(args.train, args.train_loader) + train_df = _load_dataset(args.train, train_name, train_loader, heavy_only=heavy_only) + train_df = _apply_dataset_filters( + train_df, + include_families=args.include_families, + exclude_families=args.exclude_families, + include_species=args.include_species, + ) + + 
eval_frames: list[pd.DataFrame] = [] + if args.eval: + loaders_iter = args.eval_loaders or [] + for idx, eval_path in enumerate(args.eval): + explicit = loaders_iter[idx] if idx < len(loaders_iter) else None + eval_name, eval_loader = _infer_loader(eval_path, explicit) + eval_df = _load_dataset(eval_path, eval_name, eval_loader, heavy_only=heavy_only) + eval_frames.append(eval_df) + + all_frames = [train_df, *eval_frames] + dedup_keep = {0} if args.keep_train_duplicates else set() + deduped_frames = deduplicate_sequences( + all_frames, + heavy_only=heavy_only, + key_columns=args.dedupe_key_columns, + keep_intra_frames=dedup_keep, + ) + train_df = deduped_frames[0] + eval_frames = deduped_frames[1:] + + pipeline_factory = lambda: build_feature_pipeline( # noqa: E731 + config, + backend_override=args.backend, + plm_model_override=args.plm_model, + cache_dir_override=args.cache_dir, + ) + + model_config = LinearModelConfig( + head=config.model.head, + C=config.model.C, + class_weight=config.model.class_weight, + ) + + groups = None + if not args.no_group_cv and args.cv_group_column: + if args.cv_group_column in train_df.columns: + groups = train_df[args.cv_group_column].fillna("").astype(str).to_numpy() + else: + logger.warning( + "Group column '%s' not found in training dataframe; falling back to standard CV", + args.cv_group_column, + ) + + cv_results = _cross_validate( + train_df, + pipeline_factory, + model_config, + config, + heavy_only=heavy_only, + batch_size=args.batch_size, + groups=groups, + ) + + trained_model, feature_pipeline = _fit_full_model( + train_df, + pipeline_factory, + model_config, + config, + heavy_only=heavy_only, + batch_size=args.batch_size, + ) + + outputs_dir = Path(args.report_to) + outputs_dir.mkdir(parents=True, exist_ok=True) + + metrics_df, preds_rows = _evaluate_datasets( + train_df, + eval_frames, + trained_model, + feature_pipeline, + config, + cv_results, + outputs_dir, + batch_size=args.batch_size, + heavy_only=heavy_only, + 
bootstrap_samples=args.bootstrap_samples, + bootstrap_alpha=args.bootstrap_alpha, + write_train_in_sample=args.write_train_in_sample, + ) + + write_table(metrics_df, outputs_dir / config.io.metrics_filename) + preds_df = pd.DataFrame(preds_rows) + write_table(preds_df, outputs_dir / config.io.preds_filename) + + artifact = { + "config": config, + "feature_state": feature_pipeline.get_state(), + "model": trained_model, + } + Path(args.save_to).parent.mkdir(parents=True, exist_ok=True) + joblib.dump(artifact, args.save_to) + + logger.info("Training complete. Metrics written to %s", outputs_dir) + return 0 + + +def _cross_validate( + train_df: pd.DataFrame, + pipeline_factory, + model_config: LinearModelConfig, + config: Config, + *, + heavy_only: bool, + batch_size: int, + groups: np.ndarray | None = None, +): + y = train_df["label"].to_numpy(dtype=int) + n_samples = len(y) + # Determine a safe number of folds for tiny fixtures; prefer the configured value + # but never exceed the number of samples. Fall back to non-stratified KFold when + # per-class counts are too small for stratification (e.g., 1 positive/1 negative). 
def _cross_validate(
    train_df: pd.DataFrame,
    pipeline_factory,
    model_config: "LinearModelConfig",
    config: "Config",
    *,
    heavy_only: bool,
    batch_size: int,
    groups: np.ndarray | None = None,
):
    """Run (optionally group-aware) cross-validation over the training set.

    Returns a dict with out-of-fold scores, per-fold metrics, and their
    mean/std across folds.
    """
    y = train_df["label"].to_numpy(dtype=int)
    n_samples = len(y)
    # Determine a safe number of folds for tiny fixtures; prefer the configured
    # value but never exceed the number of samples.
    n_splits = max(2, min(config.training.cv_folds, n_samples))

    # Stratify only when every observed class has at least n_splits members.
    # BUG FIX: the previous check used `class_counts.min(initial=0)`, which is
    # always 0 for non-negative bincounts and therefore silently disabled
    # stratification in every run; compare the smallest *positive* count.
    use_stratified = True
    class_counts = np.bincount(y) if y.size else np.array([])
    if class_counts.size > 0:
        observed = class_counts[class_counts > 0]
        if observed.size == 0 or int(observed.min()) < n_splits:
            use_stratified = False

    if groups is not None and use_stratified:
        splitter = StratifiedGroupKFold(
            n_splits=n_splits,
            shuffle=True,
            random_state=config.seed,
        )
        split_iter = splitter.split(train_df, y, groups)
    elif use_stratified:
        splitter = StratifiedKFold(
            n_splits=n_splits,
            shuffle=True,
            random_state=config.seed,
        )
        split_iter = splitter.split(train_df, y)
    else:
        # Non-stratified fallback for extreme class imbalance / tiny datasets
        from sklearn.model_selection import KFold  # local import to limit surface

        splitter = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)
        split_iter = splitter.split(train_df)

    oof_scores = np.zeros(len(train_df), dtype=float)
    metrics_per_fold: list[dict[str, float]] = []

    for train_idx, val_idx in split_iter:
        train_slice = train_df.iloc[train_idx].reset_index(drop=True)
        val_slice = train_df.iloc[val_idx].reset_index(drop=True)

        # A fresh pipeline per fold so feature fitting never sees validation data.
        pipeline: FeaturePipeline = pipeline_factory()
        X_train = pipeline.fit_transform(train_slice, heavy_only=heavy_only, batch_size=batch_size)
        X_val = pipeline.transform(val_slice, heavy_only=heavy_only, batch_size=batch_size)

        y_train = y[train_idx]
        y_val = y[val_idx]

        # Handle degenerate folds where training data contains a single class.
        if np.unique(y_train).size < 2:
            fallback_prob = float(y.mean()) if y.size else 0.5
            y_scores = np.full(X_val.shape[0], fallback_prob, dtype=float)
        else:
            trained = train_linear_model(
                X_train, y_train, config=model_config, random_state=config.seed
            )
            calibrator = _fit_model_calibrator(
                model_config,
                config,
                X_train,
                y_train,
                base_estimator=trained.estimator,
            )
            trained.calibrator = calibrator
            if calibrator is not None:
                y_scores = calibrator.predict_proba(X_val)[:, 1]
            else:
                y_scores = trained.predict_proba(X_val)
        oof_scores[val_idx] = y_scores

        fold_metrics = compute_metrics(y_val, y_scores)
        try:
            fold_metrics["roc_auc"] = float(roc_auc_score(y_val, y_scores))
        except ValueError:
            # For tiny validation folds with a single class, ROC-AUC is undefined.
            pass
        metrics_per_fold.append(fold_metrics)

    metrics_mean: dict[str, float] = {}
    metrics_std: dict[str, float] = {}
    metric_names = list(metrics_per_fold[0].keys()) if metrics_per_fold else []
    for metric in metric_names:
        values = [fold[metric] for fold in metrics_per_fold]
        metrics_mean[metric] = float(np.mean(values))
        metrics_std[metric] = float(np.std(values, ddof=1))

    return {
        "oof_scores": oof_scores,
        "metrics_per_fold": metrics_per_fold,
        "metrics_mean": metrics_mean,
        "metrics_std": metrics_std,
    }


def _fit_full_model(
    train_df: pd.DataFrame,
    pipeline_factory,
    model_config: "LinearModelConfig",
    config: "Config",
    *,
    heavy_only: bool,
    batch_size: int,
) -> tuple["TrainedModel", "FeaturePipeline"]:
    """Fit the feature pipeline and classifier (plus calibrator) on all training data."""
    pipeline: FeaturePipeline = pipeline_factory()
    X_train = pipeline.fit_transform(train_df, heavy_only=heavy_only, batch_size=batch_size)
    y_train = train_df["label"].to_numpy(dtype=int)

    trained = train_linear_model(X_train, y_train, config=model_config, random_state=config.seed)
    calibrator = _fit_model_calibrator(
        model_config,
        config,
        X_train,
        y_train,
        base_estimator=trained.estimator,
    )
    trained.calibrator = calibrator

    return trained, pipeline
def _evaluate_datasets(
    train_df: pd.DataFrame,
    eval_frames: list[pd.DataFrame],
    trained_model: "TrainedModel",
    pipeline: "FeaturePipeline",
    config: "Config",
    cv_results: dict,
    outputs_dir: Path,
    *,
    batch_size: int,
    heavy_only: bool,
    bootstrap_samples: int,
    bootstrap_alpha: float,
    write_train_in_sample: bool,
):
    """Assemble the metrics table and per-row predictions for train CV and eval sets."""
    rows_by_metric: dict[str, dict[str, float]] = {}
    prediction_rows: list[dict[str, float]] = []

    def _row(metric: str) -> dict[str, float]:
        # One output row per metric name, created lazily.
        return rows_by_metric.setdefault(metric, {"metric": metric})

    for name, value in cv_results["metrics_mean"].items():
        _row(name)["train_cv_mean"] = value
    for name, value in cv_results["metrics_std"].items():
        _row(name)["train_cv_std"] = value

    oof = cv_results["oof_scores"]
    train_preds = train_df[["id", "source", "label"]].copy()
    train_preds["y_true"] = train_preds["label"]
    train_preds["y_score"] = oof
    train_preds["y_pred"] = (oof >= 0.5).astype(int)
    train_preds["split"] = "train_cv_oof"
    prediction_rows.extend(
        train_preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
    )

    plot_reliability_curve(
        train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "reliability_train.png"
    )
    plot_precision_recall(
        train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "pr_train.png"
    )
    plot_roc_curve(train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "roc_train.png")

    if bootstrap_samples > 0:
        train_ci = bootstrap_metric_intervals(
            train_preds["y_true"],
            train_preds["y_score"],
            n_bootstrap=bootstrap_samples,
            alpha=bootstrap_alpha,
            random_state=config.seed,
        )
        for name, stats in train_ci.items():
            row = _row(name)
            row["train_cv_ci_lower"] = stats.get("ci_lower")
            row["train_cv_ci_upper"] = stats.get("ci_upper")
            row["train_cv_ci_median"] = stats.get("ci_median")

    if write_train_in_sample:
        # Optional: in-sample metrics are over-optimistic, hence opt-in only.
        full_features = pipeline.transform(
            train_df, heavy_only=heavy_only, batch_size=batch_size
        )
        full_scores = trained_model.predict_proba(full_features)
        full_metrics = compute_metrics(
            train_df["label"].to_numpy(dtype=int), full_scores
        )
        (outputs_dir / "train_in_sample.json").write_text(
            json.dumps(full_metrics, indent=2),
            encoding="utf-8",
        )

    for frame in eval_frames:
        if frame.empty:
            continue
        features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
        scores = trained_model.predict_proba(features)
        labels = frame["label"].to_numpy(dtype=int)
        metrics = compute_metrics(labels, scores)
        dataset_name = frame["source"].iloc[0]
        for name, value in metrics.items():
            _row(name)[dataset_name] = value

        preds = frame[["id", "source", "label"]].copy()
        preds["y_true"] = preds["label"]
        preds["y_score"] = scores
        preds["y_pred"] = (scores >= 0.5).astype(int)
        preds["split"] = dataset_name
        prediction_rows.extend(
            preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
        )

        plot_reliability_curve(
            preds["y_true"],
            preds["y_score"],
            path=outputs_dir / f"reliability_{dataset_name}.png",
        )
        plot_precision_recall(
            preds["y_true"],
            preds["y_score"],
            path=outputs_dir / f"pr_{dataset_name}.png",
        )
        plot_roc_curve(
            preds["y_true"], preds["y_score"], path=outputs_dir / f"roc_{dataset_name}.png"
        )

        if bootstrap_samples > 0:
            eval_ci = bootstrap_metric_intervals(
                preds["y_true"],
                preds["y_score"],
                n_bootstrap=bootstrap_samples,
                alpha=bootstrap_alpha,
                random_state=config.seed,
            )
            for name, stats in eval_ci.items():
                row = _row(name)
                row[f"{dataset_name}_ci_lower"] = stats.get("ci_lower")
                row[f"{dataset_name}_ci_upper"] = stats.get("ci_upper")
                row[f"{dataset_name}_ci_median"] = stats.get("ci_median")

    metrics_df = pd.DataFrame(sorted(rows_by_metric.values(), key=lambda row: row["metric"]))
    return metrics_df, prediction_rows
+ method = config.calibration.method + if not method: + return None + if len(np.unique(y)) < 2: + return None + + if len(y) >= 4: + cv_cal = min(config.training.cv_folds, max(2, len(y) // 2)) + estimator = build_estimator(config=model_config, random_state=config.seed) + if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000: + estimator.set_params(solver="lbfgs") + calibrator = fit_calibrator(estimator, X, y, method=method, cv=cv_cal) + else: + estimator = base_estimator or build_estimator(config=model_config, random_state=config.seed) + if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000: + estimator.set_params(solver="lbfgs") + estimator.fit(X, y) + calibrator = fit_calibrator(estimator, X, y, method=method, cv="prefit") + return calibrator + + +def _log_environment(logger) -> None: + try: + git_head = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + except Exception: # pragma: no cover - best effort + git_head = "unknown" + try: + pip_freeze = subprocess.check_output(["pip", "freeze"], text=True) + except Exception: # pragma: no cover + pip_freeze = "" + logger.info("git_head=%s", git_head) + logger.info("pip_freeze=%s", pip_freeze) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/polyreact/utils/__pycache__/io.cpython-311.pyc b/polyreact/utils/__pycache__/io.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91ebafc84c3874c85ef7bdc73b0da35bff5aba98 Binary files /dev/null and b/polyreact/utils/__pycache__/io.cpython-311.pyc differ diff --git a/polyreact/utils/__pycache__/logging.cpython-311.pyc b/polyreact/utils/__pycache__/logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5261b0db9649145aa87f6f49d61303c5bf9a5c56 Binary files /dev/null and b/polyreact/utils/__pycache__/logging.cpython-311.pyc differ diff --git a/polyreact/utils/__pycache__/metrics.cpython-311.pyc 
b/polyreact/utils/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9dacac89d473b641f30d5537bc374188e172a08 Binary files /dev/null and b/polyreact/utils/__pycache__/metrics.cpython-311.pyc differ diff --git a/polyreact/utils/__pycache__/plots.cpython-311.pyc b/polyreact/utils/__pycache__/plots.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e89062b3aa850abfc1b3957de392b60b3287398 Binary files /dev/null and b/polyreact/utils/__pycache__/plots.cpython-311.pyc differ diff --git a/polyreact/utils/__pycache__/seeds.cpython-311.pyc b/polyreact/utils/__pycache__/seeds.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5765096076f1aa74828ed02787677f60905492f Binary files /dev/null and b/polyreact/utils/__pycache__/seeds.cpython-311.pyc differ diff --git a/polyreact/utils/io.py b/polyreact/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..aae9c724426fcf3af47861d4238c30500c6a5a1a --- /dev/null +++ b/polyreact/utils/io.py @@ -0,0 +1,39 @@ +"""I/O helpers for reading and writing artifacts.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pandas as pd + + +def read_table(path: str | Path, **kwargs: Any) -> pd.DataFrame: + """Read CSV or JSONL files into a DataFrame.""" + + path = Path(path) + if not path.exists(): + msg = f"Input file does not exist: {path}" + raise FileNotFoundError(msg) + suffix = path.suffix.lower() + if suffix in {".jsonl", ".json"}: + return pd.read_json(path, lines=True, **kwargs) + if suffix in {".csv", ""}: + return pd.read_csv(path, **kwargs) + msg = f"Unsupported file extension: {suffix}" + raise ValueError(msg) + + +def write_table(frame: pd.DataFrame, path: str | Path, *, index: bool = False, **kwargs: Any) -> None: + """Persist a DataFrame as CSV or JSONL, creating directories as needed.""" + + path = Path(path) + 
path.parent.mkdir(parents=True, exist_ok=True) + suffix = path.suffix.lower() + if suffix in {".jsonl", ".json"}: + frame.to_json(path, orient="records", lines=True, **kwargs) + elif suffix in {".csv", ""}: + frame.to_csv(path, index=index, **kwargs) + else: + msg = f"Unsupported file extension: {suffix}" + raise ValueError(msg) diff --git a/polyreact/utils/logging.py b/polyreact/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e9301675d8de984838e0a65142fcb6c16e5d9764 --- /dev/null +++ b/polyreact/utils/logging.py @@ -0,0 +1,41 @@ +"""Logging utilities.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +LOGGER_NAME = "polyreact" + + +def configure_logging(level: int = logging.INFO) -> logging.Logger: + """Configure and return the project logger.""" + + logger = logging.getLogger(LOGGER_NAME) + if logger.handlers: + logger.setLevel(level) + return logger + + handler = logging.StreamHandler() + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(level) + return logger + + +def write_logfile(path: str | Path, *, level: int = logging.INFO) -> logging.Logger: + """Add a file handler to the project logger.""" + + logger = configure_logging(level) + log_path = Path(path) + log_path.parent.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(log_path, encoding="utf-8") + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s" + ) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + return logger diff --git a/polyreact/utils/metrics.py b/polyreact/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..def7c2be407b44afe71ac2fdc9f57e682f2b5c87 --- /dev/null +++ b/polyreact/utils/metrics.py @@ -0,0 +1,158 @@ +"""Metrics utilities.""" + +from __future__ import annotations + 
+from typing import Iterable + +import numpy as np +from numpy.typing import ArrayLike +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + brier_score_loss, + f1_score, + recall_score, + precision_score, + roc_auc_score, +) + + +def compute_metrics( + y_true: Iterable[float], + y_score: Iterable[float], + *, + threshold: float = 0.5, +) -> dict[str, float]: + """Compute classification metrics from scores and labels.""" + + y_true_arr = np.asarray(list(y_true), dtype=float) + y_score_arr = np.asarray(list(y_score), dtype=float) + y_pred = (y_score_arr >= threshold).astype(int) + + metrics: dict[str, float] = {} + + try: + metrics["roc_auc"] = float(roc_auc_score(y_true_arr, y_score_arr)) + except ValueError: + metrics["roc_auc"] = float("nan") + + try: + metrics["pr_auc"] = float(average_precision_score(y_true_arr, y_score_arr)) + except ValueError: + metrics["pr_auc"] = float("nan") + + metrics["accuracy"] = float(accuracy_score(y_true_arr, y_pred)) + metrics["f1"] = float(f1_score(y_true_arr, y_pred, zero_division=0)) + metrics["f1_positive"] = float(f1_score(y_true_arr, y_pred, pos_label=1, zero_division=0)) + metrics["f1_negative"] = float(f1_score(y_true_arr, y_pred, pos_label=0, zero_division=0)) + metrics["sensitivity"] = float(recall_score(y_true_arr, y_pred, zero_division=0)) + # Specificity is recall on the negative class + metrics["specificity"] = float( + recall_score(1 - y_true_arr, 1 - y_pred, zero_division=0) + ) + metrics["precision"] = float(precision_score(y_true_arr, y_pred, zero_division=0)) + metrics["positive_rate"] = float(y_true_arr.mean()) if y_true_arr.size else float("nan") + metrics["brier"] = float(brier_score_loss(y_true_arr, y_score_arr)) + ece, mce = _calibration_errors(y_true_arr, y_score_arr) + metrics["ece"] = float(ece) + metrics["mce"] = float(mce) + return metrics + + +def _calibration_errors( + y_true: np.ndarray, + y_score: np.ndarray, + n_bins: int = 15, +) -> tuple[float, float]: + if y_true.size == 
0: + return float("nan"), float("nan") + + # Clamp scores to [0, 1] to avoid binning issues when calibrators overshoot + scores = np.clip(y_score, 0.0, 1.0) + bins = np.linspace(0.0, 1.0, n_bins + 1) + bin_indices = np.digitize(scores, bins[1:-1], right=True) + + total = y_true.size + ece = 0.0 + mce = 0.0 + for bin_idx in range(n_bins): + mask = bin_indices == bin_idx + if not np.any(mask): + continue + bin_scores = scores[mask] + bin_true = y_true[mask] + confidence = float(bin_scores.mean()) + accuracy = float(bin_true.mean()) + gap = abs(confidence - accuracy) + weight = float(mask.sum()) / float(total) + ece += weight * gap + mce = max(mce, gap) + + return ece, mce + + +def bootstrap_metric_intervals( + y_true: ArrayLike, + y_score: ArrayLike, + *, + n_bootstrap: int = 200, + alpha: float = 0.05, + threshold: float = 0.5, + random_state: int | None = 42, +) -> dict[str, dict[str, float]]: + """Estimate bootstrap confidence intervals for core metrics. + + Parameters + ---------- + y_true, y_score: + Arrays of ground-truth labels and probability scores. + n_bootstrap: + Number of bootstrap resamples; set to ``0`` to disable. + alpha: + Two-sided confidence level (default ``0.05`` gives 95% CI). + threshold: + Decision threshold passed to :func:`compute_metrics`. + random_state: + Seed controlling the bootstrap sampler. 
+ """ + + if n_bootstrap <= 0: + return {} + + y_true_arr = np.asarray(y_true, dtype=float) + y_score_arr = np.asarray(y_score, dtype=float) + n = y_true_arr.size + if n == 0: + return {} + + rng = np.random.default_rng(random_state) + collected: dict[str, list[float]] = {} + + for _ in range(n_bootstrap): + indices = rng.integers(0, n, size=n) + resampled_true = y_true_arr[indices] + resampled_score = y_score_arr[indices] + if np.unique(resampled_true).size < 2: + continue + metrics = compute_metrics(resampled_true, resampled_score, threshold=threshold) + for metric_name, value in metrics.items(): + collected.setdefault(metric_name, []).append(value) + + lower_q = alpha / 2.0 + upper_q = 1.0 - lower_q + summary: dict[str, dict[str, float]] = {} + for metric_name, values in collected.items(): + arr = np.asarray(values, dtype=float) + valid = arr[~np.isnan(arr)] + if valid.size == 0: + continue + lower = float(np.nanquantile(valid, lower_q)) + upper = float(np.nanquantile(valid, upper_q)) + median = float(np.nanmedian(valid)) + summary[metric_name] = { + "ci_lower": lower, + "ci_upper": upper, + "ci_median": median, + } + + return summary diff --git a/polyreact/utils/plots.py b/polyreact/utils/plots.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea18367db1ca65f096ae0cfa1900c35615831f6 --- /dev/null +++ b/polyreact/utils/plots.py @@ -0,0 +1,79 @@ +"""Plotting helpers using Matplotlib.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Iterable + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.calibration import calibration_curve +from sklearn.metrics import precision_recall_curve, roc_curve + + +def plot_reliability_curve( + y_true: Iterable[float], + y_score: Iterable[float], + *, + path: str | Path, + n_bins: int = 10, +) -> None: + """Save a reliability curve plot.""" + + prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=n_bins) + fig, ax = plt.subplots(figsize=(4, 
4)) + ax.plot(prob_pred, prob_true, marker="o", label="Model") + ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect") + ax.set_xlabel("Predicted probability") + ax.set_ylabel("Observed frequency") + ax.set_title("Reliability curve") + ax.legend() + fig.tight_layout() + Path(path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(path) + plt.close(fig) + + +def plot_precision_recall( + y_true: Iterable[float], + y_score: Iterable[float], + *, + path: str | Path, +) -> None: + """Save a precision-recall curve.""" + + precision, recall, _ = precision_recall_curve(y_true, y_score) + fig, ax = plt.subplots(figsize=(4, 4)) + ax.plot(recall, precision, label="Model") + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + ax.set_title("Precision-Recall curve") + ax.set_xlim([0, 1]) + ax.set_ylim([0, 1]) + fig.tight_layout() + Path(path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(path) + plt.close(fig) + + +def plot_roc_curve( + y_true: Iterable[float], + y_score: Iterable[float], + *, + path: str | Path, +) -> None: + """Save an ROC curve plot.""" + + fpr, tpr, _ = roc_curve(y_true, y_score) + fig, ax = plt.subplots(figsize=(4, 4)) + ax.plot(fpr, tpr, label="Model") + ax.plot([0, 1], [0, 1], linestyle="--", color="gray") + ax.set_xlabel("False positive rate") + ax.set_ylabel("True positive rate") + ax.set_title("ROC curve") + ax.set_xlim([0, 1]) + ax.set_ylim([0, 1]) + fig.tight_layout() + Path(path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(path) + plt.close(fig) diff --git a/polyreact/utils/seeds.py b/polyreact/utils/seeds.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc2af4b602b6b4a941a506f7d183d449937fb7d --- /dev/null +++ b/polyreact/utils/seeds.py @@ -0,0 +1,43 @@ +"""Random seed utilities.""" + +from __future__ import annotations + +import os +import random +from dataclasses import dataclass + +import numpy as np + +try: # pragma: no cover - optional dependency + import torch +except 
ImportError: # pragma: no cover - optional dependency + torch = None # type: ignore + + +@dataclass(slots=True) +class SeedState: + """Record of RNG seeds applied across libraries.""" + + python: int + numpy: int + torch: int | None = None + + +def set_global_seeds(seed: int) -> SeedState: + """Seed ``random``, ``numpy`` and ``torch`` (if available).""" + + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + + torch_seed: int | None = None + if "torch" in globals() and torch is not None: # pragma: no branch + torch.manual_seed(seed) + if torch.cuda.is_available(): # pragma: no cover - GPU specific + torch.cuda.manual_seed_all(seed) + if hasattr(torch.backends, "cudnn"): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch_seed = seed + + return SeedState(python=seed, numpy=seed, torch=torch_seed)