makiling's picture
Upload folder using huggingface_hub
5f58699 verified
raw
history blame
2.3 kB
"""Loader for the Boughter et al. 2020 dataset."""
from __future__ import annotations
from typing import Iterable
import numpy as np
import pandas as pd
from .utils import LOGGER, standardize_frame
# Column-name aliases handed to ``standardize_frame`` via ``column_aliases``.
# NOTE(review): presumably each key is a canonical column name and each tuple
# lists alternative headers that may appear in the raw CSV -- confirm against
# ``standardize_frame`` in ``.utils``.
_COLUMN_ALIASES = {
    "id": ("sequence_id",),
    "heavy_seq": ("heavy", "heavy_chain"),
    "light_seq": ("light", "light_chain"),
    "label": ("polyreactive",),
}
def _find_flag_columns(columns: Iterable[str]) -> list[str]:
    """Return the column names that mention "flag", in their original order.

    Matching is case-insensitive and ignores spaces inside the name, so
    headers such as ``"ELISA Flag"`` or ``"F L A G"`` are both selected.

    Args:
        columns: Column names to inspect.

    Returns:
        The matching names, unchanged, in the order they were supplied.
    """
    return [
        name
        for name in columns
        if "flag" in name.lower().replace(" ", "")
    ]
def _apply_flag_policy(frame: pd.DataFrame, flag_columns: list[str]) -> pd.DataFrame:
    """Filter rows by their number of raised flags and derive binary labels.

    Each flag column is coerced to numeric (unparseable or missing entries
    count as 0) and a flag is considered raised when its value is > 0.
    Rows are then handled by their total number of raised flags:

    * 0 flags      -> kept, ``label`` = 0
    * 4 or more    -> kept, ``label`` = 1
    * 1 to 3 flags -> dropped (the count of dropped rows is logged)

    Args:
        frame: Input table containing the flag columns.
        flag_columns: Names of the flag columns in ``frame``.

    Returns:
        ``frame`` itself, untouched, when ``flag_columns`` is empty.
        Otherwise a filtered copy with ``flags_total``, ``label`` and
        ``polyreactive`` columns set (``polyreactive`` mirrors ``label``).
    """
    if not flag_columns:
        return frame

    # Coerce first, then treat anything that could not be parsed as "no flag".
    numeric_flags = frame[flag_columns].apply(pd.to_numeric, errors="coerce")
    raised = (numeric_flags.fillna(0.0) > 0).astype(int)
    flags_total = raised.sum(axis=1)

    is_nonspecific = flags_total >= 4
    keep = (flags_total == 0) | is_nonspecific

    n_dropped = int((~keep).sum())
    if n_dropped:
        LOGGER.info("Dropped %s mildly polyreactive sequences (1-3 ELISA flags)", n_dropped)

    filtered = frame.loc[keep].copy()
    filtered["flags_total"] = flags_total.loc[keep].astype(int)
    labels = np.where(is_nonspecific.loc[keep], 1, 0)
    filtered["label"] = labels
    filtered["polyreactive"] = filtered["label"]
    return filtered
def load_dataframe(path_or_url: str, heavy_only: bool = True) -> pd.DataFrame:
    """Load the Boughter dataset into the canonical format.

    The CSV is read with ``pandas.read_csv``, the flag policy is applied to
    any columns whose name mentions "flag", rows whose ``label`` is not 0 or
    1 are discarded (when a ``label`` column exists), and the result is
    passed through ``standardize_frame`` with ``source="boughter2020"`` and
    ``is_test=False``.

    Args:
        path_or_url: Location of the CSV file, forwarded to ``pd.read_csv``.
        heavy_only: Forwarded unchanged to ``standardize_frame``.

    Returns:
        The standardized frame. If the filtered input carried a
        ``flags_total`` column that standardization did not keep, it is
        copied onto the result.
    """
    raw = pd.read_csv(path_or_url)
    frame = _apply_flag_policy(raw, _find_flag_columns(raw.columns))

    # Keep only rows with a binary label; skipped when no label column exists.
    if "label" in frame.columns:
        binary_rows = frame["label"].isin({0, 1})
        frame = frame[binary_rows].copy()

    standardized = standardize_frame(
        frame,
        source="boughter2020",
        heavy_only=heavy_only,
        column_aliases=_COLUMN_ALIASES,
        is_test=False,
    )

    if "flags_total" not in frame.columns or "flags_total" in standardized.columns:
        return standardized

    # NOTE(review): this is a positional copy; it assumes standardize_frame
    # preserves the row count and order of ``frame`` -- confirm, otherwise
    # the assignment raises on a length mismatch.
    standardized["flags_total"] = frame["flags_total"].to_numpy(dtype=int)
    return standardized