from huggingface_hub import HfApi
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.modeling_outputs import MaskedLMOutput


def get_models() -> list[str]:
    """Fetch suitable ESM models from the HuggingFace Hub."""
    out = [
        m.modelId
        for m in HfApi().list_models(
            author="facebook",
            model_name="esm",
            task="fill-mask",
            sort="lastModified",
            direction=-1,
        )
    ]
    if not out:
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return out


class Model: |
    """Wrapper for ESM masked language models."""

    def __init__(self, model_name: str = ""):
        """Load the selected model and tokenizer."""
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            # Move the model only when one was loaded; doing this
            # unconditionally would raise AttributeError for an empty wrapper.
            self.model = self.model.to(self.device)

    def tokenise(self, seq: str) -> BatchEncoding:
        """Convert an input string to a batch of tokens."""
        return self.batch_converter(seq, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run model on batch of tokens."""
        return self.model(batch_tokens.to(self.device), **kwargs)

    def __getitem__(self, key: str) -> int:
        """Get token ID from character."""
        return self.alphabet[key]

    def run_model(self, data):
        """Score each substitution in `data` with this model."""

        def label_row(row: str, token_probs: torch.Tensor) -> float:
            """Score one substitution, e.g. "A123G", as log P(mt) - log P(wt)."""
            wt, idx, mt = row[0], int(row[1:-1]) - 1, row[-1]
            # Offset by 1 to skip the BOS token the tokenizer prepends.
            score = token_probs[0, 1 + idx, self[mt]] - token_probs[0, 1 + idx, self[wt]]
            return score.item()

        batch_tokens = self.tokenise(data.seq).input_ids

        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)

        data.token_probs = token_probs.cpu().numpy()

        if data.scoring_strategy.startswith("masked-marginals"):
            # Re-score each position of interest with that position masked,
            # so the model never sees the wild-type residue it is predicting.
            all_token_probs = []
            for i in tqdm(range(batch_tokens.size(1))):
                if i in data.resi:
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = self['<mask>']
                    with torch.no_grad():
                        masked_token_probs = torch.log_softmax(
                            self(batch_tokens_masked).logits, dim=-1
                        )
                else:
                    # Positions that are never mutated keep the unmasked scores.
                    masked_token_probs = token_probs
                all_token_probs.append(masked_token_probs[:, i])
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)

        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row['0'],
                token_probs,
            ),
            axis=1,
        )
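

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions, not part of the module above): `data`
# stands in for whatever object the caller passes to `run_model`. It must
# expose `seq` (the wild-type sequence), `resi` (token positions eligible for
# masking under the masked-marginals strategy), `sub` (a DataFrame whose "0"
# column holds substitutions such as "M1A"), `scoring_strategy`, and an `out`
# mapping for per-model scores. The sequence, substitutions, and checkpoint
# below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    import pandas as pd

    data = SimpleNamespace(
        seq="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        resi=range(1, 34),  # 1-based token positions (BOS sits at index 0)
        sub=pd.DataFrame({"0": ["M1A", "K2R"]}),
        scoring_strategy="masked-marginals",
        out={},
    )

    # A small checkpoint keeps the sketch cheap; any ID from get_models() works.
    model = Model("facebook/esm2_t6_8M_UR50D")
    model.run_model(data)
    print(data.out[model.model_name])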