makiling's picture
Upload folder using huggingface_hub
5f58699 verified
raw
history blame
2.87 kB
"""Command-line interface for polyreactivity predictions."""
from __future__ import annotations
import argparse
from pathlib import Path
import pandas as pd
from .api import predict_batch
from .config import load_config
from .utils.io import read_table, write_table
def build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the polyreactivity prediction CLI.

    Options are declared as ``(flag, add_argument keyword arguments)`` pairs
    and registered in order, so the help listing follows the table below.

    Returns:
        A parser exposing ``--input``, ``--output`` and ``--weights``
        (all required) plus the optional ``--config``, ``--backend``,
        ``--plm-model``, ``--heavy-only`` / ``--paired``, ``--batch-size``,
        ``--device`` and ``--cache-dir`` options.
    """
    option_table = (
        (
            "--input",
            {
                "required": True,
                "help": "Path to input CSV or JSONL file with sequences.",
            },
        ),
        (
            "--output",
            {
                "required": True,
                "help": "Path to write predictions CSV.",
            },
        ),
        (
            "--config",
            {
                "default": "configs/default.yaml",
                "help": "Path to configuration YAML file.",
            },
        ),
        (
            "--backend",
            {
                "choices": ["plm", "descriptors", "concat"],
                "help": "Override feature backend from config.",
            },
        ),
        (
            "--plm-model",
            {
                "help": "Override PLM model name.",
            },
        ),
        (
            "--weights",
            {
                "required": True,
                "help": "Path to trained model artifact (joblib).",
            },
        ),
        # --heavy-only and --paired write to the same destination; the
        # namespace value is True unless --paired is given (last flag wins).
        (
            "--heavy-only",
            {
                "dest": "heavy_only",
                "action": "store_true",
                "default": True,
                "help": "Use only heavy chains (default).",
            },
        ),
        (
            "--paired",
            {
                "dest": "heavy_only",
                "action": "store_false",
                "help": "Use paired heavy/light chains if available.",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 8,
                "help": "Batch size for model inference (PLM backend).",
            },
        ),
        (
            "--device",
            {
                "choices": ["auto", "cpu", "cuda"],
                "help": "Computation device override.",
            },
        ),
        (
            "--cache-dir",
            {
                "help": "Cache directory for embeddings.",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Polyreactivity prediction CLI")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
def main(argv: list[str] | None = None) -> int:
    """Run the prediction CLI: parse, validate, predict, write.

    Args:
        argv: Argument list handed to ``parse_args``. ``None`` lets argparse
            fall back to the process command line.

    Returns:
        ``0`` once the predictions have been written.

    Validation failures are reported through ``parser.error``, which prints
    the usage message and exits instead of returning.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Configuration is loaded before the input table is read.
    config = load_config(args.config)
    table = read_table(args.input)

    has_heavy_seq = "heavy_seq" in table.columns
    if not (has_heavy_seq or "heavy" in table.columns):
        parser.error("Input file must contain a 'heavy_seq' column (or 'heavy').")

    # Prefer 'heavy_seq'; fall back to 'heavy' only when it is absent.
    heavy_column = table["heavy_seq"] if has_heavy_seq else table["heavy"]
    # Missing values are treated as empty strings when measuring length.
    sequence_lengths = heavy_column.fillna("").str.len()
    if sequence_lengths.eq(0).all():
        parser.error("At least one non-empty heavy sequence is required.")

    records = table.to_dict("records")
    overrides = {
        "config": config,
        "backend": args.backend,
        "plm_model": args.plm_model,
        "weights": args.weights,
        "heavy_only": args.heavy_only,
        "batch_size": args.batch_size,
        "device": args.device,
        "cache_dir": args.cache_dir,
    }
    predictions = predict_batch(records, **overrides)

    write_table(predictions, args.output)
    return 0
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)