"""Reproduce key metrics and visualisations for the polyreactivity model.""" from __future__ import annotations import argparse import copy import json import subprocess from dataclasses import asdict, dataclass from pathlib import Path from typing import Any, Dict, Iterable, List, Sequence import joblib import matplotlib.pyplot as plt import numpy as np import pandas as pd import yaml from scipy.stats import pearsonr, spearmanr from sklearn.linear_model import LogisticRegression from sklearn.metrics import roc_curve from sklearn.model_selection import KFold from sklearn.preprocessing import StandardScaler from polyreact import train as train_module from polyreact.config import load_config from polyreact.features.anarsi import AnarciNumberer from polyreact.features.pipeline import FeaturePipeline from polyreact.models.ordinal import ( fit_negative_binomial_model, fit_poisson_model, pearson_dispersion, regression_metrics, ) @dataclass(slots=True) class DatasetSpec: name: str path: Path display: str DISPLAY_LABELS = { "jain": "Jain (2017)", "shehata": "Shehata PSR (398)", "shehata_curated": "Shehata curated (88)", "harvey": "Harvey (2022)", } RAW_LABELS = { "jain": "jain2017", "shehata": "shehata2019", "shehata_curated": "shehata2019_curated", "harvey": "harvey2022", } def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Reproduce paper-style metrics and plots") parser.add_argument( "--train-data", default="data/processed/boughter_counts_rebuilt.csv", help="Reconstructed Boughter dataset path.", ) parser.add_argument( "--full-data", default="data/processed/boughter_counts_rebuilt.csv", help="Dataset (including mild flags) for correlation analysis.", ) parser.add_argument("--jain", default="data/processed/jain.csv") parser.add_argument( "--shehata", default="data/processed/shehata_full.csv", help="Full Shehata PSR panel (398 sequences) in processed CSV form.", ) parser.add_argument( "--shehata-curated", default="data/processed/shehata_curated.csv", help="Optional curated subset of Shehata et al. (88 sequences).", ) parser.add_argument("--harvey", default="data/processed/harvey.csv") parser.add_argument("--output-dir", default="artifacts/paper") parser.add_argument("--config", default="configs/default.yaml") parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--rebuild", action="store_true") parser.add_argument( "--bootstrap-samples", type=int, default=1000, help="Bootstrap resamples for metrics confidence intervals.", ) parser.add_argument( "--bootstrap-alpha", type=float, default=0.05, help="Alpha for bootstrap confidence intervals (default 0.05 → 95%).", ) parser.add_argument( "--human-only", action="store_true", help=( "Restrict the main cross-validation run to human HIV and influenza families" " (legacy behaviour). By default all Boughter families, including mouse IgA," " participate in CV as in Sakhnini et al." ), ) parser.add_argument( "--skip-flag-regression", action="store_true", help="Skip ELISA flag regression diagnostics (Poisson/NB).", ) parser.add_argument( "--skip-lofo", action="store_true", help="Skip leave-one-family-out experiments.", ) parser.add_argument( "--skip-descriptor-variants", action="store_true", help="Skip descriptor-only benchmark variants.", ) parser.add_argument( "--skip-fragment-variants", action="store_true", help="Skip CDR fragment ablation benchmarks.", ) return parser def _config_to_dict(config) -> Dict[str, Any]: data = asdict(config) data.pop("raw", None) return data def _deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: result = copy.deepcopy(base) for key, value in overrides.items(): if isinstance(value, dict) and isinstance(result.get(key), dict): result[key] = _deep_merge(result.get(key, {}), value) else: result[key] = value return result def _write_variant_config( base_config: Dict[str, Any], overrides: Dict[str, Any], target_path: Path, ) -> Path: merged = _deep_merge(base_config, overrides) target_path.parent.mkdir(parents=True, exist_ok=True) with target_path.open("w", encoding="utf-8") as handle: yaml.safe_dump(merged, handle, sort_keys=False) return target_path def _collect_metric_records(variant: str, metrics: pd.DataFrame) -> list[dict[str, Any]]: tracked = { "roc_auc", "pr_auc", "accuracy", "f1", "f1_positive", "f1_negative", "precision", "sensitivity", "specificity", "brier", "ece", "mce", } records: list[dict[str, Any]] = [] for _, row in metrics.iterrows(): metric_name = row["metric"] if metric_name not in tracked: continue record = {"variant": variant, "metric": metric_name} for column in metrics.columns: if column == "metric": continue record[column] = float(row[column]) if pd.notna(row[column]) else np.nan records.append(record) return records def _dump_coefficients(model_path: Path, output_path: Path) -> None: artifact = joblib.load(model_path) trained = artifact["model"] estimator = getattr(trained, "estimator", None) if estimator is None or not hasattr(estimator, "coef_"): return coefs = estimator.coef_[0] feature_state = artifact.get("feature_state") feature_names: list[str] if feature_state is not None and getattr(feature_state, "feature_names", None): feature_names = list(feature_state.feature_names) else: feature_names = [f"f{i}" for i in range(len(coefs))] coeff_df = pd.DataFrame( { "feature": feature_names, "coef": coefs, "abs_coef": np.abs(coefs), } ).sort_values("abs_coef", ascending=False) coeff_df.to_csv(output_path, index=False) def _summarise_predictions(preds: pd.DataFrame) -> pd.DataFrame: records: list[dict[str, Any]] = [] for split, group in preds.groupby("split"): stats = { "split": split, "n_samples": int(len(group)), "positives": int(group["y_true"].sum()), "positive_rate": float(group["y_true"].mean()) if len(group) else np.nan, "score_mean": float(group["y_score"].mean()) if len(group) else np.nan, "score_std": float(group["y_score"].std(ddof=1)) if len(group) > 1 else np.nan, } records.append(stats) return pd.DataFrame(records) def _summarise_raw_dataset(path: Path, name: str) -> dict[str, Any]: df = pd.read_csv(path) summary: dict[str, Any] = { "dataset": name, "path": str(path), "rows": int(len(df)), } if "label" in df.columns: positives = int(df["label"].sum()) summary["positives"] = positives summary["positive_rate"] = float(df["label"].mean()) if len(df) else np.nan if "reactivity_count" in df.columns: summary["reactivity_count_mean"] = float(df["reactivity_count"].mean()) summary["reactivity_count_median"] = float(df["reactivity_count"].median()) summary["reactivity_count_max"] = int(df["reactivity_count"].max()) if "smp" in df.columns: summary["smp_mean"] = float(df["smp"].mean()) summary["smp_median"] = float(df["smp"].median()) summary["smp_max"] = float(df["smp"].max()) summary["smp_min"] = float(df["smp"].min()) summary["unique_heavy"] = int(df["heavy_seq"].nunique()) if "heavy_seq" in df.columns else np.nan return summary def _extract_region_sequence(sequence: str, regions: List[str], numberer: AnarciNumberer) -> str: if not sequence: return "" upper_regions = [region.upper() for region in regions] if upper_regions == ["VH"]: return sequence try: numbered = numberer.number_sequence(sequence) except Exception: return "" fragments: list[str] = [] for region in upper_regions: if region == "VH": return sequence fragment = numbered.regions.get(region) if not fragment: return "" fragments.append(fragment) return "".join(fragments) def _make_region_dataset( frame: pd.DataFrame, regions: List[str], numberer: AnarciNumberer ) -> tuple[pd.DataFrame, dict[str, Any]]: records: list[dict[str, Any]] = [] dropped = 0 for record in frame.to_dict(orient="records"): new_seq = _extract_region_sequence(record.get("heavy_seq", ""), regions, numberer) if not new_seq: dropped += 1 continue updated = record.copy() updated["heavy_seq"] = new_seq updated["light_seq"] = "" records.append(updated) result = pd.DataFrame(records, columns=frame.columns) summary = { "regions": "+".join(regions), "input_rows": int(len(frame)), "retained_rows": int(len(result)), "dropped_rows": int(dropped), } return result, summary def run_train( *, train_path: Path, eval_specs: Sequence[DatasetSpec], output_dir: Path, model_path: Path, config: str, batch_size: int, include_species: list[str] | None = None, include_families: list[str] | None = None, exclude_families: list[str] | None = None, keep_duplicates: bool = False, group_column: str | None = "lineage", train_loader: str | None = None, bootstrap_samples: int = 200, bootstrap_alpha: float = 0.05, ) -> None: args: list[str] = [ "--config", str(config), "--train", str(train_path), "--report-to", str(output_dir), "--save-to", str(model_path), "--batch-size", str(batch_size), ] if eval_specs: args.append("--eval") args.extend(str(spec.path) for spec in eval_specs) if train_loader: args.extend(["--train-loader", train_loader]) if eval_specs: args.append("--eval-loaders") args.extend(spec.name for spec in eval_specs) if include_species: args.append("--include-species") args.extend(include_species) if include_families: args.append("--include-families") args.extend(include_families) if exclude_families: args.append("--exclude-families") args.extend(exclude_families) if keep_duplicates: args.append("--keep-train-duplicates") if group_column: args.extend(["--cv-group-column", group_column]) else: args.append("--no-group-cv") args.extend(["--bootstrap-samples", str(bootstrap_samples)]) args.extend(["--bootstrap-alpha", str(bootstrap_alpha)]) exit_code = train_module.main(args) if exit_code != 0: raise RuntimeError(f"Training command failed with exit code {exit_code}") def compute_spearman(model_path: Path, dataset_path: Path, batch_size: int) -> tuple[float, float, pd.DataFrame]: artifact = joblib.load(model_path) config = artifact["config"] pipeline_state = artifact["feature_state"] trained_model = artifact["model"] pipeline = FeaturePipeline(backend=config.feature_backend, descriptors=config.descriptors, device=config.device) pipeline.load_state(pipeline_state) dataset = pd.read_csv(dataset_path) features = pipeline.transform(dataset, heavy_only=True, batch_size=batch_size) scores = trained_model.predict_proba(features) dataset = dataset.copy() dataset["score"] = scores stat, pvalue = spearmanr(dataset["reactivity_count"], dataset["score"]) return float(stat), float(pvalue), dataset def plot_accuracy( metrics: pd.DataFrame, output_path: Path, eval_specs: Sequence[DatasetSpec], ) -> None: row = metrics.loc[metrics["metric"] == "accuracy"].iloc[0] labels = ["Train CV"] + [spec.display for spec in eval_specs] values = [row.get("train_cv_mean", np.nan)] + [row.get(spec.name, np.nan) for spec in eval_specs] fig, ax = plt.subplots(figsize=(6, 4)) xs = np.arange(len(labels)) ax.bar(xs, values, color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]) ax.set_xticks(xs, labels) ax.set_ylim(0.0, 1.05) ax.set_ylabel("Accuracy") ax.set_title("Polyreactivity accuracy overview") for x, val in zip(xs, values, strict=False): if np.isnan(val): continue ax.text(x, val + 0.02, f"{val:.3f}", ha="center", va="bottom") fig.tight_layout() fig.savefig(output_path, dpi=300) plt.close(fig) def plot_rocs( preds: pd.DataFrame, output_path: Path, eval_specs: Sequence[DatasetSpec], ) -> None: mapping = {"train_cv_oof": "Train CV"} for spec in eval_specs: mapping[spec.name] = spec.display fig, ax = plt.subplots(figsize=(6, 6)) for split, label in mapping.items(): subset = preds[preds["split"] == split] if subset.empty: continue fpr, tpr, _ = roc_curve(subset["y_true"], subset["y_score"]) ax.plot(fpr, tpr, label=label) ax.plot([0, 1], [0, 1], linestyle="--", color="gray") ax.set_xlabel("False positive rate") ax.set_ylabel("True positive rate") ax.set_title("ROC curves") ax.legend() fig.tight_layout() fig.savefig(output_path, dpi=300) plt.close(fig) def plot_flags_scatter(data: pd.DataFrame, spearman_stat: float, output_path: Path) -> None: rng = np.random.default_rng(42) jitter = rng.uniform(-0.1, 0.1, size=len(data)) x = data["reactivity_count"].to_numpy(dtype=float) + jitter y = data["score"].to_numpy(dtype=float) fig, ax = plt.subplots(figsize=(6, 4)) ax.scatter(x, y, alpha=0.5, s=10) ax.set_xlabel("ELISA flag count") ax.set_ylabel("Predicted probability") ax.set_title(f"Prediction vs flag count (Spearman={spearman_stat:.2f})") fig.tight_layout() fig.savefig(output_path, dpi=300) plt.close(fig) def run_lofo( full_df: pd.DataFrame, *, families: list[str], config: str, batch_size: int, output_dir: Path, bootstrap_samples: int, bootstrap_alpha: float, ) -> pd.DataFrame: results: list[dict[str, float]] = [] for family in families: family_lower = family.lower() holdout = full_df[full_df["family"].str.lower() == family_lower].copy() train = full_df[full_df["family"].str.lower() != family_lower].copy() if holdout.empty or train.empty: continue train_path = output_dir / f"train_lofo_{family_lower}.csv" holdout_path = output_dir / f"eval_lofo_{family_lower}.csv" train.to_csv(train_path, index=False) holdout.to_csv(holdout_path, index=False) run_dir = output_dir / f"lofo_{family_lower}" run_dir.mkdir(parents=True, exist_ok=True) model_path = run_dir / "model.joblib" run_train( train_path=train_path, eval_specs=[ DatasetSpec( name="boughter", path=holdout_path, display=f"{family.title()} holdout", ) ], output_dir=run_dir, model_path=model_path, config=config, batch_size=batch_size, keep_duplicates=True, include_species=None, include_families=None, exclude_families=None, group_column="lineage", train_loader="boughter", bootstrap_samples=bootstrap_samples, bootstrap_alpha=bootstrap_alpha, ) metrics = pd.read_csv(run_dir / "metrics.csv") evaluation_cols = [ col for col in metrics.columns if col not in {"metric", "train_cv_mean", "train_cv_std"} ] if not evaluation_cols: continue eval_col = evaluation_cols[0] def _metric_value(name: str) -> float: series = metrics.loc[metrics["metric"] == name, eval_col] return float(series.values[0]) if not series.empty else float("nan") results.append( { "family": family, "accuracy": _metric_value("accuracy"), "roc_auc": _metric_value("roc_auc"), "pr_auc": _metric_value("pr_auc"), "sensitivity": _metric_value("sensitivity"), "specificity": _metric_value("specificity"), } ) return pd.DataFrame(results) def run_flag_regression( train_path: Path, *, output_dir: Path, config_path: str, batch_size: int, n_splits: int = 5, ) -> None: df = pd.read_csv(train_path) if "reactivity_count" not in df.columns: return config = load_config(config_path) kfold = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed) metrics_rows: list[dict[str, float]] = [] preds_rows: list[dict[str, float]] = [] for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(df), start=1): train_split = df.iloc[train_idx].reset_index(drop=True) val_split = df.iloc[val_idx].reset_index(drop=True) pipeline = FeaturePipeline( backend=config.feature_backend, descriptors=config.descriptors, device=config.device, ) X_train = pipeline.fit_transform(train_split, heavy_only=True, batch_size=batch_size) scaler = StandardScaler() X_train_scaled = scaler.fit_transform(X_train) y_train = train_split["reactivity_count"].to_numpy(dtype=float) # Train a logistic head to obtain probabilities as a 1-D feature clf = LogisticRegression( C=config.model.C, class_weight=config.model.class_weight, max_iter=2000, solver="lbfgs", ) clf.fit(X_train_scaled, train_split["label"].to_numpy(dtype=int)) prob_train = clf.predict_proba(X_train_scaled)[:, 1] X_val = pipeline.transform(val_split, heavy_only=True, batch_size=batch_size) X_val_scaled = scaler.transform(X_val) y_val = val_split["reactivity_count"].to_numpy(dtype=float) prob_val = clf.predict_proba(X_val_scaled)[:, 1] poisson_X_train = prob_train.reshape(-1, 1) poisson_X_val = prob_val.reshape(-1, 1) model = fit_poisson_model(poisson_X_train, y_train) poisson_preds = model.predict(poisson_X_val) n_params = poisson_X_train.shape[1] + 1 # include intercept dof = max(len(y_val) - n_params, 1) variance_to_mean = float(np.var(y_val, ddof=1) / np.mean(y_val)) if np.mean(y_val) else float("nan") spearman_val = float(spearmanr(y_val, poisson_preds).statistic) try: pearson_val = float(pearsonr(y_val, poisson_preds)[0]) except Exception: # pragma: no cover - fallback if correlation fails pearson_val = float("nan") poisson_metrics = regression_metrics(y_val, poisson_preds) poisson_metrics.update( { "spearman": spearman_val, "pearson": pearson_val, "pearson_dispersion": pearson_dispersion(y_val, poisson_preds, dof=dof), "variance_to_mean": variance_to_mean, "fold": fold_idx, "model": "poisson", "status": "ok", } ) metrics_rows.append(poisson_metrics) nb_preds: np.ndarray | None = None nb_model = None try: nb_model = fit_negative_binomial_model(poisson_X_train, y_train) nb_preds = nb_model.predict(poisson_X_val) if not np.all(np.isfinite(nb_preds)): raise ValueError("negative binomial produced non-finite predictions") except Exception: nb_metrics = { "spearman": float("nan"), "pearson": float("nan"), "pearson_dispersion": float("nan"), "variance_to_mean": variance_to_mean, "alpha": float("nan"), "fold": fold_idx, "model": "negative_binomial", "status": "failed", } metrics_rows.append(nb_metrics) else: spearman_nb = float(spearmanr(y_val, nb_preds).statistic) try: pearson_nb = float(pearsonr(y_val, nb_preds)[0]) except Exception: # pragma: no cover pearson_nb = float("nan") nb_metrics = regression_metrics(y_val, nb_preds) nb_metrics.update( { "spearman": spearman_nb, "pearson": pearson_nb, "pearson_dispersion": pearson_dispersion(y_val, nb_preds, dof=dof), "variance_to_mean": variance_to_mean, "alpha": nb_model.alpha, "fold": fold_idx, "model": "negative_binomial", "status": "ok", } ) metrics_rows.append(nb_metrics) records = list(val_split.itertuples(index=False)) for idx, row in enumerate(records): row_id = getattr(row, "id", idx) y_true_val = float(getattr(row, "reactivity_count")) preds_rows.append( { "fold": fold_idx, "model": "poisson", "id": row_id, "y_true": y_true_val, "y_pred": float(poisson_preds[idx]), } ) if nb_preds is not None: preds_rows.append( { "fold": fold_idx, "model": "negative_binomial", "id": row_id, "y_true": y_true_val, "y_pred": float(nb_preds[idx]), } ) metrics_df = pd.DataFrame(metrics_rows) metrics_df.to_csv(output_dir / "flag_regression_folds.csv", index=False) summary_records: list[dict[str, float]] = [] for model_name, group in metrics_df.groupby("model"): for column in group.columns: if column in {"fold", "model", "status"}: continue values = group[column].dropna() if values.empty: continue summary_records.append( { "model": model_name, "metric": column, "mean": float(values.mean()), "std": float(values.std(ddof=1)) if len(values) > 1 else float("nan"), } ) if summary_records: pd.DataFrame(summary_records).to_csv( output_dir / "flag_regression_metrics.csv", index=False ) if preds_rows: pd.DataFrame(preds_rows).to_csv(output_dir / "flag_regression_preds.csv", index=False) def run_descriptor_variants( base_config: Dict[str, Any], *, train_path: Path, eval_specs: Sequence[DatasetSpec], output_dir: Path, batch_size: int, include_species: List[str] | None, include_families: List[str] | None, bootstrap_samples: int, bootstrap_alpha: float, ) -> None: variants = [ ( "descriptors_full_vh", { "feature_backend": {"type": "descriptors"}, "descriptors": { "use_anarci": True, "regions": ["CDRH1", "CDRH2", "CDRH3"], "features": [ "length", "charge", "hydropathy", "aromaticity", "pI", "net_charge", ], }, }, ), ( "descriptors_cdrh3_pi", { "feature_backend": {"type": "descriptors"}, "descriptors": { "use_anarci": True, "regions": ["CDRH3"], "features": ["pI"], }, }, ), ( "descriptors_cdrh3_top5", { "feature_backend": {"type": "descriptors"}, "descriptors": { "use_anarci": True, "regions": ["CDRH3"], "features": [ "pI", "net_charge", "charge", "hydropathy", "length", ], }, }, ), ] configs_dir = output_dir / "configs" configs_dir.mkdir(parents=True, exist_ok=True) summary_records: list[dict[str, Any]] = [] for name, overrides in variants: variant_config_path = _write_variant_config( base_config, overrides, configs_dir / f"{name}.yaml", ) variant_output = output_dir / name variant_output.mkdir(parents=True, exist_ok=True) model_path = variant_output / "model.joblib" run_train( train_path=train_path, eval_specs=eval_specs, output_dir=variant_output, model_path=model_path, config=str(variant_config_path), batch_size=batch_size, include_species=include_species, include_families=include_families, keep_duplicates=True, group_column="lineage", train_loader="boughter", bootstrap_samples=bootstrap_samples, bootstrap_alpha=bootstrap_alpha, ) metrics_path = variant_output / "metrics.csv" if metrics_path.exists(): metrics_df = pd.read_csv(metrics_path) summary_records.extend(_collect_metric_records(name, metrics_df)) _dump_coefficients(model_path, variant_output / "coefficients.csv") if summary_records: pd.DataFrame(summary_records).to_csv(output_dir / "summary.csv", index=False) def run_fragment_variants( config_path: str, *, train_path: Path, eval_specs: Sequence[DatasetSpec], output_dir: Path, batch_size: int, include_species: List[str] | None, include_families: List[str] | None, bootstrap_samples: int, bootstrap_alpha: float, ) -> None: numberer = AnarciNumberer() specs = [ ("vh_full", ["VH"]), ("cdrh1", ["CDRH1"]), ("cdrh2", ["CDRH2"]), ("cdrh3", ["CDRH3"]), ("cdrh123", ["CDRH1", "CDRH2", "CDRH3"]), ] summary_rows: list[dict[str, Any]] = [] metric_summary_rows: list[dict[str, Any]] = [] for name, regions in specs: variant_dir = output_dir / name variant_dir.mkdir(parents=True, exist_ok=True) dataset_dir = variant_dir / "datasets" dataset_dir.mkdir(parents=True, exist_ok=True) train_df = pd.read_csv(train_path) train_variant, train_summary = _make_region_dataset(train_df, regions, numberer) train_variant_path = dataset_dir / "train.csv" train_variant.to_csv(train_variant_path, index=False) eval_variant_specs: list[DatasetSpec] = [] for spec in eval_specs: eval_df = pd.read_csv(spec.path) transformed, eval_summary = _make_region_dataset(eval_df, regions, numberer) eval_path = dataset_dir / f"{spec.name}.csv" transformed.to_csv(eval_path, index=False) eval_variant_specs.append( DatasetSpec(name=spec.name, path=eval_path, display=spec.display) ) eval_summary.update({"variant": name, "dataset": spec.name}) summary_rows.append(eval_summary) train_summary.update({"variant": name, "dataset": "train"}) summary_rows.append(train_summary) run_train( train_path=train_variant_path, eval_specs=eval_variant_specs, output_dir=variant_dir, model_path=variant_dir / "model.joblib", config=config_path, batch_size=batch_size, include_species=include_species, include_families=include_families, keep_duplicates=True, group_column="lineage", train_loader="boughter", bootstrap_samples=bootstrap_samples, bootstrap_alpha=bootstrap_alpha, ) metrics_path = variant_dir / "metrics.csv" if metrics_path.exists(): metrics_df = pd.read_csv(metrics_path) metric_records = _collect_metric_records(name, metrics_df) for record in metric_records: record["variant_type"] = "fragment" metric_summary_rows.extend(metric_records) if summary_rows: pd.DataFrame(summary_rows).to_csv(output_dir / "fragment_dataset_summary.csv", index=False) if metric_summary_rows: pd.DataFrame(metric_summary_rows).to_csv(output_dir / "fragment_metrics_summary.csv", index=False) def main(argv: list[str] | None = None) -> int: parser = build_parser() args = parser.parse_args(argv) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) if args.rebuild: rebuild_cmd = [ "python", "scripts/rebuild_boughter_from_counts.py", "--output", str(args.train_data), ] if subprocess.run(rebuild_cmd, check=False).returncode != 0: raise RuntimeError("Dataset rebuild failed") train_path = Path(args.train_data) def _make_spec(name: str, path_str: str) -> DatasetSpec | None: path = Path(path_str) if not path.exists(): return None display = DISPLAY_LABELS.get(name, name.replace("_", " ").title()) return DatasetSpec(name=name, path=path, display=display) eval_specs: list[DatasetSpec] = [] seen_paths: set[Path] = set() for name, path_str in [ ("jain", args.jain), ("shehata", args.shehata), ("shehata_curated", args.shehata_curated), ("harvey", args.harvey), ]: spec = _make_spec(name, path_str) if spec is not None: resolved = spec.path.resolve() if resolved in seen_paths: continue seen_paths.add(resolved) eval_specs.append(spec) base_config = load_config(args.config) base_config_dict = _config_to_dict(base_config) main_output = output_dir / "main" main_output.mkdir(parents=True, exist_ok=True) model_path = main_output / "model.joblib" main_include_species = ["human"] if args.human_only else None main_include_families = ["hiv", "influenza"] if args.human_only else None run_train( train_path=train_path, eval_specs=eval_specs, output_dir=main_output, model_path=model_path, config=args.config, batch_size=args.batch_size, include_species=main_include_species, include_families=main_include_families, keep_duplicates=True, group_column="lineage", train_loader="boughter", bootstrap_samples=args.bootstrap_samples, bootstrap_alpha=args.bootstrap_alpha, ) metrics = pd.read_csv(main_output / "metrics.csv") preds = pd.read_csv(main_output / "preds.csv") plot_accuracy(metrics, main_output / "accuracy_overview.png", eval_specs) plot_rocs(preds, main_output / "roc_overview.png", eval_specs) if not args.skip_flag_regression: run_flag_regression( train_path=train_path, output_dir=main_output, config_path=args.config, batch_size=args.batch_size, ) split_summary = _summarise_predictions(preds) split_summary.to_csv(main_output / "dataset_split_summary.csv", index=False) spearman_stat, spearman_p, corr_df = compute_spearman( model_path=model_path, dataset_path=Path(args.full_data), batch_size=args.batch_size, ) plot_flags_scatter(corr_df, spearman_stat, main_output / "prob_vs_flags.png") (main_output / "spearman_flags.json").write_text( json.dumps({"spearman": spearman_stat, "p_value": spearman_p}, indent=2) ) corr_df.to_csv(main_output / "prob_vs_flags.csv", index=False) if not args.skip_lofo: full_df = pd.read_csv(args.train_data) lofo_dir = output_dir / "lofo_runs" lofo_dir.mkdir(parents=True, exist_ok=True) lofo_df = run_lofo( full_df, families=["influenza", "hiv", "mouse_iga"], config=args.config, batch_size=args.batch_size, output_dir=lofo_dir, bootstrap_samples=args.bootstrap_samples, bootstrap_alpha=args.bootstrap_alpha, ) lofo_df.to_csv(output_dir / "lofo_metrics.csv", index=False) if not args.skip_descriptor_variants: descriptor_dir = output_dir / "descriptor_variants" descriptor_dir.mkdir(parents=True, exist_ok=True) run_descriptor_variants( base_config_dict, train_path=train_path, eval_specs=eval_specs, output_dir=descriptor_dir, batch_size=args.batch_size, include_species=main_include_species, include_families=main_include_families, bootstrap_samples=args.bootstrap_samples, bootstrap_alpha=args.bootstrap_alpha, ) if not args.skip_fragment_variants: fragment_dir = output_dir / "fragment_variants" fragment_dir.mkdir(parents=True, exist_ok=True) run_fragment_variants( args.config, train_path=train_path, eval_specs=eval_specs, output_dir=fragment_dir, batch_size=args.batch_size, include_species=main_include_species, include_families=main_include_families, bootstrap_samples=args.bootstrap_samples, bootstrap_alpha=args.bootstrap_alpha, ) raw_summaries = [] raw_summaries.append(_summarise_raw_dataset(train_path, "boughter_rebuilt")) for spec in eval_specs: summary_name = RAW_LABELS.get(spec.name, spec.name) raw_summaries.append(_summarise_raw_dataset(spec.path, summary_name)) pd.DataFrame(raw_summaries).to_csv(output_dir / "raw_dataset_summary.csv", index=False) return 0 if __name__ == "__main__": raise SystemExit(main())