Spaces:
Sleeping
Sleeping
| """Run end-to-end benchmarks for the polyreactivity model.""" | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| from typing import List | |
| from .. import train as train_cli | |
# Directory two levels above the one that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Default training CSV, taken from the test fixtures directory.
DEFAULT_TRAIN = PROJECT_ROOT.joinpath("tests", "fixtures", "boughter.csv")

# Default evaluation CSVs, taken from the same fixtures directory.
DEFAULT_EVAL = [
    PROJECT_ROOT.joinpath("tests", "fixtures", filename)
    for filename in ("jain.csv", "shehata.csv", "harvey.csv")
]
def build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the benchmark command line.

    Returns:
        A parser accepting the configuration path, the training and
        evaluation CSVs, the output locations, and optional overrides for
        the feature backend, PLM model, cache directory and device.
        Override options that are not supplied parse to ``None``.
    """
    cli = argparse.ArgumentParser(description="Run polyreactivity benchmarks")
    add = cli.add_argument

    # Input configuration and datasets.
    add("--config", default="configs/default.yaml", help="Path to configuration YAML file.")
    add(
        "--train",
        default=str(DEFAULT_TRAIN),
        help="Training dataset CSV (defaults to bundled fixture).",
    )
    add(
        "--eval",
        nargs="+",
        default=[str(path) for path in DEFAULT_EVAL],
        help="Evaluation dataset CSV paths (>=1).",
    )

    # Output locations.
    add(
        "--report-dir",
        default="artifacts",
        help="Directory to write metrics, predictions, and plots.",
    )
    add(
        "--model-path",
        default="artifacts/model.joblib",
        help="Destination for the trained model artifact.",
    )

    # Optional overrides that are only forwarded when given.
    add(
        "--backend",
        choices=["descriptors", "plm", "concat"],
        help="Override feature backend during training.",
    )
    add("--plm-model", help="Optional PLM model override.")
    add("--cache-dir", help="Embedding cache directory override.")
    add("--device", choices=["auto", "cpu", "cuda"], help="Device override for embeddings.")
    add("--paired", action="store_true", help="Use paired heavy/light chains when available.")
    add("--batch-size", type=int, default=8, help="Batch size for PLM embedding batches.")

    return cli
def main(argv: List[str] | None = None) -> int:
    """Parse benchmark options and hand them to the training CLI.

    The report directory is created (including parents) before delegating.

    Args:
        argv: Argument strings to parse. ``None`` makes argparse read the
            process command line.

    Returns:
        The value returned by ``train_cli.main`` for the forwarded arguments.
    """
    parser = build_parser()
    opts = parser.parse_args(argv)
    if not opts.eval:
        # parser.error prints the message and exits the process.
        parser.error("Provide at least one evaluation dataset via --eval.")

    out_dir = Path(opts.report_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Always-present options, in the order the training CLI receives them.
    forwarded: list[str] = [
        "--config", opts.config,
        "--train", opts.train,
        "--save-to", str(Path(opts.model_path)),
        "--report-to", str(out_dir),
        "--batch-size", str(opts.batch_size),
        "--eval", *opts.eval,
    ]

    # Value-carrying overrides are forwarded only when the user supplied them.
    optional_overrides = (
        ("--backend", opts.backend),
        ("--plm-model", opts.plm_model),
        ("--cache-dir", opts.cache_dir),
        ("--device", opts.device),
    )
    for flag, value in optional_overrides:
        if value:
            forwarded += [flag, value]

    if opts.paired:
        forwarded.append("--paired")

    return train_cli.main(forwarded)
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)