|
|
"""Gradio Space for polyreactivity prediction.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
from scipy.stats import spearmanr |
|
|
from sklearn.metrics import ( |
|
|
accuracy_score, |
|
|
average_precision_score, |
|
|
brier_score_loss, |
|
|
f1_score, |
|
|
roc_auc_score, |
|
|
) |
|
|
|
|
|
from polyreact.api import predict_batch |
|
|
|
|
|
DEFAULT_MODEL_PATH = Path(os.environ.get("POLYREACT_MODEL_PATH", "artifacts/model.joblib")).resolve() |
|
|
DEFAULT_CONFIG_PATH = Path(os.environ.get("POLYREACT_CONFIG_PATH", "configs/default.yaml")).resolve() |
|
|
|
|
|
|
|
|
def _resolve_model_path(upload: Optional[gr.File]) -> Path: |
|
|
if upload is not None: |
|
|
return Path(upload.name) |
|
|
if DEFAULT_MODEL_PATH.exists(): |
|
|
return DEFAULT_MODEL_PATH |
|
|
raise FileNotFoundError( |
|
|
"Model artifact not found. Upload a trained model (.joblib) to run predictions." |
|
|
) |
|
|
|
|
|
|
|
|
def _predict_single(
    heavy_seq: str,
    light_seq: str,
    use_paired: bool,
    backend: str,
    model_file: Optional[gr.File],
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[str, float, int]:
    """Predict polyreactivity for a single antibody.

    Args:
        heavy_seq: Heavy-chain amino acid sequence (required).
        light_seq: Optional light-chain sequence.
        use_paired: When True, evaluate heavy+light chains together.
        backend: Feature-backend override; empty string keeps the model's
            configured backend.
        model_file: Optional uploaded model artifact; falls back to the
            default artifact path.
        progress: Injected by Gradio. NOTE: progress trackers must be
            declared as a default-valued parameter — instances created
            inside the handler body are not wired to the UI, so the
            original in-body `gr.Progress(...)` never updated the bar.

    Returns:
        Tuple of (human-readable label, probability score, binary call).

    Raises:
        gr.Error: If the heavy-chain sequence is empty.
        FileNotFoundError: If no model artifact is available.
    """
    model_path = _resolve_model_path(model_file)
    heavy_seq = (heavy_seq or "").strip()
    light_seq = (light_seq or "").strip()
    if not heavy_seq:
        raise gr.Error("Please provide a heavy-chain amino acid sequence.")

    record = {
        "id": "sample",
        "heavy_seq": heavy_seq,
        "light_seq": light_seq,
    }
    # First call downloads the PLM weights, which dominates latency.
    progress(0.02, "📦 Downloading ESM-1v weights (first run can take a few minutes)…", total=None)
    preds = predict_batch(
        [record],
        weights=model_path,
        heavy_only=not use_paired,
        backend=backend or None,
        config=DEFAULT_CONFIG_PATH if DEFAULT_CONFIG_PATH.exists() else None,
    )
    progress(1.0, "✅ Prediction complete")
    score = float(preds.iloc[0]["score"])
    pred = int(preds.iloc[0]["pred"])
    label = "Polyreactive" if pred == 1 else "Non-polyreactive"
    return label, score, pred
|
|
|
|
|
|
|
|
def _format_metric(value: float) -> float: |
|
|
return float(f"{value:.4f}") |
|
|
|
|
|
|
|
|
def _compute_metrics(results: pd.DataFrame) -> tuple[pd.DataFrame, list[str], Optional[str]]: |
|
|
metrics_rows: list[dict[str, float]] = [] |
|
|
warnings: list[str] = [] |
|
|
spearman_text: Optional[str] = None |
|
|
|
|
|
if "label" in results.columns: |
|
|
label_series = results["label"].dropna() |
|
|
valid_labels = label_series.isin({0, 1}).all() |
|
|
if valid_labels and label_series.nunique() > 1: |
|
|
y_true = results.loc[label_series.index, "label"].astype(int) |
|
|
y_score = results.loc[label_series.index, "score"].astype(float) |
|
|
y_pred = results.loc[label_series.index, "pred"].astype(int) |
|
|
|
|
|
metrics_rows.append({"metric": "Accuracy", "value": _format_metric(accuracy_score(y_true, y_pred))}) |
|
|
metrics_rows.append({"metric": "F1", "value": _format_metric(f1_score(y_true, y_pred))}) |
|
|
try: |
|
|
roc = roc_auc_score(y_true, y_score) |
|
|
metrics_rows.append({"metric": "ROC-AUC", "value": _format_metric(roc)}) |
|
|
except ValueError: |
|
|
warnings.append("ROC-AUC skipped (requires both positive and negative labels).") |
|
|
try: |
|
|
pr_auc = average_precision_score(y_true, y_score) |
|
|
metrics_rows.append({"metric": "PR-AUC", "value": _format_metric(pr_auc)}) |
|
|
except ValueError: |
|
|
warnings.append("PR-AUC skipped (requires both positive and negative labels).") |
|
|
try: |
|
|
brier = brier_score_loss(y_true, y_score) |
|
|
metrics_rows.append({"metric": "Brier", "value": _format_metric(brier)}) |
|
|
except ValueError: |
|
|
warnings.append("Brier score skipped (invalid probability values).") |
|
|
else: |
|
|
warnings.append("Label column found but must contain binary 0/1 values with both classes present.") |
|
|
|
|
|
if "reactivity_count" in results.columns: |
|
|
valid = results[["reactivity_count", "score"]].dropna() |
|
|
if len(valid) > 2 and valid["reactivity_count"].nunique() > 1: |
|
|
stat, pval = spearmanr(valid["reactivity_count"], valid["score"]) |
|
|
if stat == stat: |
|
|
spearman_text = f"Spearman ρ = {stat:.4f} (p = {pval:.3g})" |
|
|
else: |
|
|
warnings.append("Flag-count Spearman skipped (need ≥3 non-identical counts).") |
|
|
|
|
|
metrics_df = pd.DataFrame(metrics_rows) |
|
|
return metrics_df, warnings, spearman_text |
|
|
|
|
|
|
|
|
def _predict_batch(
    input_file: gr.File,
    use_paired: bool,
    backend: str,
    model_file: Optional[gr.File],
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[dict, dict, dict, dict]:
    """Run batch predictions over an uploaded CSV and optionally score them.

    Args:
        input_file: CSV upload with columns ``id``, ``heavy_seq`` and
            optionally ``light_seq`` (plus optional ``label`` /
            ``reactivity_count`` columns for benchmarking).
        use_paired: When True, evaluate heavy+light chains together.
        backend: Feature-backend override; empty string keeps the model's
            configured backend.
        model_file: Optional uploaded model artifact; falls back to the
            default artifact path.
        progress: Injected by Gradio. NOTE: progress trackers must be
            declared as a default-valued parameter — instances created
            inside the handler body are not wired to the UI.

    Returns:
        ``gr.update`` payloads for the download file, metrics table,
        Spearman textbox, and notes markdown components (hence ``dict``
        elements, not component instances).

    Raises:
        gr.Error: If no file was uploaded, required columns are missing,
            or the CSV has no data rows.
    """
    if input_file is None:
        raise gr.Error("Upload a CSV file with columns id, heavy_seq[, light_seq].")
    model_path = _resolve_model_path(model_file)
    input_path = Path(input_file.name)
    frame = pd.read_csv(input_path)
    required_cols = {"id", "heavy_seq"}
    if not required_cols.issubset(frame.columns):
        raise gr.Error("CSV must include at least 'id' and 'heavy_seq' columns.")
    # Fail fast on header-only uploads instead of surfacing a backend error.
    if frame.empty:
        raise gr.Error("CSV contains no rows to predict.")

    records = frame.to_dict("records")
    # First call downloads the PLM weights, which dominates latency.
    progress(0.02, "📦 Downloading ESM-1v weights (first run can take a few minutes)…", total=None)
    preds = predict_batch(
        records,
        weights=model_path,
        heavy_only=not use_paired,
        backend=backend or None,
        config=DEFAULT_CONFIG_PATH if DEFAULT_CONFIG_PATH.exists() else None,
    )
    progress(1.0, "✅ Batch prediction complete")
    # Keep the caller's original columns alongside the predictions.
    merged = frame.merge(preds, on="id", how="left")
    output_path = input_path.parent / "polyreact_predictions.csv"
    merged.to_csv(output_path, index=False)

    metrics_df, warnings, spearman_text = _compute_metrics(merged)
    metrics_update = gr.update(value=metrics_df, visible=not metrics_df.empty)
    spearman_update = gr.update(value=spearman_text or "", visible=spearman_text is not None)
    notes_update = gr.update(
        value="\n".join(f"- {msg}" for msg in warnings) if warnings else "",
        visible=bool(warnings),
    )

    return (
        gr.update(value=str(output_path), visible=True),
        metrics_update,
        spearman_update,
        notes_update,
    )
|
|
|
|
|
|
|
|
def make_interface() -> gr.Blocks:
    """Build the Gradio Blocks app: single-sequence and batch-CSV prediction tabs."""
    with gr.Blocks() as demo:
        # Header / usage blurb shown above both tabs.
        gr.Markdown(
            """
            # Polyreactivity Predictor

            Provide an antibody heavy chain (and optional light chain) to estimate
            polyreactivity probability. Upload a trained model artifact or place it
            at `artifacts/model.joblib`.
            """
        )

        with gr.Tab("Single Sequence"):
            with gr.Row():
                heavy_input = gr.Textbox(
                    label="Heavy chain sequence",
                    lines=6,
                    placeholder="Enter amino acid sequence",
                )
                light_input = gr.Textbox(
                    label="Light chain sequence (optional)",
                    lines=6,
                    placeholder="Enter amino acid sequence",
                )
            with gr.Row():
                use_paired = gr.Checkbox(label="Use paired evaluation", value=False)
                # Empty string means "no override": the model's own backend is used.
                backend_input = gr.Dropdown(
                    label="Feature backend override",
                    choices=["", "descriptors", "plm", "concat"],
                    value="",
                )
            model_upload = gr.File(label="Model artifact (.joblib)", file_types=[".joblib"], file_count="single")

            run_button = gr.Button("Predict", variant="primary")
            result_label = gr.Textbox(label="Prediction", interactive=False)
            result_score = gr.Number(label="Probability", precision=4)
            result_class = gr.Number(label="Binary call (1=polyreactive)")

            run_button.click(
                _predict_single,
                inputs=[heavy_input, light_input, use_paired, backend_input, model_upload],
                outputs=[result_label, result_score, result_class],
            )

        with gr.Tab("Batch CSV"):
            batch_file = gr.File(label="Upload CSV", file_types=[".csv"], file_count="single")
            batch_paired = gr.Checkbox(label="Use paired evaluation", value=False)
            batch_backend = gr.Dropdown(
                label="Feature backend override",
                choices=["", "descriptors", "plm", "concat"],
                value="",
            )
            batch_model = gr.File(label="Model artifact (.joblib)", file_types=[".joblib"], file_count="single")
            batch_button = gr.Button("Run batch predictions", variant="primary")
            # Result components start hidden; _predict_batch toggles their
            # visibility via gr.update once predictions are available.
            batch_output = gr.File(label="Download predictions", visible=False)
            batch_metrics = gr.Dataframe(label="Benchmark metrics", visible=False)
            batch_spearman = gr.Textbox(label="Flag-count Spearman", interactive=False, visible=False)
            batch_notes = gr.Markdown(visible=False)

            batch_button.click(
                _predict_batch,
                inputs=[batch_file, batch_paired, batch_backend, batch_model],
                outputs=[batch_output, batch_metrics, batch_spearman, batch_notes],
            )

        # Footer notes on expected inputs and optional benchmarking columns.
        gr.Markdown(
            """
            ### Notes
            - Default configuration expects heavy-chain only evaluation.
            - Backend overrides should match how the model was trained to avoid feature mismatch.
            - CSV inputs should include `id`, `heavy_seq`, and optionally `light_seq`.
            - Add a binary `label` column to compute accuracy/F1/ROC-AUC/PR-AUC/Brier.
            - Include `reactivity_count` to report Spearman correlation with predicted probabilities.
            - **First run downloads the 650M-parameter ESM-1v model; the progress bar will display a download message until it finishes (can take several minutes).**
            """
        )

    return demo
|
|
|
|
|
|
|
|
def main() -> None:
    """Entry point: build the Blocks app and serve it."""
    make_interface().launch()
|
|
|
|
|
|
|
|
# Allow running the app directly (e.g. `python app.py`) as well as on Spaces.
if __name__ == "__main__":
    main()
|
|
|