Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- polyreact/__init__.py +10 -0
- polyreact/__pycache__/__init__.cpython-311.pyc +0 -0
- polyreact/__pycache__/api.cpython-311.pyc +0 -0
- polyreact/__pycache__/config.cpython-311.pyc +0 -0
- polyreact/__pycache__/predict.cpython-311.pyc +0 -0
- polyreact/__pycache__/train.cpython-311.pyc +0 -0
- polyreact/api.py +121 -0
- polyreact/benchmarks/__pycache__/reproduce_paper.cpython-311.pyc +0 -0
- polyreact/benchmarks/__pycache__/run_benchmarks.cpython-311.pyc +0 -0
- polyreact/benchmarks/reproduce_paper.ipynb +25 -0
- polyreact/benchmarks/reproduce_paper.py +1020 -0
- polyreact/benchmarks/run_benchmarks.py +114 -0
- polyreact/config.py +160 -0
- polyreact/configs/__init__.py +1 -0
- polyreact/configs/default.yaml +34 -0
- polyreact/data_loaders/__init__.py +3 -0
- polyreact/data_loaders/__pycache__/__init__.cpython-311.pyc +0 -0
- polyreact/data_loaders/__pycache__/boughter.cpython-311.pyc +0 -0
- polyreact/data_loaders/__pycache__/harvey.cpython-311.pyc +0 -0
- polyreact/data_loaders/__pycache__/jain.cpython-311.pyc +0 -0
- polyreact/data_loaders/__pycache__/shehata.cpython-311.pyc +0 -0
- polyreact/data_loaders/__pycache__/utils.cpython-311.pyc +0 -0
- polyreact/data_loaders/boughter.py +76 -0
- polyreact/data_loaders/harvey.py +39 -0
- polyreact/data_loaders/jain.py +41 -0
- polyreact/data_loaders/shehata.py +202 -0
- polyreact/data_loaders/utils.py +186 -0
- polyreact/features/__init__.py +13 -0
- polyreact/features/__pycache__/__init__.cpython-311.pyc +0 -0
- polyreact/features/__pycache__/anarsi.cpython-311.pyc +0 -0
- polyreact/features/__pycache__/descriptors.cpython-311.pyc +0 -0
- polyreact/features/__pycache__/pipeline.cpython-311.pyc +0 -0
- polyreact/features/__pycache__/plm.cpython-311.pyc +0 -0
- polyreact/features/anarsi.py +222 -0
- polyreact/features/descriptors.py +146 -0
- polyreact/features/pipeline.py +343 -0
- polyreact/features/plm.py +378 -0
- polyreact/models/__init__.py +3 -0
- polyreact/models/__pycache__/__init__.cpython-311.pyc +0 -0
- polyreact/models/__pycache__/calibrate.cpython-311.pyc +0 -0
- polyreact/models/__pycache__/linear.cpython-311.pyc +0 -0
- polyreact/models/__pycache__/ordinal.cpython-311.pyc +0 -0
- polyreact/models/calibrate.py +24 -0
- polyreact/models/linear.py +91 -0
- polyreact/models/ordinal.py +106 -0
- polyreact/predict.py +106 -0
- polyreact/train.py +619 -0
- polyreact/utils/__pycache__/io.cpython-311.pyc +0 -0
- polyreact/utils/__pycache__/logging.cpython-311.pyc +0 -0
- polyreact/utils/__pycache__/metrics.cpython-311.pyc +0 -0
polyreact/__init__.py
ADDED
@@ -0,0 +1,10 @@
+"""Polyreactivity prediction package."""
+
+from importlib import metadata
+
+__all__ = ["__version__"]
+
+try:
+    __version__ = metadata.version("polyreact")
+except metadata.PackageNotFoundError:  # pragma: no cover
+    __version__ = "0.0.0"
polyreact/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (502 Bytes)

polyreact/__pycache__/api.cpython-311.pyc
ADDED
Binary file (6.1 kB)

polyreact/__pycache__/config.cpython-311.pyc
ADDED
Binary file (9.9 kB)

polyreact/__pycache__/predict.cpython-311.pyc
ADDED
Binary file (4.43 kB)

polyreact/__pycache__/train.cpython-311.pyc
ADDED
Binary file (28.1 kB)

polyreact/api.py
ADDED
@@ -0,0 +1,121 @@
+"""Public Python API for polyreactivity prediction."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Iterable
+
+import copy
+
+import joblib
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+
+from .config import Config, load_config
+from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline
+
+
+def predict_batch(  # noqa: ANN003
+    records: Iterable[dict],
+    *,
+    config: Config | str | Path | None = None,
+    backend: str | None = None,
+    plm_model: str | None = None,
+    weights: str | Path | None = None,
+    heavy_only: bool = True,
+    batch_size: int = 8,
+    device: str | None = None,
+    cache_dir: str | None = None,
+) -> pd.DataFrame:
+    """Predict polyreactivity scores for a batch of sequences."""
+
+    records_list = list(records)
+    if not records_list:
+        return pd.DataFrame(columns=["id", "score", "pred"])
+
+    artifact = _load_artifact(weights)
+
+    if config is None:
+        artifact_config = artifact.get("config")
+        if isinstance(artifact_config, Config):
+            config = copy.deepcopy(artifact_config)
+        else:
+            config = load_config("configs/default.yaml")
+    elif isinstance(config, (str, Path)):
+        config = load_config(config)
+    else:
+        config = copy.deepcopy(config)
+
+    if backend:
+        config.feature_backend.type = backend
+    if plm_model:
+        config.feature_backend.plm_model_name = plm_model
+    if device:
+        config.device = device
+    if cache_dir:
+        config.feature_backend.cache_dir = cache_dir
+
+    pipeline = _restore_pipeline(config, artifact)
+    trained_model = artifact["model"]
+
+    frame = pd.DataFrame(records_list)
+    if frame.empty:
+        raise ValueError("Prediction requires at least one record.")
+    if "id" not in frame.columns:
+        frame["id"] = frame.get("sequence_id", range(len(frame))).astype(str)
+    if "heavy_seq" in frame.columns:
+        frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str)
+    else:
+        heavy_series = frame.get("heavy")
+        if heavy_series is None:
+            heavy_series = pd.Series([""] * len(frame))
+        frame["heavy_seq"] = heavy_series.fillna("").astype(str)
+
+    if "light_seq" in frame.columns:
+        frame["light_seq"] = frame["light_seq"].fillna("").astype(str)
+    else:
+        light_series = frame.get("light")
+        if light_series is None:
+            light_series = pd.Series([""] * len(frame))
+        frame["light_seq"] = light_series.fillna("").astype(str)
+
+    if heavy_only:
+        frame["light_seq"] = ""
+    if frame["heavy_seq"].str.len().eq(0).all():
+        raise ValueError("No heavy chain sequences provided for prediction.")
+
+    features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
+    scores = trained_model.predict_proba(features)
+    preds = (scores >= 0.5).astype(int)
+
+    return pd.DataFrame(
+        {
+            "id": frame["id"].astype(str),
+            "score": scores,
+            "pred": preds,
+        }
+    )
+
+
+def _load_artifact(weights: str | Path | None) -> dict:
+    if weights is None:
+        msg = "Prediction requires a path to model weights"
+        raise ValueError(msg)
+    artifact = joblib.load(weights)
+    if not isinstance(artifact, dict):
+        msg = "Model artifact must be a dictionary"
+        raise ValueError(msg)
+    return artifact
+
+
+def _restore_pipeline(config: Config, artifact: dict) -> FeaturePipeline:
+    pipeline = build_feature_pipeline(config)
+    state = artifact.get("feature_state")
+    if isinstance(state, FeaturePipelineState):
+        pipeline.load_state(state)
+        if pipeline.backend.type in {"plm", "concat"} and pipeline._plm_scaler is None:
+            pipeline._plm_scaler = StandardScaler()
+        return pipeline
+
+    msg = "Model artifact is missing feature pipeline state"
+    raise ValueError(msg)
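For orientation, a minimal usage sketch of the public API above. It is not part of the commit; the weights path and the example record are illustrative assumptions (predict_batch only requires a joblib artifact saved by the training CLI and dicts carrying heavy-chain sequences).

    from polyreact.api import predict_batch

    records = [
        # "heavy_seq" holds the VH amino-acid sequence; the value here is a placeholder.
        {"id": "mab_001", "heavy_seq": "EVQLVESGGGLVQPGGSLRLSCAAS..."},
    ]
    # Assumed artifact location; matches the default used by run_benchmarks.py.
    result = predict_batch(records, weights="artifacts/model.joblib", heavy_only=True)
    print(result)  # DataFrame with columns: id, score, pred
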
polyreact/benchmarks/__pycache__/reproduce_paper.cpython-311.pyc
ADDED
Binary file (46.4 kB)

polyreact/benchmarks/__pycache__/run_benchmarks.cpython-311.pyc
ADDED
Binary file (5.21 kB)

polyreact/benchmarks/reproduce_paper.ipynb
ADDED
@@ -0,0 +1,25 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Polyreactivity Benchmark Notebook\n",
+    "\n",
+    "This notebook will reproduce paper results once the pipeline is implemented." ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
polyreact/benchmarks/reproduce_paper.py
ADDED
@@ -0,0 +1,1020 @@
+"""Reproduce key metrics and visualisations for the polyreactivity model."""
+
+from __future__ import annotations
+
+import argparse
+import copy
+import json
+import subprocess
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Sequence
+
+import joblib
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import yaml
+from scipy.stats import pearsonr, spearmanr
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import roc_curve
+from sklearn.model_selection import KFold
+from sklearn.preprocessing import StandardScaler
+
+from polyreact import train as train_module
+from polyreact.config import load_config
+from polyreact.features.anarsi import AnarciNumberer
+from polyreact.features.pipeline import FeaturePipeline
+from polyreact.models.ordinal import (
+    fit_negative_binomial_model,
+    fit_poisson_model,
+    pearson_dispersion,
+    regression_metrics,
+)
+
+
+@dataclass(slots=True)
+class DatasetSpec:
+    name: str
+    path: Path
+    display: str
+
+
+DISPLAY_LABELS = {
+    "jain": "Jain (2017)",
+    "shehata": "Shehata PSR (398)",
+    "shehata_curated": "Shehata curated (88)",
+    "harvey": "Harvey (2022)",
+}
+
+RAW_LABELS = {
+    "jain": "jain2017",
+    "shehata": "shehata2019",
+    "shehata_curated": "shehata2019_curated",
+    "harvey": "harvey2022",
+}
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Reproduce paper-style metrics and plots")
+    parser.add_argument(
+        "--train-data",
+        default="data/processed/boughter_counts_rebuilt.csv",
+        help="Reconstructed Boughter dataset path.",
+    )
+    parser.add_argument(
+        "--full-data",
+        default="data/processed/boughter_counts_rebuilt.csv",
+        help="Dataset (including mild flags) for correlation analysis.",
+    )
+    parser.add_argument("--jain", default="data/processed/jain.csv")
+    parser.add_argument(
+        "--shehata",
+        default="data/processed/shehata_full.csv",
+        help="Full Shehata PSR panel (398 sequences) in processed CSV form.",
+    )
+    parser.add_argument(
+        "--shehata-curated",
+        default="data/processed/shehata_curated.csv",
+        help="Optional curated subset of Shehata et al. (88 sequences).",
+    )
+    parser.add_argument("--harvey", default="data/processed/harvey.csv")
+    parser.add_argument("--output-dir", default="artifacts/paper")
+    parser.add_argument("--config", default="configs/default.yaml")
+    parser.add_argument("--batch-size", type=int, default=16)
+    parser.add_argument("--rebuild", action="store_true")
+    parser.add_argument(
+        "--bootstrap-samples",
+        type=int,
+        default=1000,
+        help="Bootstrap resamples for metrics confidence intervals.",
+    )
+    parser.add_argument(
+        "--bootstrap-alpha",
+        type=float,
+        default=0.05,
+        help="Alpha for bootstrap confidence intervals (default 0.05 → 95%).",
+    )
+    parser.add_argument(
+        "--human-only",
+        action="store_true",
+        help=(
+            "Restrict the main cross-validation run to human HIV and influenza families"
+            " (legacy behaviour). By default all Boughter families, including mouse IgA,"
+            " participate in CV as in Sakhnini et al."
+        ),
+    )
+    parser.add_argument(
+        "--skip-flag-regression",
+        action="store_true",
+        help="Skip ELISA flag regression diagnostics (Poisson/NB).",
+    )
+    parser.add_argument(
+        "--skip-lofo",
+        action="store_true",
+        help="Skip leave-one-family-out experiments.",
+    )
+    parser.add_argument(
+        "--skip-descriptor-variants",
+        action="store_true",
+        help="Skip descriptor-only benchmark variants.",
+    )
+    parser.add_argument(
+        "--skip-fragment-variants",
+        action="store_true",
+        help="Skip CDR fragment ablation benchmarks.",
+    )
+    return parser
+
+
+def _config_to_dict(config) -> Dict[str, Any]:
+    data = asdict(config)
+    data.pop("raw", None)
+    return data
+
+
+def _deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
+    result = copy.deepcopy(base)
+    for key, value in overrides.items():
+        if isinstance(value, dict) and isinstance(result.get(key), dict):
+            result[key] = _deep_merge(result.get(key, {}), value)
+        else:
+            result[key] = value
+    return result
+
+
+def _write_variant_config(
+    base_config: Dict[str, Any],
+    overrides: Dict[str, Any],
+    target_path: Path,
+) -> Path:
+    merged = _deep_merge(base_config, overrides)
+    target_path.parent.mkdir(parents=True, exist_ok=True)
+    with target_path.open("w", encoding="utf-8") as handle:
+        yaml.safe_dump(merged, handle, sort_keys=False)
+    return target_path
+
+
+def _collect_metric_records(variant: str, metrics: pd.DataFrame) -> list[dict[str, Any]]:
+    tracked = {
+        "roc_auc",
+        "pr_auc",
+        "accuracy",
+        "f1",
+        "f1_positive",
+        "f1_negative",
+        "precision",
+        "sensitivity",
+        "specificity",
+        "brier",
+        "ece",
+        "mce",
+    }
+    records: list[dict[str, Any]] = []
+    for _, row in metrics.iterrows():
+        metric_name = row["metric"]
+        if metric_name not in tracked:
+            continue
+        record = {"variant": variant, "metric": metric_name}
+        for column in metrics.columns:
+            if column == "metric":
+                continue
+            record[column] = float(row[column]) if pd.notna(row[column]) else np.nan
+        records.append(record)
+    return records
+
+
+def _dump_coefficients(model_path: Path, output_path: Path) -> None:
+    artifact = joblib.load(model_path)
+    trained = artifact["model"]
+    estimator = getattr(trained, "estimator", None)
+    if estimator is None or not hasattr(estimator, "coef_"):
+        return
+    coefs = estimator.coef_[0]
+    feature_state = artifact.get("feature_state")
+    feature_names: list[str]
+    if feature_state is not None and getattr(feature_state, "feature_names", None):
+        feature_names = list(feature_state.feature_names)
+    else:
+        feature_names = [f"f{i}" for i in range(len(coefs))]
+    coeff_df = pd.DataFrame(
+        {
+            "feature": feature_names,
+            "coef": coefs,
+            "abs_coef": np.abs(coefs),
+        }
+    ).sort_values("abs_coef", ascending=False)
+    coeff_df.to_csv(output_path, index=False)
+
+
+def _summarise_predictions(preds: pd.DataFrame) -> pd.DataFrame:
+    records: list[dict[str, Any]] = []
+    for split, group in preds.groupby("split"):
+        stats = {
+            "split": split,
+            "n_samples": int(len(group)),
+            "positives": int(group["y_true"].sum()),
+            "positive_rate": float(group["y_true"].mean()) if len(group) else np.nan,
+            "score_mean": float(group["y_score"].mean()) if len(group) else np.nan,
+            "score_std": float(group["y_score"].std(ddof=1)) if len(group) > 1 else np.nan,
+        }
+        records.append(stats)
+    return pd.DataFrame(records)
+
+
+def _summarise_raw_dataset(path: Path, name: str) -> dict[str, Any]:
+    df = pd.read_csv(path)
+    summary: dict[str, Any] = {
+        "dataset": name,
+        "path": str(path),
+        "rows": int(len(df)),
+    }
+    if "label" in df.columns:
+        positives = int(df["label"].sum())
+        summary["positives"] = positives
+        summary["positive_rate"] = float(df["label"].mean()) if len(df) else np.nan
+    if "reactivity_count" in df.columns:
+        summary["reactivity_count_mean"] = float(df["reactivity_count"].mean())
+        summary["reactivity_count_median"] = float(df["reactivity_count"].median())
+        summary["reactivity_count_max"] = int(df["reactivity_count"].max())
+    if "smp" in df.columns:
+        summary["smp_mean"] = float(df["smp"].mean())
+        summary["smp_median"] = float(df["smp"].median())
+        summary["smp_max"] = float(df["smp"].max())
+        summary["smp_min"] = float(df["smp"].min())
+    summary["unique_heavy"] = int(df["heavy_seq"].nunique()) if "heavy_seq" in df.columns else np.nan
+    return summary
+
+
+def _extract_region_sequence(sequence: str, regions: List[str], numberer: AnarciNumberer) -> str:
+    if not sequence:
+        return ""
+    upper_regions = [region.upper() for region in regions]
+    if upper_regions == ["VH"]:
+        return sequence
+    try:
+        numbered = numberer.number_sequence(sequence)
+    except Exception:
+        return ""
+    fragments: list[str] = []
+    for region in upper_regions:
+        if region == "VH":
+            return sequence
+        fragment = numbered.regions.get(region)
+        if not fragment:
+            return ""
+        fragments.append(fragment)
+    return "".join(fragments)
+
+
+def _make_region_dataset(
+    frame: pd.DataFrame, regions: List[str], numberer: AnarciNumberer
+) -> tuple[pd.DataFrame, dict[str, Any]]:
+    records: list[dict[str, Any]] = []
+    dropped = 0
+    for record in frame.to_dict(orient="records"):
+        new_seq = _extract_region_sequence(record.get("heavy_seq", ""), regions, numberer)
+        if not new_seq:
+            dropped += 1
+            continue
+        updated = record.copy()
+        updated["heavy_seq"] = new_seq
+        updated["light_seq"] = ""
+        records.append(updated)
+    result = pd.DataFrame(records, columns=frame.columns)
+    summary = {
+        "regions": "+".join(regions),
+        "input_rows": int(len(frame)),
+        "retained_rows": int(len(result)),
+        "dropped_rows": int(dropped),
+    }
+    return result, summary
+
+
+def run_train(
+    *,
+    train_path: Path,
+    eval_specs: Sequence[DatasetSpec],
+    output_dir: Path,
+    model_path: Path,
+    config: str,
+    batch_size: int,
+    include_species: list[str] | None = None,
+    include_families: list[str] | None = None,
+    exclude_families: list[str] | None = None,
+    keep_duplicates: bool = False,
+    group_column: str | None = "lineage",
+    train_loader: str | None = None,
+    bootstrap_samples: int = 200,
+    bootstrap_alpha: float = 0.05,
+) -> None:
+    args: list[str] = [
+        "--config",
+        str(config),
+        "--train",
+        str(train_path),
+        "--report-to",
+        str(output_dir),
+        "--save-to",
+        str(model_path),
+        "--batch-size",
+        str(batch_size),
+    ]
+
+    if eval_specs:
+        args.append("--eval")
+        args.extend(str(spec.path) for spec in eval_specs)
+
+    if train_loader:
+        args.extend(["--train-loader", train_loader])
+    if eval_specs:
+        args.append("--eval-loaders")
+        args.extend(spec.name for spec in eval_specs)
+    if include_species:
+        args.append("--include-species")
+        args.extend(include_species)
+    if include_families:
+        args.append("--include-families")
+        args.extend(include_families)
+    if exclude_families:
+        args.append("--exclude-families")
+        args.extend(exclude_families)
+    if keep_duplicates:
+        args.append("--keep-train-duplicates")
+    if group_column:
+        args.extend(["--cv-group-column", group_column])
+    else:
+        args.append("--no-group-cv")
+    args.extend(["--bootstrap-samples", str(bootstrap_samples)])
+    args.extend(["--bootstrap-alpha", str(bootstrap_alpha)])
+
+    exit_code = train_module.main(args)
+    if exit_code != 0:
+        raise RuntimeError(f"Training command failed with exit code {exit_code}")
+
+
+def compute_spearman(model_path: Path, dataset_path: Path, batch_size: int) -> tuple[float, float, pd.DataFrame]:
+    artifact = joblib.load(model_path)
+    config = artifact["config"]
+    pipeline_state = artifact["feature_state"]
+    trained_model = artifact["model"]
+
+    pipeline = FeaturePipeline(backend=config.feature_backend, descriptors=config.descriptors, device=config.device)
+    pipeline.load_state(pipeline_state)
+
+    dataset = pd.read_csv(dataset_path)
+    features = pipeline.transform(dataset, heavy_only=True, batch_size=batch_size)
+    scores = trained_model.predict_proba(features)
+    dataset = dataset.copy()
+    dataset["score"] = scores
+
+    stat, pvalue = spearmanr(dataset["reactivity_count"], dataset["score"])
+    return float(stat), float(pvalue), dataset
+
+
+def plot_accuracy(
+    metrics: pd.DataFrame,
+    output_path: Path,
+    eval_specs: Sequence[DatasetSpec],
+) -> None:
+    row = metrics.loc[metrics["metric"] == "accuracy"].iloc[0]
+    labels = ["Train CV"] + [spec.display for spec in eval_specs]
+    values = [row.get("train_cv_mean", np.nan)] + [row.get(spec.name, np.nan) for spec in eval_specs]
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    xs = np.arange(len(labels))
+    ax.bar(xs, values, color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"])
+    ax.set_xticks(xs, labels)
+    ax.set_ylim(0.0, 1.05)
+    ax.set_ylabel("Accuracy")
+    ax.set_title("Polyreactivity accuracy overview")
+    for x, val in zip(xs, values, strict=False):
+        if np.isnan(val):
+            continue
+        ax.text(x, val + 0.02, f"{val:.3f}", ha="center", va="bottom")
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=300)
+    plt.close(fig)
+
+
+def plot_rocs(
+    preds: pd.DataFrame,
+    output_path: Path,
+    eval_specs: Sequence[DatasetSpec],
+) -> None:
+    mapping = {"train_cv_oof": "Train CV"}
+    for spec in eval_specs:
+        mapping[spec.name] = spec.display
+    fig, ax = plt.subplots(figsize=(6, 6))
+    for split, label in mapping.items():
+        subset = preds[preds["split"] == split]
+        if subset.empty:
+            continue
+        fpr, tpr, _ = roc_curve(subset["y_true"], subset["y_score"])
+        ax.plot(fpr, tpr, label=label)
+    ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
+    ax.set_xlabel("False positive rate")
+    ax.set_ylabel("True positive rate")
+    ax.set_title("ROC curves")
+    ax.legend()
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=300)
+    plt.close(fig)
+
+
+def plot_flags_scatter(data: pd.DataFrame, spearman_stat: float, output_path: Path) -> None:
+    rng = np.random.default_rng(42)
+    jitter = rng.uniform(-0.1, 0.1, size=len(data))
+    x = data["reactivity_count"].to_numpy(dtype=float) + jitter
+    y = data["score"].to_numpy(dtype=float)
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    ax.scatter(x, y, alpha=0.5, s=10)
+    ax.set_xlabel("ELISA flag count")
+    ax.set_ylabel("Predicted probability")
+    ax.set_title(f"Prediction vs flag count (Spearman={spearman_stat:.2f})")
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=300)
+    plt.close(fig)
+
+
+def run_lofo(
+    full_df: pd.DataFrame,
+    *,
+    families: list[str],
+    config: str,
+    batch_size: int,
+    output_dir: Path,
+    bootstrap_samples: int,
+    bootstrap_alpha: float,
+) -> pd.DataFrame:
+    results: list[dict[str, float]] = []
+    for family in families:
+        family_lower = family.lower()
+        holdout = full_df[full_df["family"].str.lower() == family_lower].copy()
+        train = full_df[full_df["family"].str.lower() != family_lower].copy()
+        if holdout.empty or train.empty:
+            continue
+
+        train_path = output_dir / f"train_lofo_{family_lower}.csv"
+        holdout_path = output_dir / f"eval_lofo_{family_lower}.csv"
+        train.to_csv(train_path, index=False)
+        holdout.to_csv(holdout_path, index=False)
+
+        run_dir = output_dir / f"lofo_{family_lower}"
+        run_dir.mkdir(parents=True, exist_ok=True)
+        model_path = run_dir / "model.joblib"
+
+        run_train(
+            train_path=train_path,
+            eval_specs=[
+                DatasetSpec(
+                    name="boughter",
+                    path=holdout_path,
+                    display=f"{family.title()} holdout",
+                )
+            ],
+            output_dir=run_dir,
+            model_path=model_path,
+            config=config,
+            batch_size=batch_size,
+            keep_duplicates=True,
+            include_species=None,
+            include_families=None,
+            exclude_families=None,
+            group_column="lineage",
+            train_loader="boughter",
+            bootstrap_samples=bootstrap_samples,
+            bootstrap_alpha=bootstrap_alpha,
+        )
+
+        metrics = pd.read_csv(run_dir / "metrics.csv")
+        evaluation_cols = [
+            col
+            for col in metrics.columns
+            if col not in {"metric", "train_cv_mean", "train_cv_std"}
+        ]
+        if not evaluation_cols:
+            continue
+        eval_col = evaluation_cols[0]
+        def _metric_value(name: str) -> float:
+            series = metrics.loc[metrics["metric"] == name, eval_col]
+            return float(series.values[0]) if not series.empty else float("nan")
+
+        results.append(
+            {
+                "family": family,
+                "accuracy": _metric_value("accuracy"),
+                "roc_auc": _metric_value("roc_auc"),
+                "pr_auc": _metric_value("pr_auc"),
+                "sensitivity": _metric_value("sensitivity"),
+                "specificity": _metric_value("specificity"),
+            }
+        )
+
+    return pd.DataFrame(results)
+
+
+
+def run_flag_regression(
+    train_path: Path,
+    *,
+    output_dir: Path,
+    config_path: str,
+    batch_size: int,
+    n_splits: int = 5,
+) -> None:
+    df = pd.read_csv(train_path)
+    if "reactivity_count" not in df.columns:
+        return
+
+    config = load_config(config_path)
+    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)
+
+    metrics_rows: list[dict[str, float]] = []
+    preds_rows: list[dict[str, float]] = []
+
+    for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(df), start=1):
+        train_split = df.iloc[train_idx].reset_index(drop=True)
+        val_split = df.iloc[val_idx].reset_index(drop=True)
+
+        pipeline = FeaturePipeline(
+            backend=config.feature_backend,
+            descriptors=config.descriptors,
+            device=config.device,
+        )
+        X_train = pipeline.fit_transform(train_split, heavy_only=True, batch_size=batch_size)
+        scaler = StandardScaler()
+        X_train_scaled = scaler.fit_transform(X_train)
+        y_train = train_split["reactivity_count"].to_numpy(dtype=float)
+        # Train a logistic head to obtain probabilities as a 1-D feature
+        clf = LogisticRegression(
+            C=config.model.C,
+            class_weight=config.model.class_weight,
+            max_iter=2000,
+            solver="lbfgs",
+        )
+        clf.fit(X_train_scaled, train_split["label"].to_numpy(dtype=int))
+        prob_train = clf.predict_proba(X_train_scaled)[:, 1]
+
+        X_val = pipeline.transform(val_split, heavy_only=True, batch_size=batch_size)
+        X_val_scaled = scaler.transform(X_val)
+        y_val = val_split["reactivity_count"].to_numpy(dtype=float)
+        prob_val = clf.predict_proba(X_val_scaled)[:, 1]
+
+        poisson_X_train = prob_train.reshape(-1, 1)
+        poisson_X_val = prob_val.reshape(-1, 1)
+        model = fit_poisson_model(poisson_X_train, y_train)
+        poisson_preds = model.predict(poisson_X_val)
+
+        n_params = poisson_X_train.shape[1] + 1  # include intercept
+        dof = max(len(y_val) - n_params, 1)
+        variance_to_mean = float(np.var(y_val, ddof=1) / np.mean(y_val)) if np.mean(y_val) else float("nan")
+
+        spearman_val = float(spearmanr(y_val, poisson_preds).statistic)
+        try:
+            pearson_val = float(pearsonr(y_val, poisson_preds)[0])
+        except Exception:  # pragma: no cover - fallback if correlation fails
+            pearson_val = float("nan")
+
+        poisson_metrics = regression_metrics(y_val, poisson_preds)
+        poisson_metrics.update(
+            {
+                "spearman": spearman_val,
+                "pearson": pearson_val,
+                "pearson_dispersion": pearson_dispersion(y_val, poisson_preds, dof=dof),
+                "variance_to_mean": variance_to_mean,
+                "fold": fold_idx,
+                "model": "poisson",
+                "status": "ok",
+            }
+        )
+        metrics_rows.append(poisson_metrics)
+
+        nb_preds: np.ndarray | None = None
+        nb_model = None
+        try:
+            nb_model = fit_negative_binomial_model(poisson_X_train, y_train)
+            nb_preds = nb_model.predict(poisson_X_val)
+            if not np.all(np.isfinite(nb_preds)):
+                raise ValueError("negative binomial produced non-finite predictions")
+        except Exception:
+            nb_metrics = {
+                "spearman": float("nan"),
+                "pearson": float("nan"),
+                "pearson_dispersion": float("nan"),
+                "variance_to_mean": variance_to_mean,
+                "alpha": float("nan"),
+                "fold": fold_idx,
+                "model": "negative_binomial",
+                "status": "failed",
+            }
+            metrics_rows.append(nb_metrics)
+        else:
+            spearman_nb = float(spearmanr(y_val, nb_preds).statistic)
+            try:
+                pearson_nb = float(pearsonr(y_val, nb_preds)[0])
+            except Exception:  # pragma: no cover
+                pearson_nb = float("nan")
+
+            nb_metrics = regression_metrics(y_val, nb_preds)
+            nb_metrics.update(
+                {
+                    "spearman": spearman_nb,
+                    "pearson": pearson_nb,
+                    "pearson_dispersion": pearson_dispersion(y_val, nb_preds, dof=dof),
+                    "variance_to_mean": variance_to_mean,
+                    "alpha": nb_model.alpha,
+                    "fold": fold_idx,
+                    "model": "negative_binomial",
+                    "status": "ok",
+                }
+            )
+            metrics_rows.append(nb_metrics)
+
+        records = list(val_split.itertuples(index=False))
+        for idx, row in enumerate(records):
+            row_id = getattr(row, "id", idx)
+            y_true_val = float(getattr(row, "reactivity_count"))
+            preds_rows.append(
+                {
+                    "fold": fold_idx,
+                    "model": "poisson",
+                    "id": row_id,
+                    "y_true": y_true_val,
+                    "y_pred": float(poisson_preds[idx]),
+                }
+            )
+            if nb_preds is not None:
+                preds_rows.append(
+                    {
+                        "fold": fold_idx,
+                        "model": "negative_binomial",
+                        "id": row_id,
+                        "y_true": y_true_val,
+                        "y_pred": float(nb_preds[idx]),
+                    }
+                )
+
+    metrics_df = pd.DataFrame(metrics_rows)
+    metrics_df.to_csv(output_dir / "flag_regression_folds.csv", index=False)
+
+    summary_records: list[dict[str, float]] = []
+    for model_name, group in metrics_df.groupby("model"):
+        for column in group.columns:
+            if column in {"fold", "model", "status"}:
+                continue
+            values = group[column].dropna()
+            if values.empty:
+                continue
+            summary_records.append(
+                {
+                    "model": model_name,
+                    "metric": column,
+                    "mean": float(values.mean()),
+                    "std": float(values.std(ddof=1)) if len(values) > 1 else float("nan"),
+                }
+            )
+    if summary_records:
+        pd.DataFrame(summary_records).to_csv(
+            output_dir / "flag_regression_metrics.csv", index=False
+        )
+
+    if preds_rows:
+        pd.DataFrame(preds_rows).to_csv(output_dir / "flag_regression_preds.csv", index=False)
+
+def run_descriptor_variants(
+    base_config: Dict[str, Any],
+    *,
+    train_path: Path,
+    eval_specs: Sequence[DatasetSpec],
+    output_dir: Path,
+    batch_size: int,
+    include_species: List[str] | None,
+    include_families: List[str] | None,
+    bootstrap_samples: int,
+    bootstrap_alpha: float,
+) -> None:
+    variants = [
+        (
+            "descriptors_full_vh",
+            {
+                "feature_backend": {"type": "descriptors"},
+                "descriptors": {
+                    "use_anarci": True,
+                    "regions": ["CDRH1", "CDRH2", "CDRH3"],
+                    "features": [
+                        "length",
+                        "charge",
+                        "hydropathy",
+                        "aromaticity",
+                        "pI",
+                        "net_charge",
+                    ],
+                },
+            },
+        ),
+        (
+            "descriptors_cdrh3_pi",
+            {
+                "feature_backend": {"type": "descriptors"},
+                "descriptors": {
+                    "use_anarci": True,
+                    "regions": ["CDRH3"],
+                    "features": ["pI"],
+                },
+            },
+        ),
+        (
+            "descriptors_cdrh3_top5",
+            {
+                "feature_backend": {"type": "descriptors"},
+                "descriptors": {
+                    "use_anarci": True,
+                    "regions": ["CDRH3"],
+                    "features": [
+                        "pI",
+                        "net_charge",
+                        "charge",
+                        "hydropathy",
+                        "length",
+                    ],
+                },
+            },
+        ),
+    ]
+
+    configs_dir = output_dir / "configs"
+    configs_dir.mkdir(parents=True, exist_ok=True)
+    summary_records: list[dict[str, Any]] = []
+
+    for name, overrides in variants:
+        variant_config_path = _write_variant_config(
+            base_config,
+            overrides,
+            configs_dir / f"{name}.yaml",
+        )
+        variant_output = output_dir / name
+        variant_output.mkdir(parents=True, exist_ok=True)
+        model_path = variant_output / "model.joblib"
+
+        run_train(
+            train_path=train_path,
+            eval_specs=eval_specs,
+            output_dir=variant_output,
+            model_path=model_path,
+            config=str(variant_config_path),
+            batch_size=batch_size,
+            include_species=include_species,
+            include_families=include_families,
+            keep_duplicates=True,
+            group_column="lineage",
+            train_loader="boughter",
+            bootstrap_samples=bootstrap_samples,
+            bootstrap_alpha=bootstrap_alpha,
+        )
+
+        metrics_path = variant_output / "metrics.csv"
+        if metrics_path.exists():
+            metrics_df = pd.read_csv(metrics_path)
+            summary_records.extend(_collect_metric_records(name, metrics_df))
+
+        _dump_coefficients(model_path, variant_output / "coefficients.csv")
+
+    if summary_records:
+        pd.DataFrame(summary_records).to_csv(output_dir / "summary.csv", index=False)
+
+
+def run_fragment_variants(
+    config_path: str,
+    *,
+    train_path: Path,
+    eval_specs: Sequence[DatasetSpec],
+    output_dir: Path,
+    batch_size: int,
+    include_species: List[str] | None,
+    include_families: List[str] | None,
+    bootstrap_samples: int,
+    bootstrap_alpha: float,
+) -> None:
+    numberer = AnarciNumberer()
+    specs = [
+        ("vh_full", ["VH"]),
+        ("cdrh1", ["CDRH1"]),
+        ("cdrh2", ["CDRH2"]),
+        ("cdrh3", ["CDRH3"]),
+        ("cdrh123", ["CDRH1", "CDRH2", "CDRH3"]),
+    ]
+
+    summary_rows: list[dict[str, Any]] = []
+    metric_summary_rows: list[dict[str, Any]] = []
+
+    for name, regions in specs:
+        variant_dir = output_dir / name
+        variant_dir.mkdir(parents=True, exist_ok=True)
+        dataset_dir = variant_dir / "datasets"
+        dataset_dir.mkdir(parents=True, exist_ok=True)
+
+        train_df = pd.read_csv(train_path)
+        train_variant, train_summary = _make_region_dataset(train_df, regions, numberer)
+        train_variant_path = dataset_dir / "train.csv"
+        train_variant.to_csv(train_variant_path, index=False)
+
+        eval_variant_specs: list[DatasetSpec] = []
+        for spec in eval_specs:
+            eval_df = pd.read_csv(spec.path)
+            transformed, eval_summary = _make_region_dataset(eval_df, regions, numberer)
+            eval_path = dataset_dir / f"{spec.name}.csv"
+            transformed.to_csv(eval_path, index=False)
+            eval_variant_specs.append(
+                DatasetSpec(name=spec.name, path=eval_path, display=spec.display)
+            )
+            eval_summary.update({"variant": name, "dataset": spec.name})
+            summary_rows.append(eval_summary)
+
+        train_summary.update({"variant": name, "dataset": "train"})
+        summary_rows.append(train_summary)
+
+        run_train(
+            train_path=train_variant_path,
+            eval_specs=eval_variant_specs,
+            output_dir=variant_dir,
+            model_path=variant_dir / "model.joblib",
+            config=config_path,
+            batch_size=batch_size,
+            include_species=include_species,
+            include_families=include_families,
+            keep_duplicates=True,
+            group_column="lineage",
+            train_loader="boughter",
+            bootstrap_samples=bootstrap_samples,
+            bootstrap_alpha=bootstrap_alpha,
+        )
+
+        metrics_path = variant_dir / "metrics.csv"
+        if metrics_path.exists():
+            metrics_df = pd.read_csv(metrics_path)
+            metric_records = _collect_metric_records(name, metrics_df)
+            for record in metric_records:
+                record["variant_type"] = "fragment"
+            metric_summary_rows.extend(metric_records)
+
+    if summary_rows:
+        pd.DataFrame(summary_rows).to_csv(output_dir / "fragment_dataset_summary.csv", index=False)
+    if metric_summary_rows:
+        pd.DataFrame(metric_summary_rows).to_csv(output_dir / "fragment_metrics_summary.csv", index=False)
+
+
+def main(argv: list[str] | None = None) -> int:
+    parser = build_parser()
+    args = parser.parse_args(argv)
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if args.rebuild:
+        rebuild_cmd = [
+            "python",
+            "scripts/rebuild_boughter_from_counts.py",
+            "--output",
+            str(args.train_data),
+        ]
+        if subprocess.run(rebuild_cmd, check=False).returncode != 0:
+            raise RuntimeError("Dataset rebuild failed")
+
+    train_path = Path(args.train_data)
+
+    def _make_spec(name: str, path_str: str) -> DatasetSpec | None:
+        path = Path(path_str)
+        if not path.exists():
+            return None
+        display = DISPLAY_LABELS.get(name, name.replace("_", " ").title())
+        return DatasetSpec(name=name, path=path, display=display)
+
+    eval_specs: list[DatasetSpec] = []
+    seen_paths: set[Path] = set()
+    for name, path_str in [
+        ("jain", args.jain),
+        ("shehata", args.shehata),
+        ("shehata_curated", args.shehata_curated),
+        ("harvey", args.harvey),
+    ]:
+        spec = _make_spec(name, path_str)
+        if spec is not None:
+            resolved = spec.path.resolve()
+            if resolved in seen_paths:
+                continue
+            seen_paths.add(resolved)
+            eval_specs.append(spec)
+
+    base_config = load_config(args.config)
+    base_config_dict = _config_to_dict(base_config)
+
+    main_output = output_dir / "main"
+    main_output.mkdir(parents=True, exist_ok=True)
+    model_path = main_output / "model.joblib"
+
+    main_include_species = ["human"] if args.human_only else None
+    main_include_families = ["hiv", "influenza"] if args.human_only else None
+
+    run_train(
+        train_path=train_path,
+        eval_specs=eval_specs,
+        output_dir=main_output,
+        model_path=model_path,
+        config=args.config,
+        batch_size=args.batch_size,
+        include_species=main_include_species,
+        include_families=main_include_families,
+        keep_duplicates=True,
+        group_column="lineage",
+        train_loader="boughter",
+        bootstrap_samples=args.bootstrap_samples,
+        bootstrap_alpha=args.bootstrap_alpha,
+    )
+
+    metrics = pd.read_csv(main_output / "metrics.csv")
+    preds = pd.read_csv(main_output / "preds.csv")
+
+    plot_accuracy(metrics, main_output / "accuracy_overview.png", eval_specs)
+    plot_rocs(preds, main_output / "roc_overview.png", eval_specs)
+
+    if not args.skip_flag_regression:
+        run_flag_regression(
+            train_path=train_path,
+            output_dir=main_output,
+            config_path=args.config,
+            batch_size=args.batch_size,
+        )
+
+    split_summary = _summarise_predictions(preds)
+    split_summary.to_csv(main_output / "dataset_split_summary.csv", index=False)
+
+    spearman_stat, spearman_p, corr_df = compute_spearman(
+        model_path=model_path,
+        dataset_path=Path(args.full_data),
+        batch_size=args.batch_size,
+    )
+    plot_flags_scatter(corr_df, spearman_stat, main_output / "prob_vs_flags.png")
+    (main_output / "spearman_flags.json").write_text(
+        json.dumps({"spearman": spearman_stat, "p_value": spearman_p}, indent=2)
+    )
+    corr_df.to_csv(main_output / "prob_vs_flags.csv", index=False)
+
+    if not args.skip_lofo:
+        full_df = pd.read_csv(args.train_data)
+        lofo_dir = output_dir / "lofo_runs"
+        lofo_dir.mkdir(parents=True, exist_ok=True)
+        lofo_df = run_lofo(
+            full_df,
+            families=["influenza", "hiv", "mouse_iga"],
+            config=args.config,
+            batch_size=args.batch_size,
+            output_dir=lofo_dir,
+            bootstrap_samples=args.bootstrap_samples,
+            bootstrap_alpha=args.bootstrap_alpha,
+        )
+        lofo_df.to_csv(output_dir / "lofo_metrics.csv", index=False)
+
+    if not args.skip_descriptor_variants:
+        descriptor_dir = output_dir / "descriptor_variants"
+        descriptor_dir.mkdir(parents=True, exist_ok=True)
+        run_descriptor_variants(
+            base_config_dict,
+            train_path=train_path,
+            eval_specs=eval_specs,
+            output_dir=descriptor_dir,
+            batch_size=args.batch_size,
+            include_species=main_include_species,
+            include_families=main_include_families,
+            bootstrap_samples=args.bootstrap_samples,
+            bootstrap_alpha=args.bootstrap_alpha,
+        )
+
+    if not args.skip_fragment_variants:
+        fragment_dir = output_dir / "fragment_variants"
+        fragment_dir.mkdir(parents=True, exist_ok=True)
+        run_fragment_variants(
+            args.config,
+            train_path=train_path,
+            eval_specs=eval_specs,
+            output_dir=fragment_dir,
+            batch_size=args.batch_size,
+            include_species=main_include_species,
+            include_families=main_include_families,
+            bootstrap_samples=args.bootstrap_samples,
+            bootstrap_alpha=args.bootstrap_alpha,
+        )
+
+    raw_summaries = []
+    raw_summaries.append(_summarise_raw_dataset(train_path, "boughter_rebuilt"))
+    for spec in eval_specs:
+        summary_name = RAW_LABELS.get(spec.name, spec.name)
+        raw_summaries.append(_summarise_raw_dataset(spec.path, summary_name))
+    pd.DataFrame(raw_summaries).to_csv(output_dir / "raw_dataset_summary.csv", index=False)
+
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
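A brief, hedged sketch of driving the benchmark script above programmatically. The flags and default paths come from build_parser in this file; the specific call shown is illustrative and not something exercised by this commit.

    from polyreact.benchmarks import reproduce_paper

    # Runs the main cross-validation and external evaluations with the script's
    # default data paths, writing outputs under artifacts/paper, while skipping
    # the slower leave-one-family-out loop.
    exit_code = reproduce_paper.main(["--output-dir", "artifacts/paper", "--skip-lofo"])
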
polyreact/benchmarks/run_benchmarks.py
ADDED
@@ -0,0 +1,114 @@
+"""Run end-to-end benchmarks for the polyreactivity model."""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+from typing import List
+
+from .. import train as train_cli
+
+PROJECT_ROOT = Path(__file__).resolve().parents[2]
+DEFAULT_TRAIN = PROJECT_ROOT / "tests" / "fixtures" / "boughter.csv"
+DEFAULT_EVAL = [
+    PROJECT_ROOT / "tests" / "fixtures" / "jain.csv",
+    PROJECT_ROOT / "tests" / "fixtures" / "shehata.csv",
+    PROJECT_ROOT / "tests" / "fixtures" / "harvey.csv",
+]
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Run polyreactivity benchmarks")
+    parser.add_argument(
+        "--config",
+        default="configs/default.yaml",
+        help="Path to configuration YAML file.",
+    )
+    parser.add_argument(
+        "--train",
+        default=str(DEFAULT_TRAIN),
+        help="Training dataset CSV (defaults to bundled fixture).",
+    )
+    parser.add_argument(
+        "--eval",
+        nargs="+",
+        default=[str(path) for path in DEFAULT_EVAL],
+        help="Evaluation dataset CSV paths (>=1).",
+    )
+    parser.add_argument(
+        "--report-dir",
+        default="artifacts",
+        help="Directory to write metrics, predictions, and plots.",
+    )
+    parser.add_argument(
+        "--model-path",
+        default="artifacts/model.joblib",
+        help="Destination for the trained model artifact.",
+    )
+    parser.add_argument(
+        "--backend",
+        choices=["descriptors", "plm", "concat"],
+        help="Override feature backend during training.",
+    )
+    parser.add_argument("--plm-model", help="Optional PLM model override.")
+    parser.add_argument("--cache-dir", help="Embedding cache directory override.")
+    parser.add_argument(
+        "--device",
+        choices=["auto", "cpu", "cuda"],
+        help="Device override for embeddings.",
+    )
+    parser.add_argument(
+        "--paired",
+        action="store_true",
+        help="Use paired heavy/light chains when available.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=8,
+        help="Batch size for PLM embedding batches.",
+    )
+    return parser
+
+
+def main(argv: List[str] | None = None) -> int:
+    parser = build_parser()
+    args = parser.parse_args(argv)
+
+    if len(args.eval) < 1:
+        parser.error("Provide at least one evaluation dataset via --eval.")
+
+    report_dir = Path(args.report_dir)
+    report_dir.mkdir(parents=True, exist_ok=True)
+
+    train_args: list[str] = [
+        "--config",
+        args.config,
+        "--train",
+        args.train,
+        "--save-to",
+        str(Path(args.model_path)),
+        "--report-to",
+        str(report_dir),
+        "--batch-size",
+        str(args.batch_size),
+    ]
+
+    train_args.extend(["--eval", *args.eval])
|
| 99 |
+
if args.backend:
|
| 100 |
+
train_args.extend(["--backend", args.backend])
|
| 101 |
+
if args.plm_model:
|
| 102 |
+
train_args.extend(["--plm-model", args.plm_model])
|
| 103 |
+
if args.cache_dir:
|
| 104 |
+
train_args.extend(["--cache-dir", args.cache_dir])
|
| 105 |
+
if args.device:
|
| 106 |
+
train_args.extend(["--device", args.device])
|
| 107 |
+
if args.paired:
|
| 108 |
+
train_args.append("--paired")
|
| 109 |
+
|
| 110 |
+
return train_cli.main(train_args)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
raise SystemExit(main())
|
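A minimal programmatic invocation of this runner (a sketch, assuming the package and its bundled test fixtures are installed locally; it simply forwards to the polyreact.train CLI):

    from polyreact.benchmarks import run_benchmarks

    # Same as: python -m polyreact.benchmarks.run_benchmarks --backend descriptors --batch-size 4
    # Trains on the bundled Boughter fixture, evaluates on the Jain/Shehata/Harvey fixtures,
    # and writes metrics and predictions under ./artifacts.
    exit_code = run_benchmarks.main(["--backend", "descriptors", "--batch-size", "4"])
    print(exit_code)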
polyreact/config.py
ADDED
@@ -0,0 +1,160 @@
"""Configuration helpers for the polyreactivity project."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
import importlib.resources as pkg_resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Any, Sequence

import yaml


@dataclass(slots=True)
class FeatureBackendSettings:
    type: str = "plm"
    plm_model_name: str = "facebook/esm2_t12_35M_UR50D"
    layer_pool: str = "mean"
    cache_dir: str = ".cache/embeddings"
    standardize: bool = True


@dataclass(slots=True)
class DescriptorSettings:
    use_anarci: bool = True
    regions: Sequence[str] = field(default_factory=lambda: ["CDRH1", "CDRH2", "CDRH3"])
    features: Sequence[str] = field(
        default_factory=lambda: [
            "length",
            "charge",
            "hydropathy",
            "aromaticity",
            "pI",
            "net_charge",
        ]
    )
    ph: float = 7.4


@dataclass(slots=True)
class ModelSettings:
    head: str = "logreg"
    C: float = 1.0
    class_weight: Any = "balanced"


@dataclass(slots=True)
class CalibrationSettings:
    method: str | None = "isotonic"


@dataclass(slots=True)
class TrainingSettings:
    cv_folds: int = 10
    scoring: str = "roc_auc"
    n_jobs: int = -1


@dataclass(slots=True)
class IOSettings:
    outputs_dir: str = "artifacts"
    preds_filename: str = "preds.csv"
    metrics_filename: str = "metrics.csv"


@dataclass(slots=True)
class Config:
    seed: int = 42
    device: str = "auto"
    feature_backend: FeatureBackendSettings = field(default_factory=FeatureBackendSettings)
    descriptors: DescriptorSettings = field(default_factory=DescriptorSettings)
    model: ModelSettings = field(default_factory=ModelSettings)
    calibration: CalibrationSettings = field(default_factory=CalibrationSettings)
    training: TrainingSettings = field(default_factory=TrainingSettings)
    io: IOSettings = field(default_factory=IOSettings)

    raw: dict[str, Any] = field(default_factory=dict)


def _merge_section(default: Any, data: dict[str, Any] | None) -> Any:
    if data is None:
        return default
    merged = asdict(default) | data
    return type(default)(**merged)


def load_config(path: str | Path | None = None) -> Config:
    """Load a YAML configuration file into a strongly-typed ``Config`` object."""

    data = _read_config_data(path)

    feature_backend = _merge_section(FeatureBackendSettings(), data.get("feature_backend"))
    descriptors = _merge_section(DescriptorSettings(), data.get("descriptors"))
    model = _merge_section(ModelSettings(), data.get("model"))
    calibration = _merge_section(CalibrationSettings(), data.get("calibration"))
    training = _merge_section(TrainingSettings(), data.get("training"))
    io_settings = _merge_section(IOSettings(), data.get("io"))

    config = Config(
        seed=int(data.get("seed", 42)),
        device=str(data.get("device", "auto")),
        feature_backend=feature_backend,
        descriptors=descriptors,
        model=model,
        calibration=calibration,
        training=training,
        io=io_settings,
        raw=data,
    )
    return config


def _read_config_data(path: str | Path | None) -> dict[str, Any]:
    """Return mapping data from YAML or the bundled default."""

    if path is None:
        resource = pkg_resources.files("polyreact.configs") / "default.yaml"
        return _load_yaml_resource(resource)

    resolved = _resolve_config_path(Path(path))
    if resolved is not None:
        return _load_yaml_file(resolved)

    resource_root = pkg_resources.files("polyreact")
    resource = resource_root / Path(path).as_posix()
    if resource.is_file():
        return _load_yaml_resource(resource)

    msg = f"Configuration file not found: {path}"
    raise FileNotFoundError(msg)


def _resolve_config_path(path: Path) -> Path | None:
    if path.exists():
        return path

    if not path.is_absolute():
        candidate = Path(__file__).resolve().parent / path
        if candidate.exists():
            return candidate

    return None


def _load_yaml_file(path: Path) -> dict[str, Any]:
    with path.open("r", encoding="utf-8") as handle:
        return _parse_yaml(handle.read())


def _load_yaml_resource(resource: Traversable) -> dict[str, Any]:
    with resource.open("r", encoding="utf-8") as handle:
        return _parse_yaml(handle.read())


def _parse_yaml(text: str) -> dict[str, Any]:
    parsed = yaml.safe_load(text) or {}
    if not isinstance(parsed, dict):  # pragma: no cover - safeguard
        msg = "Configuration must be a mapping at the top level"
        raise ValueError(msg)
    return parsed
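A short usage sketch of these helpers (illustrative only; the override path shown is the packaged default):

    from polyreact.config import load_config

    cfg = load_config()  # no path: falls back to the packaged configs/default.yaml
    print(cfg.feature_backend.type, cfg.model.head, cfg.training.cv_folds)

    # A user-supplied YAML is merged section by section onto the dataclass defaults.
    custom = load_config("configs/default.yaml")
    print(custom.raw.get("seed"))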
polyreact/configs/__init__.py
ADDED
@@ -0,0 +1 @@
"""Configuration package data."""
polyreact/configs/default.yaml
ADDED
@@ -0,0 +1,34 @@
seed: 42
device: "auto"
feature_backend:
  type: "plm"
  plm_model_name: "facebook/esm1v_t33_650M_UR90S_1"
  layer_pool: "mean"
  cache_dir: ".cache/embeddings"
descriptors:
  use_anarci: true
  regions:
    - "CDRH1"
    - "CDRH2"
    - "CDRH3"
  features:
    - "length"
    - "charge"
    - "hydropathy"
    - "aromaticity"
    - "pI"
    - "net_charge"
model:
  head: "logreg"
  C: 0.1
  class_weight: "balanced"
calibration:
  method: "isotonic"
training:
  cv_folds: 10
  scoring: "roc_auc"
  n_jobs: -1
io:
  outputs_dir: "artifacts"
  preds_filename: "preds.csv"
  metrics_filename: "metrics.csv"
polyreact/data_loaders/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""Dataset loaders for polyreactivity benchmarks."""

__all__ = ["boughter", "jain", "shehata", "harvey", "utils"]
polyreact/data_loaders/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (297 Bytes). View file

polyreact/data_loaders/__pycache__/boughter.cpython-311.pyc
ADDED
Binary file (3.68 kB). View file

polyreact/data_loaders/__pycache__/harvey.cpython-311.pyc
ADDED
Binary file (1.26 kB). View file

polyreact/data_loaders/__pycache__/jain.cpython-311.pyc
ADDED
Binary file (1.3 kB). View file

polyreact/data_loaders/__pycache__/shehata.cpython-311.pyc
ADDED
Binary file (8.16 kB). View file

polyreact/data_loaders/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (8.51 kB). View file
polyreact/data_loaders/boughter.py
ADDED
@@ -0,0 +1,76 @@
"""Loader for the Boughter et al. 2020 dataset."""

from __future__ import annotations

from typing import Iterable

import numpy as np
import pandas as pd

from .utils import LOGGER, standardize_frame

_COLUMN_ALIASES = {
    "id": ("sequence_id",),
    "heavy_seq": ("heavy", "heavy_chain"),
    "light_seq": ("light", "light_chain"),
    "label": ("polyreactive",),
}


def _find_flag_columns(columns: Iterable[str]) -> list[str]:
    flag_cols: list[str] = []
    for column in columns:
        normalized = column.lower().replace(" ", "")
        if "flag" in normalized:
            flag_cols.append(column)
    return flag_cols


def _apply_flag_policy(frame: pd.DataFrame, flag_columns: list[str]) -> pd.DataFrame:
    if not flag_columns:
        return frame

    flag_values = (
        frame[flag_columns]
        .apply(pd.to_numeric, errors="coerce")
        .fillna(0.0)
    )
    flag_binary = (flag_values > 0).astype(int)
    flags_total = flag_binary.sum(axis=1)

    specific_mask = flags_total == 0
    nonspecific_mask = flags_total >= 4
    keep_mask = specific_mask | nonspecific_mask

    dropped = int((~keep_mask).sum())
    if dropped:
        LOGGER.info("Dropped %s mildly polyreactive sequences (1-3 ELISA flags)", dropped)

    filtered = frame.loc[keep_mask].copy()
    filtered["flags_total"] = flags_total.loc[keep_mask].astype(int)
    filtered["label"] = np.where(nonspecific_mask.loc[keep_mask], 1, 0)
    filtered["polyreactive"] = filtered["label"]
    return filtered


def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
    """Load the Boughter dataset into the canonical format."""

    frame = pd.read_csv(path_or_url)
    flag_columns = _find_flag_columns(frame.columns)
    frame = _apply_flag_policy(frame, flag_columns)

    label_series = frame.get("label")
    if label_series is not None:
        frame = frame[label_series.isin({0, 1})].copy()

    standardized = standardize_frame(
        frame,
        source="boughter2020",
        heavy_only=heavy_only,
        column_aliases=_COLUMN_ALIASES,
        is_test=False,
    )
    if "flags_total" in frame.columns and "flags_total" not in standardized.columns:
        standardized["flags_total"] = frame["flags_total"].to_numpy(dtype=int)
    return standardized
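The loaders in this package share the same contract; a toy sketch with a fabricated two-row CSV (not real data):

    import io

    from polyreact.data_loaders import boughter

    csv = io.StringIO(
        "sequence_id,heavy,light,polyreactive\n"
        "mAb1,QVQLVQSG,DIQMTQSP,1\n"
        "mAb2,EVQLLESG,EIVLTQSP,0\n"
    )
    # pandas.read_csv accepts file-like objects as well as paths/URLs.
    df = boughter.load_dataframe(csv)
    print(df[["id", "heavy_seq", "label", "source"]])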
polyreact/data_loaders/harvey.py
ADDED
@@ -0,0 +1,39 @@
"""Loader for the Harvey et al. 2022 dataset."""

from __future__ import annotations

import pandas as pd

from .utils import standardize_frame

_COLUMN_ALIASES = {
    "id": ("id", "clone_id"),
    "heavy_seq": ("heavy", "heavy_chain", "sequence"),
    "light_seq": ("light", "light_chain"),
    "label": ("polyreactive", "is_polyreactive"),
}

_LABEL_MAP = {
    "polyreactive": 1,
    "non-polyreactive": 0,
    "positive": 1,
    "negative": 0,
    1: 1,
    0: 0,
    "1": 1,
    "0": 0,
}


def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
    """Load the Harvey dataset into the canonical format."""

    frame = pd.read_csv(path_or_url)
    return standardize_frame(
        frame,
        source="harvey2022",
        heavy_only=heavy_only,
        column_aliases=_COLUMN_ALIASES,
        label_map=_LABEL_MAP,
        is_test=True,
    )
polyreact/data_loaders/jain.py
ADDED
@@ -0,0 +1,41 @@
"""Loader for the Jain et al. 2017 dataset."""

from __future__ import annotations

import pandas as pd

from .utils import standardize_frame

_COLUMN_ALIASES = {
    "id": ("id", "antibody_id"),
    "heavy_seq": ("heavy", "heavy_sequence", "H_chain"),
    "light_seq": ("light", "light_sequence", "L_chain"),
    "label": ("class", "polyreactive"),
}

_LABEL_MAP = {
    "polyreactive": 1,
    "non-polyreactive": 0,
    "reactive": 1,
    "non-reactive": 0,
    1: 1,
    0: 0,
    1.0: 1,
    0.0: 0,
    "1": 1,
    "0": 0,
}


def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
    """Load the Jain dataset into the canonical format."""

    frame = pd.read_csv(path_or_url)
    return standardize_frame(
        frame,
        source="jain2017",
        heavy_only=heavy_only,
        column_aliases=_COLUMN_ALIASES,
        label_map=_LABEL_MAP,
        is_test=True,
    )
polyreact/data_loaders/shehata.py
ADDED
@@ -0,0 +1,202 @@
"""Loader for the Shehata et al. (2019) PSR dataset."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable

import pandas as pd

from .utils import standardize_frame

SHEHATA_SOURCE = "shehata2019"

_COLUMN_ALIASES = {
    "id": (
        "antibody_id",
        "antibody",
        "antibody name",
        "antibody_name",
        "sequence_name",
        "Antibody Name",
    ),
    "heavy_seq": (
        "heavy",
        "heavy_chain",
        "heavy aa",
        "heavy_sequence",
        "vh",
        "vh_sequence",
        "heavy chain aa",
        "Heavy Chain AA",
    ),
    "light_seq": (
        "light",
        "light_chain",
        "light aa",
        "light_sequence",
        "vl",
        "vl_sequence",
        "light chain aa",
        "Light Chain AA",
    ),
    "label": (
        "polyreactive",
        "binding_class",
        "binding class",
        "psr_class",
        "psr binding",
        "psr classification",
        "Binding class",
        "Binding Class",
    ),
}

_LABEL_MAP = {
    "polyreactive": 1,
    "non-polyreactive": 0,
    "positive": 1,
    "negative": 0,
    "high": 1,
    "low": 0,
    "pos": 1,
    "neg": 0,
    1: 1,
    0: 0,
    1.0: 1,
    0.0: 0,
    "1": 1,
    "0": 0,
}

_PSR_SCORE_ALIASES: tuple[str, ...] = (
    "psr score",
    "psr_score",
    "psr overall score",
    "overall score",
    "psr z",
    "psr_z",
)


def _clean_sequence(sequence: object) -> str:
    if isinstance(sequence, str):
        return "".join(sequence.split()).upper()
    return ""


def _maybe_extract_psr_scores(frame: pd.DataFrame) -> pd.DataFrame:
    scores: dict[str, pd.Series] = {}
    for column in frame.columns:
        lowered = column.strip().lower()
        if any(alias in lowered for alias in _PSR_SCORE_ALIASES):
            key = lowered.replace(" ", "_")
            scores[key] = frame[column]
    if not scores:
        return pd.DataFrame(index=frame.index)
    renamed = {}
    for name, series in scores.items():
        cleaned_name = name
        for prefix in ("psr_", "overall_"):
            if cleaned_name.startswith(prefix):
                cleaned_name = cleaned_name[len(prefix) :]
                break
        cleaned_name = cleaned_name.replace("__", "_")
        cleaned_name = cleaned_name.replace("(", "").replace(")", "")
        cleaned_name = cleaned_name.replace("-", "_")
        renamed[f"psr_{cleaned_name}"] = pd.to_numeric(series, errors="coerce")
    return pd.DataFrame(renamed)


def _pick_source_label(path: Path | None) -> str:
    if path is None:
        return SHEHATA_SOURCE
    stem = path.stem.lower()
    if "curated" in stem or "subset" in stem:
        return f"{SHEHATA_SOURCE}_curated"
    return SHEHATA_SOURCE


def _standardize(
    frame: pd.DataFrame,
    *,
    heavy_only: bool,
    source: str,
) -> pd.DataFrame:
    standardized = standardize_frame(
        frame,
        source=source,
        heavy_only=heavy_only,
        column_aliases=_COLUMN_ALIASES,
        label_map=_LABEL_MAP,
        is_test=True,
    )

    psr_scores = _maybe_extract_psr_scores(frame)

    mask = standardized["heavy_seq"].map(_clean_sequence) != ""
    standardized = standardized.loc[mask].copy()
    standardized.reset_index(drop=True, inplace=True)
    standardized["heavy_seq"] = standardized["heavy_seq"].map(_clean_sequence)
    standardized["light_seq"] = standardized["light_seq"].map(_clean_sequence)

    if not psr_scores.empty:
        psr_scores = psr_scores.loc[mask]
        psr_scores = psr_scores.reset_index(drop=True)
        for column in psr_scores.columns:
            standardized[column] = psr_scores[column].reset_index(drop=True)

    return standardized


def _read_excel(path: Path, *, heavy_only: bool) -> pd.DataFrame:
    excel = pd.ExcelFile(path, engine="openpyxl")
    sheet_candidates: Iterable[str] = excel.sheet_names

    def _score(name: str) -> tuple[int, str]:
        lowered = name.lower()
        priority = 0
        if "psr" in lowered or "polyreactivity" in lowered:
            priority = 2
        elif "sheet" not in lowered:
            priority = 1
        return (-priority, name)

    sheet_name = sorted(sheet_candidates, key=_score)[0]
    raw = excel.parse(sheet_name)
    raw = raw.dropna(how="all")
    return _standardize(raw, heavy_only=heavy_only, source=_pick_source_label(path))


def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
    """Load the Shehata dataset into the canonical format.

    Supports both pre-processed CSV exports and the original Excel supplement
    (*.xls/*.xlsx). Additional PSR score columns are preserved when available.
    """

    lower = path_or_url.lower()
    source_override: str | None = None
    if lower.startswith("http://") or lower.startswith("https://"):
        if lower.endswith((".xls", ".xlsx")):
            raw = pd.read_excel(path_or_url, engine="openpyxl")
            return _standardize(raw, heavy_only=heavy_only, source=SHEHATA_SOURCE)
        frame = pd.read_csv(path_or_url)
        return _standardize(frame, heavy_only=heavy_only, source=SHEHATA_SOURCE)

    path = Path(path_or_url)
    source_override = _pick_source_label(path)
    if path.suffix.lower() in {".xls", ".xlsx"}:
        engine = "openpyxl" if path.suffix.lower() == ".xlsx" else None
        if engine:
            frame = _read_excel(path, heavy_only=heavy_only)
        else:
            frame = pd.read_excel(path, engine=None)
            frame = _standardize(frame, heavy_only=heavy_only, source=source_override)
        frame["source"] = source_override
        return frame

    frame = pd.read_csv(path)
    standardized = _standardize(frame, heavy_only=heavy_only, source=source_override)
    standardized["source"] = source_override
    return standardized
polyreact/data_loaders/utils.py
ADDED
@@ -0,0 +1,186 @@
"""Utility helpers for dataset loading."""

from __future__ import annotations

import logging
from typing import Iterable, Sequence

import pandas as pd

EXPECTED_COLUMNS = ("id", "heavy_seq", "light_seq", "label")
OPTIONAL_COLUMNS = ("source", "is_test")

LOGGER = logging.getLogger("polyreact.data")

_DEFAULT_ALIASES: dict[str, Sequence[str]] = {
    "id": ("id", "sequence_id", "antibody_id", "uid"),
    "heavy_seq": ("heavy_seq", "heavy", "heavy_chain", "H", "H_chain"),
    "light_seq": ("light_seq", "light", "light_chain", "L", "L_chain"),
    "label": ("label", "polyreactive", "is_polyreactive", "class", "target"),
}

DEFAULT_LABEL_MAP: dict[str | int | float | bool, int] = {
    1: 1,
    0: 0,
    "1": 1,
    "0": 0,
    True: 1,
    False: 0,
    "true": 1,
    "false": 0,
    "polyreactive": 1,
    "non-polyreactive": 0,
    "poly": 1,
    "non": 0,
    "positive": 1,
    "negative": 0,
}


def _normalize_label_key(value: object) -> object:
    if isinstance(value, str):
        trimmed = value.strip().lower()
        if trimmed in {
            "polyreactive",
            "non-polyreactive",
            "poly",
            "non",
            "positive",
            "negative",
            "high",
            "low",
            "pos",
            "neg",
            "1",
            "0",
            "true",
            "false",
        }:
            return trimmed
        if trimmed.isdigit():
            return trimmed
    return value


def ensure_columns(frame: pd.DataFrame, *, heavy_only: bool = True) -> pd.DataFrame:
    """Validate and coerce dataframe columns to the canonical format."""

    frame = frame.copy()
    for column in ("id", "heavy_seq", "label"):
        if column not in frame.columns:
            msg = f"Required column '{column}' missing from dataframe"
            raise KeyError(msg)

    if "light_seq" not in frame.columns:
        frame["light_seq"] = ""

    if heavy_only:
        frame["light_seq"] = ""

    frame["id"] = frame["id"].astype(str)
    frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str)
    frame["light_seq"] = frame["light_seq"].fillna("").astype(str)
    frame["label"] = frame["label"].astype(int)

    ordered = list(EXPECTED_COLUMNS) + [
        col for col in frame.columns if col not in EXPECTED_COLUMNS
    ]
    return frame[ordered]


def standardize_frame(
    frame: pd.DataFrame,
    *,
    source: str,
    heavy_only: bool = True,
    column_aliases: dict[str, Sequence[str]] | None = None,
    label_map: dict[str | int | float | bool, int] | None = None,
    is_test: bool | None = None,
) -> pd.DataFrame:
    """Rename columns using aliases and coerce labels to integers."""

    aliases = {**_DEFAULT_ALIASES}
    if column_aliases:
        for key, values in column_aliases.items():
            aliases[key] = tuple(values) + tuple(aliases.get(key, ()))

    rename_map: dict[str, str] = {}
    for target, candidates in aliases.items():
        if target in frame.columns:
            continue
        for candidate in candidates:
            if candidate in frame.columns and candidate not in rename_map:
                rename_map[candidate] = target
                break

    normalized = frame.rename(columns=rename_map).copy()

    if "light_seq" not in normalized.columns:
        normalized["light_seq"] = ""

    label_lookup = label_map or DEFAULT_LABEL_MAP
    normalized["label"] = normalized["label"].map(lambda x: label_lookup.get(_normalize_label_key(x)))

    if normalized["label"].isnull().any():
        msg = "Label column contains unmapped or missing values"
        raise ValueError(msg)

    normalized["source"] = source
    if is_test is not None:
        normalized["is_test"] = bool(is_test)

    normalized = ensure_columns(normalized, heavy_only=heavy_only)
    return normalized


def deduplicate_sequences(
    frames: Iterable[pd.DataFrame],
    *,
    heavy_only: bool = True,
    key_columns: Sequence[str] | None = None,
    keep_intra_frames: set[int] | None = None,
) -> list[pd.DataFrame]:
    """Remove duplicate entries across multiple dataframes with configurable keys."""

    if key_columns is None:
        key_columns = ["heavy_seq"] if heavy_only else ["heavy_seq", "light_seq"]
    keep_intra_frames = keep_intra_frames or set()

    seen: set[tuple[str, ...]] = set()
    cleaned: list[pd.DataFrame] = []

    for frame_idx, frame in enumerate(frames):
        valid_columns = [col for col in key_columns if col in frame.columns]
        if not valid_columns:
            valid_columns = ["heavy_seq"]

        mask: list[bool] = []
        frame_seen: set[tuple[str, ...]] = set()
        allow_intra = frame_idx in keep_intra_frames

        for values in frame[valid_columns].itertuples(index=False, name=None):
            key = tuple(_normalise_key_value(value) for value in values)
            if key in seen:
                mask.append(False)
                continue
            if not allow_intra and key in frame_seen:
                mask.append(False)
                continue
            mask.append(True)
            frame_seen.add(key)
        seen.update(frame_seen)
        filtered = frame.loc[mask].reset_index(drop=True)
        removed = len(frame) - len(filtered)
        if removed:
            dataset = "<unknown>"
            if "source" in frame.columns and not frame["source"].empty:
                dataset = str(frame["source"].iloc[0])
            LOGGER.info("Removed %s duplicate sequences from %s", removed, dataset)
        cleaned.append(filtered)
    return cleaned


def _normalise_key_value(value: object) -> str:
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return ""
    return str(value).strip()
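For example (a toy sketch with fabricated sequences), a heavy chain already seen in an earlier frame is dropped from later frames:

    import pandas as pd

    from polyreact.data_loaders.utils import deduplicate_sequences, standardize_frame

    train = pd.DataFrame({"id": ["a"], "heavy": ["QVQLVQSG"], "label": [1]})
    test = pd.DataFrame({"id": ["b", "c"], "heavy": ["QVQLVQSG", "EVQLLESG"], "label": [0, 1]})

    frames = [
        standardize_frame(train, source="train_toy"),
        standardize_frame(test, source="test_toy"),
    ]
    train_clean, test_clean = deduplicate_sequences(frames)
    print(len(train_clean), len(test_clean))  # 1 1 -> the shared heavy chain is removed from the test frame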
polyreact/features/__init__.py
ADDED
@@ -0,0 +1,13 @@
"""Feature backends for polyreactivity prediction."""

from . import anarsi, descriptors, plm
from .pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline

__all__ = [
    "anarsi",
    "descriptors",
    "plm",
    "FeaturePipeline",
    "FeaturePipelineState",
    "build_feature_pipeline",
]
polyreact/features/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (530 Bytes). View file

polyreact/features/__pycache__/anarsi.cpython-311.pyc
ADDED
Binary file (10.1 kB). View file

polyreact/features/__pycache__/descriptors.cpython-311.pyc
ADDED
Binary file (9.22 kB). View file

polyreact/features/__pycache__/pipeline.cpython-311.pyc
ADDED
Binary file (17.1 kB). View file

polyreact/features/__pycache__/plm.cpython-311.pyc
ADDED
Binary file (23 kB). View file
polyreact/features/anarsi.py
ADDED
@@ -0,0 +1,222 @@
"""ANARCI/ANARCII numbering helpers."""

from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Sequence, Tuple

try:
    from anarcii.pipeline import Anarcii  # type: ignore
except ImportError:  # pragma: no cover - optional dependency
    Anarcii = None


@dataclass(slots=True)
class NumberedResidue:
    """Single residue with IMGT numbering metadata."""

    position: int
    insertion: str
    amino_acid: str


@dataclass(slots=True)
class NumberedSequence:
    """Container for numbering results and derived regions."""

    sequence: str
    scheme: str
    chain_type: str
    residues: list[NumberedResidue]
    regions: dict[str, str]


_IMGT_HEAVY_REGIONS: Sequence[Tuple[str, int, int]] = (
    ("FR1", 1, 26),
    ("CDRH1", 27, 38),
    ("FR2", 39, 55),
    ("CDRH2", 56, 65),
    ("FR3", 66, 104),
    ("CDRH3", 105, 117),
    ("FR4", 118, 128),
)

_IMGT_LIGHT_REGIONS: Sequence[Tuple[str, int, int]] = (
    ("FR1", 1, 26),
    ("CDRL1", 27, 38),
    ("FR2", 39, 55),
    ("CDRL2", 56, 65),
    ("FR3", 66, 104),
    ("CDRL3", 105, 117),
    ("FR4", 118, 128),
)

_REGION_MAP: dict[Tuple[str, str], Sequence[Tuple[str, int, int]]] = {
    ("imgt", "H"): _IMGT_HEAVY_REGIONS,
    ("imgt", "L"): _IMGT_LIGHT_REGIONS,
}

_VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
_DEFAULT_SCHEME = "imgt"
_DEFAULT_CHAIN = "H"


_DEFAULT_NUMBERER: AnarciNumberer | None = None


def _sanitize_sequence(sequence: str) -> str:
    return "".join(residue for residue in sequence.upper() if residue in _VALID_AMINO_ACIDS)


def get_default_numberer() -> AnarciNumberer:
    global _DEFAULT_NUMBERER
    if _DEFAULT_NUMBERER is None:
        _DEFAULT_NUMBERER = AnarciNumberer(chain_type=_DEFAULT_CHAIN, cpu=True, ncpu=1, verbose=False)
    return _DEFAULT_NUMBERER


def trim_variable_domain(
    sequence: str,
    *,
    numberer: AnarciNumberer | None = None,
    scheme: str = _DEFAULT_SCHEME,
    chain_type: str = _DEFAULT_CHAIN,
    fallback_length: int = 130,
) -> str:
    """Return the FR1–FR4 variable domain for a heavy/light chain sequence."""

    cleaned = _sanitize_sequence(sequence)
    if not cleaned:
        return ""

    active_numberer = numberer or get_default_numberer()
    try:
        numbered = active_numberer.number_sequence(cleaned)
    except Exception:  # pragma: no cover - best effort safeguard
        return cleaned[:fallback_length]

    region_sets = _region_boundaries(scheme, chain_type)
    pieces: list[str] = []
    for region_name, _start, _end in region_sets:
        segment = numbered.regions.get(region_name, "")
        if segment:
            pieces.append(segment)
    trimmed = "".join(pieces)
    if not trimmed:
        trimmed = numbered.regions.get("full", "")
    if not trimmed:
        trimmed = cleaned[:fallback_length]
    return trimmed


def _normalise_chain_type(chain_type: str) -> str:
    upper = chain_type.upper()
    if upper in {"H", "HV"}:
        return "H"
    if upper in {"L", "K", "LV", "KV"}:
        return "L"
    return upper


class AnarciNumberer:
    """Thin wrapper around the ANARCII pipeline to obtain IMGT regions."""

    def __init__(
        self,
        *,
        scheme: str = "imgt",
        chain_type: str = "H",
        cpu: bool = True,
        ncpu: int = 1,
        verbose: bool = False,
    ) -> None:
        if Anarcii is None:  # pragma: no cover - optional dependency guard
            msg = (
                "anarcii is required for numbering but is not installed."
                " Install 'anarcii' to enable ANARCI-based features."
            )
            raise ImportError(msg)
        self.scheme = scheme
        self.expected_chain_type = _normalise_chain_type(chain_type)
        self.cpu = cpu
        self.ncpu = ncpu
        self.verbose = verbose
        self._runner = None

    def _ensure_runner(self) -> Anarcii:
        if self._runner is None:
            self._runner = Anarcii(
                seq_type="antibody",
                mode="accuracy",
                batch_size=1,
                cpu=self.cpu,
                ncpu=self.ncpu,
                verbose=self.verbose,
            )
        return self._runner

    def number_sequence(self, sequence: str) -> NumberedSequence:
        """Return numbering metadata for a single amino-acid sequence."""

        runner = self._ensure_runner()
        output = runner.number([sequence])
        record = next(iter(output.values()))
        if record.get("error"):
            raise RuntimeError(f"ANARCI failed: {record['error']}")

        scheme = record.get("scheme", self.scheme)
        detected_chain = record.get("chain_type", self.expected_chain_type)
        normalised_chain = _normalise_chain_type(detected_chain)
        if self.expected_chain_type and normalised_chain != self.expected_chain_type:
            msg = (
                f"Expected chain type {self.expected_chain_type!r} but got"
                f" {normalised_chain!r}"
            )
            raise ValueError(msg)

        residues = [
            NumberedResidue(position=pos, insertion=ins, amino_acid=aa)
            for (pos, ins), aa in record["numbering"]
        ]
        regions = _extract_regions(
            residues=residues,
            scheme=scheme,
            chain_type=normalised_chain,
        )
        return NumberedSequence(
            sequence=sequence,
            scheme=scheme,
            chain_type=normalised_chain,
            residues=residues,
            regions=regions,
        )


@lru_cache(maxsize=32)
def _region_boundaries(scheme: str, chain_type: str) -> Sequence[Tuple[str, int, int]]:
    key = (scheme.lower(), chain_type.upper())
    return _REGION_MAP.get(key, ())


def _extract_regions(
    *,
    residues: Sequence[NumberedResidue],
    scheme: str,
    chain_type: str,
) -> dict[str, str]:
    boundaries = _region_boundaries(scheme, chain_type)
    slots: Dict[str, List[str]] = {name: [] for name, _, _ in boundaries}
    slots["full"] = []

    for residue in residues:
        aa = residue.amino_acid
        if aa == "-":
            continue
        slots["full"].append(aa)
        for name, start, end in boundaries:
            if start <= residue.position <= end:
                slots.setdefault(name, []).append(aa)
                break

    return {key: "".join(value) for key, value in slots.items()}
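A usage sketch (only meaningful when the optional 'anarcii' dependency is installed; the heavy-chain fragment below is illustrative, not a real antibody):

    from polyreact.features.anarsi import AnarciNumberer, trim_variable_domain

    heavy = "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYYMHWVRQAPGQGLEWMG"
    numberer = AnarciNumberer(chain_type="H")
    # Returns the concatenated FR1-FR4 segments; falls back to the first 130
    # residues if numbering fails.
    print(trim_variable_domain(heavy, numberer=numberer))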
polyreact/features/descriptors.py
ADDED
@@ -0,0 +1,146 @@
"""Sequence descriptor features for polyreactivity prediction."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Sequence

import numpy as np
import pandas as pd
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.preprocessing import StandardScaler

from .anarsi import AnarciNumberer, NumberedSequence

_VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")


@dataclass(slots=True)
class DescriptorConfig:
    """Configuration for descriptor-based features."""

    use_anarci: bool = True
    regions: Sequence[str] = ("CDRH1", "CDRH2", "CDRH3")
    features: Sequence[str] = (
        "length",
        "charge",
        "hydropathy",
        "aromaticity",
        "pI",
        "net_charge",
    )
    ph: float = 7.4


class DescriptorFeaturizer:
    """Compute descriptor features with optional ANARCI-based regions."""

    def __init__(
        self,
        *,
        config: DescriptorConfig,
        numberer: AnarciNumberer | None = None,
        standardize: bool = True,
    ) -> None:
        self.config = config
        self.numberer = numberer if not config.use_anarci else numberer or AnarciNumberer()
        self.standardize = standardize
        self.scaler = StandardScaler() if standardize else None
        self.feature_names_: list[str] | None = None

    def fit(self, sequences: Iterable[str]) -> "DescriptorFeaturizer":
        table = self.compute_feature_table(sequences)
        values = table.to_numpy(dtype=float)
        if self.standardize and self.scaler is not None:
            self.scaler.fit(values)
        self.feature_names_ = list(table.columns)
        return self

    def transform(self, sequences: Iterable[str]) -> np.ndarray:
        if self.feature_names_ is None:
            msg = "DescriptorFeaturizer must be fitted before calling transform."
            raise RuntimeError(msg)
        table = self.compute_feature_table(sequences)
        values = table.to_numpy(dtype=float)
        if self.standardize and self.scaler is not None:
            values = self.scaler.transform(values)
        return values

    def fit_transform(self, sequences: Iterable[str]) -> np.ndarray:
        table = self.compute_feature_table(sequences)
        values = table.to_numpy(dtype=float)
        if self.standardize and self.scaler is not None:
            self.scaler.fit(values)
            values = self.scaler.transform(values)
        self.feature_names_ = list(table.columns)
        return values

    def compute_feature_table(self, sequences: Iterable[str]) -> pd.DataFrame:
        rows: list[dict[str, float]] = []
        for sequence in sequences:
            regions = self._prepare_regions(sequence)
            if not self.config.use_anarci:
                region_names = ["FULL"]
            else:
                region_names = [region.upper() for region in self.config.regions]
            row: dict[str, float] = {}
            for region_name in region_names:
                normalized_name = region_name.upper()
                region_sequence = regions.get(normalized_name, "")
                for feature_name in self.config.features:
                    column = f"{normalized_name}_{feature_name}"
                    row[column] = _compute_feature(
                        region_sequence,
                        feature_name,
                        ph=self.config.ph,
                    )
            rows.append(row)

        if not self.config.use_anarci:
            region_names = ["FULL"]
        else:
            region_names = [region.upper() for region in self.config.regions]
        columns = [
            f"{region}_{feature}"
            for region in region_names
            for feature in self.config.features
        ]
        frame = pd.DataFrame(rows, columns=columns)
        return frame.fillna(0.0)

    def _prepare_regions(self, sequence: str) -> dict[str, str]:
        if not self.config.use_anarci:
            return {"FULL": sequence}

        try:
            numbered: NumberedSequence = self.numberer.number_sequence(sequence)
        except (RuntimeError, ValueError):
            return {}
        return {key.upper(): value for key, value in numbered.regions.items()}


def _sanitize_sequence(sequence: str) -> str:
    return "".join(residue for residue in sequence.upper() if residue in _VALID_AMINO_ACIDS)


def _compute_feature(sequence: str, feature_name: str, *, ph: float) -> float:
    sanitized = _sanitize_sequence(sequence)
    if not sanitized:
        return 0.0

    analysis = ProteinAnalysis(sanitized)
    if feature_name == "length":
        return float(len(sanitized))
    if feature_name == "hydropathy":
        return float(analysis.gravy())
    if feature_name == "aromaticity":
        return float(analysis.aromaticity())
    if feature_name == "pI":
        return float(analysis.isoelectric_point())
    if feature_name == "net_charge":
        return float(analysis.charge_at_pH(ph))
    if feature_name == "charge":
        net = analysis.charge_at_pH(ph)
        return float(net / len(sanitized))
    msg = f"Unsupported feature: {feature_name}"
    raise ValueError(msg)
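A short sketch of the featurizer on full sequences (ANARCI disabled, so no numbering dependency is needed; the sequences are made-up fragments):

    from polyreact.features.descriptors import DescriptorConfig, DescriptorFeaturizer

    config = DescriptorConfig(use_anarci=False, features=("length", "hydropathy", "net_charge"))
    featurizer = DescriptorFeaturizer(config=config, standardize=False)
    X = featurizer.fit_transform(["QVQLVQSGAEVKKPG", "EVQLLESGGGLVQPG"])
    print(featurizer.feature_names_)  # ['FULL_length', 'FULL_hydropathy', 'FULL_net_charge']
    print(X.shape)  # (2, 3)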
polyreact/features/pipeline.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Feature pipeline construction utilities."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import Iterable, Sequence

import numpy as np
from sklearn.preprocessing import StandardScaler

from ..config import Config, DescriptorSettings, FeatureBackendSettings
from .descriptors import DescriptorConfig, DescriptorFeaturizer
from .plm import PLMEmbedder


@dataclass(slots=True)
class FeaturePipelineState:
    backend_type: str
    descriptor_featurizer: DescriptorFeaturizer | None
    plm_scaler: StandardScaler | None
    descriptor_config: DescriptorConfig | None
    plm_model_name: str | None
    plm_layer_pool: str | None
    cache_dir: str | None
    device: str
    feature_names: list[str] = field(default_factory=list)


class FeaturePipeline:
    """Fit/transform feature matrices according to configuration."""

    def __init__(
        self,
        *,
        backend: FeatureBackendSettings,
        descriptors: DescriptorSettings,
        device: str,
        cache_dir_override: str | None = None,
        plm_model_override: str | None = None,
        layer_pool_override: str | None = None,
    ) -> None:
        self.backend = backend
        self.descriptor_settings = descriptors
        self.device = device
        self.cache_dir_override = cache_dir_override
        self.plm_model_override = plm_model_override
        self.layer_pool_override = layer_pool_override

        self._descriptor: DescriptorFeaturizer | None = None
        self._plm: PLMEmbedder | None = None
        self._plm_scaler: StandardScaler | None = None
        self._feature_names: list[str] = []

    def fit_transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray:  # noqa: ANN001
        backend_type = self.backend.type if self.backend.type else "descriptors"
        self._validate_heavy_support(backend_type, heavy_only)
        sequences = _extract_sequences(df, heavy_only=heavy_only)

        if backend_type == "descriptors":
            self._descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            features = self._descriptor.fit_transform(sequences)
            self._feature_names = list(self._descriptor.feature_names_ or [])
            self._plm = None
            self._plm_scaler = None
            return features.astype(np.float32)

        if backend_type == "plm":
            self._descriptor = None
            self._plm = _build_plm_embedder(
                self.backend,
                device=self.device,
                cache_dir_override=self.cache_dir_override,
                plm_model_override=self.plm_model_override,
                layer_pool_override=self.layer_pool_override,
            )
            embeddings = self._plm.embed(sequences, batch_size=batch_size)
            if self.backend.standardize:
                self._plm_scaler = StandardScaler()
                embeddings = self._plm_scaler.fit_transform(embeddings)
            else:
                self._plm_scaler = None
            self._feature_names = [f"plm_{i}" for i in range(embeddings.shape[1])]
            return embeddings.astype(np.float32)

        if backend_type == "concat":
            descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            desc_features = descriptor.fit_transform(sequences)
            plm = _build_plm_embedder(
                self.backend,
                device=self.device,
                cache_dir_override=self.cache_dir_override,
                plm_model_override=self.plm_model_override,
                layer_pool_override=self.layer_pool_override,
            )
            embeddings = plm.embed(sequences, batch_size=batch_size)
            if self.backend.standardize:
                plm_scaler = StandardScaler()
                embeddings = plm_scaler.fit_transform(embeddings)
            else:
                plm_scaler = None
            self._descriptor = descriptor
            self._plm = plm
            self._plm_scaler = plm_scaler
            self._feature_names = list(descriptor.feature_names_ or []) + [
                f"plm_{i}" for i in range(embeddings.shape[1])
            ]
            return np.concatenate([desc_features, embeddings], axis=1).astype(np.float32)

        msg = f"Unsupported feature backend: {backend_type}"
        raise ValueError(msg)

    def fit(self, df, *, heavy_only: bool, batch_size: int = 8) -> "FeaturePipeline":  # noqa: ANN001
        backend_type = self.backend.type if self.backend.type else "descriptors"
        self._validate_heavy_support(backend_type, heavy_only)
        sequences = _extract_sequences(df, heavy_only=heavy_only)

        if backend_type == "descriptors":
            self._descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            self._descriptor.fit(sequences)
            self._feature_names = list(self._descriptor.feature_names_ or [])
            self._plm = None
            self._plm_scaler = None
        elif backend_type == "plm":
            self._descriptor = None
            self._plm = _build_plm_embedder(
                self.backend,
                device=self.device,
                cache_dir_override=self.cache_dir_override,
                plm_model_override=self.plm_model_override,
                layer_pool_override=self.layer_pool_override,
            )
            embeddings = self._plm.embed(sequences, batch_size=batch_size)
            if self.backend.standardize:
                self._plm_scaler = StandardScaler()
                embeddings = self._plm_scaler.fit_transform(embeddings)
            else:
                self._plm_scaler = None
            self._feature_names = [f"plm_{i}" for i in range(embeddings.shape[1])]
        elif backend_type == "concat":
            descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            desc_features = descriptor.fit_transform(sequences)
            plm = _build_plm_embedder(
                self.backend,
                device=self.device,
                cache_dir_override=self.cache_dir_override,
                plm_model_override=self.plm_model_override,
                layer_pool_override=self.layer_pool_override,
            )
            embeddings = plm.embed(sequences, batch_size=batch_size)
            if self.backend.standardize:
                plm_scaler = StandardScaler()
                embeddings = plm_scaler.fit_transform(embeddings)
            else:
                plm_scaler = None
            self._descriptor = descriptor
            self._plm = plm
            self._plm_scaler = plm_scaler
            self._feature_names = list(descriptor.feature_names_ or []) + [
                f"plm_{i}" for i in range(embeddings.shape[1])
            ]
        else:  # pragma: no cover - defensive branch
            msg = f"Unsupported feature backend: {backend_type}"
            raise ValueError(msg)
        return self

    def transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray:  # noqa: ANN001
        backend_type = self.backend.type if self.backend.type else "descriptors"
        self._validate_heavy_support(backend_type, heavy_only)
        sequences = _extract_sequences(df, heavy_only=heavy_only)

        if backend_type == "descriptors":
            if self._descriptor is None:
                msg = "Descriptor featurizer is not fitted"
                raise RuntimeError(msg)
            features = self._descriptor.transform(sequences)
        elif backend_type == "plm":
            if self._plm is None:
                msg = "PLM embedder is not initialised"
                raise RuntimeError(msg)
            embeddings = self._plm.embed(sequences, batch_size=batch_size)
            if self.backend.standardize and self._plm_scaler is not None:
                embeddings = self._plm_scaler.transform(embeddings)
            features = embeddings
        elif backend_type == "concat":
            if self._descriptor is None or self._plm is None:
                msg = "Feature pipeline not fitted"
                raise RuntimeError(msg)
            desc_features = self._descriptor.transform(sequences)
            embeddings = self._plm.embed(sequences, batch_size=batch_size)
            if self.backend.standardize and self._plm_scaler is not None:
                embeddings = self._plm_scaler.transform(embeddings)
            features = np.concatenate([desc_features, embeddings], axis=1)
        else:  # pragma: no cover - defensive branch
            msg = f"Unsupported feature backend: {backend_type}"
            raise ValueError(msg)

        return features.astype(np.float32)

    @property
    def feature_names(self) -> list[str]:
        return self._feature_names

    def get_state(self) -> FeaturePipelineState:
        descriptor = self._descriptor
        if descriptor is not None and descriptor.numberer is not None:
            if hasattr(descriptor.numberer, "_runner"):
                descriptor.numberer._runner = None  # type: ignore[attr-defined]
        return FeaturePipelineState(
            backend_type=self.backend.type,
            descriptor_featurizer=descriptor,
            plm_scaler=self._plm_scaler,
            descriptor_config=_build_descriptor_config(self.descriptor_settings),
            plm_model_name=self._effective_plm_model_name,
            plm_layer_pool=self._effective_layer_pool,
            cache_dir=self._effective_cache_dir,
            device=self.device,
            feature_names=self._feature_names,
        )

    def load_state(self, state: FeaturePipelineState) -> None:
        self.backend.type = state.backend_type
        if state.plm_model_name:
            self.backend.plm_model_name = state.plm_model_name
            self.plm_model_override = state.plm_model_name
        if state.plm_layer_pool:
            self.backend.layer_pool = state.plm_layer_pool
            self.layer_pool_override = state.plm_layer_pool
        if state.cache_dir:
            self.backend.cache_dir = state.cache_dir
            self.cache_dir_override = state.cache_dir
        if state.descriptor_config:
            self.descriptor_settings = DescriptorSettings(
                use_anarci=state.descriptor_config.use_anarci,
                regions=tuple(state.descriptor_config.regions),
                features=tuple(state.descriptor_config.features),
                ph=state.descriptor_config.ph,
            )
        self._descriptor = state.descriptor_featurizer
        self._plm_scaler = state.plm_scaler
        self._feature_names = state.feature_names
        if self.backend.type in {"plm", "concat"}:
            self._plm = _build_plm_embedder(
                self.backend,
                device=self.device,
                cache_dir_override=self.backend.cache_dir,
                plm_model_override=self.backend.plm_model_name,
                layer_pool_override=self.backend.layer_pool,
            )
        else:
            self._plm = None

    @property
    def _effective_plm_model_name(self) -> str | None:
        if self.backend.type not in {"plm", "concat"}:
            return None
        return self.plm_model_override or self.backend.plm_model_name

    @property
    def _effective_layer_pool(self) -> str | None:
        if self.backend.type not in {"plm", "concat"}:
            return None
        return self.layer_pool_override or self.backend.layer_pool

    @property
    def _effective_cache_dir(self) -> str | None:
        if self.backend.type not in {"plm", "concat"}:
            return None
        if self.cache_dir_override is not None:
            return self.cache_dir_override
        return self.backend.cache_dir

    def _validate_heavy_support(self, backend_type: str, heavy_only: bool) -> None:
        if heavy_only:
            return
        if backend_type == "descriptors" and self.descriptor_settings.use_anarci:
            msg = "Descriptor backend with ANARCI currently supports heavy-chain only inference."
            raise ValueError(msg)
        if backend_type == "concat" and self.descriptor_settings.use_anarci:
            msg = "Concat backend with descriptors requires heavy-chain only data."
            raise ValueError(msg)


def build_feature_pipeline(
    config: Config,
    *,
    backend_override: str | None = None,
    plm_model_override: str | None = None,
    cache_dir_override: str | None = None,
    layer_pool_override: str | None = None,
) -> FeaturePipeline:
    backend = FeatureBackendSettings(**asdict(config.feature_backend))
    if backend_override:
        backend.type = backend_override
    pipeline = FeaturePipeline(
        backend=backend,
        descriptors=config.descriptors,
        device=config.device,
        cache_dir_override=cache_dir_override,
        plm_model_override=plm_model_override,
        layer_pool_override=layer_pool_override,
    )
    return pipeline


def _build_descriptor_featurizer(settings: DescriptorSettings) -> DescriptorFeaturizer:
    descriptor_config = _build_descriptor_config(settings)
    return DescriptorFeaturizer(config=descriptor_config, standardize=True)


def _build_descriptor_config(settings: DescriptorSettings) -> DescriptorConfig:
    return DescriptorConfig(
        use_anarci=settings.use_anarci,
        regions=tuple(settings.regions),
        features=tuple(settings.features),
        ph=settings.ph,
    )


def _build_plm_embedder(
    backend: FeatureBackendSettings,
    *,
    device: str,
    cache_dir_override: str | None,
    plm_model_override: str | None,
    layer_pool_override: str | None,
) -> PLMEmbedder:
    model_name = plm_model_override or backend.plm_model_name
    cache_dir = cache_dir_override or backend.cache_dir
    layer_pool = layer_pool_override or backend.layer_pool
    return PLMEmbedder(
        model_name=model_name,
        layer_pool=layer_pool,
        device=device,
        cache_dir=cache_dir,
    )


def _extract_sequences(df, heavy_only: bool) -> Sequence[str]:  # noqa: ANN001
    if heavy_only or "light_seq" not in df.columns:
        return df["heavy_seq"].fillna("").astype(str).tolist()
    heavy = df["heavy_seq"].fillna("").astype(str)
    light = df["light_seq"].fillna("").astype(str)
    return (heavy + "|" + light).tolist()
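A minimal usage sketch for the pipeline above, assuming a YAML config loadable by polyreact.config.load_config and a pandas DataFrame with a heavy_seq column; the sequences and config path below are illustrative only:

# Sketch: fit the feature pipeline on a training frame and reuse its state later.
import pandas as pd

from polyreact.config import load_config
from polyreact.features.pipeline import build_feature_pipeline

config = load_config("configs/default.yaml")
pipeline = build_feature_pipeline(config, backend_override="plm")

train_df = pd.DataFrame({"heavy_seq": ["EVQLVESGGGLVQPGGSLRLSCAAS", "QVQLQESGPGLVKPSETLSLTCTVS"]})
X_train = pipeline.fit_transform(train_df, heavy_only=True, batch_size=8)

# The fitted state (descriptor featurizer, PLM scaler, feature names) can be captured
# and restored, e.g. when a saved artifact is reloaded at prediction time.
state = pipeline.get_state()
restored = build_feature_pipeline(config)
restored.load_state(state)
X_again = restored.transform(train_df, heavy_only=True)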
polyreact/features/plm.py
ADDED
@@ -0,0 +1,378 @@
"""Protein language model embeddings backend with caching support."""

from __future__ import annotations

import hashlib
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, Iterable, List, Sequence, Tuple

import numpy as np
import torch
from torch import nn

try:  # pragma: no cover - optional dependency
    from transformers import AutoModel, AutoTokenizer
except ImportError:  # pragma: no cover - optional dependency
    AutoModel = None
    AutoTokenizer = None

try:  # pragma: no cover - optional dependency
    import esm
except ImportError:  # pragma: no cover - optional dependency
    esm = None

from .anarsi import AnarciNumberer

ModelLoader = Callable[[str, str], Tuple[object, nn.Module]]

if esm is not None:  # pragma: no cover - optional dependency
    _ESM1V_LOADERS = {
        "esm1v_t33_650m_ur90s_1": esm.pretrained.esm1v_t33_650M_UR90S_1,
        "esm1v_t33_650m_ur90s_2": esm.pretrained.esm1v_t33_650M_UR90S_2,
        "esm1v_t33_650m_ur90s_3": esm.pretrained.esm1v_t33_650M_UR90S_3,
        "esm1v_t33_650m_ur90s_4": esm.pretrained.esm1v_t33_650M_UR90S_4,
        "esm1v_t33_650m_ur90s_5": esm.pretrained.esm1v_t33_650M_UR90S_5,
    }
else:  # pragma: no cover - optional dependency
    _ESM1V_LOADERS: dict[str, Callable[[], tuple[nn.Module, object]]] = {}


class _ESMTokenizer:
    """Callable wrapper that mimics Hugging Face tokenizers for ESM models."""

    def __init__(self, alphabet) -> None:  # noqa: ANN001
        self.alphabet = alphabet
        self._batch_converter = alphabet.get_batch_converter()

    def __call__(
        self,
        sequences: Sequence[str],
        *,
        return_tensors: str = "pt",
        padding: bool = True,  # noqa: FBT002
        truncation: bool = True,  # noqa: FBT002
        add_special_tokens: bool = True,  # noqa: FBT002
        return_special_tokens_mask: bool = True,  # noqa: FBT002
    ) -> dict[str, torch.Tensor]:
        if return_tensors != "pt":  # pragma: no cover - defensive branch
            msg = "ESM tokenizer only supports return_tensors='pt'"
            raise ValueError(msg)
        data = [(str(idx), (seq or "").upper()) for idx, seq in enumerate(sequences)]
        _labels, _strings, tokens = self._batch_converter(data)
        attention_mask = (tokens != self.alphabet.padding_idx).long()
        special_tokens = torch.zeros_like(tokens)
        specials = {
            self.alphabet.padding_idx,
            self.alphabet.cls_idx,
            self.alphabet.eos_idx,
        }
        for special in specials:
            special_tokens |= tokens == special
        output: dict[str, torch.Tensor] = {
            "input_ids": tokens,
            "attention_mask": attention_mask,
        }
        if return_special_tokens_mask:
            output["special_tokens_mask"] = special_tokens.long()
        return output


class _ESMModelWrapper(nn.Module):
    """Adapter providing a Hugging Face style interface for ESM models."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.layer_index = getattr(model, "num_layers", None)
        if self.layer_index is None:
            msg = "Unable to determine final layer for ESM model"
            raise AttributeError(msg)

    def eval(self) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.eval()
        return self

    def to(self, device: str) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.to(device)
        return self

    def forward(self, input_ids: torch.Tensor, **_):  # noqa: ANN003
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                repr_layers=[self.layer_index],
                return_contacts=False,
            )
        hidden = outputs["representations"][self.layer_index]
        return SimpleNamespace(last_hidden_state=hidden)

    __call__ = forward


@dataclass(slots=True)
class PLMConfig:
    model_name: str = "facebook/esm1v_t33_650M_UR90S_1"
    layer_pool: str = "mean"
    cache_dir: Path = Path(".cache/embeddings")
    device: str = "auto"


class PLMEmbedder:
    """Embed amino-acid sequences using a transformer model with caching."""

    def __init__(
        self,
        model_name: str = "facebook/esm1v_t33_650M_UR90S_1",
        *,
        layer_pool: str = "mean",
        device: str = "auto",
        cache_dir: str | Path | None = None,
        numberer: AnarciNumberer | None = None,
        model_loader: ModelLoader | None = None,
    ) -> None:
        self.model_name = model_name
        self.layer_pool = layer_pool
        self.device = self._resolve_device(device)
        self.cache_dir = Path(cache_dir or ".cache/embeddings")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.numberer = numberer
        self.model_loader = model_loader
        self._tokenizer: object | None = None
        self._model: nn.Module | None = None

    @staticmethod
    def _resolve_device(device: str) -> str:
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    @property
    def tokenizer(self):  # noqa: D401
        if self._tokenizer is None:
            tokenizer, model = self._load_model_components()
            self._tokenizer = tokenizer
            self._model = model
        return self._tokenizer

    @property
    def model(self) -> nn.Module:
        if self._model is None:
            tokenizer, model = self._load_model_components()
            self._tokenizer = tokenizer
            self._model = model
        return self._model

    def _load_model_components(self) -> Tuple[object, nn.Module]:
        if self.model_loader is not None:
            tokenizer, model = self.model_loader(self.model_name, self.device)
            return tokenizer, model

        if self._is_esm1v_model(self.model_name):
            return self._load_esm_model()

        if AutoModel is None or AutoTokenizer is None:  # pragma: no cover - optional dependency
            msg = "transformers must be installed to use PLMEmbedder"
            raise ImportError(msg)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
        model.eval()
        model.to(self.device)
        return tokenizer, model

    def _load_esm_model(self) -> Tuple[object, nn.Module]:
        if esm is None:  # pragma: no cover - optional dependency
            msg = (
                "The 'esm' package is required to use ESM-1v models."
            )
            raise ImportError(msg)

        normalized = self._canonical_esm_name(self.model_name)
        loader = _ESM1V_LOADERS.get(normalized)
        if loader is None:  # pragma: no cover - guard branch
            msg = f"Unsupported ESM-1v model: {self.model_name}"
            raise ValueError(msg)

        model, alphabet = loader()
        model.eval()
        model.to(self.device)
        tokenizer = _ESMTokenizer(alphabet)
        wrapper = _ESMModelWrapper(model)
        return tokenizer, wrapper

    @staticmethod
    def _canonical_esm_name(model_name: str) -> str:
        name = model_name.lower()
        if "/" in name:
            name = name.split("/")[-1]
        return name

    @classmethod
    def _is_esm1v_model(cls, model_name: str) -> bool:
        return cls._canonical_esm_name(model_name).startswith("esm1v")

    def embed(self, sequences: Iterable[str], *, batch_size: int = 8) -> np.ndarray:
        batch_sequences = list(sequences)
        if not batch_sequences:
            return np.empty((0, 0), dtype=np.float32)

        outputs: List[np.ndarray | None] = [None] * len(batch_sequences)
        unique_to_compute: dict[str, List[Tuple[int, Path]]] = {}
        model_dir = self.cache_dir / self._normalized_model_name()
        model_dir.mkdir(parents=True, exist_ok=True)

        cache_hits: list[tuple[int, Path]] = []
        for idx, sequence in enumerate(batch_sequences):
            cache_path = self._sequence_cache_path(model_dir, sequence)
            if cache_path.exists():
                cache_hits.append((idx, cache_path))
            else:
                unique_to_compute.setdefault(sequence, []).append((idx, cache_path))

        if cache_hits:
            loaders = [path for _, path in cache_hits]
            max_workers = min(len(loaders), 32)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for (idx, _), embedding in zip(cache_hits, executor.map(np.load, loaders), strict=True):
                    outputs[idx] = embedding

        if unique_to_compute:
            embeddings = self._compute_embeddings(list(unique_to_compute.keys()), batch_size=batch_size)
            for sequence, embedding in zip(unique_to_compute.keys(), embeddings, strict=True):
                targets = unique_to_compute[sequence]
                for idx, cache_path in targets:
                    outputs[idx] = embedding
                    np.save(cache_path, embedding)
        if any(item is None for item in outputs):  # pragma: no cover - safety
            msg = "Failed to compute embeddings for all sequences"
            raise RuntimeError(msg)
        array_outputs = [np.asarray(item, dtype=np.float32) for item in outputs]  # type: ignore[arg-type]
        return np.stack(array_outputs, axis=0)

    def _compute_embeddings(self, sequences: Sequence[str], *, batch_size: int) -> List[np.ndarray]:
        tokenizer = self.tokenizer
        model = self.model
        model.eval()
        embeddings: List[np.ndarray] = []
        for start in range(0, len(sequences), batch_size):
            chunk = list(sequences[start : start + batch_size])
            tokenized = self._tokenize(tokenizer, chunk)
            model_inputs: dict[str, torch.Tensor] = {}
            aux_inputs: dict[str, torch.Tensor] = {}
            for key, value in tokenized.items():
                if isinstance(value, torch.Tensor):
                    tensor_value = value.to(self.device)
                else:
                    tensor_value = value
                if key == "special_tokens_mask":
                    aux_inputs[key] = tensor_value
                else:
                    model_inputs[key] = tensor_value
            with torch.no_grad():
                outputs = model(**model_inputs)
            hidden_states = outputs.last_hidden_state.detach().cpu()
            attention_mask = model_inputs.get("attention_mask")
            special_tokens_mask = aux_inputs.get("special_tokens_mask")
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = attention_mask.detach().cpu()
            if isinstance(special_tokens_mask, torch.Tensor):
                special_tokens_mask = special_tokens_mask.detach().cpu()

            for idx, sequence in enumerate(chunk):
                hidden = hidden_states[idx]
                mask = attention_mask[idx] if isinstance(attention_mask, torch.Tensor) else None
                special_mask = (
                    special_tokens_mask[idx]
                    if isinstance(special_tokens_mask, torch.Tensor)
                    else None
                )
                embedding = self._pool_hidden(hidden, mask, special_mask, sequence)
                embeddings.append(embedding)
        return embeddings

    def _tokenize(self, tokenizer, sequences: Sequence[str]):
        if hasattr(tokenizer, "__call__"):
            return tokenizer(
                list(sequences),
                return_tensors="pt",
                padding=True,
                truncation=True,
                add_special_tokens=True,
                return_special_tokens_mask=True,
            )
        msg = "Tokenizer does not implement __call__"
        raise TypeError(msg)

    def _pool_hidden(
        self,
        hidden: torch.Tensor,
        attention_mask: torch.Tensor | None,
        special_mask: torch.Tensor | None,
        sequence: str,
    ) -> np.ndarray:
        if attention_mask is None:
            attention = torch.ones(hidden.size(0), dtype=torch.float32)
        else:
            attention = attention_mask.to(dtype=torch.float32)
        if special_mask is not None:
            attention = attention * (1.0 - special_mask.to(dtype=torch.float32))
        if attention.sum() == 0:
            attention = torch.ones_like(attention)

        if self.layer_pool == "mean":
            return self._masked_mean(hidden, attention)
        if self.layer_pool == "cls":
            return hidden[0].detach().cpu().numpy()
        if self.layer_pool == "per_token_mean_cdrh3":
            return self._pool_cdrh3(hidden, attention, sequence)
        msg = f"Unsupported layer pool: {self.layer_pool}"
        raise ValueError(msg)

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> np.ndarray:
        weights = mask.unsqueeze(-1)
        weighted = hidden * weights
        denom = weights.sum()
        if denom == 0:
            pooled = hidden.mean(dim=0)
        else:
            pooled = weighted.sum(dim=0) / denom
        return pooled.detach().cpu().numpy()

    def _pool_cdrh3(self, hidden: torch.Tensor, mask: torch.Tensor, sequence: str) -> np.ndarray:
        numberer = self.numberer
        if numberer is None:
            numberer = AnarciNumberer()
            self.numberer = numberer
        numbered = numberer.number_sequence(sequence)
        cdr = numbered.regions.get("CDRH3", "")
        if not cdr:
            return self._masked_mean(hidden, mask)
        sequence_upper = sequence.upper()
        start = sequence_upper.find(cdr.upper())
        if start == -1:
            return self._masked_mean(hidden, mask)
        residues_idx = mask.nonzero(as_tuple=False).squeeze(-1).tolist()
        if not residues_idx:
            return self._masked_mean(hidden, mask)
        end = start + len(cdr)
        if end > len(residues_idx):
            return self._masked_mean(hidden, mask)
        cdr_token_positions = residues_idx[start:end]
        if not cdr_token_positions:
            return self._masked_mean(hidden, mask)
        cdr_mask = torch.zeros_like(mask)
        for pos in cdr_token_positions:
            cdr_mask[pos] = 1.0
        return self._masked_mean(hidden, cdr_mask)

    def _sequence_cache_path(self, model_dir: Path, sequence: str) -> Path:
        digest = hashlib.sha1(sequence.encode("utf-8")).hexdigest()
        return model_dir / f"{digest}.npy"

    def _normalized_model_name(self) -> str:
        if self._is_esm1v_model(self.model_name):
            return self._canonical_esm_name(self.model_name)
        return self.model_name.replace("/", "_")
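A minimal sketch of driving PLMEmbedder directly. It assumes the optional 'esm' package is installed (the default model name matches PLMConfig above); the sequences are illustrative. Each embedding is cached per sequence as a SHA1-named .npy file under cache_dir, so a second call with the same inputs is served from disk:

# Sketch: embed two heavy-chain sequences with the default ESM-1v backend.
from polyreact.features.plm import PLMEmbedder

embedder = PLMEmbedder(
    model_name="facebook/esm1v_t33_650M_UR90S_1",  # default from PLMConfig
    layer_pool="mean",                             # or "cls" / "per_token_mean_cdrh3"
    device="auto",
    cache_dir=".cache/embeddings",
)

seqs = ["EVQLVESGGGLVQPGGSLRLSCAAS", "QVQLQESGPGLVKPSETLSLTCTVS"]
embeddings = embedder.embed(seqs, batch_size=2)
print(embeddings.shape)  # (2, hidden_dim); repeated calls hit the on-disk cache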
polyreact/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
"""Models for polyreactivity classification."""

__all__ = ["linear", "calibrate", "ordinal"]
polyreact/models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (274 Bytes)
polyreact/models/__pycache__/calibrate.cpython-311.pyc
ADDED
Binary file (1.02 kB)
polyreact/models/__pycache__/linear.cpython-311.pyc
ADDED
Binary file (4.92 kB)
polyreact/models/__pycache__/ordinal.cpython-311.pyc
ADDED
Binary file (5.76 kB)
polyreact/models/calibrate.py
ADDED
@@ -0,0 +1,24 @@
"""Probability calibration helpers."""

from __future__ import annotations

from typing import Any

import numpy as np
from sklearn.calibration import CalibratedClassifierCV


def fit_calibrator(
    estimator: Any,
    X: np.ndarray,
    y: np.ndarray,
    *,
    method: str = "isotonic",
    cv: int | str | None = "prefit",
) -> CalibratedClassifierCV:
    """Fit a ``CalibratedClassifierCV`` on top of a pre-trained estimator."""

    calibrator = CalibratedClassifierCV(estimator, method=method, cv=cv)
    calibrator.fit(X, y)
    return calibrator
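A minimal sketch of the intended "prefit" usage: fit a classifier on one split, then calibrate its probabilities on held-out data. The synthetic arrays below are illustrative only:

# Sketch: calibrate a pre-trained logistic regression on a held-out split.
import numpy as np
from sklearn.linear_model import LogisticRegression

from polyreact.models.calibrate import fit_calibrator

rng = np.random.default_rng(0)
X_fit, y_fit = rng.normal(size=(200, 16)), rng.integers(0, 2, size=200)
X_cal, y_cal = rng.normal(size=(100, 16)), rng.integers(0, 2, size=100)

base = LogisticRegression(max_iter=1000).fit(X_fit, y_fit)
calibrator = fit_calibrator(base, X_cal, y_cal, method="isotonic", cv="prefit")
probs = calibrator.predict_proba(X_cal)[:, 1]  # calibrated positive-class probabilities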
polyreact/models/linear.py
ADDED
@@ -0,0 +1,91 @@
"""Linear classification heads for polyreactivity prediction."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC


@dataclass(slots=True)
class LinearModelConfig:
    """Configuration options for linear heads."""

    head: str = "logreg"
    C: float = 1.0
    class_weight: Any = "balanced"
    max_iter: int = 1000


@dataclass(slots=True)
class TrainedModel:
    """Container for trained estimators and optional calibration."""

    estimator: Any
    calibrator: Any | None = None
    vectorizer_name: str = ""
    feature_meta: dict[str, Any] = field(default_factory=dict)
    metrics_cv: dict[str, float] = field(default_factory=dict)

    def predict(self, X: np.ndarray) -> np.ndarray:
        if self.calibrator is not None and hasattr(self.calibrator, "predict"):
            return self.calibrator.predict(X)
        return self.estimator.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        if self.calibrator is not None and hasattr(self.calibrator, "predict_proba"):
            probs = self.calibrator.predict_proba(X)
            return probs[:, 1]
        if hasattr(self.estimator, "predict_proba"):
            probs = self.estimator.predict_proba(X)
            return probs[:, 1]
        if hasattr(self.estimator, "decision_function"):
            scores = self.estimator.decision_function(X)
            return 1.0 / (1.0 + np.exp(-scores))
        msg = "Estimator does not support probability prediction"
        raise AttributeError(msg)


def build_estimator(
    *, config: LinearModelConfig, random_state: int | None = 42
) -> Any:
    """Construct an unfitted linear estimator based on configuration."""

    if config.head == "logreg":
        estimator = LogisticRegression(
            C=config.C,
            max_iter=config.max_iter,
            class_weight=config.class_weight,
            solver="liblinear",
            random_state=random_state,
        )
    elif config.head == "linear_svm":
        estimator = LinearSVC(
            C=config.C,
            class_weight=config.class_weight,
            max_iter=config.max_iter,
            random_state=random_state,
        )
    else:  # pragma: no cover - defensive branch
        msg = f"Unsupported head type: {config.head}"
        raise ValueError(msg)
    return estimator


def train_linear_model(
    X: np.ndarray,
    y: np.ndarray,
    *,
    config: LinearModelConfig,
    random_state: int | None = 42,
) -> TrainedModel:
    """Fit a linear classifier on the provided feature matrix."""

    estimator = build_estimator(config=config, random_state=random_state)
    if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
        estimator.set_params(solver="lbfgs")
    estimator.fit(X, y)
    return TrainedModel(estimator=estimator)
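A minimal sketch of training the default logistic-regression head and scoring new samples; the feature matrix and labels below are synthetic and illustrative only:

# Sketch: train a linear head on synthetic features and obtain probabilities.
import numpy as np

from polyreact.models.linear import LinearModelConfig, train_linear_model

rng = np.random.default_rng(42)
X = rng.normal(size=(300, 32))
y = (X[:, 0] + 0.5 * rng.normal(size=300) > 0).astype(int)  # synthetic labels

model = train_linear_model(X, y, config=LinearModelConfig(head="logreg", C=1.0))
probs = model.predict_proba(X[:5])   # LinearSVC heads fall back to decision_function + sigmoid
labels = model.predict(X[:5])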
polyreact/models/ordinal.py
ADDED
@@ -0,0 +1,106 @@
"""Ordinal/count modeling utilities for flag regression."""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import statsmodels.api as sm
from statsmodels.discrete.discrete_model import NegativeBinomialResults
from sklearn.linear_model import PoissonRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

try:
    from sklearn.metrics import root_mean_squared_error
except ImportError:  # pragma: no cover - fallback for older sklearn
    def root_mean_squared_error(y_true, y_pred):
        return mean_squared_error(y_true, y_pred, squared=False)


@dataclass(slots=True)
class PoissonModel:
    """Wrapper storing a fitted Poisson regression model."""

    estimator: PoissonRegressor

    def predict(self, X: np.ndarray) -> np.ndarray:
        return self.estimator.predict(X)


def fit_poisson_model(
    X: np.ndarray,
    y: np.ndarray,
    *,
    alpha: float = 1e-6,
    max_iter: int = 1000,
) -> PoissonModel:
    """Train a Poisson regression model on count targets."""

    model = PoissonRegressor(alpha=alpha, max_iter=max_iter)
    model.fit(X, y)
    return PoissonModel(estimator=model)


@dataclass(slots=True)
class NegativeBinomialModel:
    """Wrapper storing a fitted negative-binomial regression model."""

    result: NegativeBinomialResults

    def predict(self, X: np.ndarray) -> np.ndarray:
        X_const = sm.add_constant(X, has_constant="add")
        return self.result.predict(X_const)

    @property
    def alpha(self) -> float:
        params = np.asarray(self.result.params, dtype=float)
        exog_dim = self.result.model.exog.shape[1]
        if params.size > exog_dim:
            # statsmodels stores log(alpha) as the final coefficient
            return float(np.exp(params[-1]))
        model_alpha = getattr(self.result.model, "alpha", None)
        if model_alpha is not None:
            return float(model_alpha)
        return float("nan")


def fit_negative_binomial_model(
    X: np.ndarray,
    y: np.ndarray,
    *,
    max_iter: int = 200,
) -> NegativeBinomialModel:
    """Train a negative binomial regression model (NB2)."""

    X_const = sm.add_constant(X, has_constant="add")
    model = sm.NegativeBinomial(y, X_const, loglike_method="nb2")
    result = model.fit(maxiter=max_iter, disp=False)
    return NegativeBinomialModel(result=result)


def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Return standard regression metrics for count predictions."""

    mae = mean_absolute_error(y_true, y_pred)
    rmse = root_mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return {
        "mae": float(mae),
        "rmse": float(rmse),
        "r2": float(r2),
    }


def pearson_dispersion(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    *,
    dof: int,
) -> float:
    """Compute Pearson dispersion (chi-square / dof)."""

    eps = 1e-8
    adjusted = np.maximum(y_pred, eps)
    resid = (y_true - y_pred) / np.sqrt(adjusted)
    denom = max(dof, 1)
    return float(np.sum(resid**2) / denom)
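A minimal sketch of fitting the count models above on synthetic flag counts and checking their fit; the data generation is illustrative only:

# Sketch: fit Poisson and negative-binomial models and compare diagnostics.
import numpy as np

from polyreact.models.ordinal import (
    fit_negative_binomial_model,
    fit_poisson_model,
    pearson_dispersion,
    regression_metrics,
)

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 8))
y = rng.poisson(lam=np.exp(0.3 * X[:, 0]))  # synthetic flag counts

poisson = fit_poisson_model(X, y)
negbin = fit_negative_binomial_model(X, y)

preds = poisson.predict(X)
print(regression_metrics(y, preds))
print(pearson_dispersion(y, preds, dof=X.shape[0] - X.shape[1] - 1))  # values well above 1 suggest overdispersion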
polyreact/predict.py
ADDED
@@ -0,0 +1,106 @@
"""Command-line interface for polyreactivity predictions."""

from __future__ import annotations

import argparse
from pathlib import Path

import pandas as pd

from .api import predict_batch
from .config import load_config
from .utils.io import read_table, write_table


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Polyreactivity prediction CLI")
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input CSV or JSONL file with sequences.",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Path to write predictions CSV.",
    )
    parser.add_argument(
        "--config",
        default="configs/default.yaml",
        help="Path to configuration YAML file.",
    )
    parser.add_argument(
        "--backend",
        choices=["plm", "descriptors", "concat"],
        help="Override feature backend from config.",
    )
    parser.add_argument(
        "--plm-model",
        help="Override PLM model name.",
    )
    parser.add_argument(
        "--weights",
        required=True,
        help="Path to trained model artifact (joblib).",
    )
    parser.add_argument(
        "--heavy-only",
        dest="heavy_only",
        action="store_true",
        default=True,
        help="Use only heavy chains (default).",
    )
    parser.add_argument(
        "--paired",
        dest="heavy_only",
        action="store_false",
        help="Use paired heavy/light chains if available.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size for model inference (PLM backend).",
    )
    parser.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda"],
        help="Computation device override.",
    )
    parser.add_argument(
        "--cache-dir",
        help="Cache directory for embeddings.",
    )
    return parser


def main(argv: list[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)

    config = load_config(args.config)
    df = read_table(args.input)

    if "heavy_seq" not in df.columns and "heavy" not in df.columns:
        parser.error("Input file must contain a 'heavy_seq' column (or 'heavy').")
    if df.get("heavy_seq", df.get("heavy", "")).fillna("").str.len().eq(0).all():
        parser.error("At least one non-empty heavy sequence is required.")

    predictions = predict_batch(
        df.to_dict("records"),
        config=config,
        backend=args.backend,
        plm_model=args.plm_model,
        weights=args.weights,
        heavy_only=args.heavy_only,
        batch_size=args.batch_size,
        device=args.device,
        cache_dir=args.cache_dir,
    )

    write_table(predictions, args.output)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
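A minimal sketch of invoking the prediction CLI programmatically. The file paths are illustrative: the input needs a 'heavy_seq' (or 'heavy') column, and the weights file is a joblib artifact produced by polyreact.train:

# Sketch: run the CLI entrypoint from Python with an explicit argument list.
from polyreact.predict import main

exit_code = main([
    "--input", "antibodies.csv",
    "--output", "predictions.csv",
    "--weights", "artifacts/model.joblib",
    "--backend", "plm",
    "--batch-size", "16",
])
assert exit_code == 0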
polyreact/train.py
ADDED
@@ -0,0 +1,619 @@
| 1 |
+
"""Training entrypoint for the polyreactivity model."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import subprocess
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Sequence
|
| 10 |
+
|
| 11 |
+
import joblib
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from sklearn.metrics import roc_auc_score
|
| 15 |
+
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
|
| 16 |
+
from sklearn.linear_model import LogisticRegression
|
| 17 |
+
|
| 18 |
+
from .config import Config, load_config
|
| 19 |
+
from .data_loaders import boughter, harvey, jain, shehata
|
| 20 |
+
from .data_loaders.utils import deduplicate_sequences
|
| 21 |
+
from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline
|
| 22 |
+
from .models.calibrate import fit_calibrator
|
| 23 |
+
from .models.linear import LinearModelConfig, TrainedModel, build_estimator, train_linear_model
|
| 24 |
+
from .utils.io import write_table
|
| 25 |
+
from .utils.logging import configure_logging
|
| 26 |
+
from .utils.metrics import bootstrap_metric_intervals, compute_metrics
|
| 27 |
+
from .utils.plots import plot_precision_recall, plot_reliability_curve, plot_roc_curve
|
| 28 |
+
from .utils.seeds import set_global_seeds
|
| 29 |
+
|
| 30 |
+
DATASET_LOADERS = {
|
| 31 |
+
"boughter": boughter.load_dataframe,
|
| 32 |
+
"jain": jain.load_dataframe,
|
| 33 |
+
"shehata": shehata.load_dataframe,
|
| 34 |
+
"shehata_curated": shehata.load_dataframe,
|
| 35 |
+
"harvey": harvey.load_dataframe,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 40 |
+
parser = argparse.ArgumentParser(description="Train polyreactivity model")
|
| 41 |
+
parser.add_argument("--config", default="configs/default.yaml", help="Config file")
|
| 42 |
+
parser.add_argument("--train", required=True, help="Training dataset path")
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--eval",
|
| 45 |
+
nargs="*",
|
| 46 |
+
default=[],
|
| 47 |
+
help="Evaluation dataset paths",
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--save-to",
|
| 51 |
+
default="artifacts/model.joblib",
|
| 52 |
+
help="Path to save trained model artifact",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--report-to",
|
| 56 |
+
default="artifacts",
|
| 57 |
+
help="Directory for metrics, predictions, and plots",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--train-loader",
|
| 61 |
+
choices=list(DATASET_LOADERS.keys()),
|
| 62 |
+
help="Optional explicit loader for training dataset",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--eval-loaders",
|
| 66 |
+
nargs="*",
|
| 67 |
+
help="Optional explicit loaders for evaluation datasets (aligned with --eval order)",
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--backend",
|
| 71 |
+
choices=["plm", "descriptors", "concat"],
|
| 72 |
+
help="Override feature backend",
|
| 73 |
+
)
|
| 74 |
+
parser.add_argument("--plm-model", help="Override PLM model name")
|
| 75 |
+
parser.add_argument("--cache-dir", help="Override embedding cache directory")
|
| 76 |
+
parser.add_argument("--device", choices=["auto", "cpu", "cuda"], help="Device override")
|
| 77 |
+
parser.add_argument("--batch-size", type=int, default=8, help="Batch size for embeddings")
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--heavy-only",
|
| 80 |
+
action="store_true",
|
| 81 |
+
default=True,
|
| 82 |
+
help="Use heavy chains only (default true)",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--paired",
|
| 86 |
+
dest="heavy_only",
|
| 87 |
+
action="store_false",
|
| 88 |
+
help="Use paired heavy/light chains when available.",
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--include-families",
|
| 92 |
+
nargs="*",
|
| 93 |
+
help="Optional list of family names to retain in the training dataset",
|
| 94 |
+
)
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--exclude-families",
|
| 97 |
+
nargs="*",
|
| 98 |
+
help="Optional list of family names to drop from the training dataset",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--include-species",
|
| 102 |
+
nargs="*",
|
| 103 |
+
help="Optional list of species (e.g. human, mouse) to retain",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--cv-group-column",
|
| 107 |
+
default="lineage",
|
| 108 |
+
help="Column name used to group samples during cross-validation (default: lineage)",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--no-group-cv",
|
| 112 |
+
action="store_true",
|
| 113 |
+
help="Disable group-aware cross-validation even if group column is present",
|
| 114 |
+
)
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--keep-train-duplicates",
|
| 117 |
+
action="store_true",
|
| 118 |
+
help="Keep duplicate keys within the training dataset when deduplicating across splits",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--dedupe-key-columns",
|
| 122 |
+
nargs="*",
|
| 123 |
+
help="Columns used to detect duplicates across datasets (defaults to heavy/light sequences)",
|
| 124 |
+
)
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--bootstrap-samples",
|
| 127 |
+
type=int,
|
| 128 |
+
default=200,
|
| 129 |
+
help="Number of bootstrap resamples for confidence intervals (0 to disable).",
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--bootstrap-alpha",
|
| 133 |
+
type=float,
|
| 134 |
+
default=0.05,
|
| 135 |
+
help="Alpha for two-sided bootstrap confidence intervals (default 0.05 → 95% CI).",
|
| 136 |
+
)
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--write-train-in-sample",
|
| 139 |
+
action="store_true",
|
| 140 |
+
help=(
|
| 141 |
+
"Persist in-sample metrics on the full training set; disabled by default to avoid"
|
| 142 |
+
" over-optimistic reporting."
|
| 143 |
+
),
|
| 144 |
+
)
|
| 145 |
+
return parser


def _infer_loader(path: str, explicit: str | None) -> tuple[str, callable]:
    if explicit:
        return explicit, DATASET_LOADERS[explicit]
    lower = Path(path).stem.lower()
    for name, loader in DATASET_LOADERS.items():
        if name in lower:
            return name, loader
    msg = f"Could not infer loader for dataset: {path}. Provide --train-loader/--eval-loaders."
    raise ValueError(msg)
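
# Behaviour sketch (assumes DATASET_LOADERS contains "jain" and "harvey" keys):
#   _infer_loader("data/jain_subset.csv", None)      -> ("jain", <jain loader>)
#   _infer_loader("data/jain_subset.csv", "harvey")  -> ("harvey", <harvey loader>)
#   _infer_loader("data/unknown.csv", None)          -> raises ValueError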


def _load_dataset(path: str, loader_name: str, loader_fn, *, heavy_only: bool) -> pd.DataFrame:
    frame = loader_fn(path, heavy_only=heavy_only)
    frame["source"] = loader_name
    return frame
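
# The "source" tag added above is what later groups predictions and metrics per
# dataset. Illustrative: every row of
#   _load_dataset("data/jain.csv", "jain", DATASET_LOADERS["jain"], heavy_only=True)
# carries source == "jain".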


def _apply_dataset_filters(
    frame: pd.DataFrame,
    *,
    include_families: Sequence[str] | None,
    exclude_families: Sequence[str] | None,
    include_species: Sequence[str] | None,
) -> pd.DataFrame:
    filtered = frame.copy()
    if include_families:
        families = {fam.lower() for fam in include_families}
        if "family" in filtered.columns:
            filtered = filtered[
                filtered["family"].astype(str).str.lower().isin(families)
            ]
    if exclude_families:
        families_ex = {fam.lower() for fam in exclude_families}
        if "family" in filtered.columns:
            filtered = filtered[
                ~filtered["family"].astype(str).str.lower().isin(families_ex)
            ]
    if include_species:
        species_set = {spec.lower() for spec in include_species}
        if "species" in filtered.columns:
            filtered = filtered[
                filtered["species"].astype(str).str.lower().isin(species_set)
            ]
    return filtered.reset_index(drop=True)
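
# Illustrative call (hypothetical family/species values): keep human rows and
# drop one family, matching case-insensitively and only when the columns exist:
#   train_df = _apply_dataset_filters(
#       train_df,
#       include_families=None,
#       exclude_families=["IGHV4"],
#       include_species=["human"],
#   )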


def main(argv: Sequence[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)

    config = load_config(args.config)
    if args.device:
        config.device = args.device
    if args.backend:
        config.feature_backend.type = args.backend
    if args.cache_dir:
        config.feature_backend.cache_dir = args.cache_dir
    if args.plm_model:
        config.feature_backend.plm_model_name = args.plm_model

    logger = configure_logging()
    set_global_seeds(config.seed)
    _log_environment(logger)

    heavy_only = args.heavy_only

    train_name, train_loader = _infer_loader(args.train, args.train_loader)
    train_df = _load_dataset(args.train, train_name, train_loader, heavy_only=heavy_only)
    train_df = _apply_dataset_filters(
        train_df,
        include_families=args.include_families,
        exclude_families=args.exclude_families,
        include_species=args.include_species,
    )

    eval_frames: list[pd.DataFrame] = []
    if args.eval:
        loaders_iter = args.eval_loaders or []
        for idx, eval_path in enumerate(args.eval):
            explicit = loaders_iter[idx] if idx < len(loaders_iter) else None
            eval_name, eval_loader = _infer_loader(eval_path, explicit)
            eval_df = _load_dataset(eval_path, eval_name, eval_loader, heavy_only=heavy_only)
            eval_frames.append(eval_df)

    all_frames = [train_df, *eval_frames]
    dedup_keep = {0} if args.keep_train_duplicates else set()
    deduped_frames = deduplicate_sequences(
        all_frames,
        heavy_only=heavy_only,
        key_columns=args.dedupe_key_columns,
        keep_intra_frames=dedup_keep,
    )
    train_df = deduped_frames[0]
    eval_frames = deduped_frames[1:]

    pipeline_factory = lambda: build_feature_pipeline(  # noqa: E731
        config,
        backend_override=args.backend,
        plm_model_override=args.plm_model,
        cache_dir_override=args.cache_dir,
    )

    model_config = LinearModelConfig(
        head=config.model.head,
        C=config.model.C,
        class_weight=config.model.class_weight,
    )

    groups = None
    if not args.no_group_cv and args.cv_group_column:
        if args.cv_group_column in train_df.columns:
            groups = train_df[args.cv_group_column].fillna("").astype(str).to_numpy()
        else:
            logger.warning(
                "Group column '%s' not found in training dataframe; falling back to standard CV",
                args.cv_group_column,
            )

    cv_results = _cross_validate(
        train_df,
        pipeline_factory,
        model_config,
        config,
        heavy_only=heavy_only,
        batch_size=args.batch_size,
        groups=groups,
    )

    trained_model, feature_pipeline = _fit_full_model(
        train_df,
        pipeline_factory,
        model_config,
        config,
        heavy_only=heavy_only,
        batch_size=args.batch_size,
    )

    outputs_dir = Path(args.report_to)
    outputs_dir.mkdir(parents=True, exist_ok=True)

    metrics_df, preds_rows = _evaluate_datasets(
        train_df,
        eval_frames,
        trained_model,
        feature_pipeline,
        config,
        cv_results,
        outputs_dir,
        batch_size=args.batch_size,
        heavy_only=heavy_only,
        bootstrap_samples=args.bootstrap_samples,
        bootstrap_alpha=args.bootstrap_alpha,
        write_train_in_sample=args.write_train_in_sample,
    )

    write_table(metrics_df, outputs_dir / config.io.metrics_filename)
    preds_df = pd.DataFrame(preds_rows)
    write_table(preds_df, outputs_dir / config.io.preds_filename)

    artifact = {
        "config": config,
        "feature_state": feature_pipeline.get_state(),
        "model": trained_model,
    }
    Path(args.save_to).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(artifact, args.save_to)

    logger.info("Training complete. Metrics written to %s", outputs_dir)
    return 0
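
# Minimal reload sketch for the artifact written above (illustrative path; the
# keys mirror the dict assembled in main):
#   import joblib
#   artifact = joblib.load("artifacts/model.joblib")
#   trained_model = artifact["model"]
#   feature_state = artifact["feature_state"]
#   cfg = artifact["config"]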


def _cross_validate(
    train_df: pd.DataFrame,
    pipeline_factory,
    model_config: LinearModelConfig,
    config: Config,
    *,
    heavy_only: bool,
    batch_size: int,
    groups: np.ndarray | None = None,
):
    y = train_df["label"].to_numpy(dtype=int)
    n_samples = len(y)
    # Determine a safe number of folds for tiny fixtures; prefer the configured value
    # but never exceed the number of samples. Fall back to non-stratified KFold when
    # per-class counts are too small for stratification (e.g., 1 positive/1 negative).
    n_splits = max(2, min(config.training.cv_folds, n_samples))

    use_stratified = True
    class_counts = np.bincount(y) if y.size else np.array([])
    if class_counts.size > 0 and (class_counts.min(initial=0) < n_splits):
        use_stratified = False

    if groups is not None and use_stratified:
        splitter = StratifiedGroupKFold(
            n_splits=n_splits,
            shuffle=True,
            random_state=config.seed,
        )
        split_iter = splitter.split(train_df, y, groups)
    elif use_stratified:
        splitter = StratifiedKFold(
            n_splits=n_splits,
            shuffle=True,
            random_state=config.seed,
        )
        split_iter = splitter.split(train_df, y)
    else:
        # Non-stratified fallback for extreme class imbalance / tiny datasets
        from sklearn.model_selection import KFold  # local import to limit surface

        splitter = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)
        split_iter = splitter.split(train_df)

    oof_scores = np.zeros(len(train_df), dtype=float)
    metrics_per_fold: list[dict[str, float]] = []

    for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
        train_slice = train_df.iloc[train_idx].reset_index(drop=True)
        val_slice = train_df.iloc[val_idx].reset_index(drop=True)

        pipeline: FeaturePipeline = pipeline_factory()
        X_train = pipeline.fit_transform(train_slice, heavy_only=heavy_only, batch_size=batch_size)
        X_val = pipeline.transform(val_slice, heavy_only=heavy_only, batch_size=batch_size)

        y_train = y[train_idx]
        y_val = y[val_idx]

        # Handle degenerate folds where training data contains a single class
        if np.unique(y_train).size < 2:
            fallback_prob = float(y.mean()) if y.size else 0.5
            y_scores = np.full(X_val.shape[0], fallback_prob, dtype=float)
        else:
            trained = train_linear_model(
                X_train, y_train, config=model_config, random_state=config.seed
            )
            calibrator = _fit_model_calibrator(
                model_config,
                config,
                X_train,
                y_train,
                base_estimator=trained.estimator,
            )
            trained.calibrator = calibrator
            if calibrator is not None:
                y_scores = calibrator.predict_proba(X_val)[:, 1]
            else:
                y_scores = trained.predict_proba(X_val)
        oof_scores[val_idx] = y_scores

        fold_metrics = compute_metrics(y_val, y_scores)
        try:
            fold_metrics["roc_auc"] = float(roc_auc_score(y_val, y_scores))
        except ValueError:
            # For tiny validation folds with a single class, ROC-AUC is undefined
            pass
        metrics_per_fold.append(fold_metrics)

    metrics_mean: dict[str, float] = {}
    metrics_std: dict[str, float] = {}
    metric_names = list(metrics_per_fold[0].keys()) if metrics_per_fold else []
    for metric in metric_names:
        values = [fold[metric] for fold in metrics_per_fold]
        metrics_mean[metric] = float(np.mean(values))
        metrics_std[metric] = float(np.std(values, ddof=1))

    return {
        "oof_scores": oof_scores,
        "metrics_per_fold": metrics_per_fold,
        "metrics_mean": metrics_mean,
        "metrics_std": metrics_std,
    }
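
# Illustrative shape of the dict returned above (metric names come from
# compute_metrics, so the keys shown are examples, not a guaranteed schema):
#   {
#       "oof_scores": <np.ndarray, one out-of-fold score per training row>,
#       "metrics_per_fold": [{"roc_auc": 0.91, ...}, ...],  # one dict per fold
#       "metrics_mean": {"roc_auc": 0.90, ...},
#       "metrics_std": {"roc_auc": 0.02, ...},
#   }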


def _fit_full_model(
    train_df: pd.DataFrame,
    pipeline_factory,
    model_config: LinearModelConfig,
    config: Config,
    *,
    heavy_only: bool,
    batch_size: int,
) -> tuple[TrainedModel, FeaturePipeline]:
    pipeline: FeaturePipeline = pipeline_factory()
    X_train = pipeline.fit_transform(train_df, heavy_only=heavy_only, batch_size=batch_size)
    y_train = train_df["label"].to_numpy(dtype=int)

    trained = train_linear_model(X_train, y_train, config=model_config, random_state=config.seed)
    calibrator = _fit_model_calibrator(
        model_config,
        config,
        X_train,
        y_train,
        base_estimator=trained.estimator,
    )
    trained.calibrator = calibrator

    return trained, pipeline


def _evaluate_datasets(
    train_df: pd.DataFrame,
    eval_frames: list[pd.DataFrame],
    trained_model: TrainedModel,
    pipeline: FeaturePipeline,
    config: Config,
    cv_results: dict,
    outputs_dir: Path,
    *,
    batch_size: int,
    heavy_only: bool,
    bootstrap_samples: int,
    bootstrap_alpha: float,
    write_train_in_sample: bool,
):
    metrics_lookup: dict[str, dict[str, float]] = {}
    preds_rows: list[dict[str, float]] = []

    metrics_mean: dict[str, float] = cv_results["metrics_mean"]
    metrics_std: dict[str, float] = cv_results["metrics_std"]

    for metric_name, value in metrics_mean.items():
        metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
            "train_cv_mean"
        ] = value
    for metric_name, value in metrics_std.items():
        metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
            "train_cv_std"
        ] = value

    train_scores = cv_results["oof_scores"]
    train_preds = train_df[["id", "source", "label"]].copy()
    train_preds["y_true"] = train_preds["label"]
    train_preds["y_score"] = train_scores
    train_preds["y_pred"] = (train_scores >= 0.5).astype(int)
    train_preds["split"] = "train_cv_oof"
    preds_rows.extend(
        train_preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
    )

    plot_reliability_curve(
        train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "reliability_train.png"
    )
    plot_precision_recall(
        train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "pr_train.png"
    )
    plot_roc_curve(train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "roc_train.png")

    if bootstrap_samples > 0:
        ci_map = bootstrap_metric_intervals(
            train_preds["y_true"],
            train_preds["y_score"],
            n_bootstrap=bootstrap_samples,
            alpha=bootstrap_alpha,
            random_state=config.seed,
        )
        for metric_name, stats in ci_map.items():
            row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
            row["train_cv_ci_lower"] = stats.get("ci_lower")
            row["train_cv_ci_upper"] = stats.get("ci_upper")
            row["train_cv_ci_median"] = stats.get("ci_median")

    if write_train_in_sample:
        train_features_full = pipeline.transform(
            train_df, heavy_only=heavy_only, batch_size=batch_size
        )
        train_full_scores = trained_model.predict_proba(train_features_full)
        train_full_metrics = compute_metrics(
            train_df["label"].to_numpy(dtype=int), train_full_scores
        )
        (outputs_dir / "train_in_sample.json").write_text(
            json.dumps(train_full_metrics, indent=2),
            encoding="utf-8",
        )

    for frame in eval_frames:
        if frame.empty:
            continue
        features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
        scores = trained_model.predict_proba(features)
        y_true = frame["label"].to_numpy(dtype=int)
        metrics = compute_metrics(y_true, scores)
        dataset_name = frame["source"].iloc[0]
        for metric_name, value in metrics.items():
            metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
                dataset_name
            ] = value

        preds = frame[["id", "source", "label"]].copy()
        preds["y_true"] = preds["label"]
        preds["y_score"] = scores
        preds["y_pred"] = (scores >= 0.5).astype(int)
        preds["split"] = dataset_name
        preds_rows.extend(
            preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
        )

        plot_reliability_curve(
            preds["y_true"],
            preds["y_score"],
            path=outputs_dir / f"reliability_{dataset_name}.png",
        )
        plot_precision_recall(
            preds["y_true"],
            preds["y_score"],
            path=outputs_dir / f"pr_{dataset_name}.png",
        )
        plot_roc_curve(
            preds["y_true"], preds["y_score"], path=outputs_dir / f"roc_{dataset_name}.png"
        )

        if bootstrap_samples > 0:
            ci_map = bootstrap_metric_intervals(
                preds["y_true"],
                preds["y_score"],
                n_bootstrap=bootstrap_samples,
                alpha=bootstrap_alpha,
                random_state=config.seed,
            )
            for metric_name, stats in ci_map.items():
                row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
                row[f"{dataset_name}_ci_lower"] = stats.get("ci_lower")
                row[f"{dataset_name}_ci_upper"] = stats.get("ci_upper")
                row[f"{dataset_name}_ci_median"] = stats.get("ci_median")

    metrics_df = pd.DataFrame(sorted(metrics_lookup.values(), key=lambda row: row["metric"]))
    return metrics_df, preds_rows
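
# Illustrative layout of the metrics table assembled above (dataset columns take
# their names from each frame's "source" value; numbers are made up):
#   metric    train_cv_mean  train_cv_std  train_cv_ci_lower  train_cv_ci_upper  jain   jain_ci_lower  ...
#   roc_auc   0.90           0.02          0.86               0.93               0.78   0.71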


def _fit_model_calibrator(
    model_config: LinearModelConfig,
    config: Config,
    X: np.ndarray,
    y: np.ndarray,
    *,
    base_estimator: Any | None = None,
):
    method = config.calibration.method
    if not method:
        return None
    if len(np.unique(y)) < 2:
        return None

    if len(y) >= 4:
        cv_cal = min(config.training.cv_folds, max(2, len(y) // 2))
        estimator = build_estimator(config=model_config, random_state=config.seed)
        if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
            estimator.set_params(solver="lbfgs")
        calibrator = fit_calibrator(estimator, X, y, method=method, cv=cv_cal)
    else:
        estimator = base_estimator or build_estimator(config=model_config, random_state=config.seed)
        if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
            estimator.set_params(solver="lbfgs")
        estimator.fit(X, y)
        calibrator = fit_calibrator(estimator, X, y, method=method, cv="prefit")
    return calibrator
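
# Rough scikit-learn analogue of the two branches above (illustrative only; that
# fit_calibrator wraps CalibratedClassifierCV is an assumption, not confirmed here):
#   from sklearn.calibration import CalibratedClassifierCV
#   from sklearn.linear_model import LogisticRegression
#   # >= 4 labelled rows: cross-fitted calibration
#   cal = CalibratedClassifierCV(LogisticRegression(max_iter=1000), method="isotonic", cv=3).fit(X, y)
#   # tiny datasets: calibrate an already-fitted estimator on the same data
#   base = LogisticRegression(max_iter=1000).fit(X, y)
#   cal = CalibratedClassifierCV(base, method="isotonic", cv="prefit").fit(X, y)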


def _log_environment(logger) -> None:
    try:
        git_head = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
    except Exception:  # pragma: no cover - best effort
        git_head = "unknown"
    try:
        pip_freeze = subprocess.check_output(["pip", "freeze"], text=True)
    except Exception:  # pragma: no cover
        pip_freeze = ""
    logger.info("git_head=%s", git_head)
    logger.info("pip_freeze=%s", pip_freeze)


if __name__ == "__main__":
    raise SystemExit(main())

polyreact/utils/__pycache__/io.cpython-311.pyc
ADDED
Binary file (2.23 kB)

polyreact/utils/__pycache__/logging.cpython-311.pyc
ADDED
Binary file (2.13 kB)

polyreact/utils/__pycache__/metrics.cpython-311.pyc
ADDED
Binary file (7.67 kB)