| """Reproduce key metrics and visualisations for the polyreactivity model.""" | |
| from __future__ import annotations | |
| import argparse | |
| import copy | |
| import json | |
| import subprocess | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Sequence | |
| import joblib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import yaml | |
| from scipy.stats import pearsonr, spearmanr | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import roc_curve | |
| from sklearn.model_selection import KFold | |
| from sklearn.preprocessing import StandardScaler | |
| from polyreact import train as train_module | |
| from polyreact.config import load_config | |
| from polyreact.features.anarsi import AnarciNumberer | |
| from polyreact.features.pipeline import FeaturePipeline | |
| from polyreact.models.ordinal import ( | |
| fit_negative_binomial_model, | |
| fit_poisson_model, | |
| pearson_dispersion, | |
| regression_metrics, | |
| ) | |


@dataclass
class DatasetSpec:
    name: str
    path: Path
    display: str


DISPLAY_LABELS = {
    "jain": "Jain (2017)",
    "shehata": "Shehata PSR (398)",
    "shehata_curated": "Shehata curated (88)",
    "harvey": "Harvey (2022)",
}

RAW_LABELS = {
    "jain": "jain2017",
    "shehata": "shehata2019",
    "shehata_curated": "shehata2019_curated",
    "harvey": "harvey2022",
}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Reproduce paper-style metrics and plots")
    parser.add_argument(
        "--train-data",
        default="data/processed/boughter_counts_rebuilt.csv",
        help="Reconstructed Boughter dataset path.",
    )
    parser.add_argument(
        "--full-data",
        default="data/processed/boughter_counts_rebuilt.csv",
        help="Dataset (including mild flags) for correlation analysis.",
    )
    parser.add_argument("--jain", default="data/processed/jain.csv")
    parser.add_argument(
        "--shehata",
        default="data/processed/shehata_full.csv",
        help="Full Shehata PSR panel (398 sequences) in processed CSV form.",
    )
    parser.add_argument(
        "--shehata-curated",
        default="data/processed/shehata_curated.csv",
        help="Optional curated subset of Shehata et al. (88 sequences).",
    )
    parser.add_argument("--harvey", default="data/processed/harvey.csv")
    parser.add_argument("--output-dir", default="artifacts/paper")
    parser.add_argument("--config", default="configs/default.yaml")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--rebuild", action="store_true")
    parser.add_argument(
        "--bootstrap-samples",
        type=int,
        default=1000,
        help="Bootstrap resamples for metrics confidence intervals.",
    )
    parser.add_argument(
        "--bootstrap-alpha",
        type=float,
        default=0.05,
        # Percent sign escaped: argparse interpolates "%" in help strings.
        help="Alpha for bootstrap confidence intervals (default 0.05 → 95%%).",
    )
    parser.add_argument(
        "--human-only",
        action="store_true",
        help=(
            "Restrict the main cross-validation run to human HIV and influenza families"
            " (legacy behaviour). By default all Boughter families, including mouse IgA,"
            " participate in CV as in Sakhnini et al."
        ),
    )
    parser.add_argument(
        "--skip-flag-regression",
        action="store_true",
        help="Skip ELISA flag regression diagnostics (Poisson/NB).",
    )
    parser.add_argument(
        "--skip-lofo",
        action="store_true",
        help="Skip leave-one-family-out experiments.",
    )
    parser.add_argument(
        "--skip-descriptor-variants",
        action="store_true",
        help="Skip descriptor-only benchmark variants.",
    )
    parser.add_argument(
        "--skip-fragment-variants",
        action="store_true",
        help="Skip CDR fragment ablation benchmarks.",
    )
    return parser
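

# Illustrative invocation of this script (the filename and repo layout are assumptions,
# not taken from this file; every flag shown is defined in build_parser above):
#
#   python reproduce_paper.py --config configs/default.yaml --output-dir artifacts/paper \
#       --bootstrap-samples 1000 --skip-lofo --skip-fragment-variants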


def _config_to_dict(config) -> Dict[str, Any]:
    data = asdict(config)
    data.pop("raw", None)
    return data


def _deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    result = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = _deep_merge(result.get(key, {}), value)
        else:
            result[key] = value
    return result
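

# Example of the merge semantics above (values are illustrative, not taken from any real config):
#   _deep_merge({"model": {"C": 1.0, "solver": "lbfgs"}}, {"model": {"C": 0.1}})
#   -> {"model": {"C": 0.1, "solver": "lbfgs"}}   # nested dicts merge, scalar leaves override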


def _write_variant_config(
    base_config: Dict[str, Any],
    overrides: Dict[str, Any],
    target_path: Path,
) -> Path:
    merged = _deep_merge(base_config, overrides)
    target_path.parent.mkdir(parents=True, exist_ok=True)
    with target_path.open("w", encoding="utf-8") as handle:
        yaml.safe_dump(merged, handle, sort_keys=False)
    return target_path


def _collect_metric_records(variant: str, metrics: pd.DataFrame) -> list[dict[str, Any]]:
    tracked = {
        "roc_auc",
        "pr_auc",
        "accuracy",
        "f1",
        "f1_positive",
        "f1_negative",
        "precision",
        "sensitivity",
        "specificity",
        "brier",
        "ece",
        "mce",
    }
    records: list[dict[str, Any]] = []
    for _, row in metrics.iterrows():
        metric_name = row["metric"]
        if metric_name not in tracked:
            continue
        record = {"variant": variant, "metric": metric_name}
        for column in metrics.columns:
            if column == "metric":
                continue
            record[column] = float(row[column]) if pd.notna(row[column]) else np.nan
        records.append(record)
    return records


def _dump_coefficients(model_path: Path, output_path: Path) -> None:
    artifact = joblib.load(model_path)
    trained = artifact["model"]
    estimator = getattr(trained, "estimator", None)
    if estimator is None or not hasattr(estimator, "coef_"):
        return
    coefs = estimator.coef_[0]
    feature_state = artifact.get("feature_state")
    feature_names: list[str]
    if feature_state is not None and getattr(feature_state, "feature_names", None):
        feature_names = list(feature_state.feature_names)
    else:
        feature_names = [f"f{i}" for i in range(len(coefs))]
    coeff_df = pd.DataFrame(
        {
            "feature": feature_names,
            "coef": coefs,
            "abs_coef": np.abs(coefs),
        }
    ).sort_values("abs_coef", ascending=False)
    coeff_df.to_csv(output_path, index=False)


def _summarise_predictions(preds: pd.DataFrame) -> pd.DataFrame:
    records: list[dict[str, Any]] = []
    for split, group in preds.groupby("split"):
        stats = {
            "split": split,
            "n_samples": int(len(group)),
            "positives": int(group["y_true"].sum()),
            "positive_rate": float(group["y_true"].mean()) if len(group) else np.nan,
            "score_mean": float(group["y_score"].mean()) if len(group) else np.nan,
            "score_std": float(group["y_score"].std(ddof=1)) if len(group) > 1 else np.nan,
        }
        records.append(stats)
    return pd.DataFrame(records)


def _summarise_raw_dataset(path: Path, name: str) -> dict[str, Any]:
    df = pd.read_csv(path)
    summary: dict[str, Any] = {
        "dataset": name,
        "path": str(path),
        "rows": int(len(df)),
    }
    if "label" in df.columns:
        positives = int(df["label"].sum())
        summary["positives"] = positives
        summary["positive_rate"] = float(df["label"].mean()) if len(df) else np.nan
    if "reactivity_count" in df.columns:
        summary["reactivity_count_mean"] = float(df["reactivity_count"].mean())
        summary["reactivity_count_median"] = float(df["reactivity_count"].median())
        summary["reactivity_count_max"] = int(df["reactivity_count"].max())
    if "smp" in df.columns:
        summary["smp_mean"] = float(df["smp"].mean())
        summary["smp_median"] = float(df["smp"].median())
        summary["smp_max"] = float(df["smp"].max())
        summary["smp_min"] = float(df["smp"].min())
    summary["unique_heavy"] = int(df["heavy_seq"].nunique()) if "heavy_seq" in df.columns else np.nan
    return summary


def _extract_region_sequence(sequence: str, regions: List[str], numberer: AnarciNumberer) -> str:
    if not sequence:
        return ""
    upper_regions = [region.upper() for region in regions]
    if upper_regions == ["VH"]:
        return sequence
    try:
        numbered = numberer.number_sequence(sequence)
    except Exception:
        return ""
    fragments: list[str] = []
    for region in upper_regions:
        if region == "VH":
            return sequence
        fragment = numbered.regions.get(region)
        if not fragment:
            return ""
        fragments.append(fragment)
    return "".join(fragments)


def _make_region_dataset(
    frame: pd.DataFrame, regions: List[str], numberer: AnarciNumberer
) -> tuple[pd.DataFrame, dict[str, Any]]:
    records: list[dict[str, Any]] = []
    dropped = 0
    for record in frame.to_dict(orient="records"):
        new_seq = _extract_region_sequence(record.get("heavy_seq", ""), regions, numberer)
        if not new_seq:
            dropped += 1
            continue
        updated = record.copy()
        updated["heavy_seq"] = new_seq
        updated["light_seq"] = ""
        records.append(updated)
    result = pd.DataFrame(records, columns=frame.columns)
    summary = {
        "regions": "+".join(regions),
        "input_rows": int(len(frame)),
        "retained_rows": int(len(result)),
        "dropped_rows": int(dropped),
    }
    return result, summary


def run_train(
    *,
    train_path: Path,
    eval_specs: Sequence[DatasetSpec],
    output_dir: Path,
    model_path: Path,
    config: str,
    batch_size: int,
    include_species: list[str] | None = None,
    include_families: list[str] | None = None,
    exclude_families: list[str] | None = None,
    keep_duplicates: bool = False,
    group_column: str | None = "lineage",
    train_loader: str | None = None,
    bootstrap_samples: int = 200,
    bootstrap_alpha: float = 0.05,
) -> None:
    args: list[str] = [
        "--config",
        str(config),
        "--train",
        str(train_path),
        "--report-to",
        str(output_dir),
        "--save-to",
        str(model_path),
        "--batch-size",
        str(batch_size),
    ]
    if eval_specs:
        args.append("--eval")
        args.extend(str(spec.path) for spec in eval_specs)
    if train_loader:
        args.extend(["--train-loader", train_loader])
    if eval_specs:
        args.append("--eval-loaders")
        args.extend(spec.name for spec in eval_specs)
    if include_species:
        args.append("--include-species")
        args.extend(include_species)
    if include_families:
        args.append("--include-families")
        args.extend(include_families)
    if exclude_families:
        args.append("--exclude-families")
        args.extend(exclude_families)
    if keep_duplicates:
        args.append("--keep-train-duplicates")
    if group_column:
        args.extend(["--cv-group-column", group_column])
    else:
        args.append("--no-group-cv")
    args.extend(["--bootstrap-samples", str(bootstrap_samples)])
    args.extend(["--bootstrap-alpha", str(bootstrap_alpha)])
    exit_code = train_module.main(args)
    if exit_code != 0:
        raise RuntimeError(f"Training command failed with exit code {exit_code}")
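

# run_train simply forwards an argument vector to polyreact's train entry point; a rough CLI
# equivalent would look like the following (the exact module invocation is an assumption,
# the flags mirror those assembled above):
#   python -m polyreact.train --config configs/default.yaml \
#       --train data/processed/boughter_counts_rebuilt.csv \
#       --report-to artifacts/paper/main --save-to artifacts/paper/main/model.joblib \
#       --batch-size 16 --train-loader boughter --cv-group-column lineage \
#       --keep-train-duplicates --bootstrap-samples 1000 --bootstrap-alpha 0.05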


def compute_spearman(model_path: Path, dataset_path: Path, batch_size: int) -> tuple[float, float, pd.DataFrame]:
    artifact = joblib.load(model_path)
    config = artifact["config"]
    pipeline_state = artifact["feature_state"]
    trained_model = artifact["model"]
    pipeline = FeaturePipeline(backend=config.feature_backend, descriptors=config.descriptors, device=config.device)
    pipeline.load_state(pipeline_state)
    dataset = pd.read_csv(dataset_path)
    features = pipeline.transform(dataset, heavy_only=True, batch_size=batch_size)
    scores = trained_model.predict_proba(features)
    dataset = dataset.copy()
    dataset["score"] = scores
    stat, pvalue = spearmanr(dataset["reactivity_count"], dataset["score"])
    return float(stat), float(pvalue), dataset


def plot_accuracy(
    metrics: pd.DataFrame,
    output_path: Path,
    eval_specs: Sequence[DatasetSpec],
) -> None:
    row = metrics.loc[metrics["metric"] == "accuracy"].iloc[0]
    labels = ["Train CV"] + [spec.display for spec in eval_specs]
    values = [row.get("train_cv_mean", np.nan)] + [row.get(spec.name, np.nan) for spec in eval_specs]
    fig, ax = plt.subplots(figsize=(6, 4))
    xs = np.arange(len(labels))
    ax.bar(xs, values, color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"])
    ax.set_xticks(xs, labels)
    ax.set_ylim(0.0, 1.05)
    ax.set_ylabel("Accuracy")
    ax.set_title("Polyreactivity accuracy overview")
    for x, val in zip(xs, values, strict=False):
        if np.isnan(val):
            continue
        ax.text(x, val + 0.02, f"{val:.3f}", ha="center", va="bottom")
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def plot_rocs(
    preds: pd.DataFrame,
    output_path: Path,
    eval_specs: Sequence[DatasetSpec],
) -> None:
    mapping = {"train_cv_oof": "Train CV"}
    for spec in eval_specs:
        mapping[spec.name] = spec.display
    fig, ax = plt.subplots(figsize=(6, 6))
    for split, label in mapping.items():
        subset = preds[preds["split"] == split]
        if subset.empty:
            continue
        fpr, tpr, _ = roc_curve(subset["y_true"], subset["y_score"])
        ax.plot(fpr, tpr, label=label)
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title("ROC curves")
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def plot_flags_scatter(data: pd.DataFrame, spearman_stat: float, output_path: Path) -> None:
    rng = np.random.default_rng(42)
    jitter = rng.uniform(-0.1, 0.1, size=len(data))
    x = data["reactivity_count"].to_numpy(dtype=float) + jitter
    y = data["score"].to_numpy(dtype=float)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(x, y, alpha=0.5, s=10)
    ax.set_xlabel("ELISA flag count")
    ax.set_ylabel("Predicted probability")
    ax.set_title(f"Prediction vs flag count (Spearman={spearman_stat:.2f})")
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def run_lofo(
    full_df: pd.DataFrame,
    *,
    families: list[str],
    config: str,
    batch_size: int,
    output_dir: Path,
    bootstrap_samples: int,
    bootstrap_alpha: float,
) -> pd.DataFrame:
    results: list[dict[str, float]] = []
    for family in families:
        family_lower = family.lower()
        holdout = full_df[full_df["family"].str.lower() == family_lower].copy()
        train = full_df[full_df["family"].str.lower() != family_lower].copy()
        if holdout.empty or train.empty:
            continue
        train_path = output_dir / f"train_lofo_{family_lower}.csv"
        holdout_path = output_dir / f"eval_lofo_{family_lower}.csv"
        train.to_csv(train_path, index=False)
        holdout.to_csv(holdout_path, index=False)
        run_dir = output_dir / f"lofo_{family_lower}"
        run_dir.mkdir(parents=True, exist_ok=True)
        model_path = run_dir / "model.joblib"
        run_train(
            train_path=train_path,
            eval_specs=[
                DatasetSpec(
                    name="boughter",
                    path=holdout_path,
                    display=f"{family.title()} holdout",
                )
            ],
            output_dir=run_dir,
            model_path=model_path,
            config=config,
            batch_size=batch_size,
            keep_duplicates=True,
            include_species=None,
            include_families=None,
            exclude_families=None,
            group_column="lineage",
            train_loader="boughter",
            bootstrap_samples=bootstrap_samples,
            bootstrap_alpha=bootstrap_alpha,
        )
        metrics = pd.read_csv(run_dir / "metrics.csv")
        evaluation_cols = [
            col
            for col in metrics.columns
            if col not in {"metric", "train_cv_mean", "train_cv_std"}
        ]
        if not evaluation_cols:
            continue
        eval_col = evaluation_cols[0]

        def _metric_value(name: str) -> float:
            series = metrics.loc[metrics["metric"] == name, eval_col]
            return float(series.values[0]) if not series.empty else float("nan")

        results.append(
            {
                "family": family,
                "accuracy": _metric_value("accuracy"),
                "roc_auc": _metric_value("roc_auc"),
                "pr_auc": _metric_value("pr_auc"),
                "sensitivity": _metric_value("sensitivity"),
                "specificity": _metric_value("specificity"),
            }
        )
    return pd.DataFrame(results)


def run_flag_regression(
    train_path: Path,
    *,
    output_dir: Path,
    config_path: str,
    batch_size: int,
    n_splits: int = 5,
) -> None:
    df = pd.read_csv(train_path)
    if "reactivity_count" not in df.columns:
        return
    config = load_config(config_path)
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)
    metrics_rows: list[dict[str, float]] = []
    preds_rows: list[dict[str, float]] = []
    for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(df), start=1):
        train_split = df.iloc[train_idx].reset_index(drop=True)
        val_split = df.iloc[val_idx].reset_index(drop=True)
        pipeline = FeaturePipeline(
            backend=config.feature_backend,
            descriptors=config.descriptors,
            device=config.device,
        )
        X_train = pipeline.fit_transform(train_split, heavy_only=True, batch_size=batch_size)
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        y_train = train_split["reactivity_count"].to_numpy(dtype=float)
        # Train a logistic head to obtain probabilities as a 1-D feature
        clf = LogisticRegression(
            C=config.model.C,
            class_weight=config.model.class_weight,
            max_iter=2000,
            solver="lbfgs",
        )
        clf.fit(X_train_scaled, train_split["label"].to_numpy(dtype=int))
        prob_train = clf.predict_proba(X_train_scaled)[:, 1]
        X_val = pipeline.transform(val_split, heavy_only=True, batch_size=batch_size)
        X_val_scaled = scaler.transform(X_val)
        y_val = val_split["reactivity_count"].to_numpy(dtype=float)
        prob_val = clf.predict_proba(X_val_scaled)[:, 1]
        poisson_X_train = prob_train.reshape(-1, 1)
        poisson_X_val = prob_val.reshape(-1, 1)
        model = fit_poisson_model(poisson_X_train, y_train)
        poisson_preds = model.predict(poisson_X_val)
        n_params = poisson_X_train.shape[1] + 1  # include intercept
        dof = max(len(y_val) - n_params, 1)
        variance_to_mean = float(np.var(y_val, ddof=1) / np.mean(y_val)) if np.mean(y_val) else float("nan")
        spearman_val = float(spearmanr(y_val, poisson_preds).statistic)
        try:
            pearson_val = float(pearsonr(y_val, poisson_preds)[0])
        except Exception:  # pragma: no cover - fallback if correlation fails
            pearson_val = float("nan")
        poisson_metrics = regression_metrics(y_val, poisson_preds)
        poisson_metrics.update(
            {
                "spearman": spearman_val,
                "pearson": pearson_val,
                "pearson_dispersion": pearson_dispersion(y_val, poisson_preds, dof=dof),
                "variance_to_mean": variance_to_mean,
                "fold": fold_idx,
                "model": "poisson",
                "status": "ok",
            }
        )
        metrics_rows.append(poisson_metrics)
        nb_preds: np.ndarray | None = None
        nb_model = None
        try:
            nb_model = fit_negative_binomial_model(poisson_X_train, y_train)
            nb_preds = nb_model.predict(poisson_X_val)
            if not np.all(np.isfinite(nb_preds)):
                raise ValueError("negative binomial produced non-finite predictions")
        except Exception:
            nb_metrics = {
                "spearman": float("nan"),
                "pearson": float("nan"),
                "pearson_dispersion": float("nan"),
                "variance_to_mean": variance_to_mean,
                "alpha": float("nan"),
                "fold": fold_idx,
                "model": "negative_binomial",
                "status": "failed",
            }
            metrics_rows.append(nb_metrics)
        else:
            spearman_nb = float(spearmanr(y_val, nb_preds).statistic)
            try:
                pearson_nb = float(pearsonr(y_val, nb_preds)[0])
            except Exception:  # pragma: no cover
                pearson_nb = float("nan")
            nb_metrics = regression_metrics(y_val, nb_preds)
            nb_metrics.update(
                {
                    "spearman": spearman_nb,
                    "pearson": pearson_nb,
                    "pearson_dispersion": pearson_dispersion(y_val, nb_preds, dof=dof),
                    "variance_to_mean": variance_to_mean,
                    "alpha": nb_model.alpha,
                    "fold": fold_idx,
                    "model": "negative_binomial",
                    "status": "ok",
                }
            )
            metrics_rows.append(nb_metrics)
        records = list(val_split.itertuples(index=False))
        for idx, row in enumerate(records):
            row_id = getattr(row, "id", idx)
            y_true_val = float(getattr(row, "reactivity_count"))
            preds_rows.append(
                {
                    "fold": fold_idx,
                    "model": "poisson",
                    "id": row_id,
                    "y_true": y_true_val,
                    "y_pred": float(poisson_preds[idx]),
                }
            )
            if nb_preds is not None:
                preds_rows.append(
                    {
                        "fold": fold_idx,
                        "model": "negative_binomial",
                        "id": row_id,
                        "y_true": y_true_val,
                        "y_pred": float(nb_preds[idx]),
                    }
                )
    metrics_df = pd.DataFrame(metrics_rows)
    metrics_df.to_csv(output_dir / "flag_regression_folds.csv", index=False)
    summary_records: list[dict[str, float]] = []
    for model_name, group in metrics_df.groupby("model"):
        for column in group.columns:
            if column in {"fold", "model", "status"}:
                continue
            values = group[column].dropna()
            if values.empty:
                continue
            summary_records.append(
                {
                    "model": model_name,
                    "metric": column,
                    "mean": float(values.mean()),
                    "std": float(values.std(ddof=1)) if len(values) > 1 else float("nan"),
                }
            )
    if summary_records:
        pd.DataFrame(summary_records).to_csv(
            output_dir / "flag_regression_metrics.csv", index=False
        )
    if preds_rows:
        pd.DataFrame(preds_rows).to_csv(output_dir / "flag_regression_preds.csv", index=False)


def run_descriptor_variants(
    base_config: Dict[str, Any],
    *,
    train_path: Path,
    eval_specs: Sequence[DatasetSpec],
    output_dir: Path,
    batch_size: int,
    include_species: List[str] | None,
    include_families: List[str] | None,
    bootstrap_samples: int,
    bootstrap_alpha: float,
) -> None:
    variants = [
        (
            "descriptors_full_vh",
            {
                "feature_backend": {"type": "descriptors"},
                "descriptors": {
                    "use_anarci": True,
                    "regions": ["CDRH1", "CDRH2", "CDRH3"],
                    "features": [
                        "length",
                        "charge",
                        "hydropathy",
                        "aromaticity",
                        "pI",
                        "net_charge",
                    ],
                },
            },
        ),
        (
            "descriptors_cdrh3_pi",
            {
                "feature_backend": {"type": "descriptors"},
                "descriptors": {
                    "use_anarci": True,
                    "regions": ["CDRH3"],
                    "features": ["pI"],
                },
            },
        ),
        (
            "descriptors_cdrh3_top5",
            {
                "feature_backend": {"type": "descriptors"},
                "descriptors": {
                    "use_anarci": True,
                    "regions": ["CDRH3"],
                    "features": [
                        "pI",
                        "net_charge",
                        "charge",
                        "hydropathy",
                        "length",
                    ],
                },
            },
        ),
    ]
    configs_dir = output_dir / "configs"
    configs_dir.mkdir(parents=True, exist_ok=True)
    summary_records: list[dict[str, Any]] = []
    for name, overrides in variants:
        variant_config_path = _write_variant_config(
            base_config,
            overrides,
            configs_dir / f"{name}.yaml",
        )
        variant_output = output_dir / name
        variant_output.mkdir(parents=True, exist_ok=True)
        model_path = variant_output / "model.joblib"
        run_train(
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=variant_output,
            model_path=model_path,
            config=str(variant_config_path),
            batch_size=batch_size,
            include_species=include_species,
            include_families=include_families,
            keep_duplicates=True,
            group_column="lineage",
            train_loader="boughter",
            bootstrap_samples=bootstrap_samples,
            bootstrap_alpha=bootstrap_alpha,
        )
        metrics_path = variant_output / "metrics.csv"
        if metrics_path.exists():
            metrics_df = pd.read_csv(metrics_path)
            summary_records.extend(_collect_metric_records(name, metrics_df))
        _dump_coefficients(model_path, variant_output / "coefficients.csv")
    if summary_records:
        pd.DataFrame(summary_records).to_csv(output_dir / "summary.csv", index=False)


def run_fragment_variants(
    config_path: str,
    *,
    train_path: Path,
    eval_specs: Sequence[DatasetSpec],
    output_dir: Path,
    batch_size: int,
    include_species: List[str] | None,
    include_families: List[str] | None,
    bootstrap_samples: int,
    bootstrap_alpha: float,
) -> None:
    numberer = AnarciNumberer()
    specs = [
        ("vh_full", ["VH"]),
        ("cdrh1", ["CDRH1"]),
        ("cdrh2", ["CDRH2"]),
        ("cdrh3", ["CDRH3"]),
        ("cdrh123", ["CDRH1", "CDRH2", "CDRH3"]),
    ]
    summary_rows: list[dict[str, Any]] = []
    metric_summary_rows: list[dict[str, Any]] = []
    for name, regions in specs:
        variant_dir = output_dir / name
        variant_dir.mkdir(parents=True, exist_ok=True)
        dataset_dir = variant_dir / "datasets"
        dataset_dir.mkdir(parents=True, exist_ok=True)
        train_df = pd.read_csv(train_path)
        train_variant, train_summary = _make_region_dataset(train_df, regions, numberer)
        train_variant_path = dataset_dir / "train.csv"
        train_variant.to_csv(train_variant_path, index=False)
        eval_variant_specs: list[DatasetSpec] = []
        for spec in eval_specs:
            eval_df = pd.read_csv(spec.path)
            transformed, eval_summary = _make_region_dataset(eval_df, regions, numberer)
            eval_path = dataset_dir / f"{spec.name}.csv"
            transformed.to_csv(eval_path, index=False)
            eval_variant_specs.append(
                DatasetSpec(name=spec.name, path=eval_path, display=spec.display)
            )
            eval_summary.update({"variant": name, "dataset": spec.name})
            summary_rows.append(eval_summary)
        train_summary.update({"variant": name, "dataset": "train"})
        summary_rows.append(train_summary)
        run_train(
            train_path=train_variant_path,
            eval_specs=eval_variant_specs,
            output_dir=variant_dir,
            model_path=variant_dir / "model.joblib",
            config=config_path,
            batch_size=batch_size,
            include_species=include_species,
            include_families=include_families,
            keep_duplicates=True,
            group_column="lineage",
            train_loader="boughter",
            bootstrap_samples=bootstrap_samples,
            bootstrap_alpha=bootstrap_alpha,
        )
        metrics_path = variant_dir / "metrics.csv"
        if metrics_path.exists():
            metrics_df = pd.read_csv(metrics_path)
            metric_records = _collect_metric_records(name, metrics_df)
            for record in metric_records:
                record["variant_type"] = "fragment"
            metric_summary_rows.extend(metric_records)
    if summary_rows:
        pd.DataFrame(summary_rows).to_csv(output_dir / "fragment_dataset_summary.csv", index=False)
    if metric_summary_rows:
        pd.DataFrame(metric_summary_rows).to_csv(output_dir / "fragment_metrics_summary.csv", index=False)


def main(argv: list[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if args.rebuild:
        rebuild_cmd = [
            "python",
            "scripts/rebuild_boughter_from_counts.py",
            "--output",
            str(args.train_data),
        ]
        if subprocess.run(rebuild_cmd, check=False).returncode != 0:
            raise RuntimeError("Dataset rebuild failed")
    train_path = Path(args.train_data)

    def _make_spec(name: str, path_str: str) -> DatasetSpec | None:
        path = Path(path_str)
        if not path.exists():
            return None
        display = DISPLAY_LABELS.get(name, name.replace("_", " ").title())
        return DatasetSpec(name=name, path=path, display=display)

    eval_specs: list[DatasetSpec] = []
    seen_paths: set[Path] = set()
    for name, path_str in [
        ("jain", args.jain),
        ("shehata", args.shehata),
        ("shehata_curated", args.shehata_curated),
        ("harvey", args.harvey),
    ]:
        spec = _make_spec(name, path_str)
        if spec is not None:
            resolved = spec.path.resolve()
            if resolved in seen_paths:
                continue
            seen_paths.add(resolved)
            eval_specs.append(spec)
    base_config = load_config(args.config)
    base_config_dict = _config_to_dict(base_config)
    main_output = output_dir / "main"
    main_output.mkdir(parents=True, exist_ok=True)
    model_path = main_output / "model.joblib"
    main_include_species = ["human"] if args.human_only else None
    main_include_families = ["hiv", "influenza"] if args.human_only else None
    run_train(
        train_path=train_path,
        eval_specs=eval_specs,
        output_dir=main_output,
        model_path=model_path,
        config=args.config,
        batch_size=args.batch_size,
        include_species=main_include_species,
        include_families=main_include_families,
        keep_duplicates=True,
        group_column="lineage",
        train_loader="boughter",
        bootstrap_samples=args.bootstrap_samples,
        bootstrap_alpha=args.bootstrap_alpha,
    )
    metrics = pd.read_csv(main_output / "metrics.csv")
    preds = pd.read_csv(main_output / "preds.csv")
    plot_accuracy(metrics, main_output / "accuracy_overview.png", eval_specs)
    plot_rocs(preds, main_output / "roc_overview.png", eval_specs)
    if not args.skip_flag_regression:
        run_flag_regression(
            train_path=train_path,
            output_dir=main_output,
            config_path=args.config,
            batch_size=args.batch_size,
        )
    split_summary = _summarise_predictions(preds)
    split_summary.to_csv(main_output / "dataset_split_summary.csv", index=False)
    spearman_stat, spearman_p, corr_df = compute_spearman(
        model_path=model_path,
        dataset_path=Path(args.full_data),
        batch_size=args.batch_size,
    )
    plot_flags_scatter(corr_df, spearman_stat, main_output / "prob_vs_flags.png")
    (main_output / "spearman_flags.json").write_text(
        json.dumps({"spearman": spearman_stat, "p_value": spearman_p}, indent=2)
    )
    corr_df.to_csv(main_output / "prob_vs_flags.csv", index=False)
    if not args.skip_lofo:
        full_df = pd.read_csv(args.train_data)
        lofo_dir = output_dir / "lofo_runs"
        lofo_dir.mkdir(parents=True, exist_ok=True)
        lofo_df = run_lofo(
            full_df,
            families=["influenza", "hiv", "mouse_iga"],
            config=args.config,
            batch_size=args.batch_size,
            output_dir=lofo_dir,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )
        lofo_df.to_csv(output_dir / "lofo_metrics.csv", index=False)
    if not args.skip_descriptor_variants:
        descriptor_dir = output_dir / "descriptor_variants"
        descriptor_dir.mkdir(parents=True, exist_ok=True)
        run_descriptor_variants(
            base_config_dict,
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=descriptor_dir,
            batch_size=args.batch_size,
            include_species=main_include_species,
            include_families=main_include_families,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )
    if not args.skip_fragment_variants:
        fragment_dir = output_dir / "fragment_variants"
        fragment_dir.mkdir(parents=True, exist_ok=True)
        run_fragment_variants(
            args.config,
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=fragment_dir,
            batch_size=args.batch_size,
            include_species=main_include_species,
            include_families=main_include_families,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )
    raw_summaries = []
    raw_summaries.append(_summarise_raw_dataset(train_path, "boughter_rebuilt"))
    for spec in eval_specs:
        summary_name = RAW_LABELS.get(spec.name, spec.name)
        raw_summaries.append(_summarise_raw_dataset(spec.path, summary_name))
    pd.DataFrame(raw_summaries).to_csv(output_dir / "raw_dataset_summary.csv", index=False)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())