File size: 3,278 Bytes
5f58699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Run end-to-end benchmarks for the polyreactivity model."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import List

from .. import train as train_cli

# Repository root: two directory levels above this module.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Bundled test fixtures used as fallback datasets when none are supplied.
_FIXTURES = PROJECT_ROOT / "tests" / "fixtures"
DEFAULT_TRAIN = _FIXTURES / "boughter.csv"
DEFAULT_EVAL = [_FIXTURES / name for name in ("jain.csv", "shehata.csv", "harvey.csv")]


def build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the benchmark runner.

    Returns:
        An ``argparse.ArgumentParser`` covering dataset paths, output
        locations, and optional feature-backend overrides.
    """
    p = argparse.ArgumentParser(description="Run polyreactivity benchmarks")

    # Dataset inputs -----------------------------------------------------
    p.add_argument(
        "--config",
        default="configs/default.yaml",
        help="Path to configuration YAML file.",
    )
    p.add_argument(
        "--train",
        default=str(DEFAULT_TRAIN),
        help="Training dataset CSV (defaults to bundled fixture).",
    )
    p.add_argument(
        "--eval",
        nargs="+",
        default=[str(fixture) for fixture in DEFAULT_EVAL],
        help="Evaluation dataset CSV paths (>=1).",
    )

    # Output locations ---------------------------------------------------
    p.add_argument(
        "--report-dir",
        default="artifacts",
        help="Directory to write metrics, predictions, and plots.",
    )
    p.add_argument(
        "--model-path",
        default="artifacts/model.joblib",
        help="Destination for the trained model artifact.",
    )

    # Optional training overrides ---------------------------------------
    p.add_argument(
        "--backend",
        choices=["descriptors", "plm", "concat"],
        help="Override feature backend during training.",
    )
    p.add_argument("--plm-model", help="Optional PLM model override.")
    p.add_argument("--cache-dir", help="Embedding cache directory override.")
    p.add_argument(
        "--device",
        choices=["auto", "cpu", "cuda"],
        help="Device override for embeddings.",
    )
    p.add_argument(
        "--paired",
        action="store_true",
        help="Use paired heavy/light chains when available.",
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size for PLM embedding batches.",
    )
    return p


def main(argv: List[str] | None = None) -> int:
    """Parse benchmark options and delegate to the training CLI.

    Args:
        argv: Optional argument vector; ``None`` falls back to ``sys.argv``.

    Returns:
        Exit status propagated from ``train_cli.main``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # nargs="+" already enforces >=1 when --eval is given explicitly, but an
    # explicit guard also protects against an empty default list.
    if not args.eval:
        parser.error("Provide at least one evaluation dataset via --eval.")

    out_dir = Path(args.report_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Required arguments forwarded to the training CLI.
    forwarded: list[str] = [
        "--config", args.config,
        "--train", args.train,
        "--save-to", str(Path(args.model_path)),
        "--report-to", str(out_dir),
        "--batch-size", str(args.batch_size),
        "--eval", *args.eval,
    ]

    # Optional overrides: forward only the flags the user actually set.
    for flag, value in (
        ("--backend", args.backend),
        ("--plm-model", args.plm_model),
        ("--cache-dir", args.cache_dir),
        ("--device", args.device),
    ):
        if value:
            forwarded.extend([flag, value])
    if args.paired:
        forwarded.append("--paired")

    return train_cli.main(forwarded)


if __name__ == "__main__":
    # SystemExit carries main()'s integer return value out as the
    # process exit code when the module is run as a script.
    raise SystemExit(main())