"""Utility helpers for dataset loading.""" from __future__ import annotations import logging from typing import Iterable, Sequence import pandas as pd EXPECTED_COLUMNS = ("id", "heavy_seq", "light_seq", "label") OPTIONAL_COLUMNS = ("source", "is_test") LOGGER = logging.getLogger("polyreact.data") _DEFAULT_ALIASES: dict[str, Sequence[str]] = { "id": ("id", "sequence_id", "antibody_id", "uid"), "heavy_seq": ("heavy_seq", "heavy", "heavy_chain", "H", "H_chain"), "light_seq": ("light_seq", "light", "light_chain", "L", "L_chain"), "label": ("label", "polyreactive", "is_polyreactive", "class", "target"), } DEFAULT_LABEL_MAP: dict[str | int | float | bool, int] = { 1: 1, 0: 0, "1": 1, "0": 0, True: 1, False: 0, "true": 1, "false": 0, "polyreactive": 1, "non-polyreactive": 0, "poly": 1, "non": 0, "positive": 1, "negative": 0, } def _normalize_label_key(value: object) -> object: if isinstance(value, str): trimmed = value.strip().lower() if trimmed in { "polyreactive", "non-polyreactive", "poly", "non", "positive", "negative", "high", "low", "pos", "neg", "1", "0", "true", "false", }: return trimmed if trimmed.isdigit(): return trimmed return value def ensure_columns(frame: pd.DataFrame, *, heavy_only: bool = True) -> pd.DataFrame: """Validate and coerce dataframe columns to the canonical format.""" frame = frame.copy() for column in ("id", "heavy_seq", "label"): if column not in frame.columns: msg = f"Required column '{column}' missing from dataframe" raise KeyError(msg) if "light_seq" not in frame.columns: frame["light_seq"] = "" if heavy_only: frame["light_seq"] = "" frame["id"] = frame["id"].astype(str) frame["heavy_seq"] = frame["heavy_seq"].fillna("").astype(str) frame["light_seq"] = frame["light_seq"].fillna("").astype(str) frame["label"] = frame["label"].astype(int) ordered = list(EXPECTED_COLUMNS) + [ col for col in frame.columns if col not in EXPECTED_COLUMNS ] return frame[ordered] def standardize_frame( frame: pd.DataFrame, *, source: str, heavy_only: bool = True, column_aliases: dict[str, Sequence[str]] | None = None, label_map: dict[str | int | float | bool, int] | None = None, is_test: bool | None = None, ) -> pd.DataFrame: """Rename columns using aliases and coerce labels to integers.""" aliases = {**_DEFAULT_ALIASES} if column_aliases: for key, values in column_aliases.items(): aliases[key] = tuple(values) + tuple(aliases.get(key, ())) rename_map: dict[str, str] = {} for target, candidates in aliases.items(): if target in frame.columns: continue for candidate in candidates: if candidate in frame.columns and candidate not in rename_map: rename_map[candidate] = target break normalized = frame.rename(columns=rename_map).copy() if "light_seq" not in normalized.columns: normalized["light_seq"] = "" label_lookup = label_map or DEFAULT_LABEL_MAP normalized["label"] = normalized["label"].map(lambda x: label_lookup.get(_normalize_label_key(x))) if normalized["label"].isnull().any(): msg = "Label column contains unmapped or missing values" raise ValueError(msg) normalized["source"] = source if is_test is not None: normalized["is_test"] = bool(is_test) normalized = ensure_columns(normalized, heavy_only=heavy_only) return normalized def deduplicate_sequences( frames: Iterable[pd.DataFrame], *, heavy_only: bool = True, key_columns: Sequence[str] | None = None, keep_intra_frames: set[int] | None = None, ) -> list[pd.DataFrame]: """Remove duplicate entries across multiple dataframes with configurable keys.""" if key_columns is None: key_columns = ["heavy_seq"] if heavy_only else ["heavy_seq", "light_seq"] keep_intra_frames = keep_intra_frames or set() seen: set[tuple[str, ...]] = set() cleaned: list[pd.DataFrame] = [] for frame_idx, frame in enumerate(frames): valid_columns = [col for col in key_columns if col in frame.columns] if not valid_columns: valid_columns = ["heavy_seq"] mask: list[bool] = [] frame_seen: set[tuple[str, ...]] = set() allow_intra = frame_idx in keep_intra_frames for values in frame[valid_columns].itertuples(index=False, name=None): key = tuple(_normalise_key_value(value) for value in values) if key in seen: mask.append(False) continue if not allow_intra and key in frame_seen: mask.append(False) continue mask.append(True) frame_seen.add(key) seen.update(frame_seen) filtered = frame.loc[mask].reset_index(drop=True) removed = len(frame) - len(filtered) if removed: dataset = "" if "source" in frame.columns and not frame["source"].empty: dataset = str(frame["source"].iloc[0]) LOGGER.info("Removed %s duplicate sequences from %s", removed, dataset) cleaned.append(filtered) return cleaned def _normalise_key_value(value: object) -> str: if value is None or (isinstance(value, float) and pd.isna(value)): return "" return str(value).strip()