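"""Embed SMILES strings with ibm/materials.selfies-ted.

Each SMILES is converted to a SELFIES string via the `selfies` library, and
the encoder's last hidden state is mean-pooled over non-padding tokens to
produce one fixed-size embedding per molecule.
"""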
import os
import sys
import pickle

import numpy as np
import pandas as pd
import torch
import selfies as sf  # selfies>=2.1.1
from datasets import Dataset
from rdkit import Chem
from transformers import AutoTokenizer, AutoModel


class SELFIES(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = None
        self.tokenizer = None
        self.invalid = []  # indices of inputs that could not be encoded

    def get_selfies(self, smiles_list):
        """Convert SMILES strings to whitespace-tokenized SELFIES strings."""
        self.invalid = []
        spaced_selfies_batch = []
        for i, smiles in enumerate(smiles_list):
            try:
                selfies = sf.encoder(smiles.rstrip())
            except Exception:
                # Retry after canonicalizing the SMILES with RDKit.
                try:
                    smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.rstrip()))
                    selfies = sf.encoder(smiles)
                except Exception:
                    selfies = "[]"
                    self.invalid.append(i)
            # Put a space between adjacent tokens: "[C][O]" -> "[C] [O]".
            spaced_selfies_batch.append(selfies.replace("][", "] ["))
        return spaced_selfies_batch

    def get_embedding(self, selfies):
        """Mean-pool the encoder's last hidden state over non-padding tokens."""
        encoding = self.tokenizer(selfies["selfies"], return_tensors='pt',
                                  max_length=128, truncation=True, padding='max_length')
        input_ids = encoding['input_ids']
        attention_mask = encoding['attention_mask']
        with torch.no_grad():  # inference only; no gradients needed
            outputs = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        model_output = outputs.last_hidden_state
        # Masked mean pooling: zero out padding positions, then divide by the
        # per-sequence token count (clamped to avoid division by zero).
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
        sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        encoding["embedding"] = sum_embeddings / sum_mask
        return encoding

    def load(self, checkpoint="bart-2908.pickle"):
        """
        inputs :
            checkpoint (legacy pickle path; unused, weights are loaded from the Hugging Face Hub)
        """
        self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
        self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
| """if os.path.isfile(checkpoint): | |
| with open(checkpoint, "rb") as input_file: | |
| self.model, self.tokenizer = pickle.load(input_file) | |
| for p in sys.path: | |
| file = p + "/" + checkpoint | |
| if os.path.isfile(file): | |
| with open(file, "rb") as input_file: | |
| self.model, self.tokenizer = pickle.load(input_file)""" | |

    # TODO: remove `use_gpu` argument in validation pipeline
    def encode(self, smiles_list=None, use_gpu=False, return_tensor=False):
        """
        inputs :
            smiles_list (list of SMILES strings)
        :return: embedding (pandas DataFrame, or torch.Tensor if return_tensor=True)
        """
        smiles_list = smiles_list or []  # avoid a mutable default argument
        selfies = self.get_selfies(smiles_list)
        selfies_df = pd.DataFrame(selfies, columns=["selfies"])
        data = Dataset.from_pandas(selfies_df)
        embedding = data.map(self.get_embedding, batched=True, num_proc=1, batch_size=128)
        emb = np.asarray(embedding["embedding"].copy())
        # Inputs that failed SELFIES encoding get NaN rows.
        for idx in self.invalid:
            emb[idx] = np.nan
            print(f"Cannot encode {smiles_list[idx]} to SELFIES; embedding replaced by NaN")
        if return_tensor:
            return torch.tensor(emb)
        return pd.DataFrame(emb)
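

# Minimal usage sketch (not part of the original module): assumes network
# access to the Hugging Face Hub; the two example SMILES are arbitrary.
if __name__ == "__main__":
    model = SELFIES()
    model.load()  # fetches the ibm/materials.selfies-ted tokenizer and weights
    df = model.encode(["CCO", "c1ccccc1"])  # ethanol, benzene
    print(df.shape)  # one row per molecule, one column per hidden dimension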