import argparse
import os
import re
import secrets
import tempfile
import uuid
from collections import defaultdict
from pathlib import Path
from typing import ClassVar, Literal

import numpy as np
import numpy.typing as npt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from molbloom import buy
from molsol import KDESol
from onmt import opts
from onmt.translate.translator import build_translator
from onmt.utils.logging import init_logger
from onmt.utils.misc import split_corpus
from onmt.utils.parse import ArgumentParser
from pydantic import BaseModel
from rdkit import Chem

ETHER0_DIR = Path(__file__).parent

auth_scheme = HTTPBearer()


def validate_token(
    credentials: HTTPAuthorizationCredentials = Depends(auth_scheme),  # noqa: B008
) -> str:
    # NOTE: don't use os.environ.get() to avoid possible empty string matches, and
    # to have clearer server failures if the ETHER0_REMOTES_API_TOKEN env var isn't set
    if not secrets.compare_digest(
        credentials.credentials, os.environ["ETHER0_REMOTES_API_TOKEN"]
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect bearer token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials


app = FastAPI(title="ether0 remotes server", dependencies=[Depends(validate_token)])
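
# Illustrative client call. The route path and payload below are placeholders
# (assumptions, not defined in this file); only the bearer-token scheme is real:
#   import requests
#   requests.post(
#       "http://localhost:8000/<route>",
#       json={"smiles": "CCO"},
#       headers={"Authorization": f"Bearer {os.environ['ETHER0_REMOTES_API_TOKEN']}"},
#   )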


class MolecularTransformer:
    """Uses code from https://doi.org/10.1021/acscentsci.9b00576."""

    DEFAULT_MOLTRANS_MODEL_PATH: ClassVar[Path] = (
        ETHER0_DIR / "USPTO480k_model_step_400000.pt"
    )

    def __init__(self):
        # Use `or None` so an empty string in the environment variable counts as unset
        os_environ_model_path = (
            os.environ.get("ETHER0_REMOTES_MOLTRANS_MODEL_PATH") or None
        )
        self.model_path = os_environ_model_path or str(self.DEFAULT_MOLTRANS_MODEL_PATH)
        if not Path(self.model_path).exists():
            raise FileNotFoundError(
                "MolTrans model not found"
                f"{f', did you misconfigure the path {os_environ_model_path}?' if os_environ_model_path else '.'}"  # noqa: E501
                " Please set the environment variable"
                " 'ETHER0_REMOTES_MOLTRANS_MODEL_PATH',"
                f" or place the model at the default path {self.DEFAULT_MOLTRANS_MODEL_PATH}."
            )

    @staticmethod
    def translate(opt: argparse.Namespace) -> None:
        ArgumentParser.validate_translate_opts(opt)
        logger = init_logger(opt.log_file)

        translator = build_translator(opt, logger=logger, report_score=True)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size)
        features_shards = []
        features_names = []
        for feat_name, feat_path in opt.src_feats.items():
            features_shards.append(split_corpus(feat_path, opt.shard_size))
            features_names.append(feat_name)
        shard_pairs = zip(src_shards, tgt_shards, *features_shards)  # noqa: B905

        for src_shard, tgt_shard, *features_shard in shard_pairs:
            features_shard_ = defaultdict(list)
            for j, x in enumerate(features_shard):
                features_shard_[features_names[j]] = x
            translator.translate(
                src=src_shard,
                src_feats=features_shard_,
                tgt=tgt_shard,
                batch_size=opt.batch_size,
                batch_type=opt.batch_type,
                attn_debug=opt.attn_debug,
                align_debug=opt.align_debug,
            )

    @staticmethod
    def smiles_tokenizer(smiles: str) -> str:
        smiles_regex = re.compile(
            r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        )
        tokens = list(smiles_regex.findall(smiles))
        return " ".join(tokens)

    @staticmethod
    def canonicalize_smiles(smiles: str) -> str:
        # Use canonical SMILES because the original USPTO data is distributed in
        # canonical form.
        try:
            return Chem.MolToSmiles(
                Chem.MolFromSmiles(smiles), isomericSmiles=True, canonical=True
            )
        except Exception as err:
            # If RDKit failed, some molecule is invalid. Canonicalize each fragment
            # separately so the error message can report exactly which SMILES are
            # invalid.
            invalid_smiles = []
            for mol in smiles.split("."):
                try:
                    Chem.MolToSmiles(
                        Chem.MolFromSmiles(mol), isomericSmiles=True, canonical=True
                    )
                except:  # noqa: E722
                    invalid_smiles.append(mol)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "The reaction could not be parsed by RDKit. The following"
                    f" SMILES were invalid: {', '.join(invalid_smiles)}"
                ),
            ) from err

    def run(self, reaction: str) -> tuple[str, uuid.UUID]:
        """Translates SMILES reaction strings using the MolTrans model.

        Args:
            reaction: SMILES representation of a chemical reaction.

        Returns:
            SMILES representation of the predicted product and a job ID.
        """
        # Create a unique ID for the request
        job_id = uuid.uuid4()

        # Create temporary files for use by the molecular transformer
        with (
            tempfile.NamedTemporaryFile(
                mode="w+", delete=False, encoding="utf-8"
            ) as precursor_file,
            tempfile.NamedTemporaryFile(
                mode="w+", delete=False, encoding="utf-8"
            ) as output_file,
        ):
            # Write the tokenized reaction to the precursor file
            precursor_file.write(MolecularTransformer.smiles_tokenizer(reaction))
            precursor_file.flush()

            # OpenNMT expects to receive a list of CLI-style arguments to translate
            parser = ArgumentParser()
            opts.config_opts(parser)
            opts.translate_opts(parser)
            args_dict = {
                "model": self.model_path,
                "src": precursor_file.name,
                "output": output_file.name,
                "batch_size": "64",
                "beam_size": "50",
                "max_length": "300",
            }
            args_list = [f"--{k}={v}" for k, v in args_dict.items()]
            opt = parser.parse_args(args_list)
            MolecularTransformer.translate(opt)

            output_file.close()
            prediction = Path(output_file.name).read_text(encoding="utf-8")

        # Clean up temporary files. We don't care if a failure leaves them dangling,
        # since they are in a temp dir.
        os.unlink(precursor_file.name)
        os.unlink(output_file.name)
        return prediction.replace(" ", "").strip(), job_id
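

# Illustrative usage of MolecularTransformer (assumes the USPTO-480k checkpoint
# is available locally): MolecularTransformer().run("CC(=O)Cl.OCC") returns a
# predicted product SMILES string together with a fresh job UUID.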


class MolBloom:
    """Uses code from https://doi.org/10.1186/s13321-023-00765-1."""

    def __init__(self) -> None:
        # Trigger eager loading of the bloom filter
        buy("C1=CC=CC=C1", catalog="zinc20")
        self.bloom = buy

    def run(self, smiles: str) -> bool:
        """Checks if a molecule is purchasable using MolBloom.

        Args:
            smiles: SMILES representation of a molecule.

        Returns:
            True if the molecule is purchasable, False otherwise.
        """
        return self.bloom(smiles, canonicalize=True, catalog="zinc20")
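

# Note: molbloom's `buy` is a bloom-filter lookup against the ZINC20 catalog, so a
# False result means the molecule is definitely absent, while a True result can
# occasionally be a false positive.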


class Solubility:
    """Uses code from https://doi.org/10.1039/D3DD00217A."""

    def __init__(self) -> None:
        self.sol = KDESol()

    def run(self, smiles: str) -> npt.NDArray[np.float32] | Literal[False]:
        """Computes a solubility prediction for a molecule using KDESol.

        Args:
            smiles: SMILES representation of a molecule.

        Returns:
            Numpy array containing the mean predicted solubility,
            aleatoric uncertainty (au), and epistemic uncertainty (eu),
            or False if the SMILES cannot be parsed or predicted.
        """
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return False  # type: ignore[unreachable]
        prediction = self.sol(Chem.MolToSmiles(m, canonical=True, isomericSmiles=False))
        if prediction is None:
            # Try again without canonicalization. The model is an LSTM over
            # SELFIES-derived tokens, so depending on the SMILES notation it might
            # not have the necessary tokens in its vocabulary to describe the molecule.
            prediction = self.sol(smiles)
        return prediction if prediction is not None else False
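

# Illustrative usage: Solubility().run("CCO") returns a 3-element array that the
# solubility endpoint below unpacks as [mean, au, eu], i.e. the mean prediction
# plus its aleatoric and epistemic uncertainties.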


class MolTransRequest(BaseModel):
    reaction: str


def translate_endpoint(request: MolTransRequest) -> dict[str, str | uuid.UUID]:
    reaction = request.reaction.replace(" ", "")
    if not reaction.count(">") == 2:  # noqa: PLR2004
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Syntax error in the reaction SMILES: {reaction}\n"
                "The reaction should have two '>' characters, and no spaces."
            ),
        )
    rxn = reaction.split(">")[:-1]
    query_reaction = MolecularTransformer.canonicalize_smiles(
        ".".join([r for r in rxn if r])
    )
    product, job_id = MolecularTransformer().run(query_reaction)
    return {
        "product": product,
        "id": job_id,
        "reaction": query_reaction + ">>" + product,
    }
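

# Illustrative request body for the endpoint above: {"reaction": "CC(=O)Cl.OCC>>"},
# i.e. reactants>agents>products with the product side left empty; the predicted
# product fills in the right-hand side of the returned reaction string.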


class MolBloomRequest(BaseModel):
    smiles: list[str] | str


def is_purchasable_endpoint(request: MolBloomRequest) -> dict[str, bool]:
    is_purchasable = MolBloom().run
    smiles = request.smiles
    if isinstance(smiles, str):
        smiles = [smiles]
    return {s: is_purchasable(s) for s in smiles}
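

# Illustrative request body for the endpoint above: {"smiles": ["CCO", "c1ccccc1"]}
# (or a single string); the response maps each input SMILES to a purchasability
# boolean.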


class SmilesRequest(BaseModel):
    smiles: str


def compute_solubility_endpoint(
    request: SmilesRequest,
) -> dict[str, float] | dict[str, str]:
    if "." in request.smiles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only single molecules are supported",
        )
    prediction = Solubility().run(smiles=request.smiles)
    if prediction is False:
        return {"error": "Solubility prediction failed."}
    mean, au, eu = prediction.tolist()
    return {"mean": mean, "au": au, "eu": eu}


def main() -> None:
    """Run uvicorn to serve the FastAPI app."""
    try:
        import uvicorn  # noqa: PLC0415
    except ImportError as exc:
        raise ImportError(
            "Serving requires the 'serve' extra for the `uvicorn` package. Please:"
            " `pip install ether0.remotes[serve]`."
        ) from exc

    uvicorn.run("ether0.server:app")


if __name__ == "__main__":
    main()
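
# Illustrative launch (module path assumed from the uvicorn target above):
#   ETHER0_REMOTES_API_TOKEN=<token> uvicorn ether0.server:app --host 0.0.0.0 --port 8000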