File size: 14,305 Bytes
5f58699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
"""Feature pipeline construction utilities."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import Iterable, Sequence

import numpy as np
from sklearn.preprocessing import StandardScaler

from ..config import Config, DescriptorSettings, FeatureBackendSettings
from .descriptors import DescriptorConfig, DescriptorFeaturizer
from .plm import PLMEmbedder


@dataclass(slots=True)
class FeaturePipelineState:
    """Snapshot of a ``FeaturePipeline`` produced by ``FeaturePipeline.get_state``.

    ``FeaturePipeline.load_state`` accepts an instance of this class to restore
    the fitted components and the PLM settings onto a pipeline.
    """

    # Backend identifier; the pipeline recognises "descriptors", "plm" and "concat".
    backend_type: str
    # Fitted descriptor featurizer, or None when the backend does not use one.
    descriptor_featurizer: DescriptorFeaturizer | None
    # Scaler fitted on PLM embeddings, or None when no standardisation was fitted.
    plm_scaler: StandardScaler | None
    # Descriptor configuration; load_state rebuilds DescriptorSettings from it.
    descriptor_config: DescriptorConfig | None
    # PLM model name, pooling mode and cache directory in effect when the state
    # was captured; get_state stores None for backends other than plm/concat.
    plm_model_name: str | None
    plm_layer_pool: str | None
    cache_dir: str | None
    # Device string of the originating pipeline (load_state does not read it).
    device: str
    # Names of the output feature columns, in column order.
    feature_names: list[str] = field(default_factory=list)


class FeaturePipeline:
    """Fit/transform feature matrices according to configuration.

    The backend is selected by ``backend.type``:

    * ``"descriptors"`` -- features from a ``DescriptorFeaturizer``;
    * ``"plm"`` -- embeddings from a ``PLMEmbedder``, standardised with a
      ``StandardScaler`` when ``backend.standardize`` is truthy at fit time;
    * ``"concat"`` -- descriptor features followed by PLM embeddings.

    An unset (falsy) ``backend.type`` is treated as ``"descriptors"``.
    """

    def __init__(
        self,
        *,
        backend: FeatureBackendSettings,
        descriptors: DescriptorSettings,
        device: str,
        cache_dir_override: str | None = None,
        plm_model_override: str | None = None,
        layer_pool_override: str | None = None,
    ) -> None:
        """Store the configuration; nothing is built or fitted here.

        Args:
            backend: Backend settings (type, PLM model, pooling, cache dir,
                standardize flag). ``load_state`` mutates this object in place.
            descriptors: Settings used to build the descriptor featurizer.
            device: Device string forwarded to the PLM embedder.
            cache_dir_override: Takes precedence over ``backend.cache_dir``
                when truthy.
            plm_model_override: Takes precedence over
                ``backend.plm_model_name`` when truthy.
            layer_pool_override: Takes precedence over ``backend.layer_pool``
                when truthy.
        """
        self.backend = backend
        self.descriptor_settings = descriptors
        self.device = device
        self.cache_dir_override = cache_dir_override
        self.plm_model_override = plm_model_override
        self.layer_pool_override = layer_pool_override

        self._descriptor: DescriptorFeaturizer | None = None
        self._plm: PLMEmbedder | None = None
        self._plm_scaler: StandardScaler | None = None
        self._feature_names: list[str] = []

    @property
    def _backend_type(self) -> str:
        """Return the configured backend, or ``"descriptors"`` when unset."""
        return self.backend.type or "descriptors"

    def fit_transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray:  # noqa: ANN001
        """Fit the backend on *df* and return its ``float32`` feature matrix.

        Args:
            df: Table with a ``heavy_seq`` column and optionally ``light_seq``.
            heavy_only: Use only the heavy-chain sequence.
            batch_size: Batch size forwarded to the PLM embedder.

        Raises:
            ValueError: If the backend is unsupported, or ANARCI-based
                descriptors are requested with ``heavy_only=False``.
        """
        backend_type = self._backend_type
        self._validate_heavy_support(backend_type, heavy_only)
        sequences = _extract_sequences(df, heavy_only=heavy_only)

        if backend_type == "descriptors":
            descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            features = descriptor.fit_transform(sequences)
            self._set_fitted_state(descriptor=descriptor)
            return features.astype(np.float32)

        if backend_type == "plm":
            plm, plm_scaler, embeddings = self._fit_plm(sequences, batch_size)
            self._set_fitted_state(
                plm=plm, plm_scaler=plm_scaler, plm_dim=embeddings.shape[1]
            )
            return embeddings.astype(np.float32)

        if backend_type == "concat":
            descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            desc_features = descriptor.fit_transform(sequences)
            plm, plm_scaler, embeddings = self._fit_plm(sequences, batch_size)
            self._set_fitted_state(
                descriptor=descriptor,
                plm=plm,
                plm_scaler=plm_scaler,
                plm_dim=embeddings.shape[1],
            )
            return np.concatenate([desc_features, embeddings], axis=1).astype(np.float32)

        msg = f"Unsupported feature backend: {backend_type}"
        raise ValueError(msg)

    def fit(self, df, *, heavy_only: bool, batch_size: int = 8) -> "FeaturePipeline":  # noqa: ANN001
        """Fit the backend on *df* and return ``self``.

        Takes the same arguments and raises the same errors as
        ``fit_transform``. PLM embeddings are still computed because the scaler
        and the feature names are derived from them.
        """
        backend_type = self._backend_type
        self._validate_heavy_support(backend_type, heavy_only)
        sequences = _extract_sequences(df, heavy_only=heavy_only)

        if backend_type == "descriptors":
            descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            descriptor.fit(sequences)
            self._set_fitted_state(descriptor=descriptor)
        elif backend_type == "plm":
            plm, plm_scaler, embeddings = self._fit_plm(sequences, batch_size)
            self._set_fitted_state(
                plm=plm, plm_scaler=plm_scaler, plm_dim=embeddings.shape[1]
            )
        elif backend_type == "concat":
            descriptor = _build_descriptor_featurizer(self.descriptor_settings)
            descriptor.fit_transform(sequences)
            plm, plm_scaler, embeddings = self._fit_plm(sequences, batch_size)
            self._set_fitted_state(
                descriptor=descriptor,
                plm=plm,
                plm_scaler=plm_scaler,
                plm_dim=embeddings.shape[1],
            )
        else:  # pragma: no cover - defensive branch
            msg = f"Unsupported feature backend: {backend_type}"
            raise ValueError(msg)
        return self

    def transform(self, df, *, heavy_only: bool, batch_size: int = 8) -> np.ndarray:  # noqa: ANN001
        """Return the ``float32`` feature matrix for *df* using fitted components.

        Raises:
            RuntimeError: If the components the backend needs are missing.
            ValueError: If the backend is unsupported, or ANARCI-based
                descriptors are requested with ``heavy_only=False``.
        """
        backend_type = self._backend_type
        self._validate_heavy_support(backend_type, heavy_only)
        sequences = _extract_sequences(df, heavy_only=heavy_only)

        if backend_type == "descriptors":
            if self._descriptor is None:
                msg = "Descriptor featurizer is not fitted"
                raise RuntimeError(msg)
            features = self._descriptor.transform(sequences)
        elif backend_type == "plm":
            if self._plm is None:
                msg = "PLM embedder is not initialised"
                raise RuntimeError(msg)
            features = self._embed(sequences, batch_size)
        elif backend_type == "concat":
            if self._descriptor is None or self._plm is None:
                msg = "Feature pipeline not fitted"
                raise RuntimeError(msg)
            desc_features = self._descriptor.transform(sequences)
            embeddings = self._embed(sequences, batch_size)
            features = np.concatenate([desc_features, embeddings], axis=1)
        else:  # pragma: no cover - defensive branch
            msg = f"Unsupported feature backend: {backend_type}"
            raise ValueError(msg)

        return features.astype(np.float32)

    def _fit_plm(
        self, sequences: Sequence[str], batch_size: int
    ) -> tuple[PLMEmbedder, StandardScaler | None, np.ndarray]:
        """Build the embedder, embed *sequences* and fit the optional scaler.

        Returns the embedder, the fitted scaler (``None`` when
        ``backend.standardize`` is falsy) and the (possibly scaled) embeddings.
        No pipeline state is modified.
        """
        plm = _build_plm_embedder(
            self.backend,
            device=self.device,
            cache_dir_override=self.cache_dir_override,
            plm_model_override=self.plm_model_override,
            layer_pool_override=self.layer_pool_override,
        )
        embeddings = plm.embed(sequences, batch_size=batch_size)
        plm_scaler = None
        if self.backend.standardize:
            plm_scaler = StandardScaler()
            embeddings = plm_scaler.fit_transform(embeddings)
        return plm, plm_scaler, embeddings

    def _embed(self, sequences: Sequence[str], batch_size: int) -> np.ndarray:
        """Embed *sequences* with the stored embedder and fitted scaler, if any.

        The scaler's presence decides whether scaling is applied, not the
        current ``backend.standardize`` flag: a scaler only exists when the
        pipeline was fitted with standardisation, so a pipeline restored via
        ``load_state`` stays consistent with how it was trained.
        """
        embeddings = self._plm.embed(sequences, batch_size=batch_size)
        if self._plm_scaler is not None:
            embeddings = self._plm_scaler.transform(embeddings)
        return embeddings

    def _set_fitted_state(
        self,
        *,
        descriptor: DescriptorFeaturizer | None = None,
        plm: PLMEmbedder | None = None,
        plm_scaler: StandardScaler | None = None,
        plm_dim: int = 0,
    ) -> None:
        """Install freshly fitted components and rebuild the feature names.

        Called only after fitting succeeded, so a failed fit leaves the
        previous state untouched. Components not passed are reset to ``None``.
        """
        names = list(descriptor.feature_names_ or []) if descriptor is not None else []
        names.extend(f"plm_{i}" for i in range(plm_dim))
        self._descriptor = descriptor
        self._plm = plm
        self._plm_scaler = plm_scaler
        self._feature_names = names

    @property
    def feature_names(self) -> list[str]:
        """Return a copy of the output column names, in column order."""
        return list(self._feature_names)

    def get_state(self) -> FeaturePipelineState:
        """Capture the fitted components and effective settings as a state object."""
        descriptor = self._descriptor
        if descriptor is not None and descriptor.numberer is not None:
            # NOTE(review): the runner handle is dropped on the live featurizer,
            # presumably so the state can be serialised -- confirm the numberer
            # recreates it on demand.
            if hasattr(descriptor.numberer, "_runner"):
                descriptor.numberer._runner = None  # type: ignore[attr-defined]
        return FeaturePipelineState(
            backend_type=self._backend_type,
            descriptor_featurizer=descriptor,
            plm_scaler=self._plm_scaler,
            descriptor_config=_build_descriptor_config(self.descriptor_settings),
            plm_model_name=self._effective_plm_model_name,
            plm_layer_pool=self._effective_layer_pool,
            cache_dir=self._effective_cache_dir,
            device=self.device,
            # Copy so later refits do not alter an already captured state.
            feature_names=list(self._feature_names),
        )

    def load_state(self, state: FeaturePipelineState) -> None:
        """Restore components and settings from *state*.

        ``self.backend`` is updated in place. For the ``plm`` and ``concat``
        backends a new embedder is built from the restored settings.
        """
        self.backend.type = state.backend_type
        if state.plm_model_name:
            self.backend.plm_model_name = state.plm_model_name
            self.plm_model_override = state.plm_model_name
        if state.plm_layer_pool:
            self.backend.layer_pool = state.plm_layer_pool
            self.layer_pool_override = state.plm_layer_pool
        if state.cache_dir:
            self.backend.cache_dir = state.cache_dir
            self.cache_dir_override = state.cache_dir
        if state.descriptor_config:
            self.descriptor_settings = DescriptorSettings(
                use_anarci=state.descriptor_config.use_anarci,
                regions=tuple(state.descriptor_config.regions),
                features=tuple(state.descriptor_config.features),
                ph=state.descriptor_config.ph,
            )
        self._descriptor = state.descriptor_featurizer
        self._plm_scaler = state.plm_scaler
        # Copy (and tolerate a missing list) so the pipeline owns its names.
        self._feature_names = list(state.feature_names or [])
        if self.backend.type in {"plm", "concat"}:
            self._plm = _build_plm_embedder(
                self.backend,
                device=self.device,
                cache_dir_override=self.backend.cache_dir,
                plm_model_override=self.backend.plm_model_name,
                layer_pool_override=self.backend.layer_pool,
            )
        else:
            self._plm = None

    @property
    def _effective_plm_model_name(self) -> str | None:
        """PLM model name in effect, or ``None`` for non-PLM backends."""
        if self._backend_type not in {"plm", "concat"}:
            return None
        return self.plm_model_override or self.backend.plm_model_name

    @property
    def _effective_layer_pool(self) -> str | None:
        """Layer pooling mode in effect, or ``None`` for non-PLM backends."""
        if self._backend_type not in {"plm", "concat"}:
            return None
        return self.layer_pool_override or self.backend.layer_pool

    @property
    def _effective_cache_dir(self) -> str | None:
        """Cache directory in effect, or ``None`` for non-PLM backends.

        A falsy override falls back to ``backend.cache_dir``, which is the same
        rule the embedder is built with.
        """
        if self._backend_type not in {"plm", "concat"}:
            return None
        return self.cache_dir_override or self.backend.cache_dir

    def _validate_heavy_support(self, backend_type: str, heavy_only: bool) -> None:
        """Reject paired-chain input for backends using ANARCI descriptors.

        Raises:
            ValueError: If ``heavy_only`` is false, ANARCI is enabled and the
                backend is ``"descriptors"`` or ``"concat"``.
        """
        if heavy_only:
            return
        if backend_type == "descriptors" and self.descriptor_settings.use_anarci:
            msg = "Descriptor backend with ANARCI currently supports heavy-chain only inference."
            raise ValueError(msg)
        if backend_type == "concat" and self.descriptor_settings.use_anarci:
            msg = "Concat backend with descriptors requires heavy-chain only data."
            raise ValueError(msg)


def build_feature_pipeline(
    config: Config,
    *,
    backend_override: str | None = None,
    plm_model_override: str | None = None,
    cache_dir_override: str | None = None,
    layer_pool_override: str | None = None,
) -> FeaturePipeline:
    """Create a ``FeaturePipeline`` from *config* and optional overrides.

    The backend settings are rebuilt from ``asdict(config.feature_backend)``,
    so assigning fields on the pipeline's backend does not write through to
    ``config.feature_backend``. A truthy *backend_override* replaces the
    backend type on that copy; the remaining overrides are handed to the
    pipeline unchanged.
    """
    backend_copy = FeatureBackendSettings(**asdict(config.feature_backend))
    if backend_override:
        backend_copy.type = backend_override

    overrides = {
        "cache_dir_override": cache_dir_override,
        "plm_model_override": plm_model_override,
        "layer_pool_override": layer_pool_override,
    }
    return FeaturePipeline(
        backend=backend_copy,
        descriptors=config.descriptors,
        device=config.device,
        **overrides,
    )


def _build_descriptor_featurizer(settings: DescriptorSettings) -> DescriptorFeaturizer:
    """Return a new ``DescriptorFeaturizer`` for *settings* with standardisation on."""
    return DescriptorFeaturizer(
        config=_build_descriptor_config(settings),
        standardize=True,
    )


def _build_descriptor_config(settings: DescriptorSettings) -> DescriptorConfig:
    """Translate *settings* into a ``DescriptorConfig``.

    ``regions`` and ``features`` are converted to tuples; ``use_anarci`` and
    ``ph`` are passed through as they are.
    """
    config_fields = {
        "use_anarci": settings.use_anarci,
        "regions": tuple(settings.regions),
        "features": tuple(settings.features),
        "ph": settings.ph,
    }
    return DescriptorConfig(**config_fields)


def _build_plm_embedder(
    backend: FeatureBackendSettings,
    *,
    device: str,
    cache_dir_override: str | None,
    plm_model_override: str | None,
    layer_pool_override: str | None,
) -> PLMEmbedder:
    """Construct a ``PLMEmbedder`` from *backend*, honouring the overrides.

    Each override wins only when it is truthy; ``None`` or an empty string
    falls back to the corresponding value on *backend*.
    """
    resolved = {
        "model_name": plm_model_override or backend.plm_model_name,
        "cache_dir": cache_dir_override or backend.cache_dir,
        "layer_pool": layer_pool_override or backend.layer_pool,
    }
    return PLMEmbedder(device=device, **resolved)


def _extract_sequences(df, heavy_only: bool) -> Sequence[str]:  # noqa: ANN001
    if heavy_only or "light_seq" not in df.columns:
        return df["heavy_seq"].fillna("").astype(str).tolist()
    heavy = df["heavy_seq"].fillna("").astype(str)
    light = df["light_seq"].fillna("").astype(str)
    return (heavy + "|" + light).tolist()