"""Public Python API for polyreactivity prediction."""

from __future__ import annotations

import copy
from pathlib import Path
from typing import Iterable

import joblib
import pandas as pd
from sklearn.preprocessing import StandardScaler

from .config import Config, load_config
from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline


def predict_batch(
    records: Iterable[dict],
    *,
    config: Config | str | Path | None = None,
    backend: str | None = None,
    plm_model: str | None = None,
    weights: str | Path | None = None,
    heavy_only: bool = True,
    batch_size: int = 8,
    device: str | None = None,
    cache_dir: str | None = None,
) -> pd.DataFrame:
    """Predict polyreactivity scores for a batch of sequences.

    Returns a DataFrame with one row per input record and the columns "id",
    "score", and "pred".
    """
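    # Minimal usage sketch (illustrative only): the weights path, record id,
    # and sequence fragment below are placeholders, not data shipped with the
    # package.
    #
    #     results = predict_batch(
    #         [{"id": "ab-1", "heavy_seq": "EVQLVESGG..."}],
    #         weights="runs/model.joblib",
    #     )
    #     # results has one row with columns "id", "score", and "pred".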
    records_list = list(records)
    if not records_list:
        return pd.DataFrame(columns=["id", "score", "pred"])

    artifact = _load_artifact(weights)
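    # The artifact is a joblib-saved dict: this module reads "model" (an
    # estimator exposing predict_proba), "feature_state" (a
    # FeaturePipelineState), and optionally "config" (a Config instance).
    #
    # Resolve the configuration: an explicit argument wins, then the config
    # stored in the artifact, then the packaged default.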
    if config is None:
        artifact_config = artifact.get("config")
        if isinstance(artifact_config, Config):
            config = copy.deepcopy(artifact_config)
        else:
            config = load_config("configs/default.yaml")
    elif isinstance(config, (str, Path)):
        config = load_config(config)
    else:
        config = copy.deepcopy(config)
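    # Per-call keyword arguments override the corresponding config fields.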
    if backend:
        config.feature_backend.type = backend
    if plm_model:
        config.feature_backend.plm_model_name = plm_model
    if device:
        config.device = device
    if cache_dir:
        config.feature_backend.cache_dir = cache_dir

    pipeline = _restore_pipeline(config, artifact)
    trained_model = artifact["model"]
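    # Build the input frame and make sure every record has a string id.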
    frame = pd.DataFrame(records_list)
    if frame.empty:
        raise ValueError("Prediction requires at least one record.")
    if "id" not in frame.columns:
        # Fall back to "sequence_id" when present, otherwise number the rows.
        fallback_ids = pd.Series(range(len(frame)), index=frame.index)
        frame["id"] = frame.get("sequence_id", fallback_ids).astype(str)
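    # Accept either "heavy_seq"/"light_seq" or the shorter "heavy"/"light"
    # record keys; missing values become empty strings.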
    if "heavy_seq" in frame.columns:
        frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str)
    else:
        heavy_series = frame.get("heavy")
        if heavy_series is None:
            heavy_series = pd.Series([""] * len(frame))
        frame["heavy_seq"] = heavy_series.fillna("").astype(str)

    if "light_seq" in frame.columns:
        frame["light_seq"] = frame["light_seq"].fillna("").astype(str)
    else:
        light_series = frame.get("light")
        if light_series is None:
            light_series = pd.Series([""] * len(frame))
        frame["light_seq"] = light_series.fillna("").astype(str)
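    # In heavy-only mode the light chains are blanked before featurization;
    # at least one non-empty heavy sequence is required either way.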
    if heavy_only:
        frame["light_seq"] = ""
    if frame["heavy_seq"].str.len().eq(0).all():
        raise ValueError("No heavy chain sequences provided for prediction.")
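    # Featurize, score, and threshold at 0.5 to obtain binary predictions.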
    features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
    scores = trained_model.predict_proba(features)
    preds = (scores >= 0.5).astype(int)

    return pd.DataFrame(
        {
            "id": frame["id"].astype(str),
            "score": scores,
            "pred": preds,
        }
    )


def _load_artifact(weights: str | Path | None) -> dict:
    """Load and validate a joblib artifact containing the trained model."""
    if weights is None:
        msg = "Prediction requires a path to model weights"
        raise ValueError(msg)
    artifact = joblib.load(weights)
    if not isinstance(artifact, dict):
        msg = "Model artifact must be a dictionary"
        raise ValueError(msg)
    return artifact


def _restore_pipeline(config: Config, artifact: dict) -> FeaturePipeline:
    """Rebuild the feature pipeline from the state stored in the artifact."""
    pipeline = build_feature_pipeline(config)
    state = artifact.get("feature_state")
    if isinstance(state, FeaturePipelineState):
        pipeline.load_state(state)
        if pipeline.backend.type in {"plm", "concat"} and pipeline._plm_scaler is None:
            pipeline._plm_scaler = StandardScaler()
        return pipeline

    msg = "Model artifact is missing feature pipeline state"
    raise ValueError(msg)