Spaces:
Running
on
Zero
Running
on
Zero
File size: 11,130 Bytes
4c346eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 |
import argparse
import os
import re
import secrets
import tempfile
import uuid
from collections import defaultdict
from pathlib import Path
from typing import ClassVar, Literal
import numpy as np
import numpy.typing as npt
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from molbloom import buy
from molsol import KDESol
from onmt import opts
from onmt.translate.translator import build_translator
from onmt.utils.logging import init_logger
from onmt.utils.misc import split_corpus
from onmt.utils.parse import ArgumentParser
from pydantic import BaseModel
from rdkit import Chem
ETHER0_DIR = Path(__file__).parent
auth_scheme = HTTPBearer()
def validate_token(
    credentials: HTTPAuthorizationCredentials = Depends(auth_scheme),  # noqa: B008
) -> str:
    """Check the request's bearer token against the server's configured token."""
    # NOTE: read os.environ[...] directly rather than os.environ.get() to avoid
    # possible empty string matches, and to have clearer server failures if the
    # AUTH_TOKEN env var isn't present
    expected_token = os.environ["ETHER0_REMOTES_API_TOKEN"]
    if secrets.compare_digest(credentials.credentials, expected_token):
        return credentials.credentials
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect bearer token",
        headers={"WWW-Authenticate": "Bearer"},
    )
app = FastAPI(title="ether0 remotes server", dependencies=[Depends(validate_token)])
class MolecularTransformer:
    """Reaction-product prediction with the Molecular Transformer.

    Uses code from https://doi.org/10.1021/acscentsci.9b00576.
    """

    DEFAULT_MOLTRANS_MODEL_PATH: ClassVar[Path] = (
        ETHER0_DIR / "USPTO480k_model_step_400000.pt"
    )
    # Hoisted to a ClassVar so the pattern is compiled once at class-creation
    # time instead of on every smiles_tokenizer() call. Alternatives are ordered
    # so multi-character tokens (bracket atoms, Br/Cl, %NN ring closures) match
    # before single characters.
    SMILES_TOKEN_PATTERN: ClassVar[re.Pattern[str]] = re.compile(
        r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
    )

    def __init__(self):
        """Resolve and validate the MolTrans model checkpoint path.

        Raises:
            FileNotFoundError: If no model checkpoint exists at the configured
                (or default) path.
        """
        # Use `or None` to deny setting empty string to the environment variable
        os_environ_model_path = (
            os.environ.get("ETHER0_REMOTES_MOLTRANS_MODEL_PATH") or None
        )
        self.model_path = os_environ_model_path or str(self.DEFAULT_MOLTRANS_MODEL_PATH)
        if not Path(self.model_path).exists():
            raise FileNotFoundError(
                f"MolTrans model not found"
                f"{f', did you misconfigure the path {os_environ_model_path}?' if os_environ_model_path else '.'}"  # noqa: E501
                " Please properly configure the environment variable"
                " 'ETHER0_REMOTES_MOLTRANS_MODEL_PATH',"
                f" or the default path checked is {self.DEFAULT_MOLTRANS_MODEL_PATH}."
            )

    @staticmethod
    def translate(opt: argparse.Namespace) -> None:
        """Run OpenNMT translation over all source shards.

        Args:
            opt: Parsed OpenNMT translate options (model, src, output, etc.),
                as produced by `opts.translate_opts`.
        """
        ArgumentParser.validate_translate_opts(opt)
        logger = init_logger(opt.log_file)
        translator = build_translator(opt, logger=logger, report_score=True)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size)
        features_shards = []
        features_names = []
        for feat_name, feat_path in opt.src_feats.items():
            features_shards.append(split_corpus(feat_path, opt.shard_size))
            features_names.append(feat_name)
        shard_pairs = zip(src_shards, tgt_shards, *features_shards)  # noqa: B905
        for src_shard, tgt_shard, *features_shard in shard_pairs:
            # Re-key the positional feature shards by their feature names
            features_shard_ = defaultdict(list)
            for j, x in enumerate(features_shard):
                features_shard_[features_names[j]] = x
            translator.translate(
                src=src_shard,
                src_feats=features_shard_,
                tgt=tgt_shard,
                batch_size=opt.batch_size,
                batch_type=opt.batch_type,
                attn_debug=opt.attn_debug,
                align_debug=opt.align_debug,
            )

    @staticmethod
    def smiles_tokenizer(smiles: str) -> str:
        """Tokenize a SMILES string into the model's space-separated vocabulary.

        Args:
            smiles: SMILES string to tokenize.

        Returns:
            The SMILES tokens joined by single spaces.
        """
        # findall already returns a list, no extra copy needed
        tokens = MolecularTransformer.SMILES_TOKEN_PATTERN.findall(smiles)
        return " ".join(tokens)

    @staticmethod
    def canonicalize_smiles(smiles: str) -> str:
        """Canonicalize a (possibly multi-fragment) SMILES string with RDKit.

        Args:
            smiles: SMILES string, with fragments separated by '.'.

        Returns:
            The isomeric canonical SMILES.

        Raises:
            HTTPException: 400 if RDKit cannot parse the input; the error
                detail lists the invalid fragments.
        """
        # Try to use canonical smiles because original uspto is distributed in canonical form.
        # If fails, we trust the augmentation and use the original smiles.
        try:
            return Chem.MolToSmiles(
                Chem.MolFromSmiles(smiles), isomericSmiles=True, canonical=True
            )
        except Exception as err:
            # If rdkit failed, it means some molecule is invalid.
            # Here we catch which ones are invalid so we inform what's wrong
            # on the error message.
            invalid_smiles = []
            for mol in smiles.split("."):
                try:
                    Chem.MolToSmiles(
                        Chem.MolFromSmiles(mol), isomericSmiles=True, canonical=True
                    )
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
                # are not swallowed while probing fragments
                except Exception:  # noqa: BLE001
                    invalid_smiles.append(mol)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    "The reaction could not be parsed by RDKit. The following"
                    f" SMILES were invalid: {', '.join(invalid_smiles)}"
                ),
            ) from err

    def run(self, reaction: str) -> tuple[str, uuid.UUID]:
        """Translates SMILES reaction strings using MolTrans model.

        Args:
            reaction: SMILES representation of a chemical reaction

        Returns:
            SMILES representation of the predicted product and a job ID
        """
        # Create a unique ID for the request
        job_id = uuid.uuid4()
        # Create temporary files for use in mol moltransformer
        with (
            tempfile.NamedTemporaryFile(
                mode="w+", delete=False, encoding="utf-8"
            ) as precursor_file,
            tempfile.NamedTemporaryFile(
                mode="w+", delete=False, encoding="utf-8"
            ) as output_file,
        ):
            # Write tokenized reaction to the precursor file
            precursor_file.write(MolecularTransformer.smiles_tokenizer(reaction))
            precursor_file.flush()
            # OpenNMT expects to receive a list of arguments to translate
            parser = ArgumentParser()
            opts.config_opts(parser)
            opts.translate_opts(parser)
            args_dict = {
                "model": self.model_path,
                "src": precursor_file.name,
                "output": output_file.name,
                "batch_size": "64",
                "beam_size": "50",
                "max_length": "300",
            }
            args_list = [f"--{k}={v}" for k, v in args_dict.items()]
            opt = parser.parse_args(args_list)
            MolecularTransformer.translate(opt)
            output_file.close()
            prediction = Path(output_file.name).read_text(encoding="utf-8")
        # Clean up temporary files
        # we don't care if a failure leaves them dangling,
        # since they are in a temp dir
        os.unlink(precursor_file.name)
        os.unlink(output_file.name)
        # Remove token-separating spaces to recover a plain SMILES string
        return prediction.replace(" ", "").strip(), job_id
class MolBloom:
    """Uses code from https://doi.org/10.1186/s13321-023-00765-1."""

    def __init__(self) -> None:
        # Query the filter once up front so it is loaded eagerly at startup,
        # then keep a handle to the query function for later calls
        buy("C1=CC=CC=C1", catalog="zinc20")
        self.bloom = buy

    def run(self, smiles: str) -> bool:
        """Checks if a molecule is purchasable using MolBloom.

        Args:
            smiles: SMILES representation of a molecule

        Returns:
            True if the molecule is purchasable, False otherwise
        """
        purchasable = self.bloom(smiles, canonicalize=True, catalog="zinc20")
        return purchasable
class Solubility:
    """Uses code from https://doi.org/10.1039/D3DD00217A."""

    def __init__(self) -> None:
        self.sol = KDESol()

    def run(self, smiles: str) -> npt.NDArray[np.float32] | Literal[False]:
        """Computes solubility prediction for a molecule using KDESol.

        Args:
            smiles: SMILES representation of a molecule.

        Returns:
            Numpy array containing the mean predicted solubility,
            aleatoric uncertainty (au), and epistemic uncertainty (eu),
            or False if the SMILES is invalid / prediction fails.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False  # type: ignore[unreachable]
        canonical = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
        prediction = self.sol(canonical)
        if prediction is not None:
            return prediction
        # Fall back to the raw input notation. The model is an LSTM whose
        # vocabulary comes from SELFIES tokens, so depending on the SMILES
        # notation it may lack the tokens needed for the canonical form.
        fallback = self.sol(smiles)
        return fallback if fallback is not None else False
class MolTransRequest(BaseModel):
    # Request body for POST /translate: a reaction SMILES string
    # (validated downstream to contain exactly two '>' separators)
    reaction: str
@app.post("/translate")
def translate_endpoint(request: MolTransRequest) -> dict[str, str | uuid.UUID]:
    # Predict the product of a reaction SMILES with the MolTrans model
    reaction = request.reaction.replace(" ", "")
    if reaction.count(">") != 2:  # noqa: PLR2004
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                f"Syntax error in the reaction SMILES: {reaction}\n"
                "The reaction should have two '>' characters, and no spaces."
            ),
        )
    # Drop the product (last '>'-delimited field); keep non-empty precursors
    precursors = [part for part in reaction.split(">")[:-1] if part]
    query_reaction = MolecularTransformer.canonicalize_smiles(".".join(precursors))
    product, job_id = MolecularTransformer().run(query_reaction)
    return {
        "product": product,
        "id": job_id,
        "reaction": f"{query_reaction}>>{product}",
    }
class MolBloomRequest(BaseModel):
    # Request body for POST /is_purchasable: a single SMILES string or a batch
    smiles: list[str] | str
@app.post("/is_purchasable")
def is_purchasable_endpoint(request: MolBloomRequest) -> dict[str, bool]:
    # Normalize a single SMILES into a one-element batch, then check each
    # against the MolBloom filter, keyed by the input SMILES
    query = [request.smiles] if isinstance(request.smiles, str) else request.smiles
    check = MolBloom().run
    return {s: check(s) for s in query}
class SmilesRequest(BaseModel):
    # Request body for POST /compute_solubility: a single-molecule SMILES string
    smiles: str
@app.post("/compute_solubility")
def compute_solubility_endpoint(
    request: SmilesRequest,
) -> dict[str, float] | dict[str, str]:
    # '.' separates fragments in SMILES, so its presence means multiple molecules
    if "." in request.smiles:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only single molecules are supported",
        )
    prediction = Solubility().run(smiles=request.smiles)
    if prediction is False:
        return {"error": "Solubility prediction failed."}
    mean_sol, aleatoric, epistemic = prediction.tolist()
    return {"mean": mean_sol, "au": aleatoric, "eu": epistemic}
def main() -> None:
    """Serve the FastAPI app with uvicorn.

    Raises:
        ImportError: If the optional `uvicorn` dependency is not installed.
    """
    # uvicorn is an optional ('serve' extra) dependency, so import lazily and
    # fail with install guidance rather than at module import time
    try:
        import uvicorn  # noqa: PLC0415
    except ImportError as exc:
        raise ImportError(
            "Serving requires the 'serve' extra for the `uvicorn` package. Please:"
            " `pip install ether0.remotes[serve]`."
        ) from exc
    else:
        uvicorn.run("ether0.server:app")
if __name__ == "__main__":
main()
|