"""Protein language model embeddings backend with caching support."""
from __future__ import annotations
import hashlib
import warnings
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, Iterable, List, Sequence, Tuple
import numpy as np
import torch
from torch import nn
try: # pragma: no cover - optional dependency
from transformers import AutoModel, AutoTokenizer
except ImportError: # pragma: no cover - optional dependency
AutoModel = None
AutoTokenizer = None
try: # pragma: no cover - optional dependency
import esm
except ImportError: # pragma: no cover - optional dependency
esm = None
if esm is not None: # pragma: no cover - optional dependency
warnings.filterwarnings(
"ignore",
message="Regression weights not found, predicting contacts will not produce correct results.",
module="esm.pretrained",
)
from .anarsi import AnarciNumberer
ModelLoader = Callable[[str, str], Tuple[object, nn.Module]]

if esm is not None:  # pragma: no cover - optional dependency
    _ESM1V_LOADERS = {
        "esm1v_t33_650m_ur90s_1": esm.pretrained.esm1v_t33_650M_UR90S_1,
        "esm1v_t33_650m_ur90s_2": esm.pretrained.esm1v_t33_650M_UR90S_2,
        "esm1v_t33_650m_ur90s_3": esm.pretrained.esm1v_t33_650M_UR90S_3,
        "esm1v_t33_650m_ur90s_4": esm.pretrained.esm1v_t33_650M_UR90S_4,
        "esm1v_t33_650m_ur90s_5": esm.pretrained.esm1v_t33_650M_UR90S_5,
    }
else:  # pragma: no cover - optional dependency
    _ESM1V_LOADERS: dict[str, Callable[[], tuple[nn.Module, object]]] = {}
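
# NOTE: the keys above are canonical model names (lowercased, org prefix stripped),
# matching what PLMEmbedder._canonical_esm_name produces, so e.g.
# "facebook/esm1v_t33_650M_UR90S_1" resolves to the "esm1v_t33_650m_ur90s_1" loader.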


class _ESMTokenizer:
    """Callable wrapper that mimics Hugging Face tokenizers for ESM models."""

    def __init__(self, alphabet) -> None:  # noqa: ANN001
        self.alphabet = alphabet
        self._batch_converter = alphabet.get_batch_converter()

    def __call__(
        self,
        sequences: Sequence[str],
        *,
        return_tensors: str = "pt",
        padding: bool = True,  # noqa: FBT002
        truncation: bool = True,  # noqa: FBT002
        add_special_tokens: bool = True,  # noqa: FBT002
        return_special_tokens_mask: bool = True,  # noqa: FBT002
    ) -> dict[str, torch.Tensor]:
        if return_tensors != "pt":  # pragma: no cover - defensive branch
            msg = "ESM tokenizer only supports return_tensors='pt'"
            raise ValueError(msg)
        data = [(str(idx), (seq or "").upper()) for idx, seq in enumerate(sequences)]
        _labels, _strings, tokens = self._batch_converter(data)
        attention_mask = (tokens != self.alphabet.padding_idx).long()
        special_tokens = torch.zeros_like(tokens)
        specials = {
            self.alphabet.padding_idx,
            self.alphabet.cls_idx,
            self.alphabet.eos_idx,
        }
        for special in specials:
            special_tokens |= tokens == special
        output: dict[str, torch.Tensor] = {
            "input_ids": tokens,
            "attention_mask": attention_mask,
        }
        if return_special_tokens_mask:
            output["special_tokens_mask"] = special_tokens.long()
        return output
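
# Illustrative output (assuming the standard ESM-1v/ESM-1b alphabet, which prepends a
# BOS token and appends an EOS token): for ["EVQLVESGG", "QVQLQ"] the tokenizer returns
# padded tensors of shape (2, 11) for "input_ids" and "attention_mask", and a
# "special_tokens_mask" that is 1 at BOS/EOS/padding positions and 0 at residue positions.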


class _ESMModelWrapper(nn.Module):
    """Adapter providing a Hugging Face style interface for ESM models."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.layer_index = getattr(model, "num_layers", None)
        if self.layer_index is None:
            msg = "Unable to determine final layer for ESM model"
            raise AttributeError(msg)

    def eval(self) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.eval()
        return self

    def to(self, device: str) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.to(device)
        return self

    def forward(self, input_ids: torch.Tensor, **_):  # noqa: ANN003
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                repr_layers=[self.layer_index],
                return_contacts=False,
            )
        hidden = outputs["representations"][self.layer_index]
        return SimpleNamespace(last_hidden_state=hidden)

    __call__ = forward
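
# The wrapper only exposes `last_hidden_state` (final-layer representations with shape
# (batch, tokens, hidden)), which is the single attribute PLMEmbedder reads from
# Hugging Face style model outputs.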


@dataclass(slots=True)
class PLMConfig:
    model_name: str = "facebook/esm1v_t33_650M_UR90S_1"
    layer_pool: str = "mean"
    cache_dir: Path = Path(".cache/embeddings")
    device: str = "auto"
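
# PLMConfig is a plain settings container; its defaults mirror the PLMEmbedder
# constructor below. It is not referenced elsewhere in this module.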


class PLMEmbedder:
    """Embed amino-acid sequences using a transformer model with caching."""

    def __init__(
        self,
        model_name: str = "facebook/esm1v_t33_650M_UR90S_1",
        *,
        layer_pool: str = "mean",
        device: str = "auto",
        cache_dir: str | Path | None = None,
        numberer: AnarciNumberer | None = None,
        model_loader: ModelLoader | None = None,
    ) -> None:
        self.model_name = model_name
        self.layer_pool = layer_pool
        self.device = self._resolve_device(device)
        self.cache_dir = Path(cache_dir or ".cache/embeddings")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.numberer = numberer
        self.model_loader = model_loader
        self._tokenizer: object | None = None
        self._model: nn.Module | None = None

    @staticmethod
    def _resolve_device(device: str) -> str:
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    @property
    def tokenizer(self):  # noqa: D401
        if self._tokenizer is None:
            tokenizer, model = self._load_model_components()
            self._tokenizer = tokenizer
            self._model = model
        return self._tokenizer

    @property
    def model(self) -> nn.Module:
        if self._model is None:
            tokenizer, model = self._load_model_components()
            self._tokenizer = tokenizer
            self._model = model
        return self._model

    def _load_model_components(self) -> Tuple[object, nn.Module]:
        if self.model_loader is not None:
            tokenizer, model = self.model_loader(self.model_name, self.device)
            return tokenizer, model
        if self._is_esm1v_model(self.model_name):
            return self._load_esm_model()
        if AutoModel is None or AutoTokenizer is None:  # pragma: no cover - optional dependency
            msg = "transformers must be installed to use PLMEmbedder"
            raise ImportError(msg)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
        model.eval()
        model.to(self.device)
        return tokenizer, model

    def _load_esm_model(self) -> Tuple[object, nn.Module]:
        if esm is None:  # pragma: no cover - optional dependency
            msg = "The 'esm' package is required to use ESM-1v models."
            raise ImportError(msg)
        normalized = self._canonical_esm_name(self.model_name)
        loader = _ESM1V_LOADERS.get(normalized)
        if loader is None:  # pragma: no cover - guard branch
            msg = f"Unsupported ESM-1v model: {self.model_name}"
            raise ValueError(msg)
        model, alphabet = loader()
        model.eval()
        model.to(self.device)
        tokenizer = _ESMTokenizer(alphabet)
        wrapper = _ESMModelWrapper(model)
        return tokenizer, wrapper

    @staticmethod
    def _canonical_esm_name(model_name: str) -> str:
        name = model_name.lower()
        if "/" in name:
            name = name.split("/")[-1]
        return name

    @classmethod
    def _is_esm1v_model(cls, model_name: str) -> bool:
        return cls._canonical_esm_name(model_name).startswith("esm1v")
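
    # Dispatch summary: an explicit `model_loader` callable takes precedence; names whose
    # canonical form starts with "esm1v" (e.g. "facebook/esm1v_t33_650M_UR90S_1") are
    # loaded through fair-esm, and every other name falls back to transformers'
    # AutoTokenizer/AutoModel with trust_remote_code=True.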

    def embed(self, sequences: Iterable[str], *, batch_size: int = 8) -> np.ndarray:
        batch_sequences = list(sequences)
        if not batch_sequences:
            return np.empty((0, 0), dtype=np.float32)
        outputs: List[np.ndarray | None] = [None] * len(batch_sequences)
        unique_to_compute: dict[str, List[Tuple[int, Path]]] = {}
        model_dir = self.cache_dir / self._normalized_model_name()
        model_dir.mkdir(parents=True, exist_ok=True)
        cache_hits: list[tuple[int, Path]] = []
        for idx, sequence in enumerate(batch_sequences):
            cache_path = self._sequence_cache_path(model_dir, sequence)
            if cache_path.exists():
                cache_hits.append((idx, cache_path))
            else:
                unique_to_compute.setdefault(sequence, []).append((idx, cache_path))
        if cache_hits:
            loaders = [path for _, path in cache_hits]
            max_workers = min(len(loaders), 32)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for (idx, _), embedding in zip(cache_hits, executor.map(np.load, loaders), strict=True):
                    outputs[idx] = embedding
        if unique_to_compute:
            embeddings = self._compute_embeddings(list(unique_to_compute.keys()), batch_size=batch_size)
            for sequence, embedding in zip(unique_to_compute.keys(), embeddings, strict=True):
                targets = unique_to_compute[sequence]
                for idx, cache_path in targets:
                    outputs[idx] = embedding
                    np.save(cache_path, embedding)
        if any(item is None for item in outputs):  # pragma: no cover - safety
            msg = "Failed to compute embeddings for all sequences"
            raise RuntimeError(msg)
        array_outputs = [np.asarray(item, dtype=np.float32) for item in outputs]  # type: ignore[arg-type]
        return np.stack(array_outputs, axis=0)
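
    # Cache layout used by embed(): <cache_dir>/<normalized model name>/<sha1(sequence)>.npy,
    # e.g. ".cache/embeddings/esm1v_t33_650m_ur90s_1/<40-hex-digest>.npy". Identical
    # sequences therefore share a single cached array per model.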

    def _compute_embeddings(self, sequences: Sequence[str], *, batch_size: int) -> List[np.ndarray]:
        tokenizer = self.tokenizer
        model = self.model
        model.eval()
        embeddings: List[np.ndarray] = []
        for start in range(0, len(sequences), batch_size):
            chunk = list(sequences[start : start + batch_size])
            tokenized = self._tokenize(tokenizer, chunk)
            model_inputs: dict[str, torch.Tensor] = {}
            aux_inputs: dict[str, torch.Tensor] = {}
            for key, value in tokenized.items():
                if isinstance(value, torch.Tensor):
                    tensor_value = value.to(self.device)
                else:
                    tensor_value = value
                if key == "special_tokens_mask":
                    aux_inputs[key] = tensor_value
                else:
                    model_inputs[key] = tensor_value
            with torch.no_grad():
                outputs = model(**model_inputs)
            hidden_states = outputs.last_hidden_state.detach().cpu()
            attention_mask = model_inputs.get("attention_mask")
            special_tokens_mask = aux_inputs.get("special_tokens_mask")
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = attention_mask.detach().cpu()
            if isinstance(special_tokens_mask, torch.Tensor):
                special_tokens_mask = special_tokens_mask.detach().cpu()
            for idx, sequence in enumerate(chunk):
                hidden = hidden_states[idx]
                mask = attention_mask[idx] if isinstance(attention_mask, torch.Tensor) else None
                special_mask = (
                    special_tokens_mask[idx]
                    if isinstance(special_tokens_mask, torch.Tensor)
                    else None
                )
                embedding = self._pool_hidden(hidden, mask, special_mask, sequence)
                embeddings.append(embedding)
        return embeddings

    def _tokenize(self, tokenizer, sequences: Sequence[str]):
        if hasattr(tokenizer, "__call__"):
            return tokenizer(
                list(sequences),
                return_tensors="pt",
                padding=True,
                truncation=True,
                add_special_tokens=True,
                return_special_tokens_mask=True,
            )
        msg = "Tokenizer does not implement __call__"
        raise TypeError(msg)

    def _pool_hidden(
        self,
        hidden: torch.Tensor,
        attention_mask: torch.Tensor | None,
        special_mask: torch.Tensor | None,
        sequence: str,
    ) -> np.ndarray:
        if attention_mask is None:
            attention = torch.ones(hidden.size(0), dtype=torch.float32)
        else:
            attention = attention_mask.to(dtype=torch.float32)
        if special_mask is not None:
            attention = attention * (1.0 - special_mask.to(dtype=torch.float32))
        if attention.sum() == 0:
            attention = torch.ones_like(attention)
        if self.layer_pool == "mean":
            return self._masked_mean(hidden, attention)
        if self.layer_pool == "cls":
            return hidden[0].detach().cpu().numpy()
        if self.layer_pool == "per_token_mean_cdrh3":
            return self._pool_cdrh3(hidden, attention, sequence)
        msg = f"Unsupported layer pool: {self.layer_pool}"
        raise ValueError(msg)
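
    # Pooling modes: "mean" averages hidden states over residue tokens (special tokens
    # masked out), "cls" takes the first token's hidden state, and "per_token_mean_cdrh3"
    # averages only over CDR-H3 residue positions located via ANARCI numbering.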

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> np.ndarray:
        weights = mask.unsqueeze(-1)
        weighted = hidden * weights
        denom = weights.sum()
        if denom == 0:
            pooled = hidden.mean(dim=0)
        else:
            pooled = weighted.sum(dim=0) / denom
        return pooled.detach().cpu().numpy()

    def _pool_cdrh3(self, hidden: torch.Tensor, mask: torch.Tensor, sequence: str) -> np.ndarray:
        numberer = self.numberer
        if numberer is None:
            numberer = AnarciNumberer()
            self.numberer = numberer
        numbered = numberer.number_sequence(sequence)
        cdr = numbered.regions.get("CDRH3", "")
        if not cdr:
            return self._masked_mean(hidden, mask)
        sequence_upper = sequence.upper()
        start = sequence_upper.find(cdr.upper())
        if start == -1:
            return self._masked_mean(hidden, mask)
        residues_idx = mask.nonzero(as_tuple=False).squeeze(-1).tolist()
        if not residues_idx:
            return self._masked_mean(hidden, mask)
        end = start + len(cdr)
        if end > len(residues_idx):
            return self._masked_mean(hidden, mask)
        cdr_token_positions = residues_idx[start:end]
        if not cdr_token_positions:
            return self._masked_mean(hidden, mask)
        cdr_mask = torch.zeros_like(mask)
        for pos in cdr_token_positions:
            cdr_mask[pos] = 1.0
        return self._masked_mean(hidden, cdr_mask)
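
    # _pool_cdrh3 maps string positions in `sequence` to token positions through the list
    # of residue-token indices (mask.nonzero), relying on residue tokens appearing in
    # sequence order once special tokens are masked out; whenever the CDR-H3 cannot be
    # located it falls back to the plain masked mean.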

    def _sequence_cache_path(self, model_dir: Path, sequence: str) -> Path:
        digest = hashlib.sha1(sequence.encode("utf-8")).hexdigest()
        return model_dir / f"{digest}.npy"

    def _normalized_model_name(self) -> str:
        if self._is_esm1v_model(self.model_name):
            return self._canonical_esm_name(self.model_name)
        return self.model_name.replace("/", "_")
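

# Minimal usage sketch (illustrative only, not executed on import). The model name and
# cache directory are this module's defaults; running it requires the optional fair-esm
# dependency and downloads model weights on first use. The sequences are arbitrary examples.
if __name__ == "__main__":
    embedder = PLMEmbedder(
        "facebook/esm1v_t33_650M_UR90S_1",
        layer_pool="mean",
        cache_dir=".cache/embeddings",
    )
    vectors = embedder.embed(
        ["EVQLVESGGGLVQPGGSLRLSCAAS", "QVQLQESGPGLVKPSETLSLTCTVS"],
        batch_size=2,
    )
    # One row per input sequence; repeated calls reuse the on-disk cache.
    print(vectors.shape)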