jonahkall's picture
Upload 51 files
4c346eb verified
raw
history blame
11.1 kB
import argparse
import os
import re
import secrets
import tempfile
import uuid
from collections import defaultdict
from pathlib import Path
from typing import ClassVar, Literal
import numpy as np
import numpy.typing as npt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from molbloom import buy
from molsol import KDESol
from onmt import opts
from onmt.translate.translator import build_translator
from onmt.utils.logging import init_logger
from onmt.utils.misc import split_corpus
from onmt.utils.parse import ArgumentParser
from pydantic import BaseModel
from rdkit import Chem
# Directory containing this module; the default MolTrans checkpoint is looked up here.
ETHER0_DIR = Path(__file__).parents[0]
# Parses the ``Authorization: Bearer <token>`` header for the token check in validate_token.
auth_scheme = HTTPBearer(auto_error=True)
def validate_token(
    credentials: HTTPAuthorizationCredentials = Depends(auth_scheme),  # noqa: B008
) -> str:
    """Reject requests whose bearer token does not match the configured secret.

    Args:
        credentials: Bearer credentials extracted from the request by ``auth_scheme``.

    Returns:
        The validated bearer token.

    Raises:
        HTTPException: 401 when the presented token does not match
            ``ETHER0_REMOTES_API_TOKEN``.
        KeyError: When ``ETHER0_REMOTES_API_TOKEN`` is not set on the server.
    """
    # NOTE: don't use os.environ.get() to avoid possible empty string matches, and
    # to have clearer server failures if the AUTH_TOKEN env var isn't present
    expected_token = os.environ["ETHER0_REMOTES_API_TOKEN"]
    # compare_digest only accepts ASCII-only str (non-ASCII raises TypeError, i.e.
    # an HTTP 500). Comparing UTF-8 bytes keeps the constant-time check and turns
    # any non-ASCII token into a normal 401.
    if not secrets.compare_digest(
        credentials.credentials.encode("utf-8"), expected_token.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect bearer token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
# Every route registered on this app requires a valid bearer token.
app = FastAPI(
    dependencies=[Depends(validate_token)],
    title="ether0 remotes server",
)
class MolecularTransformer:
    """Forward reaction prediction with the Molecular Transformer (OpenNMT-py).

    Uses code from https://doi.org/10.1021/acscentsci.9b00576.
    """

    DEFAULT_MOLTRANS_MODEL_PATH: ClassVar[Path] = (
        ETHER0_DIR / "USPTO480k_model_step_400000.pt"
    )
    # SMILES tokenization pattern, compiled once rather than on every call.
    _SMILES_REGEX: ClassVar[re.Pattern[str]] = re.compile(
        r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
    )

    def __init__(self) -> None:
        """Resolve the model checkpoint path and check that it exists.

        A non-empty ``ETHER0_REMOTES_MOLTRANS_MODEL_PATH`` takes precedence;
        otherwise ``DEFAULT_MOLTRANS_MODEL_PATH`` is used.

        Raises:
            FileNotFoundError: If the resolved checkpoint path does not exist.
        """
        # Use `or None` to deny setting empty string to the environment variable
        os_environ_model_path = (
            os.environ.get("ETHER0_REMOTES_MOLTRANS_MODEL_PATH") or None
        )
        self.model_path = os_environ_model_path or str(self.DEFAULT_MOLTRANS_MODEL_PATH)
        if not Path(self.model_path).exists():
            hint = (
                f", did you misconfigure the path {os_environ_model_path}?"
                if os_environ_model_path
                else "."
            )
            raise FileNotFoundError(
                f"MolTrans model not found{hint}"
                " Please properly configure the environment variable"
                " 'ETHER0_REMOTES_MOLTRANS_MODEL_PATH',"
                f" or the default path checked is {self.DEFAULT_MOLTRANS_MODEL_PATH}."
            )

    @staticmethod
    def translate(opt: argparse.Namespace) -> None:
        """Run OpenNMT-py translation over the source file named in ``opt``.

        The translator writes its predictions to the file named by ``opt.output``.
        """
        ArgumentParser.validate_translate_opts(opt)
        logger = init_logger(opt.log_file)
        translator = build_translator(opt, logger=logger, report_score=True)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size)
        features_shards = []
        features_names = []
        for feat_name, feat_path in opt.src_feats.items():
            features_shards.append(split_corpus(feat_path, opt.shard_size))
            features_names.append(feat_name)
        shard_pairs = zip(src_shards, tgt_shards, *features_shards)  # noqa: B905
        for src_shard, tgt_shard, *features_shard in shard_pairs:
            # Map each source-feature name to its shard for this iteration.
            features_shard_ = defaultdict(
                list, zip(features_names, features_shard)  # noqa: B905
            )
            translator.translate(
                src=src_shard,
                src_feats=features_shard_,
                tgt=tgt_shard,
                batch_size=opt.batch_size,
                batch_type=opt.batch_type,
                attn_debug=opt.attn_debug,
                align_debug=opt.align_debug,
            )

    @staticmethod
    def smiles_tokenizer(smiles: str) -> str:
        """Split ``smiles`` into space-separated tokens for the model.

        Characters the pattern does not match are dropped (``findall`` semantics).
        """
        return " ".join(MolecularTransformer._SMILES_REGEX.findall(smiles))

    @staticmethod
    def canonicalize_smiles(smiles: str) -> str:
        """Return RDKit's canonical isomeric SMILES for ``smiles``.

        Canonical form is used because the original USPTO data is distributed
        in canonical form.

        Raises:
            HTTPException: 400 if RDKit cannot parse ``smiles``. The detail lists
                the dot-separated fragments that fail to parse on their own, or
                the whole input if no single fragment fails.
        """
        try:
            return Chem.MolToSmiles(
                Chem.MolFromSmiles(smiles), isomericSmiles=True, canonical=True
            )
        except Exception as err:
            # RDKit failed, so some molecule is invalid. Re-check each fragment
            # to tell the caller which ones are at fault.
            invalid_smiles = []
            for mol in smiles.split("."):
                try:
                    Chem.MolToSmiles(
                        Chem.MolFromSmiles(mol), isomericSmiles=True, canonical=True
                    )
                except Exception:
                    invalid_smiles.append(mol)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "The reaction could not be parsed by RDKit. The following"
                    f" SMILES were invalid: {', '.join(invalid_smiles) or smiles}"
                ),
            ) from err

    def run(self, reaction: str) -> tuple[str, uuid.UUID]:
        """Translates SMILES reaction strings using MolTrans model.

        Args:
            reaction: SMILES representation of a chemical reaction

        Returns:
            SMILES representation of the predicted product and a job ID
        """
        # Create a unique ID for the request
        job_id = uuid.uuid4()
        # A private temp directory holds OpenNMT's input and output files. It is
        # removed when the block exits, even if translation raises, so failed
        # requests no longer leave files behind. Cleanup errors are ignored so
        # they never mask a successful prediction.
        with tempfile.TemporaryDirectory(
            prefix="ether0_moltrans_", ignore_cleanup_errors=True
        ) as tmp_dir:
            precursor_path = Path(tmp_dir) / "precursors.txt"
            output_path = Path(tmp_dir) / "predictions.txt"
            # Write tokenized reaction to the precursor file
            precursor_path.write_text(self.smiles_tokenizer(reaction), encoding="utf-8")
            # OpenNMT expects to receive a list of arguments to translate
            parser = ArgumentParser()
            opts.config_opts(parser)
            opts.translate_opts(parser)
            args_dict = {
                "model": self.model_path,
                "src": str(precursor_path),
                "output": str(output_path),
                "batch_size": "64",
                "beam_size": "50",
                "max_length": "300",
            }
            opt = parser.parse_args([f"--{k}={v}" for k, v in args_dict.items()])
            MolecularTransformer.translate(opt)
            prediction = output_path.read_text(encoding="utf-8")
        # The model emits space-separated tokens; re-join them into one SMILES.
        return prediction.replace(" ", "").strip(), job_id
class MolBloom:
    """Purchasability lookup backed by a MolBloom Bloom filter.

    Uses code from https://doi.org/10.1186/s13321-023-00765-1.
    """

    # Catalog used both for the warm-up query and for every lookup.
    _CATALOG = "zinc20"

    def __init__(self) -> None:
        # A throwaway query (benzene) forces the Bloom filter to load now
        # instead of on the first real request.
        buy("C1=CC=CC=C1", catalog=self._CATALOG)
        self.bloom = buy

    def run(self, smiles: str) -> bool:
        """Check whether ``smiles`` appears in the purchasable catalog.

        Args:
            smiles: SMILES representation of a molecule.

        Returns:
            The filter's verdict: True if the molecule is reported purchasable,
            False otherwise.
        """
        lookup = self.bloom
        return lookup(smiles, canonicalize=True, catalog=self._CATALOG)
class Solubility:
    """Solubility prediction with KDESol.

    Uses code from https://doi.org/10.1039/D3DD00217A.
    """

    def __init__(self) -> None:
        self.sol = KDESol()

    def run(self, smiles: str) -> npt.NDArray[np.float32] | Literal[False]:
        """Predict solubility for a single molecule.

        Args:
            smiles: SMILES representation of a molecule.

        Returns:
            Numpy array with the mean predicted solubility, aleatoric
            uncertainty (au) and epistemic uncertainty (eu); False if RDKit
            cannot parse ``smiles`` or the model yields no prediction.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False  # type: ignore[unreachable]
        # Canonical non-isomeric SMILES is tried first. The model is an LSTM over
        # SELFIES-derived tokens, and a given notation may use tokens missing
        # from its vocabulary, so the caller's original string is the fallback.
        attempts = (
            Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False),
            smiles,
        )
        for candidate in attempts:
            prediction = self.sol(candidate)
            if prediction is not None:
                return prediction
        return False
# Request body for POST /translate.
class MolTransRequest(BaseModel):
    # Reaction SMILES expected to contain exactly two '>' separators
    # ("reactants>reagents>products"); the product part is ignored and spaces
    # are stripped by the endpoint.
    reaction: str
@app.post("/translate")
def translate_endpoint(request: MolTransRequest) -> dict[str, str | uuid.UUID]:
    """Predict the product of a reaction given as ``reactants>reagents>products``.

    The product part (after the second '>') is ignored. Reactants and reagents
    are merged, canonicalized, and passed to the Molecular Transformer.

    Returns:
        Mapping with the predicted ``product``, the job ``id``, and the full
        ``reaction`` (canonical precursors + ">>" + product).

    Raises:
        HTTPException: 400 if the reaction does not have exactly two '>'
            characters, has no reactants or reagents, or cannot be parsed by RDKit.
    """
    reaction = request.reaction.replace(" ", "")
    if reaction.count(">") != 2:  # noqa: PLR2004
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Syntax error in the reaction SMILES: {reaction}\n"
                "The reaction should have two '>' characters, and no spaces."
            ),
        )
    # Keep reactants and reagents (drop the product slot) and skip empty parts.
    precursors = [part for part in reaction.split(">")[:-1] if part]
    if not precursors:
        # Otherwise an empty string would be canonicalized and sent to the model.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Syntax error in the reaction SMILES: {reaction}\n"
                "The reaction should include at least one reactant or reagent."
            ),
        )
    query_reaction = MolecularTransformer.canonicalize_smiles(".".join(precursors))
    product, job_id = MolecularTransformer().run(query_reaction)
    return {
        "product": product,
        "id": job_id,
        "reaction": query_reaction + ">>" + product,
    }
# Request body for POST /is_purchasable.
class MolBloomRequest(BaseModel):
    # A single SMILES string or a list of SMILES strings to check.
    smiles: list[str] | str
@app.post("/is_purchasable")
def is_purchasable_endpoint(request: MolBloomRequest) -> dict[str, bool]:
    """Report whether each submitted SMILES is purchasable according to MolBloom.

    Accepts one SMILES string or a list. The response maps each input SMILES to
    its result; repeated inputs collapse into a single key.
    """
    checker = MolBloom()
    queries = [request.smiles] if isinstance(request.smiles, str) else request.smiles
    return {query: checker.run(query) for query in queries}
# Request body carrying one SMILES string (used by POST /compute_solubility).
class SmilesRequest(BaseModel):
    smiles: str
@app.post("/compute_solubility")
def compute_solubility_endpoint(
    request: SmilesRequest,
) -> dict[str, float] | dict[str, str]:
    """Predict solubility for a single molecule.

    Returns:
        ``{"mean", "au", "eu"}`` on success, or ``{"error": ...}`` when the
        prediction fails.

    Raises:
        HTTPException: 400 if the SMILES contains a '.' (multiple components).
    """
    smiles = request.smiles
    has_multiple_components = "." in smiles
    if has_multiple_components:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only single molecules are supported",
        )
    result = Solubility().run(smiles=smiles)
    if result is False:
        return {"error": "Solubility prediction failed."}
    mean, aleatoric, epistemic = result.tolist()
    return {"mean": mean, "au": aleatoric, "eu": epistemic}
def main() -> None:
    """Serve the FastAPI ``app`` with uvicorn.

    Raises:
        ImportError: If ``uvicorn`` (the optional ``serve`` extra) is not installed.
    """
    try:
        import uvicorn  # noqa: PLC0415
    except ImportError as missing_dependency:
        install_hint = (
            "Serving requires the 'serve' extra for the `uvicorn` package. Please:"
            " `pip install ether0.remotes[serve]`."
        )
        raise ImportError(install_hint) from missing_dependency
    # uvicorn is handed the import string rather than the app object.
    uvicorn.run("ether0.server:app")
# Start the server when this module is executed as a script
# (e.g. `python -m ether0.server`).
if __name__ == "__main__":
    main()