makiling committed on
Commit 5f58699 · verified · 1 Parent(s): 7724405

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. polyreact/__init__.py +10 -0
  2. polyreact/__pycache__/__init__.cpython-311.pyc +0 -0
  3. polyreact/__pycache__/api.cpython-311.pyc +0 -0
  4. polyreact/__pycache__/config.cpython-311.pyc +0 -0
  5. polyreact/__pycache__/predict.cpython-311.pyc +0 -0
  6. polyreact/__pycache__/train.cpython-311.pyc +0 -0
  7. polyreact/api.py +121 -0
  8. polyreact/benchmarks/__pycache__/reproduce_paper.cpython-311.pyc +0 -0
  9. polyreact/benchmarks/__pycache__/run_benchmarks.cpython-311.pyc +0 -0
  10. polyreact/benchmarks/reproduce_paper.ipynb +25 -0
  11. polyreact/benchmarks/reproduce_paper.py +1020 -0
  12. polyreact/benchmarks/run_benchmarks.py +114 -0
  13. polyreact/config.py +160 -0
  14. polyreact/configs/__init__.py +1 -0
  15. polyreact/configs/default.yaml +34 -0
  16. polyreact/data_loaders/__init__.py +3 -0
  17. polyreact/data_loaders/__pycache__/__init__.cpython-311.pyc +0 -0
  18. polyreact/data_loaders/__pycache__/boughter.cpython-311.pyc +0 -0
  19. polyreact/data_loaders/__pycache__/harvey.cpython-311.pyc +0 -0
  20. polyreact/data_loaders/__pycache__/jain.cpython-311.pyc +0 -0
  21. polyreact/data_loaders/__pycache__/shehata.cpython-311.pyc +0 -0
  22. polyreact/data_loaders/__pycache__/utils.cpython-311.pyc +0 -0
  23. polyreact/data_loaders/boughter.py +76 -0
  24. polyreact/data_loaders/harvey.py +39 -0
  25. polyreact/data_loaders/jain.py +41 -0
  26. polyreact/data_loaders/shehata.py +202 -0
  27. polyreact/data_loaders/utils.py +186 -0
  28. polyreact/features/__init__.py +13 -0
  29. polyreact/features/__pycache__/__init__.cpython-311.pyc +0 -0
  30. polyreact/features/__pycache__/anarsi.cpython-311.pyc +0 -0
  31. polyreact/features/__pycache__/descriptors.cpython-311.pyc +0 -0
  32. polyreact/features/__pycache__/pipeline.cpython-311.pyc +0 -0
  33. polyreact/features/__pycache__/plm.cpython-311.pyc +0 -0
  34. polyreact/features/anarsi.py +222 -0
  35. polyreact/features/descriptors.py +146 -0
  36. polyreact/features/pipeline.py +343 -0
  37. polyreact/features/plm.py +378 -0
  38. polyreact/models/__init__.py +3 -0
  39. polyreact/models/__pycache__/__init__.cpython-311.pyc +0 -0
  40. polyreact/models/__pycache__/calibrate.cpython-311.pyc +0 -0
  41. polyreact/models/__pycache__/linear.cpython-311.pyc +0 -0
  42. polyreact/models/__pycache__/ordinal.cpython-311.pyc +0 -0
  43. polyreact/models/calibrate.py +24 -0
  44. polyreact/models/linear.py +91 -0
  45. polyreact/models/ordinal.py +106 -0
  46. polyreact/predict.py +106 -0
  47. polyreact/train.py +619 -0
  48. polyreact/utils/__pycache__/io.cpython-311.pyc +0 -0
  49. polyreact/utils/__pycache__/logging.cpython-311.pyc +0 -0
  50. polyreact/utils/__pycache__/metrics.cpython-311.pyc +0 -0
polyreact/__init__.py ADDED
@@ -0,0 +1,10 @@
"""Polyreactivity prediction package."""

from importlib import metadata

__all__ = ["__version__"]

try:
    __version__ = metadata.version("polyreact")
except metadata.PackageNotFoundError:  # pragma: no cover
    __version__ = "0.0.0"
polyreact/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (502 Bytes).
 
polyreact/__pycache__/api.cpython-311.pyc ADDED
Binary file (6.1 kB).
 
polyreact/__pycache__/config.cpython-311.pyc ADDED
Binary file (9.9 kB).
 
polyreact/__pycache__/predict.cpython-311.pyc ADDED
Binary file (4.43 kB).
 
polyreact/__pycache__/train.cpython-311.pyc ADDED
Binary file (28.1 kB).
 
polyreact/api.py ADDED
@@ -0,0 +1,121 @@
"""Public Python API for polyreactivity prediction."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable

import copy

import joblib
import pandas as pd
from sklearn.preprocessing import StandardScaler

from .config import Config, load_config
from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline


def predict_batch(  # noqa: ANN003
    records: Iterable[dict],
    *,
    config: Config | str | Path | None = None,
    backend: str | None = None,
    plm_model: str | None = None,
    weights: str | Path | None = None,
    heavy_only: bool = True,
    batch_size: int = 8,
    device: str | None = None,
    cache_dir: str | None = None,
) -> pd.DataFrame:
    """Predict polyreactivity scores for a batch of sequences."""

    records_list = list(records)
    if not records_list:
        return pd.DataFrame(columns=["id", "score", "pred"])

    artifact = _load_artifact(weights)

    if config is None:
        artifact_config = artifact.get("config")
        if isinstance(artifact_config, Config):
            config = copy.deepcopy(artifact_config)
        else:
            config = load_config("configs/default.yaml")
    elif isinstance(config, (str, Path)):
        config = load_config(config)
    else:
        config = copy.deepcopy(config)

    if backend:
        config.feature_backend.type = backend
    if plm_model:
        config.feature_backend.plm_model_name = plm_model
    if device:
        config.device = device
    if cache_dir:
        config.feature_backend.cache_dir = cache_dir

    pipeline = _restore_pipeline(config, artifact)
    trained_model = artifact["model"]

    frame = pd.DataFrame(records_list)
    if frame.empty:
        raise ValueError("Prediction requires at least one record.")
    if "id" not in frame.columns:
        frame["id"] = frame.get("sequence_id", range(len(frame))).astype(str)
    if "heavy_seq" in frame.columns:
        frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str)
    else:
        heavy_series = frame.get("heavy")
        if heavy_series is None:
            heavy_series = pd.Series([""] * len(frame))
        frame["heavy_seq"] = heavy_series.fillna("").astype(str)

    if "light_seq" in frame.columns:
        frame["light_seq"] = frame["light_seq"].fillna("").astype(str)
    else:
        light_series = frame.get("light")
        if light_series is None:
            light_series = pd.Series([""] * len(frame))
        frame["light_seq"] = light_series.fillna("").astype(str)

    if heavy_only:
        frame["light_seq"] = ""
    if frame["heavy_seq"].str.len().eq(0).all():
        raise ValueError("No heavy chain sequences provided for prediction.")

    features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
    scores = trained_model.predict_proba(features)
    preds = (scores >= 0.5).astype(int)

    return pd.DataFrame(
        {
            "id": frame["id"].astype(str),
            "score": scores,
            "pred": preds,
        }
    )


def _load_artifact(weights: str | Path | None) -> dict:
    if weights is None:
        msg = "Prediction requires a path to model weights"
        raise ValueError(msg)
    artifact = joblib.load(weights)
    if not isinstance(artifact, dict):
        msg = "Model artifact must be a dictionary"
        raise ValueError(msg)
    return artifact


def _restore_pipeline(config: Config, artifact: dict) -> FeaturePipeline:
    pipeline = build_feature_pipeline(config)
    state = artifact.get("feature_state")
    if isinstance(state, FeaturePipelineState):
        pipeline.load_state(state)
        if pipeline.backend.type in {"plm", "concat"} and pipeline._plm_scaler is None:
            pipeline._plm_scaler = StandardScaler()
        return pipeline

    msg = "Model artifact is missing feature pipeline state"
    raise ValueError(msg)
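A minimal usage sketch for `predict_batch`, assuming a trained artifact has already been produced by `polyreact.train` and saved locally; the weights path and the example record below are placeholders, not files shipped in this commit:

    # Sketch only: "artifacts/model.joblib" and the toy sequence are hypothetical inputs.
    from polyreact.api import predict_batch

    records = [
        {"id": "mAb-1", "heavy_seq": "EVQLVESGGGLVQPGG..."},  # truncated toy VH sequence
    ]
    result = predict_batch(records, weights="artifacts/model.joblib", heavy_only=True)
    print(result[["id", "score", "pred"]])  # one calibrated score and 0/1 call per record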
polyreact/benchmarks/__pycache__/reproduce_paper.cpython-311.pyc ADDED
Binary file (46.4 kB).
 
polyreact/benchmarks/__pycache__/run_benchmarks.cpython-311.pyc ADDED
Binary file (5.21 kB).
 
polyreact/benchmarks/reproduce_paper.ipynb ADDED
@@ -0,0 +1,25 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Polyreactivity Benchmark Notebook\n",
    "\n",
    "This notebook will reproduce paper results once the pipeline is implemented." ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
polyreact/benchmarks/reproduce_paper.py ADDED
@@ -0,0 +1,1020 @@
"""Reproduce key metrics and visualisations for the polyreactivity model."""

from __future__ import annotations

import argparse
import copy
import json
import subprocess
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from scipy.stats import pearsonr, spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

from polyreact import train as train_module
from polyreact.config import load_config
from polyreact.features.anarsi import AnarciNumberer
from polyreact.features.pipeline import FeaturePipeline
from polyreact.models.ordinal import (
    fit_negative_binomial_model,
    fit_poisson_model,
    pearson_dispersion,
    regression_metrics,
)


@dataclass(slots=True)
class DatasetSpec:
    name: str
    path: Path
    display: str


DISPLAY_LABELS = {
    "jain": "Jain (2017)",
    "shehata": "Shehata PSR (398)",
    "shehata_curated": "Shehata curated (88)",
    "harvey": "Harvey (2022)",
}

RAW_LABELS = {
    "jain": "jain2017",
    "shehata": "shehata2019",
    "shehata_curated": "shehata2019_curated",
    "harvey": "harvey2022",
}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Reproduce paper-style metrics and plots")
    parser.add_argument(
        "--train-data",
        default="data/processed/boughter_counts_rebuilt.csv",
        help="Reconstructed Boughter dataset path.",
    )
    parser.add_argument(
        "--full-data",
        default="data/processed/boughter_counts_rebuilt.csv",
        help="Dataset (including mild flags) for correlation analysis.",
    )
    parser.add_argument("--jain", default="data/processed/jain.csv")
    parser.add_argument(
        "--shehata",
        default="data/processed/shehata_full.csv",
        help="Full Shehata PSR panel (398 sequences) in processed CSV form.",
    )
    parser.add_argument(
        "--shehata-curated",
        default="data/processed/shehata_curated.csv",
        help="Optional curated subset of Shehata et al. (88 sequences).",
    )
    parser.add_argument("--harvey", default="data/processed/harvey.csv")
    parser.add_argument("--output-dir", default="artifacts/paper")
    parser.add_argument("--config", default="configs/default.yaml")
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--rebuild", action="store_true")
    parser.add_argument(
        "--bootstrap-samples",
        type=int,
        default=1000,
        help="Bootstrap resamples for metrics confidence intervals.",
    )
    parser.add_argument(
        "--bootstrap-alpha",
        type=float,
        default=0.05,
        help="Alpha for bootstrap confidence intervals (default 0.05 → 95%).",
    )
    parser.add_argument(
        "--human-only",
        action="store_true",
        help=(
            "Restrict the main cross-validation run to human HIV and influenza families"
            " (legacy behaviour). By default all Boughter families, including mouse IgA,"
            " participate in CV as in Sakhnini et al."
        ),
    )
    parser.add_argument(
        "--skip-flag-regression",
        action="store_true",
        help="Skip ELISA flag regression diagnostics (Poisson/NB).",
    )
    parser.add_argument(
        "--skip-lofo",
        action="store_true",
        help="Skip leave-one-family-out experiments.",
    )
    parser.add_argument(
        "--skip-descriptor-variants",
        action="store_true",
        help="Skip descriptor-only benchmark variants.",
    )
    parser.add_argument(
        "--skip-fragment-variants",
        action="store_true",
        help="Skip CDR fragment ablation benchmarks.",
    )
    return parser


def _config_to_dict(config) -> Dict[str, Any]:
    data = asdict(config)
    data.pop("raw", None)
    return data


def _deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    result = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = _deep_merge(result.get(key, {}), value)
        else:
            result[key] = value
    return result


def _write_variant_config(
    base_config: Dict[str, Any],
    overrides: Dict[str, Any],
    target_path: Path,
) -> Path:
    merged = _deep_merge(base_config, overrides)
    target_path.parent.mkdir(parents=True, exist_ok=True)
    with target_path.open("w", encoding="utf-8") as handle:
        yaml.safe_dump(merged, handle, sort_keys=False)
    return target_path


def _collect_metric_records(variant: str, metrics: pd.DataFrame) -> list[dict[str, Any]]:
    tracked = {
        "roc_auc",
        "pr_auc",
        "accuracy",
        "f1",
        "f1_positive",
        "f1_negative",
        "precision",
        "sensitivity",
        "specificity",
        "brier",
        "ece",
        "mce",
    }
    records: list[dict[str, Any]] = []
    for _, row in metrics.iterrows():
        metric_name = row["metric"]
        if metric_name not in tracked:
            continue
        record = {"variant": variant, "metric": metric_name}
        for column in metrics.columns:
            if column == "metric":
                continue
            record[column] = float(row[column]) if pd.notna(row[column]) else np.nan
        records.append(record)
    return records


def _dump_coefficients(model_path: Path, output_path: Path) -> None:
    artifact = joblib.load(model_path)
    trained = artifact["model"]
    estimator = getattr(trained, "estimator", None)
    if estimator is None or not hasattr(estimator, "coef_"):
        return
    coefs = estimator.coef_[0]
    feature_state = artifact.get("feature_state")
    feature_names: list[str]
    if feature_state is not None and getattr(feature_state, "feature_names", None):
        feature_names = list(feature_state.feature_names)
    else:
        feature_names = [f"f{i}" for i in range(len(coefs))]
    coeff_df = pd.DataFrame(
        {
            "feature": feature_names,
            "coef": coefs,
            "abs_coef": np.abs(coefs),
        }
    ).sort_values("abs_coef", ascending=False)
    coeff_df.to_csv(output_path, index=False)


def _summarise_predictions(preds: pd.DataFrame) -> pd.DataFrame:
    records: list[dict[str, Any]] = []
    for split, group in preds.groupby("split"):
        stats = {
            "split": split,
            "n_samples": int(len(group)),
            "positives": int(group["y_true"].sum()),
            "positive_rate": float(group["y_true"].mean()) if len(group) else np.nan,
            "score_mean": float(group["y_score"].mean()) if len(group) else np.nan,
            "score_std": float(group["y_score"].std(ddof=1)) if len(group) > 1 else np.nan,
        }
        records.append(stats)
    return pd.DataFrame(records)


def _summarise_raw_dataset(path: Path, name: str) -> dict[str, Any]:
    df = pd.read_csv(path)
    summary: dict[str, Any] = {
        "dataset": name,
        "path": str(path),
        "rows": int(len(df)),
    }
    if "label" in df.columns:
        positives = int(df["label"].sum())
        summary["positives"] = positives
        summary["positive_rate"] = float(df["label"].mean()) if len(df) else np.nan
    if "reactivity_count" in df.columns:
        summary["reactivity_count_mean"] = float(df["reactivity_count"].mean())
        summary["reactivity_count_median"] = float(df["reactivity_count"].median())
        summary["reactivity_count_max"] = int(df["reactivity_count"].max())
    if "smp" in df.columns:
        summary["smp_mean"] = float(df["smp"].mean())
        summary["smp_median"] = float(df["smp"].median())
        summary["smp_max"] = float(df["smp"].max())
        summary["smp_min"] = float(df["smp"].min())
    summary["unique_heavy"] = int(df["heavy_seq"].nunique()) if "heavy_seq" in df.columns else np.nan
    return summary


def _extract_region_sequence(sequence: str, regions: List[str], numberer: AnarciNumberer) -> str:
    if not sequence:
        return ""
    upper_regions = [region.upper() for region in regions]
    if upper_regions == ["VH"]:
        return sequence
    try:
        numbered = numberer.number_sequence(sequence)
    except Exception:
        return ""
    fragments: list[str] = []
    for region in upper_regions:
        if region == "VH":
            return sequence
        fragment = numbered.regions.get(region)
        if not fragment:
            return ""
        fragments.append(fragment)
    return "".join(fragments)


def _make_region_dataset(
    frame: pd.DataFrame, regions: List[str], numberer: AnarciNumberer
) -> tuple[pd.DataFrame, dict[str, Any]]:
    records: list[dict[str, Any]] = []
    dropped = 0
    for record in frame.to_dict(orient="records"):
        new_seq = _extract_region_sequence(record.get("heavy_seq", ""), regions, numberer)
        if not new_seq:
            dropped += 1
            continue
        updated = record.copy()
        updated["heavy_seq"] = new_seq
        updated["light_seq"] = ""
        records.append(updated)
    result = pd.DataFrame(records, columns=frame.columns)
    summary = {
        "regions": "+".join(regions),
        "input_rows": int(len(frame)),
        "retained_rows": int(len(result)),
        "dropped_rows": int(dropped),
    }
    return result, summary


def run_train(
    *,
    train_path: Path,
    eval_specs: Sequence[DatasetSpec],
    output_dir: Path,
    model_path: Path,
    config: str,
    batch_size: int,
    include_species: list[str] | None = None,
    include_families: list[str] | None = None,
    exclude_families: list[str] | None = None,
    keep_duplicates: bool = False,
    group_column: str | None = "lineage",
    train_loader: str | None = None,
    bootstrap_samples: int = 200,
    bootstrap_alpha: float = 0.05,
) -> None:
    args: list[str] = [
        "--config",
        str(config),
        "--train",
        str(train_path),
        "--report-to",
        str(output_dir),
        "--save-to",
        str(model_path),
        "--batch-size",
        str(batch_size),
    ]

    if eval_specs:
        args.append("--eval")
        args.extend(str(spec.path) for spec in eval_specs)

    if train_loader:
        args.extend(["--train-loader", train_loader])
    if eval_specs:
        args.append("--eval-loaders")
        args.extend(spec.name for spec in eval_specs)
    if include_species:
        args.append("--include-species")
        args.extend(include_species)
    if include_families:
        args.append("--include-families")
        args.extend(include_families)
    if exclude_families:
        args.append("--exclude-families")
        args.extend(exclude_families)
    if keep_duplicates:
        args.append("--keep-train-duplicates")
    if group_column:
        args.extend(["--cv-group-column", group_column])
    else:
        args.append("--no-group-cv")
    args.extend(["--bootstrap-samples", str(bootstrap_samples)])
    args.extend(["--bootstrap-alpha", str(bootstrap_alpha)])

    exit_code = train_module.main(args)
    if exit_code != 0:
        raise RuntimeError(f"Training command failed with exit code {exit_code}")


def compute_spearman(model_path: Path, dataset_path: Path, batch_size: int) -> tuple[float, float, pd.DataFrame]:
    artifact = joblib.load(model_path)
    config = artifact["config"]
    pipeline_state = artifact["feature_state"]
    trained_model = artifact["model"]

    pipeline = FeaturePipeline(backend=config.feature_backend, descriptors=config.descriptors, device=config.device)
    pipeline.load_state(pipeline_state)

    dataset = pd.read_csv(dataset_path)
    features = pipeline.transform(dataset, heavy_only=True, batch_size=batch_size)
    scores = trained_model.predict_proba(features)
    dataset = dataset.copy()
    dataset["score"] = scores

    stat, pvalue = spearmanr(dataset["reactivity_count"], dataset["score"])
    return float(stat), float(pvalue), dataset


def plot_accuracy(
    metrics: pd.DataFrame,
    output_path: Path,
    eval_specs: Sequence[DatasetSpec],
) -> None:
    row = metrics.loc[metrics["metric"] == "accuracy"].iloc[0]
    labels = ["Train CV"] + [spec.display for spec in eval_specs]
    values = [row.get("train_cv_mean", np.nan)] + [row.get(spec.name, np.nan) for spec in eval_specs]

    fig, ax = plt.subplots(figsize=(6, 4))
    xs = np.arange(len(labels))
    ax.bar(xs, values, color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"])
    ax.set_xticks(xs, labels)
    ax.set_ylim(0.0, 1.05)
    ax.set_ylabel("Accuracy")
    ax.set_title("Polyreactivity accuracy overview")
    for x, val in zip(xs, values, strict=False):
        if np.isnan(val):
            continue
        ax.text(x, val + 0.02, f"{val:.3f}", ha="center", va="bottom")
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def plot_rocs(
    preds: pd.DataFrame,
    output_path: Path,
    eval_specs: Sequence[DatasetSpec],
) -> None:
    mapping = {"train_cv_oof": "Train CV"}
    for spec in eval_specs:
        mapping[spec.name] = spec.display
    fig, ax = plt.subplots(figsize=(6, 6))
    for split, label in mapping.items():
        subset = preds[preds["split"] == split]
        if subset.empty:
            continue
        fpr, tpr, _ = roc_curve(subset["y_true"], subset["y_score"])
        ax.plot(fpr, tpr, label=label)
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title("ROC curves")
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def plot_flags_scatter(data: pd.DataFrame, spearman_stat: float, output_path: Path) -> None:
    rng = np.random.default_rng(42)
    jitter = rng.uniform(-0.1, 0.1, size=len(data))
    x = data["reactivity_count"].to_numpy(dtype=float) + jitter
    y = data["score"].to_numpy(dtype=float)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(x, y, alpha=0.5, s=10)
    ax.set_xlabel("ELISA flag count")
    ax.set_ylabel("Predicted probability")
    ax.set_title(f"Prediction vs flag count (Spearman={spearman_stat:.2f})")
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)


def run_lofo(
    full_df: pd.DataFrame,
    *,
    families: list[str],
    config: str,
    batch_size: int,
    output_dir: Path,
    bootstrap_samples: int,
    bootstrap_alpha: float,
) -> pd.DataFrame:
    results: list[dict[str, float]] = []
    for family in families:
        family_lower = family.lower()
        holdout = full_df[full_df["family"].str.lower() == family_lower].copy()
        train = full_df[full_df["family"].str.lower() != family_lower].copy()
        if holdout.empty or train.empty:
            continue

        train_path = output_dir / f"train_lofo_{family_lower}.csv"
        holdout_path = output_dir / f"eval_lofo_{family_lower}.csv"
        train.to_csv(train_path, index=False)
        holdout.to_csv(holdout_path, index=False)

        run_dir = output_dir / f"lofo_{family_lower}"
        run_dir.mkdir(parents=True, exist_ok=True)
        model_path = run_dir / "model.joblib"

        run_train(
            train_path=train_path,
            eval_specs=[
                DatasetSpec(
                    name="boughter",
                    path=holdout_path,
                    display=f"{family.title()} holdout",
                )
            ],
            output_dir=run_dir,
            model_path=model_path,
            config=config,
            batch_size=batch_size,
            keep_duplicates=True,
            include_species=None,
            include_families=None,
            exclude_families=None,
            group_column="lineage",
            train_loader="boughter",
            bootstrap_samples=bootstrap_samples,
            bootstrap_alpha=bootstrap_alpha,
        )

        metrics = pd.read_csv(run_dir / "metrics.csv")
        evaluation_cols = [
            col
            for col in metrics.columns
            if col not in {"metric", "train_cv_mean", "train_cv_std"}
        ]
        if not evaluation_cols:
            continue
        eval_col = evaluation_cols[0]

        def _metric_value(name: str) -> float:
            series = metrics.loc[metrics["metric"] == name, eval_col]
            return float(series.values[0]) if not series.empty else float("nan")

        results.append(
            {
                "family": family,
                "accuracy": _metric_value("accuracy"),
                "roc_auc": _metric_value("roc_auc"),
                "pr_auc": _metric_value("pr_auc"),
                "sensitivity": _metric_value("sensitivity"),
                "specificity": _metric_value("specificity"),
            }
        )

    return pd.DataFrame(results)


def run_flag_regression(
    train_path: Path,
    *,
    output_dir: Path,
    config_path: str,
    batch_size: int,
    n_splits: int = 5,
) -> None:
    df = pd.read_csv(train_path)
    if "reactivity_count" not in df.columns:
        return

    config = load_config(config_path)
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)

    metrics_rows: list[dict[str, float]] = []
    preds_rows: list[dict[str, float]] = []

    for fold_idx, (train_idx, val_idx) in enumerate(kfold.split(df), start=1):
        train_split = df.iloc[train_idx].reset_index(drop=True)
        val_split = df.iloc[val_idx].reset_index(drop=True)

        pipeline = FeaturePipeline(
            backend=config.feature_backend,
            descriptors=config.descriptors,
            device=config.device,
        )
        X_train = pipeline.fit_transform(train_split, heavy_only=True, batch_size=batch_size)
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        y_train = train_split["reactivity_count"].to_numpy(dtype=float)
        # Train a logistic head to obtain probabilities as a 1-D feature
        clf = LogisticRegression(
            C=config.model.C,
            class_weight=config.model.class_weight,
            max_iter=2000,
            solver="lbfgs",
        )
        clf.fit(X_train_scaled, train_split["label"].to_numpy(dtype=int))
        prob_train = clf.predict_proba(X_train_scaled)[:, 1]

        X_val = pipeline.transform(val_split, heavy_only=True, batch_size=batch_size)
        X_val_scaled = scaler.transform(X_val)
        y_val = val_split["reactivity_count"].to_numpy(dtype=float)
        prob_val = clf.predict_proba(X_val_scaled)[:, 1]

        poisson_X_train = prob_train.reshape(-1, 1)
        poisson_X_val = prob_val.reshape(-1, 1)
        model = fit_poisson_model(poisson_X_train, y_train)
        poisson_preds = model.predict(poisson_X_val)

        n_params = poisson_X_train.shape[1] + 1  # include intercept
        dof = max(len(y_val) - n_params, 1)
        variance_to_mean = float(np.var(y_val, ddof=1) / np.mean(y_val)) if np.mean(y_val) else float("nan")

        spearman_val = float(spearmanr(y_val, poisson_preds).statistic)
        try:
            pearson_val = float(pearsonr(y_val, poisson_preds)[0])
        except Exception:  # pragma: no cover - fallback if correlation fails
            pearson_val = float("nan")

        poisson_metrics = regression_metrics(y_val, poisson_preds)
        poisson_metrics.update(
            {
                "spearman": spearman_val,
                "pearson": pearson_val,
                "pearson_dispersion": pearson_dispersion(y_val, poisson_preds, dof=dof),
                "variance_to_mean": variance_to_mean,
                "fold": fold_idx,
                "model": "poisson",
                "status": "ok",
            }
        )
        metrics_rows.append(poisson_metrics)

        nb_preds: np.ndarray | None = None
        nb_model = None
        try:
            nb_model = fit_negative_binomial_model(poisson_X_train, y_train)
            nb_preds = nb_model.predict(poisson_X_val)
            if not np.all(np.isfinite(nb_preds)):
                raise ValueError("negative binomial produced non-finite predictions")
        except Exception:
            nb_metrics = {
                "spearman": float("nan"),
                "pearson": float("nan"),
                "pearson_dispersion": float("nan"),
                "variance_to_mean": variance_to_mean,
                "alpha": float("nan"),
                "fold": fold_idx,
                "model": "negative_binomial",
                "status": "failed",
            }
            metrics_rows.append(nb_metrics)
        else:
            spearman_nb = float(spearmanr(y_val, nb_preds).statistic)
            try:
                pearson_nb = float(pearsonr(y_val, nb_preds)[0])
            except Exception:  # pragma: no cover
                pearson_nb = float("nan")

            nb_metrics = regression_metrics(y_val, nb_preds)
            nb_metrics.update(
                {
                    "spearman": spearman_nb,
                    "pearson": pearson_nb,
                    "pearson_dispersion": pearson_dispersion(y_val, nb_preds, dof=dof),
                    "variance_to_mean": variance_to_mean,
                    "alpha": nb_model.alpha,
                    "fold": fold_idx,
                    "model": "negative_binomial",
                    "status": "ok",
                }
            )
            metrics_rows.append(nb_metrics)

        records = list(val_split.itertuples(index=False))
        for idx, row in enumerate(records):
            row_id = getattr(row, "id", idx)
            y_true_val = float(getattr(row, "reactivity_count"))
            preds_rows.append(
                {
                    "fold": fold_idx,
                    "model": "poisson",
                    "id": row_id,
                    "y_true": y_true_val,
                    "y_pred": float(poisson_preds[idx]),
                }
            )
            if nb_preds is not None:
                preds_rows.append(
                    {
                        "fold": fold_idx,
                        "model": "negative_binomial",
                        "id": row_id,
                        "y_true": y_true_val,
                        "y_pred": float(nb_preds[idx]),
                    }
                )

    metrics_df = pd.DataFrame(metrics_rows)
    metrics_df.to_csv(output_dir / "flag_regression_folds.csv", index=False)

    summary_records: list[dict[str, float]] = []
    for model_name, group in metrics_df.groupby("model"):
        for column in group.columns:
            if column in {"fold", "model", "status"}:
                continue
            values = group[column].dropna()
            if values.empty:
                continue
            summary_records.append(
                {
                    "model": model_name,
                    "metric": column,
                    "mean": float(values.mean()),
                    "std": float(values.std(ddof=1)) if len(values) > 1 else float("nan"),
                }
            )
    if summary_records:
        pd.DataFrame(summary_records).to_csv(
            output_dir / "flag_regression_metrics.csv", index=False
        )

    if preds_rows:
        pd.DataFrame(preds_rows).to_csv(output_dir / "flag_regression_preds.csv", index=False)


def run_descriptor_variants(
    base_config: Dict[str, Any],
    *,
    train_path: Path,
    eval_specs: Sequence[DatasetSpec],
    output_dir: Path,
    batch_size: int,
    include_species: List[str] | None,
    include_families: List[str] | None,
    bootstrap_samples: int,
    bootstrap_alpha: float,
) -> None:
    variants = [
        (
            "descriptors_full_vh",
            {
                "feature_backend": {"type": "descriptors"},
                "descriptors": {
                    "use_anarci": True,
                    "regions": ["CDRH1", "CDRH2", "CDRH3"],
                    "features": [
                        "length",
                        "charge",
                        "hydropathy",
                        "aromaticity",
                        "pI",
                        "net_charge",
                    ],
                },
            },
        ),
        (
            "descriptors_cdrh3_pi",
            {
                "feature_backend": {"type": "descriptors"},
                "descriptors": {
                    "use_anarci": True,
                    "regions": ["CDRH3"],
                    "features": ["pI"],
                },
            },
        ),
        (
            "descriptors_cdrh3_top5",
            {
                "feature_backend": {"type": "descriptors"},
                "descriptors": {
                    "use_anarci": True,
                    "regions": ["CDRH3"],
                    "features": [
                        "pI",
                        "net_charge",
                        "charge",
                        "hydropathy",
                        "length",
                    ],
                },
            },
        ),
    ]

    configs_dir = output_dir / "configs"
    configs_dir.mkdir(parents=True, exist_ok=True)
    summary_records: list[dict[str, Any]] = []

    for name, overrides in variants:
        variant_config_path = _write_variant_config(
            base_config,
            overrides,
            configs_dir / f"{name}.yaml",
        )
        variant_output = output_dir / name
        variant_output.mkdir(parents=True, exist_ok=True)
        model_path = variant_output / "model.joblib"

        run_train(
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=variant_output,
            model_path=model_path,
            config=str(variant_config_path),
            batch_size=batch_size,
            include_species=include_species,
            include_families=include_families,
            keep_duplicates=True,
            group_column="lineage",
            train_loader="boughter",
            bootstrap_samples=bootstrap_samples,
            bootstrap_alpha=bootstrap_alpha,
        )

        metrics_path = variant_output / "metrics.csv"
        if metrics_path.exists():
            metrics_df = pd.read_csv(metrics_path)
            summary_records.extend(_collect_metric_records(name, metrics_df))

        _dump_coefficients(model_path, variant_output / "coefficients.csv")

    if summary_records:
        pd.DataFrame(summary_records).to_csv(output_dir / "summary.csv", index=False)


def run_fragment_variants(
    config_path: str,
    *,
    train_path: Path,
    eval_specs: Sequence[DatasetSpec],
    output_dir: Path,
    batch_size: int,
    include_species: List[str] | None,
    include_families: List[str] | None,
    bootstrap_samples: int,
    bootstrap_alpha: float,
) -> None:
    numberer = AnarciNumberer()
    specs = [
        ("vh_full", ["VH"]),
        ("cdrh1", ["CDRH1"]),
        ("cdrh2", ["CDRH2"]),
        ("cdrh3", ["CDRH3"]),
        ("cdrh123", ["CDRH1", "CDRH2", "CDRH3"]),
    ]

    summary_rows: list[dict[str, Any]] = []
    metric_summary_rows: list[dict[str, Any]] = []

    for name, regions in specs:
        variant_dir = output_dir / name
        variant_dir.mkdir(parents=True, exist_ok=True)
        dataset_dir = variant_dir / "datasets"
        dataset_dir.mkdir(parents=True, exist_ok=True)

        train_df = pd.read_csv(train_path)
        train_variant, train_summary = _make_region_dataset(train_df, regions, numberer)
        train_variant_path = dataset_dir / "train.csv"
        train_variant.to_csv(train_variant_path, index=False)

        eval_variant_specs: list[DatasetSpec] = []
        for spec in eval_specs:
            eval_df = pd.read_csv(spec.path)
            transformed, eval_summary = _make_region_dataset(eval_df, regions, numberer)
            eval_path = dataset_dir / f"{spec.name}.csv"
            transformed.to_csv(eval_path, index=False)
            eval_variant_specs.append(
                DatasetSpec(name=spec.name, path=eval_path, display=spec.display)
            )
            eval_summary.update({"variant": name, "dataset": spec.name})
            summary_rows.append(eval_summary)

        train_summary.update({"variant": name, "dataset": "train"})
        summary_rows.append(train_summary)

        run_train(
            train_path=train_variant_path,
            eval_specs=eval_variant_specs,
            output_dir=variant_dir,
            model_path=variant_dir / "model.joblib",
            config=config_path,
            batch_size=batch_size,
            include_species=include_species,
            include_families=include_families,
            keep_duplicates=True,
            group_column="lineage",
            train_loader="boughter",
            bootstrap_samples=bootstrap_samples,
            bootstrap_alpha=bootstrap_alpha,
        )

        metrics_path = variant_dir / "metrics.csv"
        if metrics_path.exists():
            metrics_df = pd.read_csv(metrics_path)
            metric_records = _collect_metric_records(name, metrics_df)
            for record in metric_records:
                record["variant_type"] = "fragment"
            metric_summary_rows.extend(metric_records)

    if summary_rows:
        pd.DataFrame(summary_rows).to_csv(output_dir / "fragment_dataset_summary.csv", index=False)
    if metric_summary_rows:
        pd.DataFrame(metric_summary_rows).to_csv(output_dir / "fragment_metrics_summary.csv", index=False)


def main(argv: list[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.rebuild:
        rebuild_cmd = [
            "python",
            "scripts/rebuild_boughter_from_counts.py",
            "--output",
            str(args.train_data),
        ]
        if subprocess.run(rebuild_cmd, check=False).returncode != 0:
            raise RuntimeError("Dataset rebuild failed")

    train_path = Path(args.train_data)

    def _make_spec(name: str, path_str: str) -> DatasetSpec | None:
        path = Path(path_str)
        if not path.exists():
            return None
        display = DISPLAY_LABELS.get(name, name.replace("_", " ").title())
        return DatasetSpec(name=name, path=path, display=display)

    eval_specs: list[DatasetSpec] = []
    seen_paths: set[Path] = set()
    for name, path_str in [
        ("jain", args.jain),
        ("shehata", args.shehata),
        ("shehata_curated", args.shehata_curated),
        ("harvey", args.harvey),
    ]:
        spec = _make_spec(name, path_str)
        if spec is not None:
            resolved = spec.path.resolve()
            if resolved in seen_paths:
                continue
            seen_paths.add(resolved)
            eval_specs.append(spec)

    base_config = load_config(args.config)
    base_config_dict = _config_to_dict(base_config)

    main_output = output_dir / "main"
    main_output.mkdir(parents=True, exist_ok=True)
    model_path = main_output / "model.joblib"

    main_include_species = ["human"] if args.human_only else None
    main_include_families = ["hiv", "influenza"] if args.human_only else None

    run_train(
        train_path=train_path,
        eval_specs=eval_specs,
        output_dir=main_output,
        model_path=model_path,
        config=args.config,
        batch_size=args.batch_size,
        include_species=main_include_species,
        include_families=main_include_families,
        keep_duplicates=True,
        group_column="lineage",
        train_loader="boughter",
        bootstrap_samples=args.bootstrap_samples,
        bootstrap_alpha=args.bootstrap_alpha,
    )

    metrics = pd.read_csv(main_output / "metrics.csv")
    preds = pd.read_csv(main_output / "preds.csv")

    plot_accuracy(metrics, main_output / "accuracy_overview.png", eval_specs)
    plot_rocs(preds, main_output / "roc_overview.png", eval_specs)

    if not args.skip_flag_regression:
        run_flag_regression(
            train_path=train_path,
            output_dir=main_output,
            config_path=args.config,
            batch_size=args.batch_size,
        )

    split_summary = _summarise_predictions(preds)
    split_summary.to_csv(main_output / "dataset_split_summary.csv", index=False)

    spearman_stat, spearman_p, corr_df = compute_spearman(
        model_path=model_path,
        dataset_path=Path(args.full_data),
        batch_size=args.batch_size,
    )
    plot_flags_scatter(corr_df, spearman_stat, main_output / "prob_vs_flags.png")
    (main_output / "spearman_flags.json").write_text(
        json.dumps({"spearman": spearman_stat, "p_value": spearman_p}, indent=2)
    )
    corr_df.to_csv(main_output / "prob_vs_flags.csv", index=False)

    if not args.skip_lofo:
        full_df = pd.read_csv(args.train_data)
        lofo_dir = output_dir / "lofo_runs"
        lofo_dir.mkdir(parents=True, exist_ok=True)
        lofo_df = run_lofo(
            full_df,
            families=["influenza", "hiv", "mouse_iga"],
            config=args.config,
            batch_size=args.batch_size,
            output_dir=lofo_dir,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )
        lofo_df.to_csv(output_dir / "lofo_metrics.csv", index=False)

    if not args.skip_descriptor_variants:
        descriptor_dir = output_dir / "descriptor_variants"
        descriptor_dir.mkdir(parents=True, exist_ok=True)
        run_descriptor_variants(
            base_config_dict,
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=descriptor_dir,
            batch_size=args.batch_size,
            include_species=main_include_species,
            include_families=main_include_families,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )

    if not args.skip_fragment_variants:
        fragment_dir = output_dir / "fragment_variants"
        fragment_dir.mkdir(parents=True, exist_ok=True)
        run_fragment_variants(
            args.config,
            train_path=train_path,
            eval_specs=eval_specs,
            output_dir=fragment_dir,
            batch_size=args.batch_size,
            include_species=main_include_species,
            include_families=main_include_families,
            bootstrap_samples=args.bootstrap_samples,
            bootstrap_alpha=args.bootstrap_alpha,
        )

    raw_summaries = []
    raw_summaries.append(_summarise_raw_dataset(train_path, "boughter_rebuilt"))
    for spec in eval_specs:
        summary_name = RAW_LABELS.get(spec.name, spec.name)
        raw_summaries.append(_summarise_raw_dataset(spec.path, summary_name))
    pd.DataFrame(raw_summaries).to_csv(output_dir / "raw_dataset_summary.csv", index=False)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
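The script is driven entirely by its argument parser, so a run can be launched from the shell or programmatically via `main`. A minimal sketch of a quicker pass that skips the heavier ablations; whether the processed CSVs referenced by the default arguments exist in a given checkout is an assumption:

    # Sketch: assumes the data/processed/*.csv files named in the defaults are present locally.
    from polyreact.benchmarks import reproduce_paper

    exit_code = reproduce_paper.main(
        [
            "--output-dir", "artifacts/paper",
            "--batch-size", "16",
            "--skip-lofo",                 # skip leave-one-family-out runs for speed
            "--skip-fragment-variants",    # skip CDR fragment ablations
        ]
    )
    assert exit_code == 0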
polyreact/benchmarks/run_benchmarks.py ADDED
@@ -0,0 +1,114 @@
"""Run end-to-end benchmarks for the polyreactivity model."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import List

from .. import train as train_cli

PROJECT_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_TRAIN = PROJECT_ROOT / "tests" / "fixtures" / "boughter.csv"
DEFAULT_EVAL = [
    PROJECT_ROOT / "tests" / "fixtures" / "jain.csv",
    PROJECT_ROOT / "tests" / "fixtures" / "shehata.csv",
    PROJECT_ROOT / "tests" / "fixtures" / "harvey.csv",
]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run polyreactivity benchmarks")
    parser.add_argument(
        "--config",
        default="configs/default.yaml",
        help="Path to configuration YAML file.",
    )
    parser.add_argument(
        "--train",
        default=str(DEFAULT_TRAIN),
        help="Training dataset CSV (defaults to bundled fixture).",
    )
    parser.add_argument(
        "--eval",
        nargs="+",
        default=[str(path) for path in DEFAULT_EVAL],
        help="Evaluation dataset CSV paths (>=1).",
    )
    parser.add_argument(
        "--report-dir",
        default="artifacts",
        help="Directory to write metrics, predictions, and plots.",
    )
    parser.add_argument(
        "--model-path",
        default="artifacts/model.joblib",
        help="Destination for the trained model artifact.",
    )
    parser.add_argument(
        "--backend",
        choices=["descriptors", "plm", "concat"],
        help="Override feature backend during training.",
    )
    parser.add_argument("--plm-model", help="Optional PLM model override.")
    parser.add_argument("--cache-dir", help="Embedding cache directory override.")
    parser.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda"],
        help="Device override for embeddings.",
    )
    parser.add_argument(
        "--paired",
        action="store_true",
        help="Use paired heavy/light chains when available.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size for PLM embedding batches.",
    )
    return parser


def main(argv: List[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)

    if len(args.eval) < 1:
        parser.error("Provide at least one evaluation dataset via --eval.")

    report_dir = Path(args.report_dir)
    report_dir.mkdir(parents=True, exist_ok=True)

    train_args: list[str] = [
        "--config",
        args.config,
        "--train",
        args.train,
        "--save-to",
        str(Path(args.model_path)),
        "--report-to",
        str(report_dir),
        "--batch-size",
        str(args.batch_size),
    ]

    train_args.extend(["--eval", *args.eval])

    if args.backend:
        train_args.extend(["--backend", args.backend])
    if args.plm_model:
        train_args.extend(["--plm-model", args.plm_model])
    if args.cache_dir:
        train_args.extend(["--cache-dir", args.cache_dir])
    if args.device:
        train_args.extend(["--device", args.device])
    if args.paired:
        train_args.append("--paired")

    return train_cli.main(train_args)


if __name__ == "__main__":
    raise SystemExit(main())
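Because the defaults point at the bundled test fixtures, the wrapper can be exercised without downloading any external data. A short sketch; the descriptor-backend override and the report directory name are illustrative choices, not defaults:

    # Sketch: relies only on the fixture defaults declared above (tests/fixtures/*.csv).
    from polyreact.benchmarks import run_benchmarks

    exit_code = run_benchmarks.main(
        [
            "--backend", "descriptors",              # avoid fetching a PLM for a smoke run
            "--report-dir", "artifacts/benchmark_smoke",
        ]
    )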
polyreact/config.py ADDED
@@ -0,0 +1,160 @@
"""Configuration helpers for the polyreactivity project."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
import importlib.resources as pkg_resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Any, Sequence

import yaml


@dataclass(slots=True)
class FeatureBackendSettings:
    type: str = "plm"
    plm_model_name: str = "facebook/esm2_t12_35M_UR50D"
    layer_pool: str = "mean"
    cache_dir: str = ".cache/embeddings"
    standardize: bool = True


@dataclass(slots=True)
class DescriptorSettings:
    use_anarci: bool = True
    regions: Sequence[str] = field(default_factory=lambda: ["CDRH1", "CDRH2", "CDRH3"])
    features: Sequence[str] = field(
        default_factory=lambda: [
            "length",
            "charge",
            "hydropathy",
            "aromaticity",
            "pI",
            "net_charge",
        ]
    )
    ph: float = 7.4


@dataclass(slots=True)
class ModelSettings:
    head: str = "logreg"
    C: float = 1.0
    class_weight: Any = "balanced"


@dataclass(slots=True)
class CalibrationSettings:
    method: str | None = "isotonic"


@dataclass(slots=True)
class TrainingSettings:
    cv_folds: int = 10
    scoring: str = "roc_auc"
    n_jobs: int = -1


@dataclass(slots=True)
class IOSettings:
    outputs_dir: str = "artifacts"
    preds_filename: str = "preds.csv"
    metrics_filename: str = "metrics.csv"


@dataclass(slots=True)
class Config:
    seed: int = 42
    device: str = "auto"
    feature_backend: FeatureBackendSettings = field(default_factory=FeatureBackendSettings)
    descriptors: DescriptorSettings = field(default_factory=DescriptorSettings)
    model: ModelSettings = field(default_factory=ModelSettings)
    calibration: CalibrationSettings = field(default_factory=CalibrationSettings)
    training: TrainingSettings = field(default_factory=TrainingSettings)
    io: IOSettings = field(default_factory=IOSettings)

    raw: dict[str, Any] = field(default_factory=dict)


def _merge_section(default: Any, data: dict[str, Any] | None) -> Any:
    if data is None:
        return default
    merged = asdict(default) | data
    return type(default)(**merged)


def load_config(path: str | Path | None = None) -> Config:
    """Load a YAML configuration file into a strongly-typed ``Config`` object."""

    data = _read_config_data(path)

    feature_backend = _merge_section(FeatureBackendSettings(), data.get("feature_backend"))
    descriptors = _merge_section(DescriptorSettings(), data.get("descriptors"))
    model = _merge_section(ModelSettings(), data.get("model"))
    calibration = _merge_section(CalibrationSettings(), data.get("calibration"))
    training = _merge_section(TrainingSettings(), data.get("training"))
    io_settings = _merge_section(IOSettings(), data.get("io"))

    config = Config(
        seed=int(data.get("seed", 42)),
        device=str(data.get("device", "auto")),
        feature_backend=feature_backend,
        descriptors=descriptors,
        model=model,
        calibration=calibration,
        training=training,
        io=io_settings,
        raw=data,
    )
    return config


def _read_config_data(path: str | Path | None) -> dict[str, Any]:
    """Return mapping data from YAML or the bundled default."""

    if path is None:
        resource = pkg_resources.files("polyreact.configs") / "default.yaml"
        return _load_yaml_resource(resource)

    resolved = _resolve_config_path(Path(path))
    if resolved is not None:
        return _load_yaml_file(resolved)

    resource_root = pkg_resources.files("polyreact")
    resource = resource_root / Path(path).as_posix()
    if resource.is_file():
        return _load_yaml_resource(resource)

    msg = f"Configuration file not found: {path}"
    raise FileNotFoundError(msg)


def _resolve_config_path(path: Path) -> Path | None:
    if path.exists():
        return path

    if not path.is_absolute():
        candidate = Path(__file__).resolve().parent / path
        if candidate.exists():
            return candidate

    return None


def _load_yaml_file(path: Path) -> dict[str, Any]:
    with path.open("r", encoding="utf-8") as handle:
        return _parse_yaml(handle.read())


def _load_yaml_resource(resource: Traversable) -> dict[str, Any]:
    with resource.open("r", encoding="utf-8") as handle:
        return _parse_yaml(handle.read())


def _parse_yaml(text: str) -> dict[str, Any]:
    parsed = yaml.safe_load(text) or {}
    if not isinstance(parsed, dict):  # pragma: no cover - safeguard
        msg = "Configuration must be a mapping at the top level"
        raise ValueError(msg)
    return parsed
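A small sketch of how the loader composes defaults and overrides: calling it without a path falls back to the packaged `configs/default.yaml`, and any section missing from a user YAML keeps its dataclass default. The override file named here is hypothetical:

    # Sketch: "my_experiment.yaml" is a placeholder for a partial user config.
    from polyreact.config import load_config

    cfg = load_config()                      # packaged defaults
    print(cfg.feature_backend.plm_model_name, cfg.model.C)

    cfg_custom = load_config("my_experiment.yaml")   # unspecified sections keep their defaults
    cfg_custom.training.cv_folds = 5                 # dataclasses stay mutable for programmatic tweaks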
polyreact/configs/__init__.py ADDED
@@ -0,0 +1 @@
"""Configuration package data."""
polyreact/configs/default.yaml ADDED
@@ -0,0 +1,34 @@
seed: 42
device: "auto"
feature_backend:
  type: "plm"
  plm_model_name: "facebook/esm1v_t33_650M_UR90S_1"
  layer_pool: "mean"
  cache_dir: ".cache/embeddings"
descriptors:
  use_anarci: true
  regions:
    - "CDRH1"
    - "CDRH2"
    - "CDRH3"
  features:
    - "length"
    - "charge"
    - "hydropathy"
    - "aromaticity"
    - "pI"
    - "net_charge"
model:
  head: "logreg"
  C: 0.1
  class_weight: "balanced"
calibration:
  method: "isotonic"
training:
  cv_folds: 10
  scoring: "roc_auc"
  n_jobs: -1
io:
  outputs_dir: "artifacts"
  preds_filename: "preds.csv"
  metrics_filename: "metrics.csv"
polyreact/data_loaders/__init__.py ADDED
@@ -0,0 +1,3 @@
"""Dataset loaders for polyreactivity benchmarks."""

__all__ = ["boughter", "jain", "shehata", "harvey", "utils"]
polyreact/data_loaders/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (297 Bytes).
 
polyreact/data_loaders/__pycache__/boughter.cpython-311.pyc ADDED
Binary file (3.68 kB).
 
polyreact/data_loaders/__pycache__/harvey.cpython-311.pyc ADDED
Binary file (1.26 kB).
 
polyreact/data_loaders/__pycache__/jain.cpython-311.pyc ADDED
Binary file (1.3 kB).
 
polyreact/data_loaders/__pycache__/shehata.cpython-311.pyc ADDED
Binary file (8.16 kB).
 
polyreact/data_loaders/__pycache__/utils.cpython-311.pyc ADDED
Binary file (8.51 kB).
 
polyreact/data_loaders/boughter.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Loader for the Boughter et al. 2020 dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Iterable
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from .utils import LOGGER, standardize_frame
11
+
12
+ _COLUMN_ALIASES = {
13
+ "id": ("sequence_id",),
14
+ "heavy_seq": ("heavy", "heavy_chain"),
15
+ "light_seq": ("light", "light_chain"),
16
+ "label": ("polyreactive",),
17
+ }
18
+
19
+
20
+ def _find_flag_columns(columns: Iterable[str]) -> list[str]:
21
+ flag_cols: list[str] = []
22
+ for column in columns:
23
+ normalized = column.lower().replace(" ", "")
24
+ if "flag" in normalized:
25
+ flag_cols.append(column)
26
+ return flag_cols
27
+
28
+
29
+ def _apply_flag_policy(frame: pd.DataFrame, flag_columns: list[str]) -> pd.DataFrame:
30
+ if not flag_columns:
31
+ return frame
32
+
33
+ flag_values = (
34
+ frame[flag_columns]
35
+ .apply(pd.to_numeric, errors="coerce")
36
+ .fillna(0.0)
37
+ )
38
+ flag_binary = (flag_values > 0).astype(int)
39
+ flags_total = flag_binary.sum(axis=1)
40
+
41
+ specific_mask = flags_total == 0
42
+ nonspecific_mask = flags_total >= 4
43
+ keep_mask = specific_mask | nonspecific_mask
44
+
45
+ dropped = int((~keep_mask).sum())
46
+ if dropped:
47
+ LOGGER.info("Dropped %s mildly polyreactive sequences (1-3 ELISA flags)", dropped)
48
+
49
+ filtered = frame.loc[keep_mask].copy()
50
+ filtered["flags_total"] = flags_total.loc[keep_mask].astype(int)
51
+ filtered["label"] = np.where(nonspecific_mask.loc[keep_mask], 1, 0)
52
+ filtered["polyreactive"] = filtered["label"]
53
+ return filtered
54
+
55
+
56
+ def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
57
+ """Load the Boughter dataset into the canonical format."""
58
+
59
+ frame = pd.read_csv(path_or_url)
60
+ flag_columns = _find_flag_columns(frame.columns)
61
+ frame = _apply_flag_policy(frame, flag_columns)
62
+
63
+ label_series = frame.get("label")
64
+ if label_series is not None:
65
+ frame = frame[label_series.isin({0, 1})].copy()
66
+
67
+ standardized = standardize_frame(
68
+ frame,
69
+ source="boughter2020",
70
+ heavy_only=heavy_only,
71
+ column_aliases=_COLUMN_ALIASES,
72
+ is_test=False,
73
+ )
74
+ if "flags_total" in frame.columns and "flags_total" not in standardized.columns:
75
+ standardized["flags_total"] = frame["flags_total"].to_numpy(dtype=int)
76
+ return standardized
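Note: the flag policy above keeps only clones with 0 positive ELISA flags (specific, label 0) or >= 4 positive flags (polyreactive, label 1) and drops the 1-3 flag middle band. A toy sketch, calling the module's private helpers directly just for illustration (the flag column names are made up; _find_flag_columns matches any header containing "flag"):

    import pandas as pd
    from polyreact.data_loaders import boughter

    frame = pd.DataFrame(
        {
            "id": ["a", "b", "c"],
            "heavy_seq": ["EVQLVE", "QVQLQQ", "EVQLLE"],
            "flag_dna": [0, 1, 1],
            "flag_insulin": [0, 0, 1],
            "flag_lps": [0, 1, 1],
            "flag_flagellin": [0, 0, 1],
        }
    )
    flag_cols = boughter._find_flag_columns(frame.columns)
    filtered = boughter._apply_flag_policy(frame, flag_cols)
    # "a" (0 flags) -> label 0, "c" (4 flags) -> label 1, "b" (2 flags) is dropped.
    print(filtered[["id", "flags_total", "label"]])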
polyreact/data_loaders/harvey.py ADDED
@@ -0,0 +1,39 @@
1
+ """Loader for the Harvey et al. 2022 dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pandas as pd
6
+
7
+ from .utils import standardize_frame
8
+
9
+ _COLUMN_ALIASES = {
10
+ "id": ("id", "clone_id"),
11
+ "heavy_seq": ("heavy", "heavy_chain", "sequence"),
12
+ "light_seq": ("light", "light_chain"),
13
+ "label": ("polyreactive", "is_polyreactive"),
14
+ }
15
+
16
+ _LABEL_MAP = {
17
+ "polyreactive": 1,
18
+ "non-polyreactive": 0,
19
+ "positive": 1,
20
+ "negative": 0,
21
+ 1: 1,
22
+ 0: 0,
23
+ "1": 1,
24
+ "0": 0,
25
+ }
26
+
27
+
28
+ def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
29
+ """Load the Harvey dataset into the canonical format."""
30
+
31
+ frame = pd.read_csv(path_or_url)
32
+ return standardize_frame(
33
+ frame,
34
+ source="harvey2022",
35
+ heavy_only=heavy_only,
36
+ column_aliases=_COLUMN_ALIASES,
37
+ label_map=_LABEL_MAP,
38
+ is_test=True,
39
+ )
polyreact/data_loaders/jain.py ADDED
@@ -0,0 +1,41 @@
1
+ """Loader for the Jain et al. 2017 dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pandas as pd
6
+
7
+ from .utils import standardize_frame
8
+
9
+ _COLUMN_ALIASES = {
10
+ "id": ("id", "antibody_id"),
11
+ "heavy_seq": ("heavy", "heavy_sequence", "H_chain"),
12
+ "light_seq": ("light", "light_sequence", "L_chain"),
13
+ "label": ("class", "polyreactive"),
14
+ }
15
+
16
+ _LABEL_MAP = {
17
+ "polyreactive": 1,
18
+ "non-polyreactive": 0,
19
+ "reactive": 1,
20
+ "non-reactive": 0,
21
+ 1: 1,
22
+ 0: 0,
23
+ 1.0: 1,
24
+ 0.0: 0,
25
+ "1": 1,
26
+ "0": 0,
27
+ }
28
+
29
+
30
+ def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
31
+ """Load the Jain dataset into the canonical format."""
32
+
33
+ frame = pd.read_csv(path_or_url)
34
+ return standardize_frame(
35
+ frame,
36
+ source="jain2017",
37
+ heavy_only=heavy_only,
38
+ column_aliases=_COLUMN_ALIASES,
39
+ label_map=_LABEL_MAP,
40
+ is_test=True,
41
+ )
polyreact/data_loaders/shehata.py ADDED
@@ -0,0 +1,202 @@
1
+ """Loader for the Shehata et al. (2019) PSR dataset."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Iterable
7
+
8
+ import pandas as pd
9
+
10
+ from .utils import standardize_frame
11
+
12
+ SHEHATA_SOURCE = "shehata2019"
13
+
14
+ _COLUMN_ALIASES = {
15
+ "id": (
16
+ "antibody_id",
17
+ "antibody",
18
+ "antibody name",
19
+ "antibody_name",
20
+ "sequence_name",
21
+ "Antibody Name",
22
+ ),
23
+ "heavy_seq": (
24
+ "heavy",
25
+ "heavy_chain",
26
+ "heavy aa",
27
+ "heavy_sequence",
28
+ "vh",
29
+ "vh_sequence",
30
+ "heavy chain aa",
31
+ "Heavy Chain AA",
32
+ ),
33
+ "light_seq": (
34
+ "light",
35
+ "light_chain",
36
+ "light aa",
37
+ "light_sequence",
38
+ "vl",
39
+ "vl_sequence",
40
+ "light chain aa",
41
+ "Light Chain AA",
42
+ ),
43
+ "label": (
44
+ "polyreactive",
45
+ "binding_class",
46
+ "binding class",
47
+ "psr_class",
48
+ "psr binding",
49
+ "psr classification",
50
+ "Binding class",
51
+ "Binding Class",
52
+ ),
53
+ }
54
+
55
+ _LABEL_MAP = {
56
+ "polyreactive": 1,
57
+ "non-polyreactive": 0,
58
+ "positive": 1,
59
+ "negative": 0,
60
+ "high": 1,
61
+ "low": 0,
62
+ "pos": 1,
63
+ "neg": 0,
64
+ 1: 1,
65
+ 0: 0,
66
+ 1.0: 1,
67
+ 0.0: 0,
68
+ "1": 1,
69
+ "0": 0,
70
+ }
71
+
72
+ _PSR_SCORE_ALIASES: tuple[str, ...] = (
73
+ "psr score",
74
+ "psr_score",
75
+ "psr overall score",
76
+ "overall score",
77
+ "psr z",
78
+ "psr_z",
79
+ )
80
+
81
+
82
+ def _clean_sequence(sequence: object) -> str:
83
+ if isinstance(sequence, str):
84
+ return "".join(sequence.split()).upper()
85
+ return ""
86
+
87
+
88
+ def _maybe_extract_psr_scores(frame: pd.DataFrame) -> pd.DataFrame:
89
+ scores: dict[str, pd.Series] = {}
90
+ for column in frame.columns:
91
+ lowered = column.strip().lower()
92
+ if any(alias in lowered for alias in _PSR_SCORE_ALIASES):
93
+ key = lowered.replace(" ", "_")
94
+ scores[key] = frame[column]
95
+ if not scores:
96
+ return pd.DataFrame(index=frame.index)
97
+ renamed = {}
98
+ for name, series in scores.items():
99
+ cleaned_name = name
100
+ for prefix in ("psr_", "overall_"):
101
+ if cleaned_name.startswith(prefix):
102
+ cleaned_name = cleaned_name[len(prefix) :]
103
+ break
104
+ cleaned_name = cleaned_name.replace("__", "_")
105
+ cleaned_name = cleaned_name.replace("(", "").replace(")", "")
106
+ cleaned_name = cleaned_name.replace("-", "_")
107
+ renamed[f"psr_{cleaned_name}"] = pd.to_numeric(series, errors="coerce")
108
+ return pd.DataFrame(renamed)
109
+
110
+
111
+ def _pick_source_label(path: Path | None) -> str:
112
+ if path is None:
113
+ return SHEHATA_SOURCE
114
+ stem = path.stem.lower()
115
+ if "curated" in stem or "subset" in stem:
116
+ return f"{SHEHATA_SOURCE}_curated"
117
+ return SHEHATA_SOURCE
118
+
119
+
120
+ def _standardize(
121
+ frame: pd.DataFrame,
122
+ *,
123
+ heavy_only: bool,
124
+ source: str,
125
+ ) -> pd.DataFrame:
126
+ standardized = standardize_frame(
127
+ frame,
128
+ source=source,
129
+ heavy_only=heavy_only,
130
+ column_aliases=_COLUMN_ALIASES,
131
+ label_map=_LABEL_MAP,
132
+ is_test=True,
133
+ )
134
+
135
+ psr_scores = _maybe_extract_psr_scores(frame)
136
+
137
+ mask = standardized["heavy_seq"].map(_clean_sequence) != ""
138
+ standardized = standardized.loc[mask].copy()
139
+ standardized.reset_index(drop=True, inplace=True)
140
+ standardized["heavy_seq"] = standardized["heavy_seq"].map(_clean_sequence)
141
+ standardized["light_seq"] = standardized["light_seq"].map(_clean_sequence)
142
+
143
+ if not psr_scores.empty:
144
+ psr_scores = psr_scores.loc[mask]
145
+ psr_scores = psr_scores.reset_index(drop=True)
146
+ for column in psr_scores.columns:
147
+ standardized[column] = psr_scores[column].reset_index(drop=True)
148
+
149
+ return standardized
150
+
151
+
152
+ def _read_excel(path: Path, *, heavy_only: bool) -> pd.DataFrame:
153
+ excel = pd.ExcelFile(path, engine="openpyxl")
154
+ sheet_candidates: Iterable[str] = excel.sheet_names
155
+
156
+ def _score(name: str) -> tuple[int, str]:
157
+ lowered = name.lower()
158
+ priority = 0
159
+ if "psr" in lowered or "polyreactivity" in lowered:
160
+ priority = 2
161
+ elif "sheet" not in lowered:
162
+ priority = 1
163
+ return (-priority, name)
164
+
165
+ sheet_name = sorted(sheet_candidates, key=_score)[0]
166
+ raw = excel.parse(sheet_name)
167
+ raw = raw.dropna(how="all")
168
+ return _standardize(raw, heavy_only=heavy_only, source=_pick_source_label(path))
169
+
170
+
171
+ def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
172
+ """Load the Shehata dataset into the canonical format.
173
+
174
+ Supports both pre-processed CSV exports and the original Excel supplement
175
+ (*.xls/*.xlsx). Additional PSR score columns are preserved when available.
176
+ """
177
+
178
+ lower = path_or_url.lower()
179
+ source_override: str | None = None
180
+ if lower.startswith("http://") or lower.startswith("https://"):
181
+ if lower.endswith((".xls", ".xlsx")):
182
+ raw = pd.read_excel(path_or_url, engine="openpyxl")
183
+ return _standardize(raw, heavy_only=heavy_only, source=SHEHATA_SOURCE)
184
+ frame = pd.read_csv(path_or_url)
185
+ return _standardize(frame, heavy_only=heavy_only, source=SHEHATA_SOURCE)
186
+
187
+ path = Path(path_or_url)
188
+ source_override = _pick_source_label(path)
189
+ if path.suffix.lower() in {".xls", ".xlsx"}:
190
+ if path.suffix.lower() == ".xlsx":
191
+ # _read_excel picks the best sheet and already returns a standardized frame
192
+ frame = _read_excel(path, heavy_only=heavy_only)
193
+ else:
194
+ raw = pd.read_excel(path)
195
+ frame = _standardize(raw, heavy_only=heavy_only, source=source_override)
196
+ frame["source"] = source_override
197
+ return frame
198
+
199
+ frame = pd.read_csv(path)
200
+ standardized = _standardize(frame, heavy_only=heavy_only, source=source_override)
201
+ standardized["source"] = source_override
202
+ return standardized
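Note: besides the canonical columns, the Shehata loader preserves any PSR score columns it can find. A small sketch of the two helpers involved (the column headers below are illustrative; any header matching one of the _PSR_SCORE_ALIASES is picked up):

    import pandas as pd
    from polyreact.data_loaders import shehata

    frame = pd.DataFrame({"PSR Score": ["0.12", "0.87"], "Antibody Name": ["mAb1", "mAb2"]})
    scores = shehata._maybe_extract_psr_scores(frame)
    print(list(scores.columns))                  # -> ['psr_score'], coerced to numeric
    print(shehata._clean_sequence(" evql ve "))  # -> 'EVQLVE'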
polyreact/data_loaders/utils.py ADDED
@@ -0,0 +1,186 @@
1
+ """Utility helpers for dataset loading."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Iterable, Sequence
7
+
8
+ import pandas as pd
9
+
10
+ EXPECTED_COLUMNS = ("id", "heavy_seq", "light_seq", "label")
11
+ OPTIONAL_COLUMNS = ("source", "is_test")
12
+
13
+ LOGGER = logging.getLogger("polyreact.data")
14
+
15
+ _DEFAULT_ALIASES: dict[str, Sequence[str]] = {
16
+ "id": ("id", "sequence_id", "antibody_id", "uid"),
17
+ "heavy_seq": ("heavy_seq", "heavy", "heavy_chain", "H", "H_chain"),
18
+ "light_seq": ("light_seq", "light", "light_chain", "L", "L_chain"),
19
+ "label": ("label", "polyreactive", "is_polyreactive", "class", "target"),
20
+ }
21
+
22
+ DEFAULT_LABEL_MAP: dict[str | int | float | bool, int] = {
23
+ 1: 1,
24
+ 0: 0,
25
+ "1": 1,
26
+ "0": 0,
27
+ True: 1,
28
+ False: 0,
29
+ "true": 1,
30
+ "false": 0,
31
+ "polyreactive": 1,
32
+ "non-polyreactive": 0,
33
+ "poly": 1,
34
+ "non": 0,
35
+ "positive": 1,
36
+ "negative": 0,
37
+ }
38
+
39
+
40
+ def _normalize_label_key(value: object) -> object:
41
+ if isinstance(value, str):
42
+ trimmed = value.strip().lower()
43
+ if trimmed in {
44
+ "polyreactive",
45
+ "non-polyreactive",
46
+ "poly",
47
+ "non",
48
+ "positive",
49
+ "negative",
50
+ "high",
51
+ "low",
52
+ "pos",
53
+ "neg",
54
+ "1",
55
+ "0",
56
+ "true",
57
+ "false",
58
+ }:
59
+ return trimmed
60
+ if trimmed.isdigit():
61
+ return trimmed
62
+ return value
63
+
64
+
65
+ def ensure_columns(frame: pd.DataFrame, *, heavy_only: bool = True) -> pd.DataFrame:
66
+ """Validate and coerce dataframe columns to the canonical format."""
67
+
68
+ frame = frame.copy()
69
+ for column in ("id", "heavy_seq", "label"):
70
+ if column not in frame.columns:
71
+ msg = f"Required column '{column}' missing from dataframe"
72
+ raise KeyError(msg)
73
+
74
+ if "light_seq" not in frame.columns:
75
+ frame["light_seq"] = ""
76
+
77
+ if heavy_only:
78
+ frame["light_seq"] = ""
79
+
80
+ frame["id"] = frame["id"].astype(str)
81
+ frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str)
82
+ frame["light_seq"] = frame["light_seq"].fillna("").astype(str)
83
+ frame["label"] = frame["label"].astype(int)
84
+
85
+ ordered = list(EXPECTED_COLUMNS) + [
86
+ col for col in frame.columns if col not in EXPECTED_COLUMNS
87
+ ]
88
+ return frame[ordered]
89
+
90
+
91
+ def standardize_frame(
92
+ frame: pd.DataFrame,
93
+ *,
94
+ source: str,
95
+ heavy_only: bool = True,
96
+ column_aliases: dict[str, Sequence[str]] | None = None,
97
+ label_map: dict[str | int | float | bool, int] | None = None,
98
+ is_test: bool | None = None,
99
+ ) -> pd.DataFrame:
100
+ """Rename columns using aliases and coerce labels to integers."""
101
+
102
+ aliases = {**_DEFAULT_ALIASES}
103
+ if column_aliases:
104
+ for key, values in column_aliases.items():
105
+ aliases[key] = tuple(values) + tuple(aliases.get(key, ()))
106
+
107
+ rename_map: dict[str, str] = {}
108
+ for target, candidates in aliases.items():
109
+ if target in frame.columns:
110
+ continue
111
+ for candidate in candidates:
112
+ if candidate in frame.columns and candidate not in rename_map:
113
+ rename_map[candidate] = target
114
+ break
115
+
116
+ normalized = frame.rename(columns=rename_map).copy()
117
+
118
+ if "light_seq" not in normalized.columns:
119
+ normalized["light_seq"] = ""
120
+
121
+ label_lookup = label_map or DEFAULT_LABEL_MAP
122
+ normalized["label"] = normalized["label"].map(lambda x: label_lookup.get(_normalize_label_key(x)))
123
+
124
+ if normalized["label"].isnull().any():
125
+ msg = "Label column contains unmapped or missing values"
126
+ raise ValueError(msg)
127
+
128
+ normalized["source"] = source
129
+ if is_test is not None:
130
+ normalized["is_test"] = bool(is_test)
131
+
132
+ normalized = ensure_columns(normalized, heavy_only=heavy_only)
133
+ return normalized
134
+
135
+
136
+ def deduplicate_sequences(
137
+ frames: Iterable[pd.DataFrame],
138
+ *,
139
+ heavy_only: bool = True,
140
+ key_columns: Sequence[str] | None = None,
141
+ keep_intra_frames: set[int] | None = None,
142
+ ) -> list[pd.DataFrame]:
143
+ """Remove duplicate entries across multiple dataframes with configurable keys."""
144
+
145
+ if key_columns is None:
146
+ key_columns = ["heavy_seq"] if heavy_only else ["heavy_seq", "light_seq"]
147
+ keep_intra_frames = keep_intra_frames or set()
148
+
149
+ seen: set[tuple[str, ...]] = set()
150
+ cleaned: list[pd.DataFrame] = []
151
+
152
+ for frame_idx, frame in enumerate(frames):
153
+ valid_columns = [col for col in key_columns if col in frame.columns]
154
+ if not valid_columns:
155
+ valid_columns = ["heavy_seq"]
156
+
157
+ mask: list[bool] = []
158
+ frame_seen: set[tuple[str, ...]] = set()
159
+ allow_intra = frame_idx in keep_intra_frames
160
+
161
+ for values in frame[valid_columns].itertuples(index=False, name=None):
162
+ key = tuple(_normalise_key_value(value) for value in values)
163
+ if key in seen:
164
+ mask.append(False)
165
+ continue
166
+ if not allow_intra and key in frame_seen:
167
+ mask.append(False)
168
+ continue
169
+ mask.append(True)
170
+ frame_seen.add(key)
171
+ seen.update(frame_seen)
172
+ filtered = frame.loc[mask].reset_index(drop=True)
173
+ removed = len(frame) - len(filtered)
174
+ if removed:
175
+ dataset = "<unknown>"
176
+ if "source" in frame.columns and not frame["source"].empty:
177
+ dataset = str(frame["source"].iloc[0])
178
+ LOGGER.info("Removed %s duplicate sequences from %s", removed, dataset)
179
+ cleaned.append(filtered)
180
+ return cleaned
181
+
182
+
183
+ def _normalise_key_value(value: object) -> str:
184
+ if value is None or (isinstance(value, float) and pd.isna(value)):
185
+ return ""
186
+ return str(value).strip()
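Note: a minimal sketch of how the canonicalisation helpers above fit together (the raw column names are a toy example; they are resolved through _DEFAULT_ALIASES and DEFAULT_LABEL_MAP):

    import pandas as pd
    from polyreact.data_loaders.utils import standardize_frame, deduplicate_sequences

    raw = pd.DataFrame(
        {
            "sequence_id": ["s1", "s2"],
            "heavy": ["EVQLVE", "QVQLQQ"],
            "polyreactive": ["positive", "negative"],
        }
    )
    frame = standardize_frame(raw, source="toy", heavy_only=True)
    print(list(frame.columns)[:4])  # -> ['id', 'heavy_seq', 'light_seq', 'label']

    # Cross-dataset deduplication keeps only the first occurrence of each heavy chain.
    train, test = deduplicate_sequences([frame, frame.copy()], heavy_only=True)
    print(len(train), len(test))  # -> 2 0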
polyreact/features/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ """Feature backends for polyreactivity prediction."""
2
+
3
+ from . import anarsi, descriptors, plm
4
+ from .pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline
5
+
6
+ __all__ = [
7
+ "anarsi",
8
+ "descriptors",
9
+ "plm",
10
+ "FeaturePipeline",
11
+ "FeaturePipelineState",
12
+ "build_feature_pipeline",
13
+ ]
polyreact/features/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (530 Bytes). View file
 
polyreact/features/__pycache__/anarsi.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
polyreact/features/__pycache__/descriptors.cpython-311.pyc ADDED
Binary file (9.22 kB). View file
 
polyreact/features/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (17.1 kB). View file
 
polyreact/features/__pycache__/plm.cpython-311.pyc ADDED
Binary file (23 kB). View file
 
polyreact/features/anarsi.py ADDED
@@ -0,0 +1,222 @@
1
+ """ANARCI/ANARCII numbering helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ from typing import Dict, List, Sequence, Tuple
8
+
9
+ try:
10
+ from anarcii.pipeline import Anarcii # type: ignore
11
+ except ImportError: # pragma: no cover - optional dependency
12
+ Anarcii = None
13
+
14
+
15
+ @dataclass(slots=True)
16
+ class NumberedResidue:
17
+ """Single residue with IMGT numbering metadata."""
18
+
19
+ position: int
20
+ insertion: str
21
+ amino_acid: str
22
+
23
+
24
+ @dataclass(slots=True)
25
+ class NumberedSequence:
26
+ """Container for numbering results and derived regions."""
27
+
28
+ sequence: str
29
+ scheme: str
30
+ chain_type: str
31
+ residues: list[NumberedResidue]
32
+ regions: dict[str, str]
33
+
34
+
35
+ _IMGT_HEAVY_REGIONS: Sequence[Tuple[str, int, int]] = (
36
+ ("FR1", 1, 26),
37
+ ("CDRH1", 27, 38),
38
+ ("FR2", 39, 55),
39
+ ("CDRH2", 56, 65),
40
+ ("FR3", 66, 104),
41
+ ("CDRH3", 105, 117),
42
+ ("FR4", 118, 128),
43
+ )
44
+
45
+ _IMGT_LIGHT_REGIONS: Sequence[Tuple[str, int, int]] = (
46
+ ("FR1", 1, 26),
47
+ ("CDRL1", 27, 38),
48
+ ("FR2", 39, 55),
49
+ ("CDRL2", 56, 65),
50
+ ("FR3", 66, 104),
51
+ ("CDRL3", 105, 117),
52
+ ("FR4", 118, 128),
53
+ )
54
+
55
+ _REGION_MAP: dict[Tuple[str, str], Sequence[Tuple[str, int, int]]] = {
56
+ ("imgt", "H"): _IMGT_HEAVY_REGIONS,
57
+ ("imgt", "L"): _IMGT_LIGHT_REGIONS,
58
+ }
59
+
60
+ _VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
61
+ _DEFAULT_SCHEME = "imgt"
62
+ _DEFAULT_CHAIN = "H"
63
+
64
+
65
+ _DEFAULT_NUMBERER: AnarciNumberer | None = None
66
+
67
+
68
+ def _sanitize_sequence(sequence: str) -> str:
69
+ return "".join(residue for residue in sequence.upper() if residue in _VALID_AMINO_ACIDS)
70
+
71
+
72
+ def get_default_numberer() -> AnarciNumberer:
73
+ global _DEFAULT_NUMBERER
74
+ if _DEFAULT_NUMBERER is None:
75
+ _DEFAULT_NUMBERER = AnarciNumberer(chain_type=_DEFAULT_CHAIN, cpu=True, ncpu=1, verbose=False)
76
+ return _DEFAULT_NUMBERER
77
+
78
+
79
+ def trim_variable_domain(
80
+ sequence: str,
81
+ *,
82
+ numberer: AnarciNumberer | None = None,
83
+ scheme: str = _DEFAULT_SCHEME,
84
+ chain_type: str = _DEFAULT_CHAIN,
85
+ fallback_length: int = 130,
86
+ ) -> str:
87
+ """Return the FR1–FR4 variable domain for a heavy/light chain sequence."""
88
+
89
+ cleaned = _sanitize_sequence(sequence)
90
+ if not cleaned:
91
+ return ""
92
+
93
+ active_numberer = numberer or get_default_numberer()
94
+ try:
95
+ numbered = active_numberer.number_sequence(cleaned)
96
+ except Exception: # pragma: no cover - best effort safeguard
97
+ return cleaned[:fallback_length]
98
+
99
+ region_sets = _region_boundaries(scheme, chain_type)
100
+ pieces: list[str] = []
101
+ for region_name, _start, _end in region_sets:
102
+ segment = numbered.regions.get(region_name, "")
103
+ if segment:
104
+ pieces.append(segment)
105
+ trimmed = "".join(pieces)
106
+ if not trimmed:
107
+ trimmed = numbered.regions.get("full", "")
108
+ if not trimmed:
109
+ trimmed = cleaned[:fallback_length]
110
+ return trimmed
111
+
112
+
113
+ def _normalise_chain_type(chain_type: str) -> str:
114
+ upper = chain_type.upper()
115
+ if upper in {"H", "HV"}:
116
+ return "H"
117
+ if upper in {"L", "K", "LV", "KV"}:
118
+ return "L"
119
+ return upper
120
+
121
+
122
+ class AnarciNumberer:
123
+ """Thin wrapper around the ANARCII pipeline to obtain IMGT regions."""
124
+
125
+ def __init__(
126
+ self,
127
+ *,
128
+ scheme: str = "imgt",
129
+ chain_type: str = "H",
130
+ cpu: bool = True,
131
+ ncpu: int = 1,
132
+ verbose: bool = False,
133
+ ) -> None:
134
+ if Anarcii is None: # pragma: no cover - optional dependency guard
135
+ msg = (
136
+ "anarcii is required for numbering but is not installed."
137
+ " Install 'anarcii' to enable ANARCI-based features."
138
+ )
139
+ raise ImportError(msg)
140
+ self.scheme = scheme
141
+ self.expected_chain_type = _normalise_chain_type(chain_type)
142
+ self.cpu = cpu
143
+ self.ncpu = ncpu
144
+ self.verbose = verbose
145
+ self._runner = None
146
+
147
+ def _ensure_runner(self) -> Anarcii:
148
+ if self._runner is None:
149
+ self._runner = Anarcii(
150
+ seq_type="antibody",
151
+ mode="accuracy",
152
+ batch_size=1,
153
+ cpu=self.cpu,
154
+ ncpu=self.ncpu,
155
+ verbose=self.verbose,
156
+ )
157
+ return self._runner
158
+
159
+ def number_sequence(self, sequence: str) -> NumberedSequence:
160
+ """Return numbering metadata for a single amino-acid sequence."""
161
+
162
+ runner = self._ensure_runner()
163
+ output = runner.number([sequence])
164
+ record = next(iter(output.values()))
165
+ if record.get("error"):
166
+ raise RuntimeError(f"ANARCI failed: {record['error']}")
167
+
168
+ scheme = record.get("scheme", self.scheme)
169
+ detected_chain = record.get("chain_type", self.expected_chain_type)
170
+ normalised_chain = _normalise_chain_type(detected_chain)
171
+ if self.expected_chain_type and normalised_chain != self.expected_chain_type:
172
+ msg = (
173
+ f"Expected chain type {self.expected_chain_type!r} but got"
174
+ f" {normalised_chain!r}"
175
+ )
176
+ raise ValueError(msg)
177
+
178
+ residues = [
179
+ NumberedResidue(position=pos, insertion=ins, amino_acid=aa)
180
+ for (pos, ins), aa in record["numbering"]
181
+ ]
182
+ regions = _extract_regions(
183
+ residues=residues,
184
+ scheme=scheme,
185
+ chain_type=normalised_chain,
186
+ )
187
+ return NumberedSequence(
188
+ sequence=sequence,
189
+ scheme=scheme,
190
+ chain_type=normalised_chain,
191
+ residues=residues,
192
+ regions=regions,
193
+ )
194
+
195
+
196
+ @lru_cache(maxsize=32)
197
+ def _region_boundaries(scheme: str, chain_type: str) -> Sequence[Tuple[str, int, int]]:
198
+ key = (scheme.lower(), chain_type.upper())
199
+ return _REGION_MAP.get(key, ())
200
+
201
+
202
+ def _extract_regions(
203
+ *,
204
+ residues: Sequence[NumberedResidue],
205
+ scheme: str,
206
+ chain_type: str,
207
+ ) -> dict[str, str]:
208
+ boundaries = _region_boundaries(scheme, chain_type)
209
+ slots: Dict[str, List[str]] = {name: [] for name, _, _ in boundaries}
210
+ slots["full"] = []
211
+
212
+ for residue in residues:
213
+ aa = residue.amino_acid
214
+ if aa == "-":
215
+ continue
216
+ slots["full"].append(aa)
217
+ for name, start, end in boundaries:
218
+ if start <= residue.position <= end:
219
+ slots.setdefault(name, []).append(aa)
220
+ break
221
+
222
+ return {key: "".join(value) for key, value in slots.items()}
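Note: the wrapper needs the optional 'anarcii' package at runtime. A usage sketch, assuming that dependency is installed and that a complete VH amino-acid sequence is supplied (the string below is truncated for display, so real output will differ):

    from polyreact.features.anarsi import AnarciNumberer

    numberer = AnarciNumberer(scheme="imgt", chain_type="H")
    vh = "EVQLVESGGGLVQPGGSLRLSCAAS"  # replace with a full VH domain
    numbered = numberer.number_sequence(vh)
    print(numbered.chain_type)            # -> 'H'
    print(numbered.regions.get("CDRH3"))  # residues at IMGT positions 105-117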
polyreact/features/descriptors.py ADDED
@@ -0,0 +1,146 @@
1
+ """Sequence descriptor features for polyreactivity prediction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Iterable, Sequence
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from Bio.SeqUtils.ProtParam import ProteinAnalysis
11
+ from sklearn.preprocessing import StandardScaler
12
+
13
+ from .anarsi import AnarciNumberer, NumberedSequence
14
+
15
+ _VALID_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")
16
+
17
+
18
+ @dataclass(slots=True)
19
+ class DescriptorConfig:
20
+ """Configuration for descriptor-based features."""
21
+
22
+ use_anarci: bool = True
23
+ regions: Sequence[str] = ("CDRH1", "CDRH2", "CDRH3")
24
+ features: Sequence[str] = (
25
+ "length",
26
+ "charge",
27
+ "hydropathy",
28
+ "aromaticity",
29
+ "pI",
30
+ "net_charge",
31
+ )
32
+ ph: float = 7.4
33
+
34
+
35
+ class DescriptorFeaturizer:
36
+ """Compute descriptor features with optional ANARCI-based regions."""
37
+
38
+ def __init__(
39
+ self,
40
+ *,
41
+ config: DescriptorConfig,
42
+ numberer: AnarciNumberer | None = None,
43
+ standardize: bool = True,
44
+ ) -> None:
45
+ self.config = config
46
+ self.numberer = numberer if not config.use_anarci else numberer or AnarciNumberer()
47
+ self.standardize = standardize
48
+ self.scaler = StandardScaler() if standardize else None
49
+ self.feature_names_: list[str] | None = None
50
+
51
+ def fit(self, sequences: Iterable[str]) -> "DescriptorFeaturizer":
52
+ table = self.compute_feature_table(sequences)
53
+ values = table.to_numpy(dtype=float)
54
+ if self.standardize and self.scaler is not None:
55
+ self.scaler.fit(values)
56
+ self.feature_names_ = list(table.columns)
57
+ return self
58
+
59
+ def transform(self, sequences: Iterable[str]) -> np.ndarray:
60
+ if self.feature_names_ is None:
61
+ msg = "DescriptorFeaturizer must be fitted before calling transform."
62
+ raise RuntimeError(msg)
63
+ table = self.compute_feature_table(sequences)
64
+ values = table.to_numpy(dtype=float)
65
+ if self.standardize and self.scaler is not None:
66
+ values = self.scaler.transform(values)
67
+ return values
68
+
69
+ def fit_transform(self, sequences: Iterable[str]) -> np.ndarray:
70
+ table = self.compute_feature_table(sequences)
71
+ values = table.to_numpy(dtype=float)
72
+ if self.standardize and self.scaler is not None:
73
+ self.scaler.fit(values)
74
+ values = self.scaler.transform(values)
75
+ self.feature_names_ = list(table.columns)
76
+ return values
77
+
78
+ def compute_feature_table(self, sequences: Iterable[str]) -> pd.DataFrame:
79
+ rows: list[dict[str, float]] = []
80
+ for sequence in sequences:
81
+ regions = self._prepare_regions(sequence)
82
+ if not self.config.use_anarci:
83
+ region_names = ["FULL"]
84
+ else:
85
+ region_names = [region.upper() for region in self.config.regions]
86
+ row: dict[str, float] = {}
87
+ for region_name in region_names:
88
+ normalized_name = region_name.upper()
89
+ region_sequence = regions.get(normalized_name, "")
90
+ for feature_name in self.config.features:
91
+ column = f"{normalized_name}_{feature_name}"
92
+ row[column] = _compute_feature(
93
+ region_sequence,
94
+ feature_name,
95
+ ph=self.config.ph,
96
+ )
97
+ rows.append(row)
98
+
99
+ if not self.config.use_anarci:
100
+ region_names = ["FULL"]
101
+ else:
102
+ region_names = [region.upper() for region in self.config.regions]
103
+ columns = [
104
+ f"{region}_{feature}"
105
+ for region in region_names
106
+ for feature in self.config.features
107
+ ]
108
+ frame = pd.DataFrame(rows, columns=columns)
109
+ return frame.fillna(0.0)
110
+
111
+ def _prepare_regions(self, sequence: str) -> dict[str, str]:
112
+ if not self.config.use_anarci:
113
+ return {"FULL": sequence}
114
+
115
+ try:
116
+ numbered: NumberedSequence = self.numberer.number_sequence(sequence)
117
+ except (RuntimeError, ValueError):
118
+ return {}
119
+ return {key.upper(): value for key, value in numbered.regions.items()}
120
+
121
+
122
+ def _sanitize_sequence(sequence: str) -> str:
123
+ return "".join(residue for residue in sequence.upper() if residue in _VALID_AMINO_ACIDS)
124
+
125
+
126
+ def _compute_feature(sequence: str, feature_name: str, *, ph: float) -> float:
127
+ sanitized = _sanitize_sequence(sequence)
128
+ if not sanitized:
129
+ return 0.0
130
+
131
+ analysis = ProteinAnalysis(sanitized)
132
+ if feature_name == "length":
133
+ return float(len(sanitized))
134
+ if feature_name == "hydropathy":
135
+ return float(analysis.gravy())
136
+ if feature_name == "aromaticity":
137
+ return float(analysis.aromaticity())
138
+ if feature_name == "pI":
139
+ return float(analysis.isoelectric_point())
140
+ if feature_name == "net_charge":
141
+ return float(analysis.charge_at_pH(ph))
142
+ if feature_name == "charge":
143
+ net = analysis.charge_at_pH(ph)
144
+ return float(net / len(sanitized))
145
+ msg = f"Unsupported feature: {feature_name}"
146
+ raise ValueError(msg)
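Note: with use_anarci=False the featurizer computes descriptors over the full input sequence, so the example below needs only Biopython (already imported by this module) and no ANARCI numbering:

    from polyreact.features.descriptors import DescriptorConfig, DescriptorFeaturizer

    config = DescriptorConfig(use_anarci=False, features=("length", "hydropathy", "net_charge"))
    featurizer = DescriptorFeaturizer(config=config, standardize=False)
    X = featurizer.fit_transform(["EVQLVESGGGLVQPGGSLRLSCAAS", "QVQLQQWGAGLLKPSETLSLTCAVY"])
    print(featurizer.feature_names_)  # -> ['FULL_length', 'FULL_hydropathy', 'FULL_net_charge']
    print(X.shape)                    # -> (2, 3)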
polyreact/features/pipeline.py ADDED
@@ -0,0 +1,343 @@
1
+ """Feature pipeline construction utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import asdict, dataclass, field
6
+ from typing import Iterable, Sequence
7
+
8
+ import numpy as np
9
+ from sklearn.preprocessing import StandardScaler
10
+
11
+ from ..config import Config, DescriptorSettings, FeatureBackendSettings
12
+ from .descriptors import DescriptorConfig, DescriptorFeaturizer
13
+ from .plm import PLMEmbedder
14
+
15
+
16
+ @dataclass(slots=True)
17
+ class FeaturePipelineState:
18
+ backend_type: str
19
+ descriptor_featurizer: DescriptorFeaturizer | None
20
+ plm_scaler: StandardScaler | None
21
+ descriptor_config: DescriptorConfig | None
22
+ plm_model_name: str | None
23
+ plm_layer_pool: str | None
24
+ cache_dir: str | None
25
+ device: str
26
+ feature_names: list[str] = field(default_factory=list)
27
+
28
+
29
+ class FeaturePipeline:
30
+ """Fit/transform feature matrices according to configuration."""
31
+
32
+ def __init__(
33
+ self,
34
+ *,
35
+ backend: FeatureBackendSettings,
36
+ descriptors: DescriptorSettings,
37
+ device: str,
38
+ cache_dir_override: str | None = None,
39
+ plm_model_override: str | None = None,
40
+ layer_pool_override: str | None = None,
41
+ ) -> None:
42
+ self.backend = backend
43
+ self.descriptor_settings = descriptors
44
+ self.device = device
45
+ self.cache_dir_override = cache_dir_override
46
+ self.plm_model_override = plm_model_override
47
+ self.layer_pool_override = layer_pool_override
48
+
49
+ self._descriptor: DescriptorFeaturizer | None = None
50
+ self._plm: PLMEmbedder | None = None
51
+ self._plm_scaler: StandardScaler | None = None
52
+ self._feature_names: list[str] = []
53
+
54
+ def fit_transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray: # noqa: ANN001
55
+ backend_type = self.backend.type if self.backend.type else "descriptors"
56
+ self._validate_heavy_support(backend_type, heavy_only)
57
+ sequences = _extract_sequences(df, heavy_only=heavy_only)
58
+
59
+ if backend_type == "descriptors":
60
+ self._descriptor = _build_descriptor_featurizer(self.descriptor_settings)
61
+ features = self._descriptor.fit_transform(sequences)
62
+ self._feature_names = list(self._descriptor.feature_names_ or [])
63
+ self._plm = None
64
+ self._plm_scaler = None
65
+ return features.astype(np.float32)
66
+
67
+ if backend_type == "plm":
68
+ self._descriptor = None
69
+ self._plm = _build_plm_embedder(
70
+ self.backend,
71
+ device=self.device,
72
+ cache_dir_override=self.cache_dir_override,
73
+ plm_model_override=self.plm_model_override,
74
+ layer_pool_override=self.layer_pool_override,
75
+ )
76
+ embeddings = self._plm.embed(sequences, batch_size=batch_size)
77
+ if self.backend.standardize:
78
+ self._plm_scaler = StandardScaler()
79
+ embeddings = self._plm_scaler.fit_transform(embeddings)
80
+ else:
81
+ self._plm_scaler = None
82
+ self._feature_names = [f"plm_{i}" for i in range(embeddings.shape[1])]
83
+ return embeddings.astype(np.float32)
84
+
85
+ if backend_type == "concat":
86
+ descriptor = _build_descriptor_featurizer(self.descriptor_settings)
87
+ desc_features = descriptor.fit_transform(sequences)
88
+ plm = _build_plm_embedder(
89
+ self.backend,
90
+ device=self.device,
91
+ cache_dir_override=self.cache_dir_override,
92
+ plm_model_override=self.plm_model_override,
93
+ layer_pool_override=self.layer_pool_override,
94
+ )
95
+ embeddings = plm.embed(sequences, batch_size=batch_size)
96
+ if self.backend.standardize:
97
+ plm_scaler = StandardScaler()
98
+ embeddings = plm_scaler.fit_transform(embeddings)
99
+ else:
100
+ plm_scaler = None
101
+ self._descriptor = descriptor
102
+ self._plm = plm
103
+ self._plm_scaler = plm_scaler
104
+ self._feature_names = list(descriptor.feature_names_ or []) + [
105
+ f"plm_{i}" for i in range(embeddings.shape[1])
106
+ ]
107
+ return np.concatenate([desc_features, embeddings], axis=1).astype(np.float32)
108
+
109
+ msg = f"Unsupported feature backend: {backend_type}"
110
+ raise ValueError(msg)
111
+
112
+ def fit(self, df, *, heavy_only: bool, batch_size: int = 8) -> "FeaturePipeline": # noqa: ANN001
113
+ backend_type = self.backend.type if self.backend.type else "descriptors"
114
+ self._validate_heavy_support(backend_type, heavy_only)
115
+ sequences = _extract_sequences(df, heavy_only=heavy_only)
116
+
117
+ if backend_type == "descriptors":
118
+ self._descriptor = _build_descriptor_featurizer(self.descriptor_settings)
119
+ self._descriptor.fit(sequences)
120
+ self._feature_names = list(self._descriptor.feature_names_ or [])
121
+ self._plm = None
122
+ self._plm_scaler = None
123
+ elif backend_type == "plm":
124
+ self._descriptor = None
125
+ self._plm = _build_plm_embedder(
126
+ self.backend,
127
+ device=self.device,
128
+ cache_dir_override=self.cache_dir_override,
129
+ plm_model_override=self.plm_model_override,
130
+ layer_pool_override=self.layer_pool_override,
131
+ )
132
+ embeddings = self._plm.embed(sequences, batch_size=batch_size)
133
+ if self.backend.standardize:
134
+ self._plm_scaler = StandardScaler()
135
+ embeddings = self._plm_scaler.fit_transform(embeddings)
136
+ else:
137
+ self._plm_scaler = None
138
+ self._feature_names = [f"plm_{i}" for i in range(embeddings.shape[1])]
139
+ elif backend_type == "concat":
140
+ descriptor = _build_descriptor_featurizer(self.descriptor_settings)
141
+ desc_features = descriptor.fit_transform(sequences)
142
+ plm = _build_plm_embedder(
143
+ self.backend,
144
+ device=self.device,
145
+ cache_dir_override=self.cache_dir_override,
146
+ plm_model_override=self.plm_model_override,
147
+ layer_pool_override=self.layer_pool_override,
148
+ )
149
+ embeddings = plm.embed(sequences, batch_size=batch_size)
150
+ if self.backend.standardize:
151
+ plm_scaler = StandardScaler()
152
+ embeddings = plm_scaler.fit_transform(embeddings)
153
+ else:
154
+ plm_scaler = None
155
+ self._descriptor = descriptor
156
+ self._plm = plm
157
+ self._plm_scaler = plm_scaler
158
+ self._feature_names = list(descriptor.feature_names_ or []) + [
159
+ f"plm_{i}" for i in range(embeddings.shape[1])
160
+ ]
161
+ else: # pragma: no cover - defensive branch
162
+ msg = f"Unsupported feature backend: {backend_type}"
163
+ raise ValueError(msg)
164
+ return self
165
+
166
+ def transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray: # noqa: ANN001
167
+ backend_type = self.backend.type if self.backend.type else "descriptors"
168
+ self._validate_heavy_support(backend_type, heavy_only)
169
+ sequences = _extract_sequences(df, heavy_only=heavy_only)
170
+
171
+ if backend_type == "descriptors":
172
+ if self._descriptor is None:
173
+ msg = "Descriptor featurizer is not fitted"
174
+ raise RuntimeError(msg)
175
+ features = self._descriptor.transform(sequences)
176
+ elif backend_type == "plm":
177
+ if self._plm is None:
178
+ msg = "PLM embedder is not initialised"
179
+ raise RuntimeError(msg)
180
+ embeddings = self._plm.embed(sequences, batch_size=batch_size)
181
+ if self.backend.standardize and self._plm_scaler is not None:
182
+ embeddings = self._plm_scaler.transform(embeddings)
183
+ features = embeddings
184
+ elif backend_type == "concat":
185
+ if self._descriptor is None or self._plm is None:
186
+ msg = "Feature pipeline not fitted"
187
+ raise RuntimeError(msg)
188
+ desc_features = self._descriptor.transform(sequences)
189
+ embeddings = self._plm.embed(sequences, batch_size=batch_size)
190
+ if self.backend.standardize and self._plm_scaler is not None:
191
+ embeddings = self._plm_scaler.transform(embeddings)
192
+ features = np.concatenate([desc_features, embeddings], axis=1)
193
+ else: # pragma: no cover - defensive branch
194
+ msg = f"Unsupported feature backend: {backend_type}"
195
+ raise ValueError(msg)
196
+
197
+ return features.astype(np.float32)
198
+
199
+ @property
200
+ def feature_names(self) -> list[str]:
201
+ return self._feature_names
202
+
203
+ def get_state(self) -> FeaturePipelineState:
204
+ descriptor = self._descriptor
205
+ if descriptor is not None and descriptor.numberer is not None:
206
+ if hasattr(descriptor.numberer, "_runner"):
207
+ descriptor.numberer._runner = None # type: ignore[attr-defined]
208
+ return FeaturePipelineState(
209
+ backend_type=self.backend.type,
210
+ descriptor_featurizer=descriptor,
211
+ plm_scaler=self._plm_scaler,
212
+ descriptor_config=_build_descriptor_config(self.descriptor_settings),
213
+ plm_model_name=self._effective_plm_model_name,
214
+ plm_layer_pool=self._effective_layer_pool,
215
+ cache_dir=self._effective_cache_dir,
216
+ device=self.device,
217
+ feature_names=self._feature_names,
218
+ )
219
+
220
+ def load_state(self, state: FeaturePipelineState) -> None:
221
+ self.backend.type = state.backend_type
222
+ if state.plm_model_name:
223
+ self.backend.plm_model_name = state.plm_model_name
224
+ self.plm_model_override = state.plm_model_name
225
+ if state.plm_layer_pool:
226
+ self.backend.layer_pool = state.plm_layer_pool
227
+ self.layer_pool_override = state.plm_layer_pool
228
+ if state.cache_dir:
229
+ self.backend.cache_dir = state.cache_dir
230
+ self.cache_dir_override = state.cache_dir
231
+ if state.descriptor_config:
232
+ self.descriptor_settings = DescriptorSettings(
233
+ use_anarci=state.descriptor_config.use_anarci,
234
+ regions=tuple(state.descriptor_config.regions),
235
+ features=tuple(state.descriptor_config.features),
236
+ ph=state.descriptor_config.ph,
237
+ )
238
+ self._descriptor = state.descriptor_featurizer
239
+ self._plm_scaler = state.plm_scaler
240
+ self._feature_names = state.feature_names
241
+ if self.backend.type in {"plm", "concat"}:
242
+ self._plm = _build_plm_embedder(
243
+ self.backend,
244
+ device=self.device,
245
+ cache_dir_override=self.backend.cache_dir,
246
+ plm_model_override=self.backend.plm_model_name,
247
+ layer_pool_override=self.backend.layer_pool,
248
+ )
249
+ else:
250
+ self._plm = None
251
+
252
+ @property
253
+ def _effective_plm_model_name(self) -> str | None:
254
+ if self.backend.type not in {"plm", "concat"}:
255
+ return None
256
+ return self.plm_model_override or self.backend.plm_model_name
257
+
258
+ @property
259
+ def _effective_layer_pool(self) -> str | None:
260
+ if self.backend.type not in {"plm", "concat"}:
261
+ return None
262
+ return self.layer_pool_override or self.backend.layer_pool
263
+
264
+ @property
265
+ def _effective_cache_dir(self) -> str | None:
266
+ if self.backend.type not in {"plm", "concat"}:
267
+ return None
268
+ if self.cache_dir_override is not None:
269
+ return self.cache_dir_override
270
+ return self.backend.cache_dir
271
+
272
+ def _validate_heavy_support(self, backend_type: str, heavy_only: bool) -> None:
273
+ if heavy_only:
274
+ return
275
+ if backend_type == "descriptors" and self.descriptor_settings.use_anarci:
276
+ msg = "Descriptor backend with ANARCI currently supports heavy-chain only inference."
277
+ raise ValueError(msg)
278
+ if backend_type == "concat" and self.descriptor_settings.use_anarci:
279
+ msg = "Concat backend with descriptors requires heavy-chain only data."
280
+ raise ValueError(msg)
281
+
282
+
283
+ def build_feature_pipeline(
284
+ config: Config,
285
+ *,
286
+ backend_override: str | None = None,
287
+ plm_model_override: str | None = None,
288
+ cache_dir_override: str | None = None,
289
+ layer_pool_override: str | None = None,
290
+ ) -> FeaturePipeline:
291
+ backend = FeatureBackendSettings(**asdict(config.feature_backend))
292
+ if backend_override:
293
+ backend.type = backend_override
294
+ pipeline = FeaturePipeline(
295
+ backend=backend,
296
+ descriptors=config.descriptors,
297
+ device=config.device,
298
+ cache_dir_override=cache_dir_override,
299
+ plm_model_override=plm_model_override,
300
+ layer_pool_override=layer_pool_override,
301
+ )
302
+ return pipeline
303
+
304
+
305
+ def _build_descriptor_featurizer(settings: DescriptorSettings) -> DescriptorFeaturizer:
306
+ descriptor_config = _build_descriptor_config(settings)
307
+ return DescriptorFeaturizer(config=descriptor_config, standardize=True)
308
+
309
+
310
+ def _build_descriptor_config(settings: DescriptorSettings) -> DescriptorConfig:
311
+ return DescriptorConfig(
312
+ use_anarci=settings.use_anarci,
313
+ regions=tuple(settings.regions),
314
+ features=tuple(settings.features),
315
+ ph=settings.ph,
316
+ )
317
+
318
+
319
+ def _build_plm_embedder(
320
+ backend: FeatureBackendSettings,
321
+ *,
322
+ device: str,
323
+ cache_dir_override: str | None,
324
+ plm_model_override: str | None,
325
+ layer_pool_override: str | None,
326
+ ) -> PLMEmbedder:
327
+ model_name = plm_model_override or backend.plm_model_name
328
+ cache_dir = cache_dir_override or backend.cache_dir
329
+ layer_pool = layer_pool_override or backend.layer_pool
330
+ return PLMEmbedder(
331
+ model_name=model_name,
332
+ layer_pool=layer_pool,
333
+ device=device,
334
+ cache_dir=cache_dir,
335
+ )
336
+
337
+
338
+ def _extract_sequences(df, heavy_only: bool) -> Sequence[str]: # noqa: ANN001
339
+ if heavy_only or "light_seq" not in df.columns:
340
+ return df["heavy_seq"].fillna("").astype(str).tolist()
341
+ heavy = df["heavy_seq"].fillna("").astype(str)
342
+ light = df["light_seq"].fillna("").astype(str)
343
+ return (heavy + "|" + light).tolist()
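Note: a minimal wiring sketch using the descriptors backend (so no PLM download is needed). FeatureBackendSettings and DescriptorSettings are the config dataclasses imported at the top of this module; the sketch assumes they accept these fields as keyword arguments and provide sensible defaults for the rest:

    import pandas as pd
    from polyreact.config import DescriptorSettings, FeatureBackendSettings
    from polyreact.features.pipeline import FeaturePipeline

    pipeline = FeaturePipeline(
        backend=FeatureBackendSettings(type="descriptors"),
        descriptors=DescriptorSettings(use_anarci=False),
        device="cpu",
    )
    df = pd.DataFrame({"heavy_seq": ["EVQLVESGGGLVQPGGSLRLSCAAS", "QVQLQQWGAGLLKPSETLSLTCAVY"]})
    X = pipeline.fit_transform(df, heavy_only=True)
    print(X.shape, len(pipeline.feature_names))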
polyreact/features/plm.py ADDED
@@ -0,0 +1,378 @@
1
+ """Protein language model embeddings backend with caching support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from types import SimpleNamespace
10
+ from typing import Callable, Iterable, List, Sequence, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+ try: # pragma: no cover - optional dependency
17
+ from transformers import AutoModel, AutoTokenizer
18
+ except ImportError: # pragma: no cover - optional dependency
19
+ AutoModel = None
20
+ AutoTokenizer = None
21
+
22
+ try: # pragma: no cover - optional dependency
23
+ import esm
24
+ except ImportError: # pragma: no cover - optional dependency
25
+ esm = None
26
+
27
+ from .anarsi import AnarciNumberer
28
+
29
+ ModelLoader = Callable[[str, str], Tuple[object, nn.Module]]
30
+
31
+ if esm is not None: # pragma: no cover - optional dependency
32
+ _ESM1V_LOADERS = {
33
+ "esm1v_t33_650m_ur90s_1": esm.pretrained.esm1v_t33_650M_UR90S_1,
34
+ "esm1v_t33_650m_ur90s_2": esm.pretrained.esm1v_t33_650M_UR90S_2,
35
+ "esm1v_t33_650m_ur90s_3": esm.pretrained.esm1v_t33_650M_UR90S_3,
36
+ "esm1v_t33_650m_ur90s_4": esm.pretrained.esm1v_t33_650M_UR90S_4,
37
+ "esm1v_t33_650m_ur90s_5": esm.pretrained.esm1v_t33_650M_UR90S_5,
38
+ }
39
+ else: # pragma: no cover - optional dependency
40
+ _ESM1V_LOADERS: dict[str, Callable[[], tuple[nn.Module, object]]] = {}
41
+
42
+
43
+ class _ESMTokenizer:
44
+ """Callable wrapper that mimics Hugging Face tokenizers for ESM models."""
45
+
46
+ def __init__(self, alphabet) -> None: # noqa: ANN001
47
+ self.alphabet = alphabet
48
+ self._batch_converter = alphabet.get_batch_converter()
49
+
50
+ def __call__(
51
+ self,
52
+ sequences: Sequence[str],
53
+ *,
54
+ return_tensors: str = "pt",
55
+ padding: bool = True, # noqa: FBT002
56
+ truncation: bool = True, # noqa: FBT002
57
+ add_special_tokens: bool = True, # noqa: FBT002
58
+ return_special_tokens_mask: bool = True, # noqa: FBT002
59
+ ) -> dict[str, torch.Tensor]:
60
+ if return_tensors != "pt": # pragma: no cover - defensive branch
61
+ msg = "ESM tokenizer only supports return_tensors='pt'"
62
+ raise ValueError(msg)
63
+ data = [(str(idx), (seq or "").upper()) for idx, seq in enumerate(sequences)]
64
+ _labels, _strings, tokens = self._batch_converter(data)
65
+ attention_mask = (tokens != self.alphabet.padding_idx).long()
66
+ special_tokens = torch.zeros_like(tokens)
67
+ specials = {
68
+ self.alphabet.padding_idx,
69
+ self.alphabet.cls_idx,
70
+ self.alphabet.eos_idx,
71
+ }
72
+ for special in specials:
73
+ special_tokens |= tokens == special
74
+ output: dict[str, torch.Tensor] = {
75
+ "input_ids": tokens,
76
+ "attention_mask": attention_mask,
77
+ }
78
+ if return_special_tokens_mask:
79
+ output["special_tokens_mask"] = special_tokens.long()
80
+ return output
81
+
82
+
83
+ class _ESMModelWrapper(nn.Module):
84
+ """Adapter providing a Hugging Face style interface for ESM models."""
85
+
86
+ def __init__(self, model: nn.Module) -> None:
87
+ super().__init__()
88
+ self.model = model
89
+ self.layer_index = getattr(model, "num_layers", None)
90
+ if self.layer_index is None:
91
+ msg = "Unable to determine final layer for ESM model"
92
+ raise AttributeError(msg)
93
+
94
+ def eval(self) -> "_ESMModelWrapper": # pragma: no cover - trivial
95
+ self.model.eval()
96
+ return self
97
+
98
+ def to(self, device: str) -> "_ESMModelWrapper": # pragma: no cover - trivial
99
+ self.model.to(device)
100
+ return self
101
+
102
+ def forward(self, input_ids: torch.Tensor, **_): # noqa: ANN003
103
+ with torch.no_grad():
104
+ outputs = self.model(
105
+ input_ids,
106
+ repr_layers=[self.layer_index],
107
+ return_contacts=False,
108
+ )
109
+ hidden = outputs["representations"][self.layer_index]
110
+ return SimpleNamespace(last_hidden_state=hidden)
111
+
112
+ __call__ = forward
113
+
114
+
115
+ @dataclass(slots=True)
116
+ class PLMConfig:
117
+ model_name: str = "facebook/esm1v_t33_650M_UR90S_1"
118
+ layer_pool: str = "mean"
119
+ cache_dir: Path = Path(".cache/embeddings")
120
+ device: str = "auto"
121
+
122
+
123
+ class PLMEmbedder:
124
+ """Embed amino-acid sequences using a transformer model with caching."""
125
+
126
+ def __init__(
127
+ self,
128
+ model_name: str = "facebook/esm1v_t33_650M_UR90S_1",
129
+ *,
130
+ layer_pool: str = "mean",
131
+ device: str = "auto",
132
+ cache_dir: str | Path | None = None,
133
+ numberer: AnarciNumberer | None = None,
134
+ model_loader: ModelLoader | None = None,
135
+ ) -> None:
136
+ self.model_name = model_name
137
+ self.layer_pool = layer_pool
138
+ self.device = self._resolve_device(device)
139
+ self.cache_dir = Path(cache_dir or ".cache/embeddings")
140
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
141
+ self.numberer = numberer
142
+ self.model_loader = model_loader
143
+ self._tokenizer: object | None = None
144
+ self._model: nn.Module | None = None
145
+
146
+ @staticmethod
147
+ def _resolve_device(device: str) -> str:
148
+ if device == "auto":
149
+ return "cuda" if torch.cuda.is_available() else "cpu"
150
+ return device
151
+
152
+ @property
153
+ def tokenizer(self): # noqa: D401
154
+ if self._tokenizer is None:
155
+ tokenizer, model = self._load_model_components()
156
+ self._tokenizer = tokenizer
157
+ self._model = model
158
+ return self._tokenizer
159
+
160
+ @property
161
+ def model(self) -> nn.Module:
162
+ if self._model is None:
163
+ tokenizer, model = self._load_model_components()
164
+ self._tokenizer = tokenizer
165
+ self._model = model
166
+ return self._model
167
+
168
+ def _load_model_components(self) -> Tuple[object, nn.Module]:
169
+ if self.model_loader is not None:
170
+ tokenizer, model = self.model_loader(self.model_name, self.device)
171
+ return tokenizer, model
172
+
173
+ if self._is_esm1v_model(self.model_name):
174
+ return self._load_esm_model()
175
+
176
+ if AutoModel is None or AutoTokenizer is None: # pragma: no cover - optional dependency
177
+ msg = "transformers must be installed to use PLMEmbedder"
178
+ raise ImportError(msg)
179
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
180
+ model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
181
+ model.eval()
182
+ model.to(self.device)
183
+ return tokenizer, model
184
+
185
+ def _load_esm_model(self) -> Tuple[object, nn.Module]:
186
+ if esm is None: # pragma: no cover - optional dependency
187
+ msg = (
188
+ "The 'esm' package is required to use ESM-1v models."
189
+ )
190
+ raise ImportError(msg)
191
+
192
+ normalized = self._canonical_esm_name(self.model_name)
193
+ loader = _ESM1V_LOADERS.get(normalized)
194
+ if loader is None: # pragma: no cover - guard branch
195
+ msg = f"Unsupported ESM-1v model: {self.model_name}"
196
+ raise ValueError(msg)
197
+
198
+ model, alphabet = loader()
199
+ model.eval()
200
+ model.to(self.device)
201
+ tokenizer = _ESMTokenizer(alphabet)
202
+ wrapper = _ESMModelWrapper(model)
203
+ return tokenizer, wrapper
204
+
205
+ @staticmethod
206
+ def _canonical_esm_name(model_name: str) -> str:
207
+ name = model_name.lower()
208
+ if "/" in name:
209
+ name = name.split("/")[-1]
210
+ return name
211
+
212
+ @classmethod
213
+ def _is_esm1v_model(cls, model_name: str) -> bool:
214
+ return cls._canonical_esm_name(model_name).startswith("esm1v")
215
+
216
+ def embed(self, sequences: Iterable[str], *, batch_size: int = 8) -> np.ndarray:
217
+ batch_sequences = list(sequences)
218
+ if not batch_sequences:
219
+ return np.empty((0, 0), dtype=np.float32)
220
+
221
+ outputs: List[np.ndarray | None] = [None] * len(batch_sequences)
222
+ unique_to_compute: dict[str, List[Tuple[int, Path]]] = {}
223
+ model_dir = self.cache_dir / self._normalized_model_name()
224
+ model_dir.mkdir(parents=True, exist_ok=True)
225
+
226
+ cache_hits: list[tuple[int, Path]] = []
227
+ for idx, sequence in enumerate(batch_sequences):
228
+ cache_path = self._sequence_cache_path(model_dir, sequence)
229
+ if cache_path.exists():
230
+ cache_hits.append((idx, cache_path))
231
+ else:
232
+ unique_to_compute.setdefault(sequence, []).append((idx, cache_path))
233
+
234
+ if cache_hits:
235
+ loaders = [path for _, path in cache_hits]
236
+ max_workers = min(len(loaders), 32)
237
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
238
+ for (idx, _), embedding in zip(cache_hits, executor.map(np.load, loaders), strict=True):
239
+ outputs[idx] = embedding
240
+
241
+ if unique_to_compute:
242
+ embeddings = self._compute_embeddings(list(unique_to_compute.keys()), batch_size=batch_size)
243
+ for sequence, embedding in zip(unique_to_compute.keys(), embeddings, strict=True):
244
+ targets = unique_to_compute[sequence]
245
+ for idx, cache_path in targets:
246
+ outputs[idx] = embedding
247
+ np.save(cache_path, embedding)
248
+ if any(item is None for item in outputs): # pragma: no cover - safety
249
+ msg = "Failed to compute embeddings for all sequences"
250
+ raise RuntimeError(msg)
251
+ array_outputs = [np.asarray(item, dtype=np.float32) for item in outputs] # type: ignore[arg-type]
252
+ return np.stack(array_outputs, axis=0)
253
+
254
+ def _compute_embeddings(self, sequences: Sequence[str], *, batch_size: int) -> List[np.ndarray]:
255
+ tokenizer = self.tokenizer
256
+ model = self.model
257
+ model.eval()
258
+ embeddings: List[np.ndarray] = []
259
+ for start in range(0, len(sequences), batch_size):
260
+ chunk = list(sequences[start : start + batch_size])
261
+ tokenized = self._tokenize(tokenizer, chunk)
262
+ model_inputs: dict[str, torch.Tensor] = {}
263
+ aux_inputs: dict[str, torch.Tensor] = {}
264
+ for key, value in tokenized.items():
265
+ if isinstance(value, torch.Tensor):
266
+ tensor_value = value.to(self.device)
267
+ else:
268
+ tensor_value = value
269
+ if key == "special_tokens_mask":
270
+ aux_inputs[key] = tensor_value
271
+ else:
272
+ model_inputs[key] = tensor_value
273
+ with torch.no_grad():
274
+ outputs = model(**model_inputs)
275
+ hidden_states = outputs.last_hidden_state.detach().cpu()
276
+ attention_mask = model_inputs.get("attention_mask")
277
+ special_tokens_mask = aux_inputs.get("special_tokens_mask")
278
+ if isinstance(attention_mask, torch.Tensor):
279
+ attention_mask = attention_mask.detach().cpu()
280
+ if isinstance(special_tokens_mask, torch.Tensor):
281
+ special_tokens_mask = special_tokens_mask.detach().cpu()
282
+
283
+ for idx, sequence in enumerate(chunk):
284
+ hidden = hidden_states[idx]
285
+ mask = attention_mask[idx] if isinstance(attention_mask, torch.Tensor) else None
286
+ special_mask = (
287
+ special_tokens_mask[idx]
288
+ if isinstance(special_tokens_mask, torch.Tensor)
289
+ else None
290
+ )
291
+ embedding = self._pool_hidden(hidden, mask, special_mask, sequence)
292
+ embeddings.append(embedding)
293
+ return embeddings
294
+
295
+ def _tokenize(self, tokenizer, sequences: Sequence[str]):
296
+ if hasattr(tokenizer, "__call__"):
297
+ return tokenizer(
298
+ list(sequences),
299
+ return_tensors="pt",
300
+ padding=True,
301
+ truncation=True,
302
+ add_special_tokens=True,
303
+ return_special_tokens_mask=True,
304
+ )
305
+ msg = "Tokenizer does not implement __call__"
306
+ raise TypeError(msg)
307
+
308
+ def _pool_hidden(
309
+ self,
310
+ hidden: torch.Tensor,
311
+ attention_mask: torch.Tensor | None,
312
+ special_mask: torch.Tensor | None,
313
+ sequence: str,
314
+ ) -> np.ndarray:
315
+ if attention_mask is None:
316
+ attention = torch.ones(hidden.size(0), dtype=torch.float32)
317
+ else:
318
+ attention = attention_mask.to(dtype=torch.float32)
319
+ if special_mask is not None:
320
+ attention = attention * (1.0 - special_mask.to(dtype=torch.float32))
321
+ if attention.sum() == 0:
322
+ attention = torch.ones_like(attention)
323
+
324
+ if self.layer_pool == "mean":
325
+ return self._masked_mean(hidden, attention)
326
+ if self.layer_pool == "cls":
327
+ return hidden[0].detach().cpu().numpy()
328
+ if self.layer_pool == "per_token_mean_cdrh3":
329
+ return self._pool_cdrh3(hidden, attention, sequence)
330
+ msg = f"Unsupported layer pool: {self.layer_pool}"
331
+ raise ValueError(msg)
332
+
333
+ @staticmethod
334
+ def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> np.ndarray:
335
+ weights = mask.unsqueeze(-1)
336
+ weighted = hidden * weights
337
+ denom = weights.sum()
338
+ if denom == 0:
339
+ pooled = hidden.mean(dim=0)
340
+ else:
341
+ pooled = weighted.sum(dim=0) / denom
342
+ return pooled.detach().cpu().numpy()
343
+
344
+ def _pool_cdrh3(self, hidden: torch.Tensor, mask: torch.Tensor, sequence: str) -> np.ndarray:
345
+ numberer = self.numberer
346
+ if numberer is None:
347
+ numberer = AnarciNumberer()
348
+ self.numberer = numberer
349
+ numbered = numberer.number_sequence(sequence)
350
+ cdr = numbered.regions.get("CDRH3", "")
351
+ if not cdr:
352
+ return self._masked_mean(hidden, mask)
353
+ sequence_upper = sequence.upper()
354
+ start = sequence_upper.find(cdr.upper())
355
+ if start == -1:
356
+ return self._masked_mean(hidden, mask)
357
+ residues_idx = mask.nonzero(as_tuple=False).squeeze(-1).tolist()
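+ # residues_idx maps residue order to token positions; the slice below assumes one token per residue (ESM-style protein tokenizers).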
358
+ if not residues_idx:
359
+ return self._masked_mean(hidden, mask)
360
+ end = start + len(cdr)
361
+ if end > len(residues_idx):
362
+ return self._masked_mean(hidden, mask)
363
+ cdr_token_positions = residues_idx[start:end]
364
+ if not cdr_token_positions:
365
+ return self._masked_mean(hidden, mask)
366
+ cdr_mask = torch.zeros_like(mask)
367
+ for pos in cdr_token_positions:
368
+ cdr_mask[pos] = 1.0
369
+ return self._masked_mean(hidden, cdr_mask)
370
+
371
+ def _sequence_cache_path(self, model_dir: Path, sequence: str) -> Path:
372
+ digest = hashlib.sha1(sequence.encode("utf-8")).hexdigest()
373
+ return model_dir / f"{digest}.npy"
374
+
375
+ def _normalized_model_name(self) -> str:
376
+ if self._is_esm1v_model(self.model_name):
377
+ return self._canonical_esm_name(self.model_name)
378
+ return self.model_name.replace("/", "_")
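For orientation, a minimal, self-contained sketch of what the masked mean pooling above computes: padding and special tokens are zeroed out of the mask and the remaining hidden states are averaged. The tensors below are toy values, not outputs of any real model.

import torch

hidden = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [0.0, 0.0]])  # (tokens, dim): CLS, two residues, padding
attention = torch.tensor([1.0, 1.0, 1.0, 0.0])   # padding excluded
special = torch.tensor([1.0, 0.0, 0.0, 0.0])     # CLS flagged as a special token
mask = attention * (1.0 - special)               # keep residue tokens only
pooled = (hidden * mask.unsqueeze(-1)).sum(dim=0) / mask.sum()
print(pooled)  # tensor([2.5000, 2.5000]): mean of the two residue embeddings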
polyreact/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """Models for polyreactivity classification."""
2
+
3
+ __all__ = ["linear", "calibrate", "ordinal"]
polyreact/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (274 Bytes). View file
 
polyreact/models/__pycache__/calibrate.cpython-311.pyc ADDED
Binary file (1.02 kB). View file
 
polyreact/models/__pycache__/linear.cpython-311.pyc ADDED
Binary file (4.92 kB). View file
 
polyreact/models/__pycache__/ordinal.cpython-311.pyc ADDED
Binary file (5.76 kB). View file
 
polyreact/models/calibrate.py ADDED
@@ -0,0 +1,24 @@
1
+ """Probability calibration helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ from sklearn.calibration import CalibratedClassifierCV
9
+
10
+
11
+ def fit_calibrator(
12
+ estimator: Any,
13
+ X: np.ndarray,
14
+ y: np.ndarray,
15
+ *,
16
+ method: str = "isotonic",
17
+ cv: int | str | None = "prefit",
18
+ ) -> CalibratedClassifierCV:
19
+ """Fit a ``CalibratedClassifierCV`` on top of a pre-trained estimator."""
20
+
21
+ calibrator = CalibratedClassifierCV(estimator, method=method, cv=cv)
22
+ calibrator.fit(X, y)
23
+ return calibrator
24
+
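A minimal usage sketch for fit_calibrator with a prefit base estimator; the synthetic data, split sizes, and logistic-regression settings are illustrative assumptions, not values used elsewhere in the repo.

import numpy as np
from sklearn.linear_model import LogisticRegression
from polyreact.models.calibrate import fit_calibrator

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))
y = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)

base = LogisticRegression(max_iter=1000).fit(X[:150], y[:150])          # pre-trained estimator
calibrated = fit_calibrator(base, X[150:], y[150:], method="isotonic")  # cv="prefit" by default
probs = calibrated.predict_proba(X[150:])[:, 1]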
polyreact/models/linear.py ADDED
@@ -0,0 +1,91 @@
1
+ """Linear classification heads for polyreactivity prediction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.svm import LinearSVC
11
+
12
+
13
+ @dataclass(slots=True)
14
+ class LinearModelConfig:
15
+ """Configuration options for linear heads."""
16
+
17
+ head: str = "logreg"
18
+ C: float = 1.0
19
+ class_weight: Any = "balanced"
20
+ max_iter: int = 1000
21
+
22
+
23
+ @dataclass(slots=True)
24
+ class TrainedModel:
25
+ """Container for trained estimators and optional calibration."""
26
+
27
+ estimator: Any
28
+ calibrator: Any | None = None
29
+ vectorizer_name: str = ""
30
+ feature_meta: dict[str, Any] = field(default_factory=dict)
31
+ metrics_cv: dict[str, float] = field(default_factory=dict)
32
+
33
+ def predict(self, X: np.ndarray) -> np.ndarray:
34
+ if self.calibrator is not None and hasattr(self.calibrator, "predict"):
35
+ return self.calibrator.predict(X)
36
+ return self.estimator.predict(X)
37
+
38
+ def predict_proba(self, X: np.ndarray) -> np.ndarray:
39
+ if self.calibrator is not None and hasattr(self.calibrator, "predict_proba"):
40
+ probs = self.calibrator.predict_proba(X)
41
+ return probs[:, 1]
42
+ if hasattr(self.estimator, "predict_proba"):
43
+ probs = self.estimator.predict_proba(X)
44
+ return probs[:, 1]
45
+ if hasattr(self.estimator, "decision_function"):
46
+ scores = self.estimator.decision_function(X)
47
+ return 1.0 / (1.0 + np.exp(-scores))
48
+ msg = "Estimator does not support probability prediction"
49
+ raise AttributeError(msg)
50
+
51
+
52
+ def build_estimator(
53
+ *, config: LinearModelConfig, random_state: int | None = 42
54
+ ) -> Any:
55
+ """Construct an unfitted linear estimator based on configuration."""
56
+
57
+ if config.head == "logreg":
58
+ estimator = LogisticRegression(
59
+ C=config.C,
60
+ max_iter=config.max_iter,
61
+ class_weight=config.class_weight,
62
+ solver="liblinear",
63
+ random_state=random_state,
64
+ )
65
+ elif config.head == "linear_svm":
66
+ estimator = LinearSVC(
67
+ C=config.C,
68
+ class_weight=config.class_weight,
69
+ max_iter=config.max_iter,
70
+ random_state=random_state,
71
+ )
72
+ else: # pragma: no cover - defensive branch
73
+ msg = f"Unsupported head type: {config.head}"
74
+ raise ValueError(msg)
75
+ return estimator
76
+
77
+
78
+ def train_linear_model(
79
+ X: np.ndarray,
80
+ y: np.ndarray,
81
+ *,
82
+ config: LinearModelConfig,
83
+ random_state: int | None = 42,
84
+ ) -> TrainedModel:
85
+ """Fit a linear classifier on the provided feature matrix."""
86
+
87
+ estimator = build_estimator(config=config, random_state=random_state)
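+ # build_estimator defaults LogisticRegression to liblinear; switch to lbfgs for larger sample counts, where it scales better.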
88
+ if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
89
+ estimator.set_params(solver="lbfgs")
90
+ estimator.fit(X, y)
91
+ return TrainedModel(estimator=estimator)
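A short end-to-end sketch of the linear head on synthetic features; the random matrix stands in for real embeddings and the settings are just the dataclass defaults.

import numpy as np
from polyreact.models.linear import LinearModelConfig, train_linear_model

rng = np.random.default_rng(42)
X = rng.normal(size=(120, 16))
y = (X[:, 0] - X[:, 1] > 0).astype(int)

model = train_linear_model(X, y, config=LinearModelConfig(head="logreg"), random_state=42)
scores = model.predict_proba(X)  # probability of the positive class
labels = model.predict(X)        # hard 0/1 predictions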
polyreact/models/ordinal.py ADDED
@@ -0,0 +1,106 @@
1
+ """Ordinal/count modeling utilities for flag regression."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import numpy as np
8
+ import statsmodels.api as sm
9
+ from statsmodels.discrete.discrete_model import NegativeBinomialResults
10
+ from sklearn.linear_model import PoissonRegressor
11
+ from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
12
+
13
+ try:
14
+ from sklearn.metrics import root_mean_squared_error
15
+ except ImportError: # pragma: no cover - fallback for older sklearn
16
+ def root_mean_squared_error(y_true, y_pred):
17
+ return mean_squared_error(y_true, y_pred, squared=False)
18
+
19
+
20
+ @dataclass(slots=True)
21
+ class PoissonModel:
22
+ """Wrapper storing a fitted Poisson regression model."""
23
+
24
+ estimator: PoissonRegressor
25
+
26
+ def predict(self, X: np.ndarray) -> np.ndarray:
27
+ return self.estimator.predict(X)
28
+
29
+
30
+ def fit_poisson_model(
31
+ X: np.ndarray,
32
+ y: np.ndarray,
33
+ *,
34
+ alpha: float = 1e-6,
35
+ max_iter: int = 1000,
36
+ ) -> PoissonModel:
37
+ """Train a Poisson regression model on count targets."""
38
+
39
+ model = PoissonRegressor(alpha=alpha, max_iter=max_iter)
40
+ model.fit(X, y)
41
+ return PoissonModel(estimator=model)
42
+
43
+
44
+ @dataclass(slots=True)
45
+ class NegativeBinomialModel:
46
+ """Wrapper storing a fitted negative-binomial regression model."""
47
+
48
+ result: NegativeBinomialResults
49
+
50
+ def predict(self, X: np.ndarray) -> np.ndarray:
51
+ X_const = sm.add_constant(X, has_constant="add")
52
+ return self.result.predict(X_const)
53
+
54
+ @property
55
+ def alpha(self) -> float:
56
+ params = np.asarray(self.result.params, dtype=float)
57
+ exog_dim = self.result.model.exog.shape[1]
58
+ if params.size > exog_dim:
59
+ # statsmodels stores log(alpha) as the final coefficient
60
+ return float(np.exp(params[-1]))
61
+ model_alpha = getattr(self.result.model, "alpha", None)
62
+ if model_alpha is not None:
63
+ return float(model_alpha)
64
+ return float("nan")
65
+
66
+
67
+ def fit_negative_binomial_model(
68
+ X: np.ndarray,
69
+ y: np.ndarray,
70
+ *,
71
+ max_iter: int = 200,
72
+ ) -> NegativeBinomialModel:
73
+ """Train a negative binomial regression model (NB2)."""
74
+
75
+ X_const = sm.add_constant(X, has_constant="add")
76
+ model = sm.NegativeBinomial(y, X_const, loglike_method="nb2")
77
+ result = model.fit(maxiter=max_iter, disp=False)
78
+ return NegativeBinomialModel(result=result)
79
+
80
+
81
+ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
82
+ """Return standard regression metrics for count predictions."""
83
+
84
+ mae = mean_absolute_error(y_true, y_pred)
85
+ rmse = root_mean_squared_error(y_true, y_pred)
86
+ r2 = r2_score(y_true, y_pred)
87
+ return {
88
+ "mae": float(mae),
89
+ "rmse": float(rmse),
90
+ "r2": float(r2),
91
+ }
92
+
93
+
94
+ def pearson_dispersion(
95
+ y_true: np.ndarray,
96
+ y_pred: np.ndarray,
97
+ *,
98
+ dof: int,
99
+ ) -> float:
100
+ """Compute Pearson dispersion (chi-square / dof)."""
101
+
102
+ eps = 1e-8
103
+ adjusted = np.maximum(y_pred, eps)
104
+ resid = (y_true - y_pred) / np.sqrt(adjusted)
105
+ denom = max(dof, 1)
106
+ return float(np.sum(resid**2) / denom)
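For the count-regression utilities, a sketch on synthetic Poisson counts showing the metric and dispersion helpers; the data-generating process and regularisation strength are illustrative only.

import numpy as np
from polyreact.models.ordinal import fit_poisson_model, regression_metrics, pearson_dispersion

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 5))
y = rng.poisson(np.exp(0.3 * X[:, 0]))

model = fit_poisson_model(X, y, alpha=1e-6)
pred = model.predict(X)
print(regression_metrics(y, pred))                               # mae / rmse / r2
print(pearson_dispersion(y, pred, dof=len(y) - X.shape[1] - 1))  # close to 1 for Poisson-like data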
polyreact/predict.py ADDED
@@ -0,0 +1,106 @@
1
+ """Command-line interface for polyreactivity predictions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ import pandas as pd
9
+
10
+ from .api import predict_batch
11
+ from .config import load_config
12
+ from .utils.io import read_table, write_table
13
+
14
+
15
+ def build_parser() -> argparse.ArgumentParser:
16
+ parser = argparse.ArgumentParser(description="Polyreactivity prediction CLI")
17
+ parser.add_argument(
18
+ "--input",
19
+ required=True,
20
+ help="Path to input CSV or JSONL file with sequences.",
21
+ )
22
+ parser.add_argument(
23
+ "--output",
24
+ required=True,
25
+ help="Path to write predictions CSV.",
26
+ )
27
+ parser.add_argument(
28
+ "--config",
29
+ default="configs/default.yaml",
30
+ help="Path to configuration YAML file.",
31
+ )
32
+ parser.add_argument(
33
+ "--backend",
34
+ choices=["plm", "descriptors", "concat"],
35
+ help="Override feature backend from config.",
36
+ )
37
+ parser.add_argument(
38
+ "--plm-model",
39
+ help="Override PLM model name.",
40
+ )
41
+ parser.add_argument(
42
+ "--weights",
43
+ required=True,
44
+ help="Path to trained model artifact (joblib).",
45
+ )
46
+ parser.add_argument(
47
+ "--heavy-only",
48
+ dest="heavy_only",
49
+ action="store_true",
50
+ default=True,
51
+ help="Use only heavy chains (default).",
52
+ )
53
+ parser.add_argument(
54
+ "--paired",
55
+ dest="heavy_only",
56
+ action="store_false",
57
+ help="Use paired heavy/light chains if available.",
58
+ )
59
+ parser.add_argument(
60
+ "--batch-size",
61
+ type=int,
62
+ default=8,
63
+ help="Batch size for model inference (PLM backend).",
64
+ )
65
+ parser.add_argument(
66
+ "--device",
67
+ choices=["auto", "cpu", "cuda"],
68
+ help="Computation device override.",
69
+ )
70
+ parser.add_argument(
71
+ "--cache-dir",
72
+ help="Cache directory for embeddings.",
73
+ )
74
+ return parser
75
+
76
+
77
+ def main(argv: list[str] | None = None) -> int:
78
+ parser = build_parser()
79
+ args = parser.parse_args(argv)
80
+
81
+ config = load_config(args.config)
82
+ df = read_table(args.input)
83
+
84
+ if "heavy_seq" not in df.columns and "heavy" not in df.columns:
85
+ parser.error("Input file must contain a 'heavy_seq' column (or 'heavy').")
86
+ if df.get("heavy_seq", df.get("heavy", "")).fillna("").str.len().eq(0).all():
87
+ parser.error("At least one non-empty heavy sequence is required.")
88
+
89
+ predictions = predict_batch(
90
+ df.to_dict("records"),
91
+ config=config,
92
+ backend=args.backend,
93
+ plm_model=args.plm_model,
94
+ weights=args.weights,
95
+ heavy_only=args.heavy_only,
96
+ batch_size=args.batch_size,
97
+ device=args.device,
98
+ cache_dir=args.cache_dir,
99
+ )
100
+
101
+ write_table(predictions, args.output)
102
+ return 0
103
+
104
+
105
+ if __name__ == "__main__":
106
+ raise SystemExit(main())
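The CLI can also be invoked programmatically; a hedged example in which the input, output, and artifact paths are placeholders:

from polyreact.predict import main

# The input file must contain a 'heavy_seq' (or 'heavy') column.
main([
    "--input", "sequences.csv",
    "--output", "predictions.csv",
    "--config", "configs/default.yaml",
    "--weights", "artifacts/model.joblib",
    "--batch-size", "8",
])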
polyreact/train.py ADDED
@@ -0,0 +1,619 @@
1
+ """Training entrypoint for the polyreactivity model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import subprocess
8
+ from pathlib import Path
9
+ from typing import Any, Callable, Sequence
10
+
11
+ import joblib
12
+ import numpy as np
13
+ import pandas as pd
14
+ from sklearn.metrics import roc_auc_score
15
+ from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
16
+ from sklearn.linear_model import LogisticRegression
17
+
18
+ from .config import Config, load_config
19
+ from .data_loaders import boughter, harvey, jain, shehata
20
+ from .data_loaders.utils import deduplicate_sequences
21
+ from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline
22
+ from .models.calibrate import fit_calibrator
23
+ from .models.linear import LinearModelConfig, TrainedModel, build_estimator, train_linear_model
24
+ from .utils.io import write_table
25
+ from .utils.logging import configure_logging
26
+ from .utils.metrics import bootstrap_metric_intervals, compute_metrics
27
+ from .utils.plots import plot_precision_recall, plot_reliability_curve, plot_roc_curve
28
+ from .utils.seeds import set_global_seeds
29
+
30
+ DATASET_LOADERS = {
31
+ "boughter": boughter.load_dataframe,
32
+ "jain": jain.load_dataframe,
33
+ "shehata": shehata.load_dataframe,
34
+ "shehata_curated": shehata.load_dataframe,
35
+ "harvey": harvey.load_dataframe,
36
+ }
37
+
38
+
39
+ def build_parser() -> argparse.ArgumentParser:
40
+ parser = argparse.ArgumentParser(description="Train polyreactivity model")
41
+ parser.add_argument("--config", default="configs/default.yaml", help="Config file")
42
+ parser.add_argument("--train", required=True, help="Training dataset path")
43
+ parser.add_argument(
44
+ "--eval",
45
+ nargs="*",
46
+ default=[],
47
+ help="Evaluation dataset paths",
48
+ )
49
+ parser.add_argument(
50
+ "--save-to",
51
+ default="artifacts/model.joblib",
52
+ help="Path to save trained model artifact",
53
+ )
54
+ parser.add_argument(
55
+ "--report-to",
56
+ default="artifacts",
57
+ help="Directory for metrics, predictions, and plots",
58
+ )
59
+ parser.add_argument(
60
+ "--train-loader",
61
+ choices=list(DATASET_LOADERS.keys()),
62
+ help="Optional explicit loader for training dataset",
63
+ )
64
+ parser.add_argument(
65
+ "--eval-loaders",
66
+ nargs="*",
67
+ help="Optional explicit loaders for evaluation datasets (aligned with --eval order)",
68
+ )
69
+ parser.add_argument(
70
+ "--backend",
71
+ choices=["plm", "descriptors", "concat"],
72
+ help="Override feature backend",
73
+ )
74
+ parser.add_argument("--plm-model", help="Override PLM model name")
75
+ parser.add_argument("--cache-dir", help="Override embedding cache directory")
76
+ parser.add_argument("--device", choices=["auto", "cpu", "cuda"], help="Device override")
77
+ parser.add_argument("--batch-size", type=int, default=8, help="Batch size for embeddings")
78
+ parser.add_argument(
79
+ "--heavy-only",
80
+ action="store_true",
81
+ default=True,
82
+ help="Use heavy chains only (default true)",
83
+ )
84
+ parser.add_argument(
85
+ "--paired",
86
+ dest="heavy_only",
87
+ action="store_false",
88
+ help="Use paired heavy/light chains when available.",
89
+ )
90
+ parser.add_argument(
91
+ "--include-families",
92
+ nargs="*",
93
+ help="Optional list of family names to retain in the training dataset",
94
+ )
95
+ parser.add_argument(
96
+ "--exclude-families",
97
+ nargs="*",
98
+ help="Optional list of family names to drop from the training dataset",
99
+ )
100
+ parser.add_argument(
101
+ "--include-species",
102
+ nargs="*",
103
+ help="Optional list of species (e.g. human, mouse) to retain",
104
+ )
105
+ parser.add_argument(
106
+ "--cv-group-column",
107
+ default="lineage",
108
+ help="Column name used to group samples during cross-validation (default: lineage)",
109
+ )
110
+ parser.add_argument(
111
+ "--no-group-cv",
112
+ action="store_true",
113
+ help="Disable group-aware cross-validation even if group column is present",
114
+ )
115
+ parser.add_argument(
116
+ "--keep-train-duplicates",
117
+ action="store_true",
118
+ help="Keep duplicate keys within the training dataset when deduplicating across splits",
119
+ )
120
+ parser.add_argument(
121
+ "--dedupe-key-columns",
122
+ nargs="*",
123
+ help="Columns used to detect duplicates across datasets (defaults to heavy/light sequences)",
124
+ )
125
+ parser.add_argument(
126
+ "--bootstrap-samples",
127
+ type=int,
128
+ default=200,
129
+ help="Number of bootstrap resamples for confidence intervals (0 to disable).",
130
+ )
131
+ parser.add_argument(
132
+ "--bootstrap-alpha",
133
+ type=float,
134
+ default=0.05,
135
+ help="Alpha for two-sided bootstrap confidence intervals (default 0.05 → 95% CI).",
136
+ )
137
+ parser.add_argument(
138
+ "--write-train-in-sample",
139
+ action="store_true",
140
+ help=(
141
+ "Persist in-sample metrics on the full training set; disabled by default to avoid"
142
+ " over-optimistic reporting."
143
+ ),
144
+ )
145
+ return parser
146
+
147
+
148
+ def _infer_loader(path: str, explicit: str | None) -> tuple[str, Callable[..., pd.DataFrame]]:
149
+ if explicit:
150
+ return explicit, DATASET_LOADERS[explicit]
151
+ lower = Path(path).stem.lower()
152
+ for name, loader in DATASET_LOADERS.items():
153
+ if name in lower:
154
+ return name, loader
155
+ msg = f"Could not infer loader for dataset: {path}. Provide --train-loader/--eval-loaders."
156
+ raise ValueError(msg)
157
+
158
+
159
+ def _load_dataset(path: str, loader_name: str, loader_fn, *, heavy_only: bool) -> pd.DataFrame:
160
+ frame = loader_fn(path, heavy_only=heavy_only)
161
+ frame["source"] = loader_name
162
+ return frame
163
+
164
+
165
+ def _apply_dataset_filters(
166
+ frame: pd.DataFrame,
167
+ *,
168
+ include_families: Sequence[str] | None,
169
+ exclude_families: Sequence[str] | None,
170
+ include_species: Sequence[str] | None,
171
+ ) -> pd.DataFrame:
172
+ filtered = frame.copy()
173
+ if include_families:
174
+ families = {fam.lower() for fam in include_families}
175
+ if "family" in filtered.columns:
176
+ filtered = filtered[
177
+ filtered["family"].astype(str).str.lower().isin(families)
178
+ ]
179
+ if exclude_families:
180
+ families_ex = {fam.lower() for fam in exclude_families}
181
+ if "family" in filtered.columns:
182
+ filtered = filtered[
183
+ ~filtered["family"].astype(str).str.lower().isin(families_ex)
184
+ ]
185
+ if include_species:
186
+ species_set = {spec.lower() for spec in include_species}
187
+ if "species" in filtered.columns:
188
+ filtered = filtered[
189
+ filtered["species"].astype(str).str.lower().isin(species_set)
190
+ ]
191
+ return filtered.reset_index(drop=True)
192
+
193
+
194
+ def main(argv: Sequence[str] | None = None) -> int:
195
+ parser = build_parser()
196
+ args = parser.parse_args(argv)
197
+
198
+ config = load_config(args.config)
199
+ if args.device:
200
+ config.device = args.device
201
+ if args.backend:
202
+ config.feature_backend.type = args.backend
203
+ if args.cache_dir:
204
+ config.feature_backend.cache_dir = args.cache_dir
205
+ if args.plm_model:
206
+ config.feature_backend.plm_model_name = args.plm_model
207
+
208
+ logger = configure_logging()
209
+ set_global_seeds(config.seed)
210
+ _log_environment(logger)
211
+
212
+ heavy_only = args.heavy_only
213
+
214
+ train_name, train_loader = _infer_loader(args.train, args.train_loader)
215
+ train_df = _load_dataset(args.train, train_name, train_loader, heavy_only=heavy_only)
216
+ train_df = _apply_dataset_filters(
217
+ train_df,
218
+ include_families=args.include_families,
219
+ exclude_families=args.exclude_families,
220
+ include_species=args.include_species,
221
+ )
222
+
223
+ eval_frames: list[pd.DataFrame] = []
224
+ if args.eval:
225
+ loaders_iter = args.eval_loaders or []
226
+ for idx, eval_path in enumerate(args.eval):
227
+ explicit = loaders_iter[idx] if idx < len(loaders_iter) else None
228
+ eval_name, eval_loader = _infer_loader(eval_path, explicit)
229
+ eval_df = _load_dataset(eval_path, eval_name, eval_loader, heavy_only=heavy_only)
230
+ eval_frames.append(eval_df)
231
+
232
+ all_frames = [train_df, *eval_frames]
233
+ dedup_keep = {0} if args.keep_train_duplicates else set()
234
+ deduped_frames = deduplicate_sequences(
235
+ all_frames,
236
+ heavy_only=heavy_only,
237
+ key_columns=args.dedupe_key_columns,
238
+ keep_intra_frames=dedup_keep,
239
+ )
240
+ train_df = deduped_frames[0]
241
+ eval_frames = deduped_frames[1:]
242
+
243
+ pipeline_factory = lambda: build_feature_pipeline( # noqa: E731
244
+ config,
245
+ backend_override=args.backend,
246
+ plm_model_override=args.plm_model,
247
+ cache_dir_override=args.cache_dir,
248
+ )
249
+
250
+ model_config = LinearModelConfig(
251
+ head=config.model.head,
252
+ C=config.model.C,
253
+ class_weight=config.model.class_weight,
254
+ )
255
+
256
+ groups = None
257
+ if not args.no_group_cv and args.cv_group_column:
258
+ if args.cv_group_column in train_df.columns:
259
+ groups = train_df[args.cv_group_column].fillna("").astype(str).to_numpy()
260
+ else:
261
+ logger.warning(
262
+ "Group column '%s' not found in training dataframe; falling back to standard CV",
263
+ args.cv_group_column,
264
+ )
265
+
266
+ cv_results = _cross_validate(
267
+ train_df,
268
+ pipeline_factory,
269
+ model_config,
270
+ config,
271
+ heavy_only=heavy_only,
272
+ batch_size=args.batch_size,
273
+ groups=groups,
274
+ )
275
+
276
+ trained_model, feature_pipeline = _fit_full_model(
277
+ train_df,
278
+ pipeline_factory,
279
+ model_config,
280
+ config,
281
+ heavy_only=heavy_only,
282
+ batch_size=args.batch_size,
283
+ )
284
+
285
+ outputs_dir = Path(args.report_to)
286
+ outputs_dir.mkdir(parents=True, exist_ok=True)
287
+
288
+ metrics_df, preds_rows = _evaluate_datasets(
289
+ train_df,
290
+ eval_frames,
291
+ trained_model,
292
+ feature_pipeline,
293
+ config,
294
+ cv_results,
295
+ outputs_dir,
296
+ batch_size=args.batch_size,
297
+ heavy_only=heavy_only,
298
+ bootstrap_samples=args.bootstrap_samples,
299
+ bootstrap_alpha=args.bootstrap_alpha,
300
+ write_train_in_sample=args.write_train_in_sample,
301
+ )
302
+
303
+ write_table(metrics_df, outputs_dir / config.io.metrics_filename)
304
+ preds_df = pd.DataFrame(preds_rows)
305
+ write_table(preds_df, outputs_dir / config.io.preds_filename)
306
+
307
+ artifact = {
308
+ "config": config,
309
+ "feature_state": feature_pipeline.get_state(),
310
+ "model": trained_model,
311
+ }
312
+ Path(args.save_to).parent.mkdir(parents=True, exist_ok=True)
313
+ joblib.dump(artifact, args.save_to)
314
+
315
+ logger.info("Training complete. Metrics written to %s", outputs_dir)
316
+ return 0
317
+
318
+
319
+ def _cross_validate(
320
+ train_df: pd.DataFrame,
321
+ pipeline_factory,
322
+ model_config: LinearModelConfig,
323
+ config: Config,
324
+ *,
325
+ heavy_only: bool,
326
+ batch_size: int,
327
+ groups: np.ndarray | None = None,
328
+ ):
329
+ y = train_df["label"].to_numpy(dtype=int)
330
+ n_samples = len(y)
331
+ # Determine a safe number of folds for tiny fixtures; prefer the configured value
332
+ # but never exceed the number of samples. Fall back to non-stratified KFold when
333
+ # per-class counts are too small for stratification (e.g., 1 positive/1 negative).
334
+ n_splits = max(2, min(config.training.cv_folds, n_samples))
335
+
336
+ use_stratified = True
337
+ class_counts = np.bincount(y) if y.size else np.array([])
338
+ if class_counts.size > 0 and (class_counts.min(initial=0) < n_splits):
339
+ use_stratified = False
340
+
341
+ if groups is not None and use_stratified:
342
+ splitter = StratifiedGroupKFold(
343
+ n_splits=n_splits,
344
+ shuffle=True,
345
+ random_state=config.seed,
346
+ )
347
+ split_iter = splitter.split(train_df, y, groups)
348
+ elif use_stratified:
349
+ splitter = StratifiedKFold(
350
+ n_splits=n_splits,
351
+ shuffle=True,
352
+ random_state=config.seed,
353
+ )
354
+ split_iter = splitter.split(train_df, y)
355
+ else:
356
+ # Non-stratified fallback for extreme class imbalance / tiny datasets
357
+ from sklearn.model_selection import KFold # local import to limit surface
358
+
359
+ splitter = KFold(n_splits=n_splits, shuffle=True, random_state=config.seed)
360
+ split_iter = splitter.split(train_df)
361
+ oof_scores = np.zeros(len(train_df), dtype=float)
362
+ metrics_per_fold: list[dict[str, float]] = []
363
+
364
+ for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
365
+ train_slice = train_df.iloc[train_idx].reset_index(drop=True)
366
+ val_slice = train_df.iloc[val_idx].reset_index(drop=True)
367
+
368
+ pipeline: FeaturePipeline = pipeline_factory()
369
+ X_train = pipeline.fit_transform(train_slice, heavy_only=heavy_only, batch_size=batch_size)
370
+ X_val = pipeline.transform(val_slice, heavy_only=heavy_only, batch_size=batch_size)
371
+
372
+ y_train = y[train_idx]
373
+ y_val = y[val_idx]
374
+
375
+ # Handle degenerate folds where training data contains a single class
376
+ if np.unique(y_train).size < 2:
377
+ fallback_prob = float(y.mean()) if y.size else 0.5
378
+ y_scores = np.full(X_val.shape[0], fallback_prob, dtype=float)
379
+ else:
380
+ trained = train_linear_model(
381
+ X_train, y_train, config=model_config, random_state=config.seed
382
+ )
383
+ calibrator = _fit_model_calibrator(
384
+ model_config,
385
+ config,
386
+ X_train,
387
+ y_train,
388
+ base_estimator=trained.estimator,
389
+ )
390
+ trained.calibrator = calibrator
391
+ if calibrator is not None:
392
+ y_scores = calibrator.predict_proba(X_val)[:, 1]
393
+ else:
394
+ y_scores = trained.predict_proba(X_val)
395
+ oof_scores[val_idx] = y_scores
396
+
397
+ fold_metrics = compute_metrics(y_val, y_scores)
398
+ try:
399
+ fold_metrics["roc_auc"] = float(roc_auc_score(y_val, y_scores))
400
+ except ValueError:
401
+ # For tiny validation folds with a single class, ROC-AUC is undefined
402
+ pass
403
+ metrics_per_fold.append(fold_metrics)
404
+
405
+ metrics_mean: dict[str, float] = {}
406
+ metrics_std: dict[str, float] = {}
407
+ metric_names = list(metrics_per_fold[0].keys()) if metrics_per_fold else []
408
+ for metric in metric_names:
409
+ values = [fold[metric] for fold in metrics_per_fold]
410
+ metrics_mean[metric] = float(np.mean(values))
411
+ metrics_std[metric] = float(np.std(values, ddof=1))
412
+
413
+ return {
414
+ "oof_scores": oof_scores,
415
+ "metrics_per_fold": metrics_per_fold,
416
+ "metrics_mean": metrics_mean,
417
+ "metrics_std": metrics_std,
418
+ }
419
+
420
+
421
+ def _fit_full_model(
422
+ train_df: pd.DataFrame,
423
+ pipeline_factory,
424
+ model_config: LinearModelConfig,
425
+ config: Config,
426
+ *,
427
+ heavy_only: bool,
428
+ batch_size: int,
429
+ ) -> tuple[TrainedModel, FeaturePipeline]:
430
+ pipeline: FeaturePipeline = pipeline_factory()
431
+ X_train = pipeline.fit_transform(train_df, heavy_only=heavy_only, batch_size=batch_size)
432
+ y_train = train_df["label"].to_numpy(dtype=int)
433
+
434
+ trained = train_linear_model(X_train, y_train, config=model_config, random_state=config.seed)
435
+ calibrator = _fit_model_calibrator(
436
+ model_config,
437
+ config,
438
+ X_train,
439
+ y_train,
440
+ base_estimator=trained.estimator,
441
+ )
442
+ trained.calibrator = calibrator
443
+
444
+ return trained, pipeline
445
+
446
+
447
+ def _evaluate_datasets(
448
+ train_df: pd.DataFrame,
449
+ eval_frames: list[pd.DataFrame],
450
+ trained_model: TrainedModel,
451
+ pipeline: FeaturePipeline,
452
+ config: Config,
453
+ cv_results: dict,
454
+ outputs_dir: Path,
455
+ *,
456
+ batch_size: int,
457
+ heavy_only: bool,
458
+ bootstrap_samples: int,
459
+ bootstrap_alpha: float,
460
+ write_train_in_sample: bool,
461
+ ):
462
+ metrics_lookup: dict[str, dict[str, float]] = {}
463
+ preds_rows: list[dict[str, Any]] = []
464
+
465
+ metrics_mean: dict[str, float] = cv_results["metrics_mean"]
466
+ metrics_std: dict[str, float] = cv_results["metrics_std"]
467
+
468
+ for metric_name, value in metrics_mean.items():
469
+ metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
470
+ "train_cv_mean"
471
+ ] = value
472
+ for metric_name, value in metrics_std.items():
473
+ metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
474
+ "train_cv_std"
475
+ ] = value
476
+
477
+ train_scores = cv_results["oof_scores"]
478
+ train_preds = train_df[["id", "source", "label"]].copy()
479
+ train_preds["y_true"] = train_preds["label"]
480
+ train_preds["y_score"] = train_scores
481
+ train_preds["y_pred"] = (train_scores >= 0.5).astype(int)
482
+ train_preds["split"] = "train_cv_oof"
483
+ preds_rows.extend(
484
+ train_preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
485
+ )
486
+
487
+ plot_reliability_curve(
488
+ train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "reliability_train.png"
489
+ )
490
+ plot_precision_recall(
491
+ train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "pr_train.png"
492
+ )
493
+ plot_roc_curve(train_preds["y_true"], train_preds["y_score"], path=outputs_dir / "roc_train.png")
494
+
495
+ if bootstrap_samples > 0:
496
+ ci_map = bootstrap_metric_intervals(
497
+ train_preds["y_true"],
498
+ train_preds["y_score"],
499
+ n_bootstrap=bootstrap_samples,
500
+ alpha=bootstrap_alpha,
501
+ random_state=config.seed,
502
+ )
503
+ for metric_name, stats in ci_map.items():
504
+ row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
505
+ row["train_cv_ci_lower"] = stats.get("ci_lower")
506
+ row["train_cv_ci_upper"] = stats.get("ci_upper")
507
+ row["train_cv_ci_median"] = stats.get("ci_median")
508
+
509
+ if write_train_in_sample:
510
+ train_features_full = pipeline.transform(
511
+ train_df, heavy_only=heavy_only, batch_size=batch_size
512
+ )
513
+ train_full_scores = trained_model.predict_proba(train_features_full)
514
+ train_full_metrics = compute_metrics(
515
+ train_df["label"].to_numpy(dtype=int), train_full_scores
516
+ )
517
+ (outputs_dir / "train_in_sample.json").write_text(
518
+ json.dumps(train_full_metrics, indent=2),
519
+ encoding="utf-8",
520
+ )
521
+
522
+ for frame in eval_frames:
523
+ if frame.empty:
524
+ continue
525
+ features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
526
+ scores = trained_model.predict_proba(features)
527
+ y_true = frame["label"].to_numpy(dtype=int)
528
+ metrics = compute_metrics(y_true, scores)
529
+ dataset_name = frame["source"].iloc[0]
530
+ for metric_name, value in metrics.items():
531
+ metrics_lookup.setdefault(metric_name, {"metric": metric_name})[
532
+ dataset_name
533
+ ] = value
534
+
535
+ preds = frame[["id", "source", "label"]].copy()
536
+ preds["y_true"] = preds["label"]
537
+ preds["y_score"] = scores
538
+ preds["y_pred"] = (scores >= 0.5).astype(int)
539
+ preds["split"] = dataset_name
540
+ preds_rows.extend(
541
+ preds[["id", "source", "split", "y_true", "y_score", "y_pred"]].to_dict("records")
542
+ )
543
+
544
+ plot_reliability_curve(
545
+ preds["y_true"],
546
+ preds["y_score"],
547
+ path=outputs_dir / f"reliability_{dataset_name}.png",
548
+ )
549
+ plot_precision_recall(
550
+ preds["y_true"],
551
+ preds["y_score"],
552
+ path=outputs_dir / f"pr_{dataset_name}.png",
553
+ )
554
+ plot_roc_curve(
555
+ preds["y_true"], preds["y_score"], path=outputs_dir / f"roc_{dataset_name}.png"
556
+ )
557
+
558
+ if bootstrap_samples > 0:
559
+ ci_map = bootstrap_metric_intervals(
560
+ preds["y_true"],
561
+ preds["y_score"],
562
+ n_bootstrap=bootstrap_samples,
563
+ alpha=bootstrap_alpha,
564
+ random_state=config.seed,
565
+ )
566
+ for metric_name, stats in ci_map.items():
567
+ row = metrics_lookup.setdefault(metric_name, {"metric": metric_name})
568
+ row[f"{dataset_name}_ci_lower"] = stats.get("ci_lower")
569
+ row[f"{dataset_name}_ci_upper"] = stats.get("ci_upper")
570
+ row[f"{dataset_name}_ci_median"] = stats.get("ci_median")
571
+
572
+ metrics_df = pd.DataFrame(sorted(metrics_lookup.values(), key=lambda row: row["metric"]))
573
+ return metrics_df, preds_rows
574
+
575
+
576
+ def _fit_model_calibrator(
577
+ model_config: LinearModelConfig,
578
+ config: Config,
579
+ X: np.ndarray,
580
+ y: np.ndarray,
581
+ *,
582
+ base_estimator: Any | None = None,
583
+ ):
584
+ method = config.calibration.method
585
+ if not method:
586
+ return None
587
+ if len(np.unique(y)) < 2:
588
+ return None
589
+
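+ # Prefer cross-validated calibration when enough labeled samples are available; otherwise fit the estimator and calibrate it 'prefit' on the same data (optimistic, but workable for very small datasets).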
590
+ if len(y) >= 4:
591
+ cv_cal = min(config.training.cv_folds, max(2, len(y) // 2))
592
+ estimator = build_estimator(config=model_config, random_state=config.seed)
593
+ if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
594
+ estimator.set_params(solver="lbfgs")
595
+ calibrator = fit_calibrator(estimator, X, y, method=method, cv=cv_cal)
596
+ else:
597
+ estimator = base_estimator or build_estimator(config=model_config, random_state=config.seed)
598
+ if isinstance(estimator, LogisticRegression) and X.shape[0] >= 1000:
599
+ estimator.set_params(solver="lbfgs")
600
+ estimator.fit(X, y)
601
+ calibrator = fit_calibrator(estimator, X, y, method=method, cv="prefit")
602
+ return calibrator
603
+
604
+
605
+ def _log_environment(logger) -> None:
606
+ try:
607
+ git_head = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
608
+ except Exception: # pragma: no cover - best effort
609
+ git_head = "unknown"
610
+ try:
611
+ pip_freeze = subprocess.check_output(["pip", "freeze"], text=True)
612
+ except Exception: # pragma: no cover
613
+ pip_freeze = ""
614
+ logger.info("git_head=%s", git_head)
615
+ logger.info("pip_freeze=%s", pip_freeze)
616
+
617
+
618
+ if __name__ == "__main__":
619
+ raise SystemExit(main())
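The training entrypoint follows the same pattern; the dataset paths below are placeholders, and the loaders are inferred from the file stems (a stem containing "boughter", "jain", "shehata", or "harvey" selects the matching loader):

from polyreact.train import main

main([
    "--config", "configs/default.yaml",
    "--train", "data/boughter.csv",
    "--eval", "data/jain.csv", "data/shehata.csv",
    "--save-to", "artifacts/model.joblib",
    "--report-to", "artifacts",
    "--batch-size", "8",
])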
polyreact/utils/__pycache__/io.cpython-311.pyc ADDED
Binary file (2.23 kB). View file
 
polyreact/utils/__pycache__/logging.cpython-311.pyc ADDED
Binary file (2.13 kB). View file
 
polyreact/utils/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (7.67 kB). View file