"""Training entrypoint for the polyreactivity model."""
from __future__ import annotations
import argparse
import json
import subprocess
from pathlib import Path
from typing import Any, Callable, Sequence
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from sklearn.linear_model import LogisticRegression
from .config import Config, load_config
from .data_loaders import boughter, harvey, jain, shehata
from .data_loaders.utils import deduplicate_sequences
from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline
from .models.calibrate import fit_calibrator
from .models.linear import LinearModelConfig, TrainedModel, build_estimator, train_linear_model
from .utils.io import write_table
from .utils.logging import configure_logging
from .utils.metrics import bootstrap_metric_intervals, compute_metrics
from .utils.plots import plot_precision_recall, plot_reliability_curve, plot_roc_curve
from .utils.seeds import set_global_seeds
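

# Map dataset names to their loader functions; "shehata_curated" reuses the shehata loader.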
DATASET_LOADERS = {
"boughter": boughter.load_dataframe,
"jain": jain.load_dataframe,
"shehata": shehata.load_dataframe,
"shehata_curated": shehata.load_dataframe,
"harvey": harvey.load_dataframe,
}


def build_parser() -> argparse.ArgumentParser:
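    """Build the command-line argument parser for training."""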
parser = argparse.ArgumentParser(description="Train polyreactivity model")
parser.add_argument("--config", default="configs/default.yaml", help="Config file")
parser.add_argument("--train", required=True, help="Training dataset path")
parser.add_argument(
"--eval",
nargs="*",
default=[],
help="Evaluation dataset paths",
)
parser.add_argument(
"--save-to",
default="artifacts/model.joblib",
help="Path to save trained model artifact",
)
parser.add_argument(
"--report-to",
default="artifacts",
help="Directory for metrics, predictions, and plots",
)
parser.add_argument(
"--train-loader",
choices=list(DATASET_LOADERS.keys()),
help="Optional explicit loader for training dataset",
)
parser.add_argument(
"--eval-loaders",
nargs="*",
help="Optional explicit loaders for evaluation datasets (aligned with --eval order)",
)
parser.add_argument(
"--backend",
choices=["plm", "descriptors", "concat"],
help="Override feature backend",
)
parser.add_argument("--plm-model", help="Override PLM model name")
parser.add_argument("--cache-dir", help="Override embedding cache directory")
parser.add_argument("--device", choices=["auto", "cpu", "cuda"], help="Device override")
parser.add_argument("--batch-size", type=int, default=8, help="Batch size for embeddings")
parser.add_argument(
"--heavy-only",
action="store_true",
default=True,
help="Use heavy chains only (default true)",
)
parser.add_argument(
"--paired",
dest="heavy_only",
action="store_false",
help="Use paired heavy/light chains when available.",
)
parser.add_argument(
"--include-families",
nargs="*",
help="Optional list of family names to retain in the training dataset",
)
parser.add_argument(
"--exclude-families",
nargs="*",
help="Optional list of family names to drop from the training dataset",
)
parser.add_argument(
"--include-species",
nargs="*",
help="Optional list of species (e.g. human, mouse) to retain",
)
parser.add_argument(
"--cv-group-column",
default="lineage",
help="Column name used to group samples during cross-validation (default: lineage)",
)
parser.add_argument(
"--no-group-cv",
action="store_true",
help="Disable group-aware cross-validation even if group column is present",
)
parser.add_argument(
"--keep-train-duplicates",
action="store_true",
help="Keep duplicate keys within the training dataset when deduplicating across splits",
)
parser.add_argument(
"--dedupe-key-columns",
nargs="*",
help="Columns used to detect duplicates across datasets (defaults to heavy/light sequences)",
)
parser.add_argument(
"--bootstrap-samples",
type=int,
default=200,
help="Number of bootstrap resamples for confidence intervals (0 to disable).",
)
parser.add_argument(
"--bootstrap-alpha",
type=float,
default=0.05,
help="Alpha for two-sided bootstrap confidence intervals (default 0.05 → 95% CI).",
)
parser.add_argument(
"--write-train-in-sample",
action="store_true",
help=(
"Persist in-sample metrics on the full training set; disabled by default to avoid"
" over-optimistic reporting."
),
)
return parser


def _infer_loader(path: str, explicit: str | None) -> tuple[str, Callable[..., pd.DataFrame]]:
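    """Resolve a dataset loader, either from an explicit name or by matching the file stem."""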
if explicit:
return explicit, DATASET_LOADERS[explicit]
lower = Path(path).stem.lower()
for name, loader in DATASET_LOADERS.items():
if name in lower:
return name, loader
msg = f"Could not infer loader for dataset: {path}. Provide --train-loader/--eval-loaders."
raise ValueError(msg)


def _load_dataset(path: str, loader_name: str, loader_fn, *, heavy_only: bool) -> pd.DataFrame:
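    """Load a dataset with the given loader and tag each row with its source name."""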
frame = loader_fn(path, heavy_only=heavy_only)
frame["source"] = loader_name
return frame


def _apply_dataset_filters(
frame: pd.DataFrame,
*,
include_families: Sequence[str] | None,
exclude_families: Sequence[str] | None,
include_species: Sequence[str] | None,
) -> pd.DataFrame:
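    """Filter rows by family and species when the corresponding columns are present."""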
filtered = frame.copy()
if include_families:
families = {fam.lower() for fam in include_families}
if "family" in filtered.columns:
filtered = filtered[
filtered["family"].astype(str).str.lower().isin(families)
]
if exclude_families:
families_ex = {fam.lower() for fam in exclude_families}
if "family" in filtered.columns:
filtered = filtered[
~filtered["family"].astype(str).str.lower().isin(families_ex)
]
if include_species:
species_set = {spec.lower() for spec in include_species}
if "species" in filtered.columns:
filtered = filtered[
filtered["species"].astype(str).str.lower().isin(species_set)
]
return filtered.reset_index(drop=True)


def main(argv: Sequence[str] | None = None) -> int:
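    """Train the polyreactivity model end to end and write metrics, predictions, and plots."""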
parser = build_parser()
args = parser.parse_args(argv)
config = load_config(args.config)
if args.device:
config.device = args.device
if args.backend:
config.feature_backend.type = args.backend
if args.cache_dir:
config.feature_backend.cache_dir = args.cache_dir
if args.plm_model:
config.feature_backend.plm_model_name = args.plm_model
logger = configure_logging()
set_global_seeds(config.seed)
_log_environment(logger)
heavy_only = args.heavy_only
train_name, train_loader = _infer_loader(args.train, args.train_loader)
train_df = _load_dataset(args.train, train_name, train_loader, heavy_only=heavy_only)
train_df = _apply_dataset_filters(
train_df,
include_families=args.include_families,
exclude_families=args.exclude_families,
include_species=args.include_species,
)
eval_frames: list[pd.DataFrame] = []
if args.eval:
loaders_iter = args.eval_loaders or []
for idx, eval_path in enumerate(args.eval):
explicit = loaders_iter[idx] if idx < len(loaders_iter) else None
eval_name, eval_loader = _infer_loader(eval_path, explicit)
eval_df = _load_dataset(eval_path, eval_name, eval_loader, heavy_only=heavy_only)
eval_frames.append(eval_df)
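    # Deduplicate sequences across the train/eval frames so evaluation data never overlaps training.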
all_frames = [train_df, *eval_frames]
dedup_keep = {0} if args.keep_train_duplicates else set()
deduped_frames = deduplicate_sequences(
all_frames,
heavy_only=heavy_only,
key_columns=args.dedupe_key_columns,
keep_intra_frames=dedup_keep,
)
train_df = deduped_frames[0]
eval_frames = deduped_frames[1:]
pipeline_factory = lambda: build_feature_pipeline( # noqa: E731
config,
backend_override=args.backend,
plm_model_override=args.plm_model,
cache_dir_override=args.cache_dir,
)
model_config = LinearModelConfig(
head=config.model.head,
C=config.model.C,
class_weight=config.model.class_weight,
)
groups = None
if not args.no_group_cv and args.cv_group_column:
if args.cv_group_column in train_df.columns:
groups = train_df[args.cv_group_column].fillna("").astype(str).to_numpy()
else:
logger.warning(
"Group column '%s' not found in training dataframe; falling back to standard CV",
args.cv_group_column,
)
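    # Cross-validation on the training set yields the out-of-fold scores used for reporting.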
cv_results = _cross_validate(
train_df,
pipeline_factory,
model_config,
config,
heavy_only=heavy_only,
batch_size=args.batch_size,
groups=groups,
)
trained_model, feature_pipeline = _fit_full_model(
train_df,
pipeline_factory,
model_config,
config,
heavy_only=heavy_only,
batch_size=args.batch_size,
)
outputs_dir = Path(args.report_to)
outputs_dir.mkdir(parents=True, exist_ok=True)
metrics_df, preds_rows = _evaluate_datasets(
train_df,
eval_frames,
trained_model,
feature_pipeline,
config,
cv_results,
outputs_dir,
batch_size=args.batch_size,
heavy_only=heavy_only,
bootstrap_samples=args.bootstrap_samples,
bootstrap_alpha=args.bootstrap_alpha,
write_train_in_sample=args.write_train_in_sample,
)
write_table(metrics_df, outputs_dir / config.io.metrics_filename)
preds_df = pd.DataFrame(preds_rows)
write_table(preds_df, outputs_dir / config.io.preds_filename)
artifact = {
"config": config,
"feature_state": feature_pipeline.get_state(),
"model": trained_model,
}
Path(args.save_to).parent.mkdir(parents=True, exist_ok=True)
joblib.dump(artifact, args.save_to)
logger.info("Training complete. Metrics written to %s", outputs_dir)
return 0


def _cross_validate(
train_df: pd.DataFrame,
pipeline_factory,
model_config: LinearModelConfig,
config: Config,
*,
heavy_only: bool,
batch_size: int,
groups: np.ndarray | None = None,
):
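    """Run (optionally group-aware) cross-validation and return OOF scores plus per-fold metrics."""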
y = train_df["label"].to_numpy(dtype=int)
n_samples = len(y)
# Determine a safe number of folds for tiny fixtures; prefer the configured value
# but never exceed the number of samples. Fall back to non-stratified KFold when
# per-class counts are too small for stratification (e.g., 1 positive/1 negative).
n_splits = max(2, min(config.training.cv_folds, n_samples))
use_stratified = True
class_counts = np.bincount(y) if y.size else np.array([])
if class_counts.size > 0 and (class_counts.min(initial=0) < n_splits):
use_stratified = False
if groups is not None and use_stratified:
splitter = StratifiedGroupKFold(
n_splits=n_splits,
shuffle=True,
random_state=config.seed,
)
split_iter = splitter.split(train_df, y, groups)
elif use_stratified:
splitter = StratifiedKFold(
n_splits=n_splits,
shuffle=True,
random_state=config.seed,
)
split_iter = splitter.split(train_df, y)
else:
# Non-stratified fallback for extreme class imbalance / tiny datasets
from sklearn.model_selection import KFold # local import to limit surface
splitter = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)
split_iter = splitter.split(train_df)
oof_scores = np.zeros(len(train_df), dtype=float)
metrics_per_fold: list[dict[str, float]] = []
for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
train_slice = train_df.iloc[train_idx].reset_index(drop=True)
val_slice = train_df.iloc[val_idx].reset_index(drop=True)
pipeline: FeaturePipeline = pipeline_factory()
X_train = pipeline.fit_transform(train_slice, heavy_only=heavy_only, batch_size=batch_size)
X_val = pipeline.transform(val_slice, heavy_only=heavy_only, batch_size=batch_size)
y_train = y[train_idx]
y_val = y[val_idx]
# Handle degenerate folds where training data contains a single class
if np.unique(y_train).size < 2:
fallback_prob = float(y.mean()) if y.size else 0.5
y_scores = np.full(X_val.shape[0], fallback_prob, dtype=float)
else:
trained = train_linear_model(
X_train, y_train, config=model_config, random_state=config.seed
)
calibrator = _fit_model_calibrator(
model_config,
config,
X_train,
y_train,
base_estimator=trained.estimator,
)
trained.calibrator = calibrator
if calibrator is not None:
y_scores = calibrator.predict_proba(X_val)[:, 1]
else:
y_scores = trained.predict_proba(X_val)
oof_scores[val_idx] = y_scores
fold_metrics = compute_metrics(y_val, y_scores)
try:
fold_metrics["roc_auc"] = float(roc_auc_score(y_val, y_scores))
except ValueError:
# For tiny validation folds with a single class, ROC-AUC is undefined
pass
metrics_per_fold.append(fold_metrics)
metrics_mean: dict[str, float] = {}
metrics_std: dict[str, float] = {}
    metric_names = sorted({name for fold in metrics_per_fold for name in fold})
    for metric in metric_names:
        # roc_auc may be missing for folds whose validation split contains a single class.
        values = [fold[metric] for fold in metrics_per_fold if metric in fold]
        metrics_mean[metric] = float(np.mean(values))
        metrics_std[metric] = float(np.std(values, ddof=1)) if len(values) > 1 else 0.0
return {
"oof_scores": oof_scores,
"metrics_per_fold": metrics_per_fold,
"metrics_mean": metrics_mean,
"metrics_std": metrics_std,
}


def _fit_full_model(
train_df: pd.DataFrame,
pipeline_factory,
model_config: LinearModelConfig,
config: Config,
*,
heavy_only: bool,
batch_size: int,
) -> tuple[TrainedModel, FeaturePipeline]:
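    """Fit the feature pipeline, linear model, and optional calibrator on the full training set."""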
pipeline: FeaturePipeline = pipeline_factory()
X_train = pipeline.fit_transform(train_df, heavy_only=heavy_only, batch_size=batch_size)
y_train = train_df["label"].to_numpy(dtype=int)
trained = train_linear_model(X_train, y_train, config=model_config, random_state=config.seed)
calibrator = _fit_model_calibrator(
model_config,
config,
X_train,
y_train,
base_estimator=trained.estimator,
)
trained.calibrator = calibrator
return trained, pipeline


def _evaluate_datasets(
train_df: pd.DataFrame,
eval_frames: list[pd.DataFrame],
trained_model: TrainedModel,
pipeline: FeaturePipeline,
config: Config,
cv_results: dict,
outputs_dir: Path,
*,
batch_size: int,
heavy_only: bool,
bootstrap_samples: int,
bootstrap_alpha: float,
write_train_in_sample: bool,
):
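    """Assemble CV metrics, score evaluation datasets, write plots, and return (metrics_df, preds_rows)."""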
metrics_lookup: dict[str, dict[str, float]] = {}
preds_rows: list[dict[str, float]] = []
metrics_mean: dict[str, float] = cv_results["metrics_mean"]
metrics_std: dict[str, float] = cv_results["metrics_std"]
for metric_name, value in metrics_mean.items():
metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
"train_cv_mean"
] = value
for metric_name, value in metrics_std.items():
metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
"train_cv_std"
] = value
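    # Out-of-fold training predictions populate the "train_cv_oof" split and the training plots.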
train_scores = cv_results["oof_scores"]
train_preds = train_df[["id", "source", "label"]].copy()
train_preds["y_true"] = train_preds["label"]
train_preds["y_score"] = train_scores
train_preds["y_pred"] = (train_scores >= 0.5).astype(int)
train_preds["split"] = "train_cv_oof"
preds_rows.extend(
train_preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
)
plot_reliability_curve(
train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "reliability_train.png"
)
plot_precision_recall(
train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "pr_train.png"
)
plot_roc_curve(train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "roc_train.png")
if bootstrap_samples > 0:
ci_map = bootstrap_metric_intervals(
train_preds["y_true"],
train_preds["y_score"],
n_bootstrap=bootstrap_samples,
alpha=bootstrap_alpha,
random_state=config.seed,
)
for metric_name, stats in ci_map.items():
row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
row["train_cv_ci_lower"] = stats.get("ci_lower")
row["train_cv_ci_upper"] = stats.get("ci_upper")
row["train_cv_ci_median"] = stats.get("ci_median")
if write_train_in_sample:
train_features_full = pipeline.transform(
train_df, heavy_only=heavy_only, batch_size=batch_size
)
train_full_scores = trained_model.predict_proba(train_features_full)
train_full_metrics = compute_metrics(
train_df["label"].to_numpy(dtype=int), train_full_scores
)
(outputs_dir / "train_in_sample.json").write_text(
json.dumps(train_full_metrics, indent=2),
encoding="utf-8",
)
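    # Score each held-out evaluation dataset with the model fit on the full training set.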
for frame in eval_frames:
if frame.empty:
continue
features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
scores = trained_model.predict_proba(features)
y_true = frame["label"].to_numpy(dtype=int)
metrics = compute_metrics(y_true, scores)
dataset_name = frame["source"].iloc[0]
for metric_name, value in metrics.items():
metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
dataset_name
] = value
preds = frame[["id", "source", "label"]].copy()
preds["y_true"] = preds["label"]
preds["y_score"] = scores
preds["y_pred"] = (scores >= 0.5).astype(int)
preds["split"] = dataset_name
preds_rows.extend(
preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
)
plot_reliability_curve(
preds["y_true"],
preds["y_score"],
path=outputs_dir / f"reliability_{dataset_name}.png",
)
plot_precision_recall(
preds["y_true"],
preds["y_score"],
path=outputs_dir / f"pr_{dataset_name}.png",
)
plot_roc_curve(
preds["y_true"], preds["y_score"], path=outputs_dir / f"roc_{dataset_name}.png"
)
if bootstrap_samples > 0:
ci_map = bootstrap_metric_intervals(
preds["y_true"],
preds["y_score"],
n_bootstrap=bootstrap_samples,
alpha=bootstrap_alpha,
random_state=config.seed,
)
for metric_name, stats in ci_map.items():
row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
row[f"{dataset_name}_ci_lower"] = stats.get("ci_lower")
row[f"{dataset_name}_ci_upper"] = stats.get("ci_upper")
row[f"{dataset_name}_ci_median"] = stats.get("ci_median")
metrics_df = pd.DataFrame(sorted(metrics_lookup.values(), key=lambda row: row["metric"]))
return metrics_df, preds_rows


def _fit_model_calibrator(
model_config: LinearModelConfig,
config: Config,
X: np.ndarray,
y: np.ndarray,
*,
base_estimator: Any | None = None,
):
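    """Fit an optional probability calibrator according to the configured calibration method."""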
method = config.calibration.method
if not method:
return None
if len(np.unique(y)) < 2:
return None
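    # With at least four labelled samples, calibrate via internal CV; otherwise calibrate a prefit estimator.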
if len(y) >= 4:
cv_cal = min(config.training.cv_folds, max(2, len(y) // 2))
estimator = build_estimator(config=model_config, random_state=config.seed)
if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
estimator.set_params(solver="lbfgs")
calibrator = fit_calibrator(estimator, X, y, method=method, cv=cv_cal)
else:
estimator = base_estimator or build_estimator(config=model_config, random_state=config.seed)
if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
estimator.set_params(solver="lbfgs")
estimator.fit(X, y)
calibrator = fit_calibrator(estimator, X, y, method=method, cv="prefit")
return calibrator


def _log_environment(logger) -> None:
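    """Log the current git commit and installed packages for reproducibility (best effort)."""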
try:
git_head = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
except Exception: # pragma: no cover - best effort
git_head = "unknown"
try:
pip_freeze = subprocess.check_output(["pip", "freeze"], text=True)
except Exception: # pragma: no cover
pip_freeze = ""
logger.info("git_head=%s", git_head)
logger.info("pip_freeze=%s", pip_freeze)


if __name__ == "__main__":
raise SystemExit(main())