"""Protein language model embeddings backend with caching support."""

from __future__ import annotations

import hashlib
import warnings
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, Iterable, List, Sequence, Tuple

import numpy as np
import torch
from torch import nn

try:  # pragma: no cover - optional dependency
    from transformers import AutoModel, AutoTokenizer
except ImportError:  # pragma: no cover - optional dependency
    AutoModel = None
    AutoTokenizer = None

try:  # pragma: no cover - optional dependency
    import esm
except ImportError:  # pragma: no cover - optional dependency
    esm = None

if esm is not None:  # pragma: no cover - optional dependency
    warnings.filterwarnings(
        "ignore",
        message="Regression weights not found, predicting contacts will not produce correct results.",
        module="esm.pretrained",
    )

from .anarsi import AnarciNumberer

ModelLoader = Callable[[str, str], Tuple[object, nn.Module]]

if esm is not None:  # pragma: no cover - optional dependency
    _ESM1V_LOADERS = {
        "esm1v_t33_650m_ur90s_1": esm.pretrained.esm1v_t33_650M_UR90S_1,
        "esm1v_t33_650m_ur90s_2": esm.pretrained.esm1v_t33_650M_UR90S_2,
        "esm1v_t33_650m_ur90s_3": esm.pretrained.esm1v_t33_650M_UR90S_3,
        "esm1v_t33_650m_ur90s_4": esm.pretrained.esm1v_t33_650M_UR90S_4,
        "esm1v_t33_650m_ur90s_5": esm.pretrained.esm1v_t33_650M_UR90S_5,
    }
else:  # pragma: no cover - optional dependency
    _ESM1V_LOADERS: dict[str, Callable[[], tuple[nn.Module, object]]] = {}


class _ESMTokenizer:
    """Callable wrapper that mimics Hugging Face tokenizers for ESM models."""

    def __init__(self, alphabet) -> None:  # noqa: ANN001
        self.alphabet = alphabet
        self._batch_converter = alphabet.get_batch_converter()

    def __call__(
        self,
        sequences: Sequence[str],
        *,
        return_tensors: str = "pt",
        padding: bool = True,  # noqa: FBT002
        truncation: bool = True,  # noqa: FBT002
        add_special_tokens: bool = True,  # noqa: FBT002
        return_special_tokens_mask: bool = True,  # noqa: FBT002
    ) -> dict[str, torch.Tensor]:
        if return_tensors != "pt":  # pragma: no cover - defensive branch
            msg = "ESM tokenizer only supports return_tensors='pt'"
            raise ValueError(msg)
        data = [(str(idx), (seq or "").upper()) for idx, seq in enumerate(sequences)]
        _labels, _strings, tokens = self._batch_converter(data)
        attention_mask = (tokens != self.alphabet.padding_idx).long()
        special_tokens = torch.zeros_like(tokens)
        specials = {
            self.alphabet.padding_idx,
            self.alphabet.cls_idx,
            self.alphabet.eos_idx,
        }
        for special in specials:
            special_tokens |= tokens == special
        output: dict[str, torch.Tensor] = {
            "input_ids": tokens,
            "attention_mask": attention_mask,
        }
        if return_special_tokens_mask:
            output["special_tokens_mask"] = special_tokens.long()
        return output
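
# Illustrative only: the wrapper behaves like a Hugging Face tokenizer call and
# returns "input_ids", "attention_mask", and (optionally) "special_tokens_mask"
# tensors. The sequences below are hypothetical placeholders.
#
#     tokenizer = _ESMTokenizer(alphabet)
#     batch = tokenizer(["EVQLVESGGGLVQPGG", "QVQLQQSG"])
#     sorted(batch)  # ["attention_mask", "input_ids", "special_tokens_mask"]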


class _ESMModelWrapper(nn.Module):
    """Adapter providing a Hugging Face style interface for ESM models."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model
        self.layer_index = getattr(model, "num_layers", None)
        if self.layer_index is None:
            msg = "Unable to determine final layer for ESM model"
            raise AttributeError(msg)

    def eval(self) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.eval()
        return self

    def to(self, device: str) -> "_ESMModelWrapper":  # pragma: no cover - trivial
        self.model.to(device)
        return self

    def forward(self, input_ids: torch.Tensor, **_):  # noqa: ANN003
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                repr_layers=[self.layer_index],
                return_contacts=False,
            )
        hidden = outputs["representations"][self.layer_index]
        return SimpleNamespace(last_hidden_state=hidden)

    __call__ = forward


@dataclass(slots=True)
class PLMConfig:
    model_name: str = "facebook/esm1v_t33_650M_UR90S_1"
    layer_pool: str = "mean"
    cache_dir: Path = Path(".cache/embeddings")
    device: str = "auto"
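
# A minimal sketch of how PLMConfig mirrors the PLMEmbedder constructor arguments
# (the values below are illustrative, not recommendations):
#
#     config = PLMConfig(layer_pool="per_token_mean_cdrh3", device="cuda")
#     embedder = PLMEmbedder(
#         config.model_name,
#         layer_pool=config.layer_pool,
#         device=config.device,
#         cache_dir=config.cache_dir,
#     )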


class PLMEmbedder:
    """Embed amino-acid sequences using a transformer model with caching."""

    def __init__(
        self,
        model_name: str = "facebook/esm1v_t33_650M_UR90S_1",
        *,
        layer_pool: str = "mean",
        device: str = "auto",
        cache_dir: str | Path | None = None,
        numberer: AnarciNumberer | None = None,
        model_loader: ModelLoader | None = None,
    ) -> None:
        self.model_name = model_name
        self.layer_pool = layer_pool
        self.device = self._resolve_device(device)
        self.cache_dir = Path(cache_dir or ".cache/embeddings")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.numberer = numberer
        self.model_loader = model_loader
        self._tokenizer: object | None = None
        self._model: nn.Module | None = None

    @staticmethod
    def _resolve_device(device: str) -> str:
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    @property
    def tokenizer(self):  # noqa: D401
        """Tokenizer for the configured model, loaded lazily on first access."""
        if self._tokenizer is None:
            tokenizer, model = self._load_model_components()
            self._tokenizer = tokenizer
            self._model = model
        return self._tokenizer

    @property
    def model(self) -> nn.Module:
        """Underlying embedding model, loaded lazily on first access."""
        if self._model is None:
            tokenizer, model = self._load_model_components()
            self._tokenizer = tokenizer
            self._model = model
        return self._model

    def _load_model_components(self) -> Tuple[object, nn.Module]:
        if self.model_loader is not None:
            tokenizer, model = self.model_loader(self.model_name, self.device)
            return tokenizer, model

        if self._is_esm1v_model(self.model_name):
            return self._load_esm_model()

        if AutoModel is None or AutoTokenizer is None:  # pragma: no cover - optional dependency
            msg = "transformers must be installed to use PLMEmbedder"
            raise ImportError(msg)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True)
        model.eval()
        model.to(self.device)
        return tokenizer, model

    def _load_esm_model(self) -> Tuple[object, nn.Module]:
        if esm is None:  # pragma: no cover - optional dependency
            msg = "The 'esm' package is required to use ESM-1v models."
            raise ImportError(msg)

        normalized = self._canonical_esm_name(self.model_name)
        loader = _ESM1V_LOADERS.get(normalized)
        if loader is None:  # pragma: no cover - guard branch
            msg = f"Unsupported ESM-1v model: {self.model_name}"
            raise ValueError(msg)

        model, alphabet = loader()
        model.eval()
        model.to(self.device)
        tokenizer = _ESMTokenizer(alphabet)
        wrapper = _ESMModelWrapper(model)
        return tokenizer, wrapper

    @staticmethod
    def _canonical_esm_name(model_name: str) -> str:
        name = model_name.lower()
        if "/" in name:
            name = name.split("/")[-1]
        return name

    @classmethod
    def _is_esm1v_model(cls, model_name: str) -> bool:
        return cls._canonical_esm_name(model_name).startswith("esm1v")
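
    # For example (illustrative), "facebook/esm1v_t33_650M_UR90S_1" normalizes to
    # "esm1v_t33_650m_ur90s_1", which is the key used in _ESM1V_LOADERS above.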

    def embed(self, sequences: Iterable[str], *, batch_size: int = 8) -> np.ndarray:
        """Embed sequences, reusing per-sequence .npy caches where available."""
        batch_sequences = list(sequences)
        if not batch_sequences:
            return np.empty((0, 0), dtype=np.float32)

        outputs: List[np.ndarray | None] = [None] * len(batch_sequences)
        unique_to_compute: dict[str, List[Tuple[int, Path]]] = {}
        model_dir = self.cache_dir / self._normalized_model_name()
        model_dir.mkdir(parents=True, exist_ok=True)

        # Partition the batch into cache hits and deduplicated sequences to compute.
        cache_hits: list[tuple[int, Path]] = []
        for idx, sequence in enumerate(batch_sequences):
            cache_path = self._sequence_cache_path(model_dir, sequence)
            if cache_path.exists():
                cache_hits.append((idx, cache_path))
            else:
                unique_to_compute.setdefault(sequence, []).append((idx, cache_path))

        if cache_hits:
            loaders = [path for _, path in cache_hits]
            max_workers = min(len(loaders), 32)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for (idx, _), embedding in zip(cache_hits, executor.map(np.load, loaders), strict=True):
                    outputs[idx] = embedding

        if unique_to_compute:
            embeddings = self._compute_embeddings(list(unique_to_compute.keys()), batch_size=batch_size)
            for sequence, embedding in zip(unique_to_compute.keys(), embeddings, strict=True):
                targets = unique_to_compute[sequence]
                for idx, cache_path in targets:
                    outputs[idx] = embedding
                    np.save(cache_path, embedding)
        if any(item is None for item in outputs):  # pragma: no cover - safety
            msg = "Failed to compute embeddings for all sequences"
            raise RuntimeError(msg)
        array_outputs = [np.asarray(item, dtype=np.float32) for item in outputs]  # type: ignore[arg-type]
        return np.stack(array_outputs, axis=0)

    def _compute_embeddings(self, sequences: Sequence[str], *, batch_size: int) -> List[np.ndarray]:
        """Run the model over sequences in batches and pool one vector per sequence."""
        tokenizer = self.tokenizer
        model = self.model
        model.eval()
        embeddings: List[np.ndarray] = []
        for start in range(0, len(sequences), batch_size):
            chunk = list(sequences[start : start + batch_size])
            tokenized = self._tokenize(tokenizer, chunk)
            model_inputs: dict[str, torch.Tensor] = {}
            aux_inputs: dict[str, torch.Tensor] = {}
            for key, value in tokenized.items():
                if isinstance(value, torch.Tensor):
                    tensor_value = value.to(self.device)
                else:
                    tensor_value = value
                if key == "special_tokens_mask":
                    aux_inputs[key] = tensor_value
                else:
                    model_inputs[key] = tensor_value
            with torch.no_grad():
                outputs = model(**model_inputs)
                hidden_states = outputs.last_hidden_state.detach().cpu()
            attention_mask = model_inputs.get("attention_mask")
            special_tokens_mask = aux_inputs.get("special_tokens_mask")
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = attention_mask.detach().cpu()
            if isinstance(special_tokens_mask, torch.Tensor):
                special_tokens_mask = special_tokens_mask.detach().cpu()

            for idx, sequence in enumerate(chunk):
                hidden = hidden_states[idx]
                mask = attention_mask[idx] if isinstance(attention_mask, torch.Tensor) else None
                special_mask = (
                    special_tokens_mask[idx]
                    if isinstance(special_tokens_mask, torch.Tensor)
                    else None
                )
                embedding = self._pool_hidden(hidden, mask, special_mask, sequence)
                embeddings.append(embedding)
        return embeddings

    def _tokenize(self, tokenizer, sequences: Sequence[str]):
        """Tokenize a batch of sequences with a Hugging Face style callable tokenizer."""
        if callable(tokenizer):
            return tokenizer(
                list(sequences),
                return_tensors="pt",
                padding=True,
                truncation=True,
                add_special_tokens=True,
                return_special_tokens_mask=True,
            )
        msg = "Tokenizer does not implement __call__"
        raise TypeError(msg)

    def _pool_hidden(
        self,
        hidden: torch.Tensor,
        attention_mask: torch.Tensor | None,
        special_mask: torch.Tensor | None,
        sequence: str,
    ) -> np.ndarray:
        """Pool token hidden states into one vector using the configured layer_pool."""
        if attention_mask is None:
            attention = torch.ones(hidden.size(0), dtype=torch.float32)
        else:
            attention = attention_mask.to(dtype=torch.float32)
        if special_mask is not None:
            attention = attention * (1.0 - special_mask.to(dtype=torch.float32))
        if attention.sum() == 0:
            attention = torch.ones_like(attention)

        if self.layer_pool == "mean":
            return self._masked_mean(hidden, attention)
        if self.layer_pool == "cls":
            return hidden[0].detach().cpu().numpy()
        if self.layer_pool == "per_token_mean_cdrh3":
            return self._pool_cdrh3(hidden, attention, sequence)
        msg = f"Unsupported layer pool: {self.layer_pool}"
        raise ValueError(msg)

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> np.ndarray:
        weights = mask.unsqueeze(-1)
        weighted = hidden * weights
        denom = weights.sum()
        if denom == 0:
            pooled = hidden.mean(dim=0)
        else:
            pooled = weighted.sum(dim=0) / denom
        return pooled.detach().cpu().numpy()

    def _pool_cdrh3(self, hidden: torch.Tensor, mask: torch.Tensor, sequence: str) -> np.ndarray:
        """Average only the CDR-H3 residue tokens, falling back to a masked mean."""
        numberer = self.numberer
        if numberer is None:
            numberer = AnarciNumberer()
            self.numberer = numberer
        numbered = numberer.number_sequence(sequence)
        cdr = numbered.regions.get("CDRH3", "")
        if not cdr:
            return self._masked_mean(hidden, mask)
        sequence_upper = sequence.upper()
        start = sequence_upper.find(cdr.upper())
        if start == -1:
            return self._masked_mean(hidden, mask)
        # Token positions of real residues; the slice below assumes each amino acid
        # maps to exactly one token.
        residues_idx = mask.nonzero(as_tuple=False).squeeze(-1).tolist()
        if not residues_idx:
            return self._masked_mean(hidden, mask)
        end = start + len(cdr)
        if end > len(residues_idx):
            return self._masked_mean(hidden, mask)
        cdr_token_positions = residues_idx[start:end]
        if not cdr_token_positions:
            return self._masked_mean(hidden, mask)
        cdr_mask = torch.zeros_like(mask)
        for pos in cdr_token_positions:
            cdr_mask[pos] = 1.0
        return self._masked_mean(hidden, cdr_mask)

    def _sequence_cache_path(self, model_dir: Path, sequence: str) -> Path:
        digest = hashlib.sha1(sequence.encode("utf-8")).hexdigest()
        return model_dir / f"{digest}.npy"

    def _normalized_model_name(self) -> str:
        if self._is_esm1v_model(self.model_name):
            return self._canonical_esm_name(self.model_name)
        return self.model_name.replace("/", "_")
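

if __name__ == "__main__":  # pragma: no cover - illustrative usage sketch
    # Minimal sketch, assuming the default ESM-1v checkpoint (or a transformers
    # model) can be loaded; the antibody fragments below are hypothetical
    # placeholders, not project data.
    demo_embedder = PLMEmbedder(
        "facebook/esm1v_t33_650M_UR90S_1",
        layer_pool="mean",
        cache_dir=".cache/embeddings",
    )
    demo_vectors = demo_embedder.embed(
        ["EVQLVESGGGLVQPGGSLRLSCAAS", "QVQLQQSGAELARPGASVKMSCKAS"],
        batch_size=2,
    )
    print(demo_vectors.shape)  # (2, hidden_size) under mean pooling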