| """Public Python API for polyreactivity prediction.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Iterable | |
| import copy | |
| import joblib | |
| import pandas as pd | |
| from sklearn.preprocessing import StandardScaler | |
| from .config import Config, load_config | |
| from .features.pipeline import FeaturePipeline, FeaturePipelineState, build_feature_pipeline | |


def predict_batch(  # noqa: ANN003
    records: Iterable[dict],
    *,
    config: Config | str | Path | None = None,
    backend: str | None = None,
    plm_model: str | None = None,
    weights: str | Path | None = None,
    heavy_only: bool = True,
    batch_size: int = 8,
    device: str | None = None,
    cache_dir: str | None = None,
) -> pd.DataFrame:
    """Predict polyreactivity scores for a batch of sequences."""
    records_list = list(records)
    if not records_list:
        return pd.DataFrame(columns=["id", "score", "pred"])
    artifact = _load_artifact(weights)
    if config is None:
        artifact_config = artifact.get("config")
        if isinstance(artifact_config, Config):
            config = copy.deepcopy(artifact_config)
        else:
            config = load_config("configs/default.yaml")
    elif isinstance(config, (str, Path)):
        config = load_config(config)
    else:
        config = copy.deepcopy(config)
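    # Explicit keyword arguments override whatever the config (or the artifact's
    # embedded config) provided.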
    if backend:
        config.feature_backend.type = backend
    if plm_model:
        config.feature_backend.plm_model_name = plm_model
    if device:
        config.device = device
    if cache_dir:
        config.feature_backend.cache_dir = cache_dir
    pipeline = _restore_pipeline(config, artifact)
    trained_model = artifact["model"]
    frame = pd.DataFrame(records_list)
    if frame.empty:
        raise ValueError("Prediction requires at least one record.")
    if "id" not in frame.columns:
| frame["id"] = frame.get("sequence_id", range(len(frame))).astype(str) | |
| if "heavy_seq" in frame.columns: | |
| frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str) | |
| else: | |
| heavy_series = frame.get("heavy") | |
| if heavy_series is None: | |
| heavy_series = pd.Series([""] * len(frame)) | |
| frame["heavy_seq"] = heavy_series.fillna("").astype(str) | |
| if "light_seq" in frame.columns: | |
| frame["light_seq"] = frame["light_seq"].fillna("").astype(str) | |
| else: | |
| light_series = frame.get("light") | |
| if light_series is None: | |
| light_series = pd.Series([""] * len(frame)) | |
| frame["light_seq"] = light_series.fillna("").astype(str) | |
| if heavy_only: | |
| frame["light_seq"] = "" | |
| if frame["heavy_seq"].str.len().eq(0).all(): | |
| raise ValueError("No heavy chain sequences provided for prediction.") | |
    features = pipeline.transform(frame, heavy_only=heavy_only, batch_size=batch_size)
    scores = trained_model.predict_proba(features)
    preds = (scores >= 0.5).astype(int)
    return pd.DataFrame(
        {
            "id": frame["id"].astype(str),
            "score": scores,
            "pred": preds,
        }
    )


def _load_artifact(weights: str | Path | None) -> dict:
    """Load a joblib model artifact and check that it is a dictionary."""
    if weights is None:
        msg = "Prediction requires a path to model weights"
        raise ValueError(msg)
    artifact = joblib.load(weights)
    if not isinstance(artifact, dict):
        msg = "Model artifact must be a dictionary"
        raise ValueError(msg)
    return artifact
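
# Note: the training-side code that writes the artifact is not shown in this module,
# so the layout is inferred from how the keys are consumed here: a joblib-serialized
# dict with a fitted "model" exposing predict_proba, an optional "config" (Config),
# and a "feature_state" (FeaturePipelineState).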


def _restore_pipeline(config: Config, artifact: dict) -> FeaturePipeline:
    """Rebuild the feature pipeline from the config and restore its fitted state."""
    pipeline = build_feature_pipeline(config)
    state = artifact.get("feature_state")
    if isinstance(state, FeaturePipelineState):
        pipeline.load_state(state)
        if pipeline.backend.type in {"plm", "concat"} and pipeline._plm_scaler is None:
            pipeline._plm_scaler = StandardScaler()
        return pipeline
    msg = "Model artifact is missing feature pipeline state"
    raise ValueError(msg)
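

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. The weights path and the
    # heavy-chain sequence below are placeholders (assumptions), not files or data
    # that ship with this module; point them at a real training artifact instead.
    demo_records = [
        {"id": "ab-1", "heavy_seq": "EVQLVESGGGLVQPGGSLRLSCAAS"},  # truncated placeholder
    ]
    demo_scores = predict_batch(
        demo_records,
        weights="models/model.joblib",  # placeholder path
        heavy_only=True,
    )
    print(demo_scores)  # DataFrame with columns: id, score, pred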