makiling's picture
Upload folder using huggingface_hub
5f58699 verified
raw
history blame
5.18 kB
"""Sequence descriptor features for polyreactivity prediction."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Sequence
import numpy as np
import pandas as pd
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.preprocessing import StandardScaler
from .anarsi import AnarciNumberer, NumberedSequence
# The 20 canonical one-letter amino-acid codes; ``_sanitize_sequence`` drops
# every other character before descriptors are computed.
_VALID_AMINO_ACIDS = {residue for residue in "ACDEFGHIKLMNPQRSTVWY"}
@dataclass(slots=True)
class DescriptorConfig:
"""Configuration for descriptor-based features."""
use_anarci: bool = True
regions: Sequence[str] = ("CDRH1", "CDRH2", "CDRH3")
features: Sequence[str] = (
"length",
"charge",
"hydropathy",
"aromaticity",
"pI",
"net_charge",
)
ph: float = 7.4
class DescriptorFeaturizer:
    """Compute descriptor features with optional ANARCI-based regions.

    Every input sequence is mapped to a set of named regions. When
    ``config.use_anarci`` is true, these are the configured ``config.regions``
    (upper-cased); otherwise there is a single ``"FULL"`` region holding the
    whole sequence. Each feature in ``config.features`` is then computed per
    region. Output columns are named ``"{REGION}_{feature}"`` in region-major
    order.

    When ``standardize`` is true, the numeric matrix is scaled with a
    ``StandardScaler``. The scaler is fitted by :meth:`fit` /
    :meth:`fit_transform` and reused by :meth:`transform`.
    """

    def __init__(
        self,
        *,
        config: DescriptorConfig,
        numberer: AnarciNumberer | None = None,
        standardize: bool = True,
    ) -> None:
        """Store the configuration and build the optional numberer and scaler.

        Args:
            config: Regions, features and pH to use.
            numberer: Numberer used to split sequences into regions. If
                ``config.use_anarci`` is true and none is given, a default
                ``AnarciNumberer()`` is created. It is never consulted when
                ``config.use_anarci`` is false.
            standardize: Whether to scale features with a ``StandardScaler``.
        """
        self.config = config
        # Only build a default numberer when numbered regions are needed.
        if config.use_anarci and numberer is None:
            numberer = AnarciNumberer()
        self.numberer = numberer
        self.standardize = standardize
        self.scaler = StandardScaler() if standardize else None
        # Column names recorded at fit time; ``None`` means "not fitted yet".
        self.feature_names_: list[str] | None = None

    def fit(self, sequences: Iterable[str]) -> "DescriptorFeaturizer":
        """Fit the scaler (if enabled) on *sequences* and record column names.

        Returns:
            ``self``, to allow chaining.
        """
        self._fit_table(self.compute_feature_table(sequences))
        return self

    def transform(self, sequences: Iterable[str]) -> np.ndarray:
        """Return the (optionally standardized) feature matrix for *sequences*.

        Raises:
            RuntimeError: If neither :meth:`fit` nor :meth:`fit_transform`
                has been called yet.
        """
        if self.feature_names_ is None:
            msg = "DescriptorFeaturizer must be fitted before calling transform."
            raise RuntimeError(msg)
        table = self.compute_feature_table(sequences)
        return self._scale(table.to_numpy(dtype=float))

    def fit_transform(self, sequences: Iterable[str]) -> np.ndarray:
        """Fit on *sequences* and return their (optionally standardized) features.

        The feature table is computed only once, so one-shot iterables work.
        """
        values = self._fit_table(self.compute_feature_table(sequences))
        return self._scale(values)

    def compute_feature_table(self, sequences: Iterable[str]) -> pd.DataFrame:
        """Return one row per sequence and one column per (region, feature) pair.

        A region missing for a sequence (for example when numbering fails) is
        featurized as an empty string. ``_compute_feature`` maps that to
        ``0.0``. Any remaining NaN in the table is also replaced by ``0.0``.
        """
        region_names = self._region_names()
        # Materialize once so the row keys and the column list always agree,
        # even if ``config.features`` is a one-shot iterable.
        features = tuple(self.config.features)
        ph = self.config.ph
        columns = [
            f"{region}_{feature}" for region in region_names for feature in features
        ]
        rows: list[dict[str, float]] = []
        for sequence in sequences:
            regions = self._prepare_regions(sequence)
            row: dict[str, float] = {}
            for region_name in region_names:
                region_sequence = regions.get(region_name, "")
                for feature_name in features:
                    row[f"{region_name}_{feature_name}"] = _compute_feature(
                        region_sequence,
                        feature_name,
                        ph=ph,
                    )
            rows.append(row)
        frame = pd.DataFrame(rows, columns=columns)
        return frame.fillna(0.0)

    def _region_names(self) -> list[str]:
        """Return the upper-cased region names used as column prefixes."""
        if not self.config.use_anarci:
            return ["FULL"]
        return [region.upper() for region in self.config.regions]

    def _fit_table(self, table: pd.DataFrame) -> np.ndarray:
        """Fit the scaler on *table*, record its columns, and return raw values."""
        values = table.to_numpy(dtype=float)
        if self.standardize and self.scaler is not None:
            self.scaler.fit(values)
        self.feature_names_ = list(table.columns)
        return values

    def _scale(self, values: np.ndarray) -> np.ndarray:
        """Apply the fitted scaler when standardization is enabled."""
        if self.standardize and self.scaler is not None:
            return self.scaler.transform(values)
        return values

    def _prepare_regions(self, sequence: str) -> dict[str, str]:
        """Map upper-cased region names to their residue strings for *sequence*.

        Without ANARCI, the whole sequence is returned under ``"FULL"``. If
        numbering raises ``RuntimeError`` or ``ValueError``, an empty mapping is
        returned on purpose (best-effort), so every region of that sequence is
        featurized as empty.
        """
        if not self.config.use_anarci:
            return {"FULL": sequence}
        try:
            numbered: NumberedSequence = self.numberer.number_sequence(sequence)
        except (RuntimeError, ValueError):
            return {}
        # NOTE(review): assumes ``numbered.regions`` maps region name -> residue
        # string; confirm against ``NumberedSequence``.
        return {key.upper(): value for key, value in numbered.regions.items()}
def _sanitize_sequence(sequence: str) -> str:
    """Return *sequence* upper-cased, keeping only canonical amino-acid codes.

    Any character whose upper-case form is not one of the 20 codes in
    ``_VALID_AMINO_ACIDS`` is dropped. This includes ``X``, gap characters,
    whitespace and digits.

    Each character is upper-cased on its own, and the result is kept only if
    it is a single valid code. Upper-casing the whole string first would let
    Unicode case expansions slip through: ``"ß".upper()`` is ``"SS"``, which
    would be read as two serines.
    """
    return "".join(
        upper for upper in map(str.upper, sequence) if upper in _VALID_AMINO_ACIDS
    )
# Feature name -> callable(analysis, residues, ph) returning the raw value.
# ``analysis`` is a ``ProteinAnalysis`` built from the sanitized ``residues``.
_FEATURE_CALCULATORS = {
    "length": lambda analysis, residues, ph: len(residues),
    # Net charge per residue at the given pH.
    "charge": lambda analysis, residues, ph: analysis.charge_at_pH(ph) / len(residues),
    "hydropathy": lambda analysis, residues, ph: analysis.gravy(),
    "aromaticity": lambda analysis, residues, ph: analysis.aromaticity(),
    "pI": lambda analysis, residues, ph: analysis.isoelectric_point(),
    "net_charge": lambda analysis, residues, ph: analysis.charge_at_pH(ph),
}


def _compute_feature(sequence: str, feature_name: str, *, ph: float) -> float:
    """Compute one descriptor for *sequence*.

    The sequence is first reduced to canonical amino-acid codes with
    ``_sanitize_sequence``. If nothing remains, ``0.0`` is returned.

    Args:
        sequence: Raw residue string (may be empty).
        feature_name: One of the keys of ``_FEATURE_CALCULATORS``.
        ph: pH used by the ``"charge"`` and ``"net_charge"`` features.

    Returns:
        The descriptor value as a float.

    Raises:
        ValueError: If *feature_name* is not supported. The check runs before
            the sequence is inspected, so a misspelled feature is reported
            even for empty sequences instead of silently yielding ``0.0``.
    """
    calculator = _FEATURE_CALCULATORS.get(feature_name)
    if calculator is None:
        msg = f"Unsupported feature: {feature_name}"
        raise ValueError(msg)
    sanitized = _sanitize_sequence(sequence)
    if not sanitized:
        return 0.0
    return float(calculator(ProteinAnalysis(sanitized), sanitized, ph))