| """Gradio Space for polyreactivity prediction.""" | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Optional | |
| import gradio as gr | |
| import pandas as pd | |
| from scipy.stats import spearmanr | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| average_precision_score, | |
| brier_score_loss, | |
| f1_score, | |
| roc_auc_score, | |
| ) | |
| from polyreact.api import predict_batch | |
| DEFAULT_MODEL_PATH = Path(os.environ.get("POLYREACT_MODEL_PATH", "artifacts/model.joblib")).resolve() | |
| DEFAULT_CONFIG_PATH = Path(os.environ.get("POLYREACT_CONFIG_PATH", "configs/default.yaml")).resolve() | |
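
# The two defaults above are read from the environment at import time, so overrides must
# be set before the app starts. A minimal sketch (the paths and the `app.py` file name
# are illustrative assumptions, not files shipped with this Space):
#
#   POLYREACT_MODEL_PATH=/data/my_model.joblib POLYREACT_CONFIG_PATH=/data/custom.yaml python app.py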


def _resolve_model_path(upload: Optional[gr.File]) -> Path:
    """Return the uploaded model artifact if given, otherwise the default path."""
    if upload is not None:
        return Path(upload.name)
    if DEFAULT_MODEL_PATH.exists():
        return DEFAULT_MODEL_PATH
    raise FileNotFoundError(
        "Model artifact not found. Upload a trained model (.joblib) to run predictions."
    )


def _predict_single(
    heavy_seq: str,
    light_seq: str,
    use_paired: bool,
    backend: str,
    model_file: Optional[gr.File],
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[str, float, int]:
    """Score one antibody and return (label, probability, binary call)."""
    model_path = _resolve_model_path(model_file)
    heavy_seq = (heavy_seq or "").strip()
    light_seq = (light_seq or "").strip()
    if not heavy_seq:
        raise gr.Error("Please provide a heavy-chain amino acid sequence.")
    record = {
        "id": "sample",
        "heavy_seq": heavy_seq,
        "light_seq": light_seq,
    }
    # `progress` is declared as a default parameter so Gradio injects a tracked instance.
    progress(0.02, desc="📦 Downloading ESM-1v weights (first run can take a few minutes)…")
    preds = predict_batch(
        [record],
        weights=model_path,
        heavy_only=not use_paired,
        backend=backend or None,
        config=DEFAULT_CONFIG_PATH if DEFAULT_CONFIG_PATH.exists() else None,
    )
    progress(1.0, desc="✅ Prediction complete")
    score = float(preds.iloc[0]["score"])
    pred = int(preds.iloc[0]["pred"])
    label = "Polyreactive" if pred == 1 else "Non-polyreactive"
    return label, score, pred
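

# Illustrative only: the same call path without the Gradio UI. This helper is not wired
# into the interface; the heavy-chain string is a truncated placeholder, not a real
# antibody, and the keyword arguments simply mirror those used in `_predict_single`.
def _example_direct_call() -> pd.DataFrame:
    record = {
        "id": "example",
        "heavy_seq": "EVQLVESGGGLVQPGGSLRLSCAAS",  # placeholder fragment
        "light_seq": "",
    }
    return predict_batch(
        [record],
        weights=DEFAULT_MODEL_PATH,
        heavy_only=True,
        backend=None,
        config=DEFAULT_CONFIG_PATH if DEFAULT_CONFIG_PATH.exists() else None,
    )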


def _format_metric(value: float) -> float:
    return float(f"{value:.4f}")


def _compute_metrics(results: pd.DataFrame) -> tuple[pd.DataFrame, list[str], Optional[str]]:
    """Compute optional benchmark metrics from `label` and `reactivity_count` columns."""
    metrics_rows: list[dict[str, float]] = []
    warnings: list[str] = []
    spearman_text: Optional[str] = None
    if "label" in results.columns:
        label_series = results["label"].dropna()
        valid_labels = label_series.isin({0, 1}).all()
        if valid_labels and label_series.nunique() > 1:
            y_true = results.loc[label_series.index, "label"].astype(int)
            y_score = results.loc[label_series.index, "score"].astype(float)
            y_pred = results.loc[label_series.index, "pred"].astype(int)
            metrics_rows.append({"metric": "Accuracy", "value": _format_metric(accuracy_score(y_true, y_pred))})
            metrics_rows.append({"metric": "F1", "value": _format_metric(f1_score(y_true, y_pred))})
            try:
                roc = roc_auc_score(y_true, y_score)
                metrics_rows.append({"metric": "ROC-AUC", "value": _format_metric(roc)})
            except ValueError:
                warnings.append("ROC-AUC skipped (requires both positive and negative labels).")
            try:
                pr_auc = average_precision_score(y_true, y_score)
                metrics_rows.append({"metric": "PR-AUC", "value": _format_metric(pr_auc)})
            except ValueError:
                warnings.append("PR-AUC skipped (requires both positive and negative labels).")
            try:
                brier = brier_score_loss(y_true, y_score)
                metrics_rows.append({"metric": "Brier", "value": _format_metric(brier)})
            except ValueError:
                warnings.append("Brier score skipped (invalid probability values).")
        else:
            warnings.append("Label column found but must contain binary 0/1 values with both classes present.")
    if "reactivity_count" in results.columns:
        valid = results[["reactivity_count", "score"]].dropna()
        if len(valid) > 2 and valid["reactivity_count"].nunique() > 1:
            stat, pval = spearmanr(valid["reactivity_count"], valid["score"])
            if stat == stat:  # guard against NaN from degenerate inputs
                spearman_text = f"Spearman ρ = {stat:.4f} (p = {pval:.3g})"
        else:
            warnings.append("Flag-count Spearman skipped (need ≥3 non-identical counts).")
    metrics_df = pd.DataFrame(metrics_rows)
    return metrics_df, warnings, spearman_text
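

# Illustrative only: the optional columns `_compute_metrics` looks for. With a binary
# `label` column containing both classes it reports Accuracy/F1/ROC-AUC/PR-AUC/Brier;
# with ≥3 rows of non-identical `reactivity_count` it adds a Spearman correlation.
# The values below are made-up toy numbers, not real predictions.
def _example_metrics() -> tuple[pd.DataFrame, list[str], Optional[str]]:
    toy = pd.DataFrame(
        {
            "id": ["a", "b", "c", "d"],
            "score": [0.91, 0.12, 0.67, 0.35],
            "pred": [1, 0, 1, 0],
            "label": [1, 0, 1, 1],
            "reactivity_count": [4, 0, 2, 1],
        }
    )
    return _compute_metrics(toy)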


def _predict_batch(
    input_file: gr.File,
    use_paired: bool,
    backend: str,
    model_file: Optional[gr.File],
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[dict, dict, dict, dict]:
    """Score an uploaded CSV and return update dicts for the batch-tab components."""
    if input_file is None:
        raise gr.Error("Upload a CSV file with columns id, heavy_seq[, light_seq].")
    model_path = _resolve_model_path(model_file)
    input_path = Path(input_file.name)
    frame = pd.read_csv(input_path)
    required_cols = {"id", "heavy_seq"}
    if not required_cols.issubset(frame.columns):
        raise gr.Error("CSV must include at least 'id' and 'heavy_seq' columns.")
    records = frame.to_dict("records")
    progress(0.02, desc="📦 Downloading ESM-1v weights (first run can take a few minutes)…")
    preds = predict_batch(
        records,
        weights=model_path,
        heavy_only=not use_paired,
        backend=backend or None,
        config=DEFAULT_CONFIG_PATH if DEFAULT_CONFIG_PATH.exists() else None,
    )
    progress(1.0, desc="✅ Batch prediction complete")
    merged = frame.merge(preds, on="id", how="left")
    output_path = input_path.parent / "polyreact_predictions.csv"
    merged.to_csv(output_path, index=False)
    metrics_df, warnings, spearman_text = _compute_metrics(merged)
    metrics_update = gr.update(value=metrics_df, visible=not metrics_df.empty)
    spearman_update = gr.update(value=spearman_text or "", visible=spearman_text is not None)
    notes_update = gr.update(
        value="\n".join(f"- {msg}" for msg in warnings) if warnings else "",
        visible=bool(warnings),
    )
    return (
        gr.update(value=str(output_path), visible=True),
        metrics_update,
        spearman_update,
        notes_update,
    )
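

# Expected batch CSV layout (values are illustrative placeholders; only `id` and
# `heavy_seq` are required, the remaining columns are optional):
#
#   id,heavy_seq,light_seq,label,reactivity_count
#   ab_001,EVQLVESGGGLVQPGG...,DIQMTQSPSSLSASV...,1,3
#   ab_002,QVQLQESGPGLVKPS...,,0,0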


def make_interface() -> gr.Blocks:
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Polyreactivity Predictor

            Provide an antibody heavy chain (and optional light chain) to estimate
            polyreactivity probability. Upload a trained model artifact or place it
            at `artifacts/model.joblib`.
            """
        )
        with gr.Tab("Single Sequence"):
            with gr.Row():
                heavy_input = gr.Textbox(
                    label="Heavy chain sequence",
                    lines=6,
                    placeholder="Enter amino acid sequence",
                )
                light_input = gr.Textbox(
                    label="Light chain sequence (optional)",
                    lines=6,
                    placeholder="Enter amino acid sequence",
                )
            with gr.Row():
                use_paired = gr.Checkbox(label="Use paired evaluation", value=False)
                backend_input = gr.Dropdown(
                    label="Feature backend override",
                    choices=["", "descriptors", "plm", "concat"],
                    value="",
                )
            model_upload = gr.File(label="Model artifact (.joblib)", file_types=[".joblib"], file_count="single")
            run_button = gr.Button("Predict", variant="primary")
            result_label = gr.Textbox(label="Prediction", interactive=False)
            result_score = gr.Number(label="Probability", precision=4)
            result_class = gr.Number(label="Binary call (1=polyreactive)")
            run_button.click(
                _predict_single,
                inputs=[heavy_input, light_input, use_paired, backend_input, model_upload],
                outputs=[result_label, result_score, result_class],
            )
        with gr.Tab("Batch CSV"):
            batch_file = gr.File(label="Upload CSV", file_types=[".csv"], file_count="single")
            batch_paired = gr.Checkbox(label="Use paired evaluation", value=False)
            batch_backend = gr.Dropdown(
                label="Feature backend override",
                choices=["", "descriptors", "plm", "concat"],
                value="",
            )
            batch_model = gr.File(label="Model artifact (.joblib)", file_types=[".joblib"], file_count="single")
            batch_button = gr.Button("Run batch predictions", variant="primary")
            batch_output = gr.File(label="Download predictions", visible=False)
            batch_metrics = gr.Dataframe(label="Benchmark metrics", visible=False)
            batch_spearman = gr.Textbox(label="Flag-count Spearman", interactive=False, visible=False)
            batch_notes = gr.Markdown(visible=False)
            batch_button.click(
                _predict_batch,
                inputs=[batch_file, batch_paired, batch_backend, batch_model],
                outputs=[batch_output, batch_metrics, batch_spearman, batch_notes],
            )
        gr.Markdown(
            """
            ### Notes
            - Default configuration expects heavy-chain only evaluation.
            - Backend overrides should match how the model was trained to avoid feature mismatch.
            - CSV inputs should include `id`, `heavy_seq`, and optionally `light_seq`.
            - Add a binary `label` column to compute accuracy/F1/ROC-AUC/PR-AUC/Brier.
            - Include `reactivity_count` to report Spearman correlation with predicted probabilities.
            - **First run downloads the 650M-parameter ESM-1v model; the progress bar will display a download message until it finishes (can take several minutes).**
            """
        )
    return demo


def main() -> None:
    demo = make_interface()
    demo.launch()


if __name__ == "__main__":
    main()
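

# Local usage sketch (assumptions: this file is saved as app.py, the `polyreact` package
# is installed, and a trained artifact exists at artifacts/model.joblib or is uploaded
# through the UI):
#
#   pip install gradio pandas scipy scikit-learn
#   python app.py
#
# On Gradio versions where the queue is not enabled by default, progress updates only
# render when queuing is on, e.g. `demo.queue().launch()` instead of `demo.launch()`.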