"""Run end-to-end benchmarks for the polyreactivity model."""
from __future__ import annotations

import argparse
from pathlib import Path
from typing import List

from .. import train as train_cli
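# Default datasets point at the bundled test fixtures so the benchmark can run
# without external data; pass --train/--eval to use full datasets instead.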
PROJECT_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_TRAIN = PROJECT_ROOT / "tests" / "fixtures" / "boughter.csv"
DEFAULT_EVAL = [
    PROJECT_ROOT / "tests" / "fixtures" / "jain.csv",
    PROJECT_ROOT / "tests" / "fixtures" / "shehata.csv",
    PROJECT_ROOT / "tests" / "fixtures" / "harvey.csv",
]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run polyreactivity benchmarks")
    parser.add_argument(
        "--config",
        default="configs/default.yaml",
        help="Path to configuration YAML file.",
    )
    parser.add_argument(
        "--train",
        default=str(DEFAULT_TRAIN),
        help="Training dataset CSV (defaults to bundled fixture).",
    )
    parser.add_argument(
        "--eval",
        nargs="+",
        default=[str(path) for path in DEFAULT_EVAL],
        help="Evaluation dataset CSV paths (>=1).",
    )
    parser.add_argument(
        "--report-dir",
        default="artifacts",
        help="Directory to write metrics, predictions, and plots.",
    )
    parser.add_argument(
        "--model-path",
        default="artifacts/model.joblib",
        help="Destination for the trained model artifact.",
    )
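    # Feature backend: classical sequence descriptors, protein language model
    # (PLM) embeddings, or presumably a concatenation of the two ("concat").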
    parser.add_argument(
        "--backend",
        choices=["descriptors", "plm", "concat"],
        help="Override feature backend during training.",
    )
    parser.add_argument("--plm-model", help="Optional PLM model override.")
    parser.add_argument("--cache-dir", help="Embedding cache directory override.")
    parser.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda"],
        help="Device override for embeddings.",
    )
    parser.add_argument(
        "--paired",
        action="store_true",
        help="Use paired heavy/light chains when available.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size for PLM embedding batches.",
    )
    return parser


def main(argv: List[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)
    if len(args.eval) < 1:
        parser.error("Provide at least one evaluation dataset via --eval.")
    report_dir = Path(args.report_dir)
    report_dir.mkdir(parents=True, exist_ok=True)
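    # Delegate the heavy lifting to the train CLI: it fits the model, evaluates
    # on each --eval dataset, and writes metrics/plots under report_dir.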
    train_args: list[str] = [
        "--config",
        args.config,
        "--train",
        args.train,
        "--save-to",
        str(Path(args.model_path)),
        "--report-to",
        str(report_dir),
        "--batch-size",
        str(args.batch_size),
    ]
    train_args.extend(["--eval", *args.eval])
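    # Forward optional overrides only when they were explicitly provided, so the
    # train CLI keeps its own defaults otherwise.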
    if args.backend:
        train_args.extend(["--backend", args.backend])
    if args.plm_model:
        train_args.extend(["--plm-model", args.plm_model])
    if args.cache_dir:
        train_args.extend(["--cache-dir", args.cache_dir])
    if args.device:
        train_args.extend(["--device", args.device])
    if args.paired:
        train_args.append("--paired")
    return train_cli.main(train_args)


if __name__ == "__main__":
    raise SystemExit(main())