"""Loader for the Boughter et al. 2020 dataset.""" from __future__ import annotations from typing import Iterable import numpy as np import pandas as pd from .utils import LOGGER, standardize_frame _COLUMN_ALIASES = { "id": ("sequence_id",), "heavy_seq": ("heavy", "heavy_chain"), "light_seq": ("light", "light_chain"), "label": ("polyreactive",), } def _find_flag_columns(columns: Iterable[str]) -> list[str]: flag_cols: list[str] = [] for column in columns: normalized = column.lower().replace(" ", "") if "flag" in normalized: flag_cols.append(column) return flag_cols def _apply_flag_policy(frame: pd.DataFrame, flag_columns: list[str]) -> pd.DataFrame: if not flag_columns: return frame flag_values = ( frame[flag_columns] .apply(pd.to_numeric, errors="coerce") .fillna(0.0) ) flag_binary = (flag_values > 0).astype(int) flags_total = flag_binary.sum(axis=1) specific_mask = flags_total == 0 nonspecific_mask = flags_total >= 4 keep_mask = specific_mask | nonspecific_mask dropped = int((~keep_mask).sum()) if dropped: LOGGER.info("Dropped %s mildly polyreactive sequences (1-3 ELISA flags)", dropped) filtered = frame.loc[keep_mask].copy() filtered["flags_total"] = flags_total.loc[keep_mask].astype(int) filtered["label"] = np.where(nonspecific_mask.loc[keep_mask], 1, 0) filtered["polyreactive"] = filtered["label"] return filtered def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame: """Load the Boughter dataset into the canonical format.""" frame = pd.read_csv(path_or_url) flag_columns = _find_flag_columns(frame.columns) frame = _apply_flag_policy(frame, flag_columns) label_series = frame.get("label") if label_series is not None: frame = frame[label_series.isin({0, 1})].copy() standardized = standardize_frame( frame, source="boughter2020", heavy_only=heavy_only, column_aliases=_COLUMN_ALIASES, is_test=False, ) if "flags_total" in frame.columns and "flags_total" not in standardized.columns: standardized["flags_total"] = frame["flags_total"].to_numpy(dtype=int) return standardized