makiling's picture
Upload folder using huggingface_hub
5f58699 verified
raw
history blame
5.83 kB
"""Loader for the Shehata et al. (2019) PSR dataset."""
from __future__ import annotations
from pathlib import Path
from typing import Iterable
import pandas as pd
from .utils import standardize_frame
# Base source label for this dataset. It is passed to ``standardize_frame`` as
# ``source`` and, for local files, written to the "source" column (with a
# "_curated" suffix when the file stem mentions "curated" or "subset").
SHEHATA_SOURCE = "shehata2019"

# Canonical column name -> header spellings accepted for that column.
# Handed to ``standardize_frame`` as ``column_aliases``.
# NOTE(review): both lower-case and capitalised spellings are listed; whether
# matching is case-sensitive is decided inside ``standardize_frame`` -- confirm.
_COLUMN_ALIASES = {
    # Antibody identifier.
    "id": (
        "antibody_id",
        "antibody",
        "antibody name",
        "antibody_name",
        "sequence_name",
        "Antibody Name",
    ),
    # Heavy-chain amino-acid sequence.
    "heavy_seq": (
        "heavy",
        "heavy_chain",
        "heavy aa",
        "heavy_sequence",
        "vh",
        "vh_sequence",
        "heavy chain aa",
        "Heavy Chain AA",
    ),
    # Light-chain amino-acid sequence.
    "light_seq": (
        "light",
        "light_chain",
        "light aa",
        "light_sequence",
        "vl",
        "vl_sequence",
        "light chain aa",
        "Light Chain AA",
    ),
    # Binary polyreactivity label (raw values are translated by _LABEL_MAP).
    "label": (
        "polyreactive",
        "binding_class",
        "binding class",
        "psr_class",
        "psr binding",
        "psr classification",
        "Binding class",
        "Binding Class",
    ),
}

# Raw label value -> binary class (1 = polyreactive/positive/high,
# 0 = non-polyreactive/negative/low). Handed to ``standardize_frame`` as
# ``label_map``.
# Note: in Python ``1 == 1.0`` and ``0 == 0.0`` (same hash), so the int and
# float entries below collapse into a single dict key each; the string keys
# "1" and "0" are distinct entries.
_LABEL_MAP = {
    "polyreactive": 1,
    "non-polyreactive": 0,
    "positive": 1,
    "negative": 0,
    "high": 1,
    "low": 0,
    "pos": 1,
    "neg": 0,
    1: 1,
    0: 0,
    1.0: 1,
    0.0: 0,
    "1": 1,
    "0": 0,
}

# Lower-case fragments that mark a column as a PSR score column. They are
# matched as substrings of the stripped, lower-cased header in
# ``_maybe_extract_psr_scores``.
_PSR_SCORE_ALIASES: tuple[str, ...] = (
    "psr score",
    "psr_score",
    "psr overall score",
    "overall score",
    "psr z",
    "psr_z",
)
def _clean_sequence(sequence: object) -> str:
    """Return *sequence* upper-cased with every whitespace character removed.

    Any value that is not a ``str`` (``None``, NaN, numbers, ...) yields the
    empty string, which callers treat as "no sequence".
    """
    if not isinstance(sequence, str):
        return ""
    # str.split() with no argument splits on any run of whitespace, so joining
    # the pieces drops spaces, tabs and newlines in one pass.
    compact = "".join(sequence.split())
    return compact.upper()
def _maybe_extract_psr_scores(frame: pd.DataFrame) -> pd.DataFrame:
    """Collect PSR score columns from *frame* as numeric ``psr_*`` columns.

    A column is selected when its stripped, lower-cased header contains any
    entry of ``_PSR_SCORE_ALIASES`` as a substring. The output name is built
    from the header by replacing spaces with underscores, removing one leading
    ``psr_`` or ``overall_`` prefix, tidying punctuation, and prepending
    ``psr_``. Values are converted with ``pd.to_numeric(errors="coerce")``,
    so unparseable cells become NaN.

    Args:
        frame: Raw table as read from CSV/Excel.

    Returns:
        A frame of the selected score columns, or an empty (column-less)
        frame sharing ``frame.index`` when nothing matches.

    Note:
        Headers that normalise to the same name (e.g. "PSR Score" and
        "Overall Score" both become ``psr_score``) overwrite each other; the
        last one in column order wins.
    """
    scores: dict[str, pd.Series] = {}
    for column in frame.columns:
        # Column labels are not guaranteed to be strings (integer headers,
        # NaN, timestamps from Excel); coerce so they are simply not matched
        # instead of raising AttributeError on ``.strip()``.
        lowered = str(column).strip().lower()
        if any(alias in lowered for alias in _PSR_SCORE_ALIASES):
            scores[lowered.replace(" ", "_")] = frame[column]

    if not scores:
        return pd.DataFrame(index=frame.index)

    renamed: dict[str, pd.Series] = {}
    for name, series in scores.items():
        cleaned_name = name
        # Strip at most one known prefix so the final "psr_" is not doubled.
        for prefix in ("psr_", "overall_"):
            if cleaned_name.startswith(prefix):
                cleaned_name = cleaned_name[len(prefix):]
                break
        cleaned_name = cleaned_name.replace("__", "_")
        cleaned_name = cleaned_name.replace("(", "").replace(")", "")
        cleaned_name = cleaned_name.replace("-", "_")
        renamed[f"psr_{cleaned_name}"] = pd.to_numeric(series, errors="coerce")
    return pd.DataFrame(renamed)
def _pick_source_label(path: Path | None) -> str:
    """Choose the ``source`` tag for a dataset file.

    Returns ``"<SHEHATA_SOURCE>_curated"`` when the lower-cased file stem
    contains "curated" or "subset"; otherwise (including ``path is None``)
    returns ``SHEHATA_SOURCE`` unchanged.
    """
    if path is not None:
        stem = path.stem.lower()
        if any(marker in stem for marker in ("curated", "subset")):
            return f"{SHEHATA_SOURCE}_curated"
    return SHEHATA_SOURCE
def _standardize(
    frame: pd.DataFrame,
    *,
    heavy_only: bool,
    source: str,
) -> pd.DataFrame:
    """Convert a raw table into the canonical layout and attach PSR scores.

    Steps:
      1. Run ``standardize_frame`` with this module's column aliases and
         label map (always with ``is_test=True``).
      2. Drop rows whose heavy-chain sequence is empty after cleaning and
         renumber the remaining rows from 0.
      3. Normalise ``heavy_seq`` and ``light_seq`` with ``_clean_sequence``.
      4. If the raw table carries PSR score columns, filter them with the
         same row mask and append them as ``psr_*`` columns.

    Args:
        frame: Raw table as read from CSV/Excel.
        heavy_only: Forwarded to ``standardize_frame``.
        source: Source label forwarded to ``standardize_frame``.

    Returns:
        The standardized frame with a fresh 0..n-1 index.
    """
    canonical = standardize_frame(
        frame,
        source=source,
        heavy_only=heavy_only,
        column_aliases=_COLUMN_ALIASES,
        label_map=_LABEL_MAP,
        is_test=True,
    )
    score_frame = _maybe_extract_psr_scores(frame)

    # Rows without a usable heavy chain are discarded.
    keep = canonical["heavy_seq"].map(_clean_sequence) != ""
    canonical = canonical.loc[keep].copy().reset_index(drop=True)
    for chain_column in ("heavy_seq", "light_seq"):
        canonical[chain_column] = canonical[chain_column].map(_clean_sequence)

    if score_frame.empty:
        return canonical

    # NOTE(review): `keep` is indexed like the output of standardize_frame,
    # while `score_frame` is indexed like the raw `frame`; this presumes
    # standardize_frame preserves the input row index -- confirm.
    score_frame = score_frame.loc[keep].reset_index(drop=True)
    for column in score_frame.columns:
        canonical[column] = score_frame[column]
    return canonical
def _read_excel(path: Path, *, heavy_only: bool) -> pd.DataFrame:
    """Read the most relevant sheet of an ``.xlsx`` workbook and standardize it.

    Sheet choice: names containing "psr" or "polyreactivity" rank first, then
    names that do not contain "sheet" (i.e. not a default "Sheet1"-style
    name), then everything else; ties are broken alphabetically.

    Args:
        path: Workbook location; opened with the ``openpyxl`` engine.
        heavy_only: Forwarded to ``_standardize``.

    Returns:
        The standardized frame for the chosen sheet, with fully empty rows
        removed beforehand.
    """

    def _score(name: str) -> tuple[int, str]:
        lowered = name.lower()
        priority = 0
        if "psr" in lowered or "polyreactivity" in lowered:
            priority = 2
        elif "sheet" not in lowered:
            priority = 1
        # Negated so that an ascending sort puts the highest priority first.
        return (-priority, name)

    # ExcelFile keeps the workbook handle open; the context manager closes it
    # even when sheet selection or parsing raises.
    with pd.ExcelFile(path, engine="openpyxl") as excel:
        sheet_candidates: Iterable[str] = excel.sheet_names
        sheet_name = sorted(sheet_candidates, key=_score)[0]
        raw = excel.parse(sheet_name)

    raw = raw.dropna(how="all")
    return _standardize(raw, heavy_only=heavy_only, source=_pick_source_label(path))
def load_dataframe(path_or_url: str | Path, heavy_only: bool = True) -> pd.DataFrame:
    """Load the Shehata dataset into the canonical format.

    Supports both pre-processed CSV exports and the original Excel supplement
    (*.xls/*.xlsx). Additional PSR score columns are preserved when available.

    Args:
        path_or_url: Local path or ``http(s)://`` URL of the dataset file.
            The format is chosen from the file extension; for URLs only the
            path component is inspected, so query strings and fragments
            (e.g. ``...data.xlsx?download=1``) do not hide the extension.
        heavy_only: Forwarded to the standardization step.

    Returns:
        The standardized frame. For local files the ``source`` column is set
        from the file name (see ``_pick_source_label``); for URLs the plain
        ``SHEHATA_SOURCE`` label is used.
    """
    from urllib.parse import urlsplit

    location = str(path_or_url)

    if location.lower().startswith(("http://", "https://")):
        url_path = urlsplit(location).path.lower()
        if url_path.endswith((".xls", ".xlsx")):
            # openpyxl only reads the newer OOXML formats; for legacy .xls let
            # pandas pick the engine, exactly as the local-file branch does.
            engine = "openpyxl" if url_path.endswith(".xlsx") else None
            raw = pd.read_excel(location, engine=engine)
        else:
            raw = pd.read_csv(location)
        return _standardize(raw, heavy_only=heavy_only, source=SHEHATA_SOURCE)

    path = Path(location)
    source_label = _pick_source_label(path)
    suffix = path.suffix.lower()

    if suffix == ".xlsx":
        # Workbook path: picks the most relevant sheet and drops empty rows.
        standardized = _read_excel(path, heavy_only=heavy_only)
    else:
        if suffix == ".xls":
            raw = pd.read_excel(path, engine=None)
        else:
            raw = pd.read_csv(path)
        standardized = _standardize(raw, heavy_only=heavy_only, source=source_label)

    # Set explicitly so the column is present with the file-derived label
    # regardless of how the standardization helper treats ``source``.
    standardized["source"] = source_label
    return standardized