Spaces:
Sleeping
Sleeping
| """ANARCI/ANARCII numbering helpers.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from typing import Dict, List, Sequence, Tuple | |
| try: | |
| from anarcii.pipeline import Anarcii # type: ignore | |
| except ImportError: # pragma: no cover - optional dependency | |
| Anarcii = None | |
@dataclass
class NumberedResidue:
    """Single residue with IMGT numbering metadata.

    Instances are built with keyword arguments by
    ``AnarciNumberer.number_sequence``, so the class must be a dataclass
    (a plain class with bare annotations accepts no constructor arguments).
    """

    # Numbering position reported by ANARCI; compared against the inclusive
    # region bounds when regions are extracted.
    position: int
    # Insertion code paired with ``position`` in the numbering output.
    insertion: str
    # One-letter residue code; entries equal to "-" are skipped during
    # region extraction.
    amino_acid: str
@dataclass
class NumberedSequence:
    """Container for numbering results and derived regions.

    Constructed with keyword arguments by ``AnarciNumberer.number_sequence``,
    which requires the dataclass-generated ``__init__``.
    """

    # Sequence that was submitted for numbering.
    sequence: str
    # Numbering scheme attached to the result (e.g. "imgt").
    scheme: str
    # Chain type after normalisation ("H", "L", or another upper-cased label).
    chain_type: str
    # Numbered residues in the order the numbering tool returned them.
    residues: list[NumberedResidue]
    # Region name -> concatenated residues, plus a "full" entry holding every
    # non-gap residue.
    regions: dict[str, str]
# IMGT region table: (region name, first position, last position), both ends
# inclusive. Heavy and light chains use the same position ranges and differ
# only in their CDR labels.
_IMGT_HEAVY_REGIONS: Sequence[Tuple[str, int, int]] = (
    ("FR1", 1, 26),
    ("CDRH1", 27, 38),
    ("FR2", 39, 55),
    ("CDRH2", 56, 65),
    ("FR3", 66, 104),
    ("CDRH3", 105, 117),
    ("FR4", 118, 128),
)
# Light-chain table: the heavy table with every "CDRH" label relabelled "CDRL".
_IMGT_LIGHT_REGIONS: Sequence[Tuple[str, int, int]] = tuple(
    (label.replace("CDRH", "CDRL"), first, last)
    for label, first, last in _IMGT_HEAVY_REGIONS
)
# Region tables keyed by (lower-case scheme, upper-case chain type).
_REGION_MAP: dict[Tuple[str, str], Sequence[Tuple[str, int, int]]] = {
    ("imgt", "H"): _IMGT_HEAVY_REGIONS,
    ("imgt", "L"): _IMGT_LIGHT_REGIONS,
}

# The twenty standard one-letter amino-acid codes; sanitisation drops
# every other character.
_VALID_AMINO_ACIDS = {*"ACDEFGHIKLMNPQRSTVWY"}

# Defaults for the module-level convenience helpers.
_DEFAULT_SCHEME = "imgt"
_DEFAULT_CHAIN = "H"
# Shared numberer, created on first use by get_default_numberer().
_DEFAULT_NUMBERER: AnarciNumberer | None = None
def _sanitize_sequence(sequence: str) -> str:
    """Upper-case *sequence* and keep only characters in ``_VALID_AMINO_ACIDS``."""
    upper_seq = sequence.upper()
    kept = [char for char in upper_seq if char in _VALID_AMINO_ACIDS]
    return "".join(kept)
def get_default_numberer() -> AnarciNumberer:
    """Return the module-wide numberer, building it on the first call.

    The instance expects ``_DEFAULT_CHAIN`` and is created with ``cpu=True``,
    ``ncpu=1`` and ``verbose=False``. Later calls return the stored instance.
    """
    global _DEFAULT_NUMBERER
    if _DEFAULT_NUMBERER is not None:
        return _DEFAULT_NUMBERER
    _DEFAULT_NUMBERER = AnarciNumberer(
        chain_type=_DEFAULT_CHAIN,
        cpu=True,
        ncpu=1,
        verbose=False,
    )
    return _DEFAULT_NUMBERER
def trim_variable_domain(
    sequence: str,
    *,
    numberer: AnarciNumberer | None = None,
    scheme: str = _DEFAULT_SCHEME,
    chain_type: str = _DEFAULT_CHAIN,
    fallback_length: int = 130,
) -> str:
    """Return the FR1–FR4 variable domain for a heavy/light chain sequence.

    The input is sanitised (upper-cased, non-standard characters dropped) and
    numbered. The named framework/CDR segments are then concatenated in
    region-table order.

    Args:
        sequence: Raw amino-acid sequence.
        numberer: Numberer to use. When omitted, a shared numberer expecting
            *scheme*/*chain_type* is used: the module default for IMGT heavy
            chains, or a cached per-(scheme, chain) instance otherwise.
        scheme: Scheme used to pick the numberer when *numberer* is omitted.
        chain_type: Expected chain type ("H", "L", "K", ...), used to pick
            the numberer when *numberer* is omitted.
        fallback_length: Length of the sanitised prefix returned when
            numbering fails or yields nothing usable.

    Returns:
        The concatenated region segments. If none are found, the numbering's
        "full" entry is returned; failing that, ``cleaned[:fallback_length]``.
        An input with no valid residues returns "".
    """
    cleaned = _sanitize_sequence(sequence)
    if not cleaned:
        return ""
    active_numberer = _resolve_numberer(numberer, scheme, chain_type)
    try:
        numbered = active_numberer.number_sequence(cleaned)
    except Exception:  # pragma: no cover - best effort safeguard
        return cleaned[:fallback_length]
    # numbered.regions is keyed by the labels of the scheme/chain the numbering
    # actually used (e.g. "CDRL1" for a light chain). Looking up names derived
    # from the caller's defaults would silently drop the CDRs.
    region_sets = _region_boundaries(numbered.scheme, numbered.chain_type)
    trimmed = "".join(
        numbered.regions.get(region_name, "")
        for region_name, _start, _end in region_sets
    )
    return trimmed or numbered.regions.get("full", "") or cleaned[:fallback_length]


def _resolve_numberer(
    numberer: AnarciNumberer | None,
    scheme: str,
    chain_type: str,
) -> AnarciNumberer:
    """Return *numberer*, or a shared numberer expecting *scheme*/*chain_type*."""
    if numberer is not None:
        return numberer
    scheme_key = scheme.lower()
    chain_key = _normalise_chain_type(chain_type)
    if scheme_key == _DEFAULT_SCHEME and chain_key == _DEFAULT_CHAIN:
        # Reuse the module-wide default so existing callers share one runner.
        return get_default_numberer()
    return _cached_numberer(scheme_key, chain_key)


@lru_cache(maxsize=8)
def _cached_numberer(scheme: str, chain_type: str) -> AnarciNumberer:
    """Build one numberer per non-default (scheme, chain type) pair."""
    return AnarciNumberer(
        scheme=scheme,
        chain_type=chain_type,
        cpu=True,
        ncpu=1,
        verbose=False,
    )
def _normalise_chain_type(chain_type: str) -> str:
    """Map a chain-type label onto "H" or "L" (case-insensitive).

    "H" and "HV" become "H". "L", "K", "LV" and "KV" become "L". Any other
    label is returned upper-cased but otherwise unchanged.
    """
    label = chain_type.upper()
    family = {"H": "H", "HV": "H", "L": "L", "K": "L", "LV": "L", "KV": "L"}
    return family.get(label, label)
class AnarciNumberer:
    """Thin wrapper around the ANARCII pipeline to obtain IMGT regions.

    The ``Anarcii`` runner is built lazily on the first call to
    ``number_sequence`` and reused afterwards. When an expected chain type is
    set, each result's detected chain is checked against it.
    """

    def __init__(
        self,
        *,
        scheme: str = "imgt",
        chain_type: str = "H",
        cpu: bool = True,
        ncpu: int = 1,
        verbose: bool = False,
    ) -> None:
        """Store runner options.

        Args:
            scheme: Scheme attached to results when ANARCII reports none.
            chain_type: Expected chain type; normalised with
                ``_normalise_chain_type``.
            cpu: Forwarded to the ``Anarcii`` constructor.
            ncpu: Forwarded to the ``Anarcii`` constructor.
            verbose: Forwarded to the ``Anarcii`` constructor.

        Raises:
            ImportError: If the optional ``anarcii`` package is not installed.
        """
        if Anarcii is None:  # pragma: no cover - optional dependency guard
            msg = (
                "anarcii is required for numbering but is not installed."
                " Install 'anarcii' to enable ANARCI-based features."
            )
            raise ImportError(msg)
        self.scheme = scheme
        self.expected_chain_type = _normalise_chain_type(chain_type)
        self.cpu = cpu
        self.ncpu = ncpu
        self.verbose = verbose
        self._runner = None

    def _ensure_runner(self) -> Anarcii:
        """Create the ``Anarcii`` runner on first use and return it."""
        if self._runner is None:
            self._runner = Anarcii(
                seq_type="antibody",
                mode="accuracy",
                batch_size=1,
                cpu=self.cpu,
                ncpu=self.ncpu,
                verbose=self.verbose,
            )
        return self._runner

    def number_sequence(self, sequence: str) -> NumberedSequence:
        """Return numbering metadata for a single amino-acid sequence.

        Raises:
            RuntimeError: If ANARCII reports an error, returns no record, or
                returns a record without numbering.
            ValueError: If the detected chain type differs from the expected one.
        """
        runner = self._ensure_runner()
        output = runner.number([sequence])
        # Only one sequence is submitted, so the first record is the answer.
        record = next(iter(output.values()), None)
        if record is None:
            raise RuntimeError("ANARCI failed: no numbering result returned")
        if record.get("error"):
            raise RuntimeError(f"ANARCI failed: {record['error']}")
        numbering = record.get("numbering")
        if numbering is None:
            raise RuntimeError("ANARCI failed: record has no numbering")

        # Treat explicit None the same as a missing key so the string
        # normalisation below never sees None.
        scheme = record.get("scheme") or self.scheme
        detected_chain = record.get("chain_type") or self.expected_chain_type
        normalised_chain = _normalise_chain_type(detected_chain)
        if self.expected_chain_type and normalised_chain != self.expected_chain_type:
            msg = (
                f"Expected chain type {self.expected_chain_type!r} but got"
                f" {normalised_chain!r}"
            )
            raise ValueError(msg)

        # Each numbering entry unpacks as ((position, insertion), amino_acid).
        residues = [
            NumberedResidue(position=pos, insertion=ins, amino_acid=aa)
            for (pos, ins), aa in numbering
        ]
        regions = _extract_regions(
            residues=residues,
            scheme=scheme,
            chain_type=normalised_chain,
        )
        return NumberedSequence(
            sequence=sequence,
            scheme=scheme,
            chain_type=normalised_chain,
            residues=residues,
            regions=regions,
        )
def _region_boundaries(scheme: str, chain_type: str) -> Sequence[Tuple[str, int, int]]:
    """Return the region table for *scheme*/*chain_type*, or () if none exists.

    The scheme is matched case-insensitively (lower-cased) and the chain type
    upper-cased before lookup in ``_REGION_MAP``.
    """
    lookup_key = scheme.lower(), chain_type.upper()
    try:
        return _REGION_MAP[lookup_key]
    except KeyError:
        return ()
def _extract_regions(
    *,
    residues: Sequence[NumberedResidue],
    scheme: str,
    chain_type: str,
) -> dict[str, str]:
    """Group numbered residues into the named regions of *scheme*/*chain_type*.

    The result has one key per region in table order, followed by "full".
    Each key maps to the concatenated one-letter codes. Residues whose
    ``amino_acid`` is "-" are ignored everywhere. Residues outside every
    region's inclusive range appear only in "full". For an unknown
    scheme/chain pair the region table is empty, so only "full" is returned.
    """
    spans = _region_boundaries(scheme, chain_type)
    kept = [res for res in residues if res.amino_acid != "-"]
    buckets: Dict[str, List[str]] = {label: [] for label, _lo, _hi in spans}
    for res in kept:
        # The first span (in table order) containing the position wins.
        owner = next(
            (label for label, lo, hi in spans if lo <= res.position <= hi),
            None,
        )
        if owner is not None:
            buckets[owner].append(res.amino_acid)
    buckets["full"] = [res.amino_acid for res in kept]
    return {label: "".join(chars) for label, chars in buckets.items()}