File size: 3,910 Bytes
5f58699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Public Python API for polyreactivity prediction."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable

import copy

import joblib
import pandas as pd
from sklearn.preprocessing import StandardScaler

from .config import Config, load_config
from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline


def predict_batch(  # noqa: ANN003
    records: Iterable[dict],
    *,
    config: Config | str | Path | None = None,
    backend: str | None = None,
    plm_model: str | None = None,
    weights: str | Path | None = None,
    heavy_only: bool = True,
    batch_size: int = 8,
    device: str | None = None,
    cache_dir: str | None = None,
) -> pd.DataFrame:
    """Predict polyreactivity scores for a batch of sequences.

    Args:
        records: Iterable of per-sequence dicts. Recognised keys are ``id``
            (falling back to ``sequence_id``, then the row position),
            ``heavy_seq`` (or ``heavy``) and ``light_seq`` (or ``light``).
        config: A ``Config`` instance (copied, never mutated), a path to a
            config file, or ``None`` to use the config stored in the artifact
            (or ``configs/default.yaml`` when the artifact holds none).
        backend: Optional override for ``config.feature_backend.type``.
        plm_model: Optional override for
            ``config.feature_backend.plm_model_name``.
        weights: Path to the joblib-serialised model artifact (required).
        heavy_only: When true, light chains are blanked before featurisation.
        batch_size: Forwarded to the feature pipeline's ``transform``.
        device: Optional override for ``config.device``.
        cache_dir: Optional override for ``config.feature_backend.cache_dir``.

    Returns:
        A DataFrame with columns ``id`` (str), ``score`` and ``pred``
        (``score >= 0.5`` as int). Empty ``records`` yields an empty frame
        with those columns.

    Raises:
        ValueError: if the records contain no data or no heavy chain, if
            ``weights`` is missing, or if the artifact is malformed.
    """
    records_list = list(records)
    if not records_list:
        return pd.DataFrame(columns=["id", "score", "pred"])

    # Validate the cheap, user-supplied input first so bad requests fail
    # before the artifact is deserialised and the feature backend is built.
    frame = _prepare_frame(records_list, heavy_only=heavy_only)

    artifact = _load_artifact(weights)
    config = _resolve_config(config, artifact)

    # Only truthy overrides are applied; ``None``/empty keep the config value.
    if backend:
        config.feature_backend.type = backend
    if plm_model:
        config.feature_backend.plm_model_name = plm_model
    if device:
        config.device = device
    if cache_dir:
        config.feature_backend.cache_dir = cache_dir

    pipeline = _restore_pipeline(config, artifact)
    trained_model = artifact["model"]

    features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
    # NOTE(review): assumes ``predict_proba`` returns a 1-D array of
    # positive-class probabilities (a plain sklearn classifier would return an
    # (n, 2) array, which the DataFrame below rejects) — confirm model type.
    scores = trained_model.predict_proba(features)
    preds = (scores >= 0.5).astype(int)

    return pd.DataFrame(
        {
            "id": frame["id"].astype(str),
            "score": scores,
            "pred": preds,
        }
    )


def _resolve_config(config: Config | str | Path | None, artifact: dict) -> Config:
    """Return a private ``Config`` for this run that callers' objects don't share."""
    if config is None:
        stored = artifact.get("config")
        if isinstance(stored, Config):
            return copy.deepcopy(stored)
        # NOTE(review): relative path, resolved against the current working
        # directory rather than the package — confirm this is intended.
        return load_config("configs/default.yaml")
    if isinstance(config, (str, Path)):
        return load_config(config)
    # Deep-copy so the overrides applied by the caller never leak back.
    return copy.deepcopy(config)


def _chain_column(frame: pd.DataFrame, canonical: str, alias: str) -> pd.Series:
    """Return ``canonical`` (else ``alias``) as NaN-free strings; all ``""`` if absent."""
    for name in (canonical, alias):
        if name in frame.columns:
            return frame[name].fillna("").astype(str)
    return pd.Series([""] * len(frame), index=frame.index, dtype=object)


def _prepare_frame(records_list: list[dict], *, heavy_only: bool) -> pd.DataFrame:
    """Build the input frame with string ``id``, ``heavy_seq`` and ``light_seq``."""
    frame = pd.DataFrame(records_list)
    if frame.empty:
        # e.g. ``[{}]`` produces rows but no columns.
        raise ValueError("Prediction requires at least one record.")

    if "id" not in frame.columns:
        if "sequence_id" in frame.columns:
            ids = frame["sequence_id"]
        else:
            # Fall back to row position. Wrap it in an index-aligned Series,
            # because a bare ``range`` has no ``.astype``.
            ids = pd.Series(range(len(frame)), index=frame.index)
        frame["id"] = ids.astype(str)

    frame["heavy_seq"] = _chain_column(frame, "heavy_seq", "heavy")
    frame["light_seq"] = _chain_column(frame, "light_seq", "light")

    if heavy_only:
        frame["light_seq"] = ""
    if frame["heavy_seq"].str.len().eq(0).all():
        raise ValueError("No heavy chain sequences provided for prediction.")
    return frame


def _load_artifact(weights: str | Path | None) -> dict:
    """Deserialise the model artifact saved at ``weights`` with joblib.

    The artifact must be a ``dict``. ``predict_batch`` and
    ``_restore_pipeline`` read its ``model``, ``config`` and
    ``feature_state`` entries.

    Security: ``joblib.load`` unpickles the file, which can execute arbitrary
    code, so only load weights from trusted sources.

    Raises:
        ValueError: if ``weights`` is ``None`` or the loaded object is not a
            ``dict``.
    """
    if weights is None:
        raise ValueError("Prediction requires a path to model weights")

    loaded = joblib.load(weights)
    if isinstance(loaded, dict):
        return loaded
    raise ValueError("Model artifact must be a dictionary")


def _restore_pipeline(config: Config, artifact: dict) -> FeaturePipeline:
    """Build a feature pipeline from ``config`` and load its saved state.

    The state must be a ``FeaturePipelineState`` stored under the
    artifact's ``feature_state`` key. For the ``plm`` and ``concat`` backends,
    a missing PLM scaler is replaced with a fresh ``StandardScaler``.

    Raises:
        ValueError: if the artifact holds no ``FeaturePipelineState``.
    """
    pipeline = build_feature_pipeline(config)
    saved_state = artifact.get("feature_state")
    if not isinstance(saved_state, FeaturePipelineState):
        raise ValueError("Model artifact is missing feature pipeline state")

    pipeline.load_state(saved_state)

    # NOTE(review): the substituted scaler is unfitted. Confirm that
    # ``transform`` fits it or does not need it for these backends.
    uses_plm_features = pipeline.backend.type in {"plm", "concat"}
    if uses_plm_features and pipeline._plm_scaler is None:
        pipeline._plm_scaler = StandardScaler()
    return pipeline