| """Training entrypoint for the polyreactivity model.""" | |
from __future__ import annotations

import argparse
import json
import subprocess
from pathlib import Path
from typing import Any, Callable, Sequence

import joblib
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold

from .config import Config, load_config
from .data_loaders import boughter, harvey, jain, shehata
from .data_loaders.utils import deduplicate_sequences
from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline
from .models.calibrate import fit_calibrator
from .models.linear import LinearModelConfig, TrainedModel, build_estimator, train_linear_model
from .utils.io import write_table
from .utils.logging import configure_logging
from .utils.metrics import bootstrap_metric_intervals, compute_metrics
from .utils.plots import plot_precision_recall, plot_reliability_curve, plot_roc_curve
from .utils.seeds import set_global_seeds

DATASET_LOADERS = {
    "boughter": boughter.load_dataframe,
    "jain": jain.load_dataframe,
    "shehata": shehata.load_dataframe,
    "shehata_curated": shehata.load_dataframe,
    "harvey": harvey.load_dataframe,
}


def build_parser() -> argparse.ArgumentParser:
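    """Build the command-line argument parser for training runs."""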
    parser = argparse.ArgumentParser(description="Train polyreactivity model")
    parser.add_argument("--config", default="configs/default.yaml", help="Config file")
    parser.add_argument("--train", required=True, help="Training dataset path")
    parser.add_argument(
        "--eval",
        nargs="*",
        default=[],
        help="Evaluation dataset paths",
    )
    parser.add_argument(
        "--save-to",
        default="artifacts/model.joblib",
        help="Path to save trained model artifact",
    )
    parser.add_argument(
        "--report-to",
        default="artifacts",
        help="Directory for metrics, predictions, and plots",
    )
    parser.add_argument(
        "--train-loader",
        choices=list(DATASET_LOADERS.keys()),
        help="Optional explicit loader for training dataset",
    )
    parser.add_argument(
        "--eval-loaders",
        nargs="*",
        help="Optional explicit loaders for evaluation datasets (aligned with --eval order)",
    )
    parser.add_argument(
        "--backend",
        choices=["plm", "descriptors", "concat"],
        help="Override feature backend",
    )
    parser.add_argument("--plm-model", help="Override PLM model name")
    parser.add_argument("--cache-dir", help="Override embedding cache directory")
    parser.add_argument("--device", choices=["auto", "cpu", "cuda"], help="Device override")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size for embeddings")
    parser.add_argument(
        "--heavy-only",
        action="store_true",
        default=True,
        help="Use heavy chains only (default true)",
    )
    parser.add_argument(
        "--paired",
        dest="heavy_only",
        action="store_false",
        help="Use paired heavy/light chains when available.",
    )
    parser.add_argument(
        "--include-families",
        nargs="*",
        help="Optional list of family names to retain in the training dataset",
    )
    parser.add_argument(
        "--exclude-families",
        nargs="*",
        help="Optional list of family names to drop from the training dataset",
    )
    parser.add_argument(
        "--include-species",
        nargs="*",
        help="Optional list of species (e.g. human, mouse) to retain",
    )
    parser.add_argument(
        "--cv-group-column",
        default="lineage",
        help="Column name used to group samples during cross-validation (default: lineage)",
    )
    parser.add_argument(
        "--no-group-cv",
        action="store_true",
        help="Disable group-aware cross-validation even if group column is present",
    )
    parser.add_argument(
        "--keep-train-duplicates",
        action="store_true",
        help="Keep duplicate keys within the training dataset when deduplicating across splits",
    )
    parser.add_argument(
        "--dedupe-key-columns",
        nargs="*",
        help="Columns used to detect duplicates across datasets (defaults to heavy/light sequences)",
    )
    parser.add_argument(
        "--bootstrap-samples",
        type=int,
        default=200,
        help="Number of bootstrap resamples for confidence intervals (0 to disable).",
    )
    parser.add_argument(
        "--bootstrap-alpha",
        type=float,
        default=0.05,
        help="Alpha for two-sided bootstrap confidence intervals (default 0.05 → 95%% CI).",
    )
    parser.add_argument(
        "--write-train-in-sample",
        action="store_true",
        help=(
            "Persist in-sample metrics on the full training set; disabled by default to avoid"
            " over-optimistic reporting."
        ),
    )
    return parser


def _infer_loader(path: str, explicit: str | None) -> tuple[str, Callable[..., pd.DataFrame]]:
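    """Resolve the loader for ``path``, preferring an explicitly named loader.

    Falls back to matching a known dataset name against the file stem, e.g.
    ``data/jain_test.csv`` resolves to the ``jain`` loader.
    """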
    if explicit:
        return explicit, DATASET_LOADERS[explicit]
    lower = Path(path).stem.lower()
    for name, loader in DATASET_LOADERS.items():
        if name in lower:
            return name, loader
    msg = f"Could not infer loader for dataset: {path}. Provide --train-loader/--eval-loaders."
    raise ValueError(msg)


def _load_dataset(path: str, loader_name: str, loader_fn, *, heavy_only: bool) -> pd.DataFrame:
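    """Load a dataset with the given loader and tag each row with its source name."""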
    frame = loader_fn(path, heavy_only=heavy_only)
    frame["source"] = loader_name
    return frame


def _apply_dataset_filters(
    frame: pd.DataFrame,
    *,
    include_families: Sequence[str] | None,
    exclude_families: Sequence[str] | None,
    include_species: Sequence[str] | None,
) -> pd.DataFrame:
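    """Filter rows by family and species, matching case-insensitively when the columns exist."""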
    filtered = frame.copy()
    if include_families and "family" in filtered.columns:
        families = {fam.lower() for fam in include_families}
        filtered = filtered[filtered["family"].astype(str).str.lower().isin(families)]
    if exclude_families and "family" in filtered.columns:
        families_ex = {fam.lower() for fam in exclude_families}
        filtered = filtered[~filtered["family"].astype(str).str.lower().isin(families_ex)]
    if include_species and "species" in filtered.columns:
        species_set = {spec.lower() for spec in include_species}
        filtered = filtered[filtered["species"].astype(str).str.lower().isin(species_set)]
    return filtered.reset_index(drop=True)


def main(argv: Sequence[str] | None = None) -> int:
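    """Run the training CLI: load data, cross-validate, fit, evaluate, and persist artifacts."""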
    parser = build_parser()
    args = parser.parse_args(argv)
    config = load_config(args.config)
    if args.device:
        config.device = args.device
    if args.backend:
        config.feature_backend.type = args.backend
    if args.cache_dir:
        config.feature_backend.cache_dir = args.cache_dir
    if args.plm_model:
        config.feature_backend.plm_model_name = args.plm_model

    logger = configure_logging()
    set_global_seeds(config.seed)
    _log_environment(logger)

    heavy_only = args.heavy_only
    train_name, train_loader = _infer_loader(args.train, args.train_loader)
    train_df = _load_dataset(args.train, train_name, train_loader, heavy_only=heavy_only)
    train_df = _apply_dataset_filters(
        train_df,
        include_families=args.include_families,
        exclude_families=args.exclude_families,
        include_species=args.include_species,
    )

    eval_frames: list[pd.DataFrame] = []
    if args.eval:
        loaders_iter = args.eval_loaders or []
        for idx, eval_path in enumerate(args.eval):
            explicit = loaders_iter[idx] if idx < len(loaders_iter) else None
            eval_name, eval_loader = _infer_loader(eval_path, explicit)
            eval_df = _load_dataset(eval_path, eval_name, eval_loader, heavy_only=heavy_only)
            eval_frames.append(eval_df)

    all_frames = [train_df, *eval_frames]
    dedup_keep = {0} if args.keep_train_duplicates else set()
    deduped_frames = deduplicate_sequences(
        all_frames,
        heavy_only=heavy_only,
        key_columns=args.dedupe_key_columns,
        keep_intra_frames=dedup_keep,
    )
    train_df = deduped_frames[0]
    eval_frames = deduped_frames[1:]

    def pipeline_factory() -> FeaturePipeline:
        return build_feature_pipeline(
            config,
            backend_override=args.backend,
            plm_model_override=args.plm_model,
            cache_dir_override=args.cache_dir,
        )

    model_config = LinearModelConfig(
        head=config.model.head,
        C=config.model.C,
        class_weight=config.model.class_weight,
    )

    groups = None
    if not args.no_group_cv and args.cv_group_column:
        if args.cv_group_column in train_df.columns:
            groups = train_df[args.cv_group_column].fillna("").astype(str).to_numpy()
        else:
            logger.warning(
                "Group column '%s' not found in training dataframe; falling back to standard CV",
                args.cv_group_column,
            )

    cv_results = _cross_validate(
        train_df,
        pipeline_factory,
        model_config,
        config,
        heavy_only=heavy_only,
        batch_size=args.batch_size,
        groups=groups,
    )
    trained_model, feature_pipeline = _fit_full_model(
        train_df,
        pipeline_factory,
        model_config,
        config,
        heavy_only=heavy_only,
        batch_size=args.batch_size,
    )

    outputs_dir = Path(args.report_to)
    outputs_dir.mkdir(parents=True, exist_ok=True)
    metrics_df, preds_rows = _evaluate_datasets(
        train_df,
        eval_frames,
        trained_model,
        feature_pipeline,
        config,
        cv_results,
        outputs_dir,
        batch_size=args.batch_size,
        heavy_only=heavy_only,
        bootstrap_samples=args.bootstrap_samples,
        bootstrap_alpha=args.bootstrap_alpha,
        write_train_in_sample=args.write_train_in_sample,
    )
    write_table(metrics_df, outputs_dir / config.io.metrics_filename)
    preds_df = pd.DataFrame(preds_rows)
    write_table(preds_df, outputs_dir / config.io.preds_filename)

    artifact = {
        "config": config,
        "feature_state": feature_pipeline.get_state(),
        "model": trained_model,
    }
    Path(args.save_to).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(artifact, args.save_to)
    logger.info("Training complete. Metrics written to %s", outputs_dir)
    return 0


def _cross_validate(
    train_df: pd.DataFrame,
    pipeline_factory,
    model_config: LinearModelConfig,
    config: Config,
    *,
    heavy_only: bool,
    batch_size: int,
    groups: np.ndarray | None = None,
):
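    """Cross-validate on the training set, returning out-of-fold scores and per-fold metrics."""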
    y = train_df["label"].to_numpy(dtype=int)
    n_samples = len(y)
    # Determine a safe number of folds for tiny fixtures; prefer the configured value
    # but never exceed the number of samples. Fall back to non-stratified KFold when
    # per-class counts are too small for stratification (e.g., 1 positive / 1 negative).
    n_splits = max(2, min(config.training.cv_folds, n_samples))
    use_stratified = True
    class_counts = np.bincount(y) if y.size else np.array([])
    if class_counts.size > 0 and class_counts.min(initial=0) < n_splits:
        use_stratified = False
    if groups is not None and use_stratified:
        splitter = StratifiedGroupKFold(
            n_splits=n_splits,
            shuffle=True,
            random_state=config.seed,
        )
        split_iter = splitter.split(train_df, y, groups)
    elif use_stratified:
        splitter = StratifiedKFold(
            n_splits=n_splits,
            shuffle=True,
            random_state=config.seed,
        )
        split_iter = splitter.split(train_df, y)
    else:
        # Non-stratified fallback for extreme class imbalance / tiny datasets.
        from sklearn.model_selection import KFold  # local import to limit surface

        splitter = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)
        split_iter = splitter.split(train_df)

    oof_scores = np.zeros(len(train_df), dtype=float)
    metrics_per_fold: list[dict[str, float]] = []
    for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
        train_slice = train_df.iloc[train_idx].reset_index(drop=True)
        val_slice = train_df.iloc[val_idx].reset_index(drop=True)
        pipeline: FeaturePipeline = pipeline_factory()
        X_train = pipeline.fit_transform(train_slice, heavy_only=heavy_only, batch_size=batch_size)
        X_val = pipeline.transform(val_slice, heavy_only=heavy_only, batch_size=batch_size)
        y_train = y[train_idx]
        y_val = y[val_idx]
        # Handle degenerate folds where the training split contains a single class.
        if np.unique(y_train).size < 2:
            fallback_prob = float(y.mean()) if y.size else 0.5
            y_scores = np.full(X_val.shape[0], fallback_prob, dtype=float)
        else:
            trained = train_linear_model(
                X_train, y_train, config=model_config, random_state=config.seed
            )
            calibrator = _fit_model_calibrator(
                model_config,
                config,
                X_train,
                y_train,
                base_estimator=trained.estimator,
            )
            trained.calibrator = calibrator
            if calibrator is not None:
                y_scores = calibrator.predict_proba(X_val)[:, 1]
            else:
                y_scores = trained.predict_proba(X_val)
        oof_scores[val_idx] = y_scores
        fold_metrics = compute_metrics(y_val, y_scores)
        try:
            fold_metrics["roc_auc"] = float(roc_auc_score(y_val, y_scores))
        except ValueError:
            # ROC-AUC is undefined when a tiny validation fold contains a single class.
            pass
        metrics_per_fold.append(fold_metrics)

    metrics_mean: dict[str, float] = {}
    metrics_std: dict[str, float] = {}
    metric_names = list(metrics_per_fold[0].keys()) if metrics_per_fold else []
    for metric in metric_names:
        values = [fold[metric] for fold in metrics_per_fold]
        metrics_mean[metric] = float(np.mean(values))
        metrics_std[metric] = float(np.std(values, ddof=1))
    return {
        "oof_scores": oof_scores,
        "metrics_per_fold": metrics_per_fold,
        "metrics_mean": metrics_mean,
        "metrics_std": metrics_std,
    }


def _fit_full_model(
    train_df: pd.DataFrame,
    pipeline_factory,
    model_config: LinearModelConfig,
    config: Config,
    *,
    heavy_only: bool,
    batch_size: int,
) -> tuple[TrainedModel, FeaturePipeline]:
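    """Fit the feature pipeline, linear model, and optional calibrator on the full training set."""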
    pipeline: FeaturePipeline = pipeline_factory()
    X_train = pipeline.fit_transform(train_df, heavy_only=heavy_only, batch_size=batch_size)
    y_train = train_df["label"].to_numpy(dtype=int)
    trained = train_linear_model(X_train, y_train, config=model_config, random_state=config.seed)
    calibrator = _fit_model_calibrator(
        model_config,
        config,
        X_train,
        y_train,
        base_estimator=trained.estimator,
    )
    trained.calibrator = calibrator
    return trained, pipeline


def _evaluate_datasets(
    train_df: pd.DataFrame,
    eval_frames: list[pd.DataFrame],
    trained_model: TrainedModel,
    pipeline: FeaturePipeline,
    config: Config,
    cv_results: dict,
    outputs_dir: Path,
    *,
    batch_size: int,
    heavy_only: bool,
    bootstrap_samples: int,
    bootstrap_alpha: float,
    write_train_in_sample: bool,
):
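    """Assemble metrics and per-sample predictions for the training CV and each evaluation set.

    Returns a wide metrics dataframe (one row per metric, one column per dataset or statistic)
    together with a list of prediction rows suitable for the predictions table.
    """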
    metrics_lookup: dict[str, dict[str, float]] = {}
    preds_rows: list[dict[str, float]] = []
    metrics_mean: dict[str, float] = cv_results["metrics_mean"]
    metrics_std: dict[str, float] = cv_results["metrics_std"]
    for metric_name, value in metrics_mean.items():
        metrics_lookup.setdefault(metric_name, {"metric": metric_name})["train_cv_mean"] = value
    for metric_name, value in metrics_std.items():
        metrics_lookup.setdefault(metric_name, {"metric": metric_name})["train_cv_std"] = value

    train_scores = cv_results["oof_scores"]
    train_preds = train_df[["id", "source", "label"]].copy()
    train_preds["y_true"] = train_preds["label"]
    train_preds["y_score"] = train_scores
    train_preds["y_pred"] = (train_scores >= 0.5).astype(int)
    train_preds["split"] = "train_cv_oof"
    preds_rows.extend(
        train_preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
    )
    plot_reliability_curve(
        train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "reliability_train.png"
    )
    plot_precision_recall(
        train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "pr_train.png"
    )
    plot_roc_curve(train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "roc_train.png")
    if bootstrap_samples > 0:
        ci_map = bootstrap_metric_intervals(
            train_preds["y_true"],
            train_preds["y_score"],
            n_bootstrap=bootstrap_samples,
            alpha=bootstrap_alpha,
            random_state=config.seed,
        )
        for metric_name, stats in ci_map.items():
            row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
            row["train_cv_ci_lower"] = stats.get("ci_lower")
            row["train_cv_ci_upper"] = stats.get("ci_upper")
            row["train_cv_ci_median"] = stats.get("ci_median")

    if write_train_in_sample:
        train_features_full = pipeline.transform(
            train_df, heavy_only=heavy_only, batch_size=batch_size
        )
        train_full_scores = trained_model.predict_proba(train_features_full)
        train_full_metrics = compute_metrics(
            train_df["label"].to_numpy(dtype=int), train_full_scores
        )
        (outputs_dir / "train_in_sample.json").write_text(
            json.dumps(train_full_metrics, indent=2),
            encoding="utf-8",
        )

    for frame in eval_frames:
        if frame.empty:
            continue
        features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
        scores = trained_model.predict_proba(features)
        y_true = frame["label"].to_numpy(dtype=int)
        metrics = compute_metrics(y_true, scores)
        dataset_name = frame["source"].iloc[0]
        for metric_name, value in metrics.items():
            metrics_lookup.setdefault(metric_name, {"metric": metric_name})[dataset_name] = value
        preds = frame[["id", "source", "label"]].copy()
        preds["y_true"] = preds["label"]
        preds["y_score"] = scores
        preds["y_pred"] = (scores >= 0.5).astype(int)
        preds["split"] = dataset_name
        preds_rows.extend(
            preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
        )
        plot_reliability_curve(
            preds["y_true"],
            preds["y_score"],
            path=outputs_dir / f"reliability_{dataset_name}.png",
        )
        plot_precision_recall(
            preds["y_true"],
            preds["y_score"],
            path=outputs_dir / f"pr_{dataset_name}.png",
        )
        plot_roc_curve(
            preds["y_true"], preds["y_score"], path=outputs_dir / f"roc_{dataset_name}.png"
        )
        if bootstrap_samples > 0:
            ci_map = bootstrap_metric_intervals(
                preds["y_true"],
                preds["y_score"],
                n_bootstrap=bootstrap_samples,
                alpha=bootstrap_alpha,
                random_state=config.seed,
            )
            for metric_name, stats in ci_map.items():
                row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
                row[f"{dataset_name}_ci_lower"] = stats.get("ci_lower")
                row[f"{dataset_name}_ci_upper"] = stats.get("ci_upper")
                row[f"{dataset_name}_ci_median"] = stats.get("ci_median")

    metrics_df = pd.DataFrame(sorted(metrics_lookup.values(), key=lambda row: row["metric"]))
    return metrics_df, preds_rows


def _fit_model_calibrator(
    model_config: LinearModelConfig,
    config: Config,
    X: np.ndarray,
    y: np.ndarray,
    *,
    base_estimator: Any | None = None,
):
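    """Fit a probability calibrator when configured and when the labels support it.

    With at least four samples the calibrator is cross-validated on a fresh estimator;
    otherwise the base (or a freshly fitted) estimator is calibrated with ``cv="prefit"``.
    Returns ``None`` when calibration is disabled or only one class is present.
    """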
    method = config.calibration.method
    if not method:
        return None
    if len(np.unique(y)) < 2:
        return None
    if len(y) >= 4:
        cv_cal = min(config.training.cv_folds, max(2, len(y) // 2))
        estimator = build_estimator(config=model_config, random_state=config.seed)
        if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
            estimator.set_params(solver="lbfgs")
        calibrator = fit_calibrator(estimator, X, y, method=method, cv=cv_cal)
    else:
        estimator = base_estimator or build_estimator(config=model_config, random_state=config.seed)
        if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
            estimator.set_params(solver="lbfgs")
        estimator.fit(X, y)
        calibrator = fit_calibrator(estimator, X, y, method=method, cv="prefit")
    return calibrator


def _log_environment(logger) -> None:
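    """Log the current git commit and installed packages for reproducibility (best effort)."""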
    try:
        git_head = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
    except Exception:  # pragma: no cover - best effort
        git_head = "unknown"
    try:
        pip_freeze = subprocess.check_output(["pip", "freeze"], text=True)
    except Exception:  # pragma: no cover
        pip_freeze = ""
    logger.info("git_head=%s", git_head)
    logger.info("pip_freeze=%s", pip_freeze)


if __name__ == "__main__":
    raise SystemExit(main())