"""Configuration helpers for the polyreactivity project.""" from __future__ import annotations from dataclasses import asdict, dataclass, field try: import importlib.resources as pkg_resources from importlib.resources.abc import Traversable except (ModuleNotFoundError, AttributeError): # pragma: no cover - compatibility import importlib_resources as pkg_resources # type: ignore[no-redef] from importlib_resources.abc import Traversable # type: ignore[assignment] from pathlib import Path from typing import Any, Sequence import yaml @dataclass(slots=True) class FeatureBackendSettings: type: str = "plm" plm_model_name: str = "facebook/esm2_t12_35M_UR50D" layer_pool: str = "mean" cache_dir: str = ".cache/embeddings" standardize: bool = True @dataclass(slots=True) class DescriptorSettings: use_anarci: bool = True regions: Sequence[str] = field(default_factory=lambda: ["CDRH1", "CDRH2", "CDRH3"]) features: Sequence[str] = field( default_factory=lambda: [ "length", "charge", "hydropathy", "aromaticity", "pI", "net_charge", ] ) ph: float = 7.4 @dataclass(slots=True) class ModelSettings: head: str = "logreg" C: float = 1.0 class_weight: Any = "balanced" @dataclass(slots=True) class CalibrationSettings: method: str | None = "isotonic" @dataclass(slots=True) class TrainingSettings: cv_folds: int = 10 scoring: str = "roc_auc" n_jobs: int = -1 @dataclass(slots=True) class IOSettings: outputs_dir: str = "artifacts" preds_filename: str = "preds.csv" metrics_filename: str = "metrics.csv" @dataclass(slots=True) class Config: seed: int = 42 device: str = "auto" feature_backend: FeatureBackendSettings = field(default_factory=FeatureBackendSettings) descriptors: DescriptorSettings = field(default_factory=DescriptorSettings) model: ModelSettings = field(default_factory=ModelSettings) calibration: CalibrationSettings = field(default_factory=CalibrationSettings) training: TrainingSettings = field(default_factory=TrainingSettings) io: IOSettings = field(default_factory=IOSettings) raw: dict[str, Any] = field(default_factory=dict) def _merge_section(default: Any, data: dict[str, Any] | None) -> Any: if data is None: return default merged = asdict(default) | data return type(default)(**merged) def load_config(path: str | Path | None = None) -> Config: """Load a YAML configuration file into a strongly-typed ``Config`` object.""" data = _read_config_data(path) feature_backend = _merge_section(FeatureBackendSettings(), data.get("feature_backend")) descriptors = _merge_section(DescriptorSettings(), data.get("descriptors")) model = _merge_section(ModelSettings(), data.get("model")) calibration = _merge_section(CalibrationSettings(), data.get("calibration")) training = _merge_section(TrainingSettings(), data.get("training")) io_settings = _merge_section(IOSettings(), data.get("io")) config = Config( seed=int(data.get("seed", 42)), device=str(data.get("device", "auto")), feature_backend=feature_backend, descriptors=descriptors, model=model, calibration=calibration, training=training, io=io_settings, raw=data, ) return config def _read_config_data(path: str | Path | None) -> dict[str, Any]: """Return mapping data from YAML or the bundled default.""" if path is None: resource = pkg_resources.files("polyreact.configs") / "default.yaml" return _load_yaml_resource(resource) resolved = _resolve_config_path(Path(path)) if resolved is not None: return _load_yaml_file(resolved) resource_root = pkg_resources.files("polyreact") resource = resource_root / Path(path).as_posix() if resource.is_file(): return _load_yaml_resource(resource) msg = f"Configuration file not found: {path}" raise FileNotFoundError(msg) def _resolve_config_path(path: Path) -> Path | None: if path.exists(): return path if not path.is_absolute(): candidate = Path(__file__).resolve().parent / path if candidate.exists(): return candidate return None def _load_yaml_file(path: Path) -> dict[str, Any]: with path.open("r", encoding="utf-8") as handle: return _parse_yaml(handle.read()) def _load_yaml_resource(resource: Traversable) -> dict[str, Any]: with resource.open("r", encoding="utf-8") as handle: return _parse_yaml(handle.read()) def _parse_yaml(text: str) -> dict[str, Any]: parsed = yaml.safe_load(text) or {} if not isinstance(parsed, dict): # pragma: no cover - safeguard msg = "Configuration must be a mapping at the top level" raise ValueError(msg) return parsed