"""Command-line interface for polyreactivity predictions.""" from __future__ import annotations import argparse from pathlib import Path import pandas as pd from .api import predict_batch from .config import load_config from .utils.io import read_table, write_table def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Polyreactivity prediction CLI") parser.add_argument( "--input", required=True, help="Path to input CSV or JSONL file with sequences.", ) parser.add_argument( "--output", required=True, help="Path to write predictions CSV.", ) parser.add_argument( "--config", default="configs/default.yaml", help="Path to configuration YAML file.", ) parser.add_argument( "--backend", choices=["plm", "descriptors", "concat"], help="Override feature backend from config.", ) parser.add_argument( "--plm-model", help="Override PLM model name.", ) parser.add_argument( "--weights", required=True, help="Path to trained model artifact (joblib).", ) parser.add_argument( "--heavy-only", dest="heavy_only", action="store_true", default=True, help="Use only heavy chains (default).", ) parser.add_argument( "--paired", dest="heavy_only", action="store_false", help="Use paired heavy/light chains if available.", ) parser.add_argument( "--batch-size", type=int, default=8, help="Batch size for model inference (PLM backend).", ) parser.add_argument( "--device", choices=["auto", "cpu", "cuda"], help="Computation device override.", ) parser.add_argument( "--cache-dir", help="Cache directory for embeddings.", ) return parser def main(argv: list[str] | None = None) -> int: parser = build_parser() args = parser.parse_args(argv) config = load_config(args.config) df = read_table(args.input) if "heavy_seq" not in df.columns and "heavy" not in df.columns: parser.error("Input file must contain a 'heavy_seq' column (or 'heavy').") if df.get("heavy_seq", df.get("heavy", "")).fillna("").str.len().eq(0).all(): parser.error("At least one non-empty heavy sequence is required.") predictions = predict_batch( df.to_dict("records"), config=config, backend=args.backend, plm_model=args.plm_model, weights=args.weights, heavy_only=args.heavy_only, batch_size=args.batch_size, device=args.device, cache_dir=args.cache_dir, ) write_table(predictions, args.output) return 0 if __name__ == "__main__": raise SystemExit(main())