|
|
from unittest.mock import patch |
|
|
|
|
|
import pytest |
|
|
from pydantic import JsonValue |
|
|
|
|
|
from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles |
|
|
from ether0.models import RewardReason |
|
|
from ether0.rewards import ( |
|
|
caption_eval, |
|
|
formula_diff, |
|
|
formula_eval, |
|
|
functional_group_eval, |
|
|
oracle_rxn_eval, |
|
|
product_eval, |
|
|
rxn_eval, |
|
|
str_eval, |
|
|
valid_mol_eval, |
|
|
valid_molecule_eval, |
|
|
) |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "y", "expected"),
    [
        pytest.param(
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate",
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1(H)-pyrrolo[3,4-c]pyridine-5-carboxylate",
            1.0,
            id="parentheses",
        ),
        pytest.param(
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate",
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1{H}-pyrrolo[3,4-c]pyridine-5-carboxylate",
            1.0,
            id="culies parentheses",
        ),
        pytest.param(
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate",
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate",
            1.0,
            id="same",
        ),
        pytest.param(
            " methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate",
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate ",
            1.0,
            id="spacing",
        ),
        pytest.param(
            "methyl 3-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate",
            "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate",
            0.0,
            id="different",
        ),
        pytest.param(
            "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            1.0,
            id="italics",
        ),
        pytest.param(
            "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0(4,7)]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^(3,10).0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            1.0,
            id="curlies and carrots",
        ),
        pytest.param(
            "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-benzoylamino-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0(4,7)]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^(3,10).0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            1.0,
            id="more parentheses",
        ),
        pytest.param(
            "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0(4,7)]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyloxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^(3,10).0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate",
            0.0,
            id="bad-parentheses",
        ),
    ],
)
def test_str_eval(yhat: str, y: str, expected: float) -> None:
    """Check the score `str_eval` assigns to a predicted/reference name pair.

    The cases pair names that differ only in parentheses, curly braces,
    carets, `~` markers or surrounding whitespace (expected 1.0) with names
    that differ in a locant or in their bracket structure (expected 0.0).
    """
    score = str_eval(yhat, y)
    assert score == expected
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "y", "expected"),
    [
        pytest.param(
            "Buchwald-Hartwig amination",
            "Buchwald-Hartwig amination",
            1.0,
            id="same rxn",
        ),
        pytest.param(
            "buchwald hartwig amination",
            "Buchwald-Hartwig amination",
            1.0,
            id="caps/hyphens",
        ),
        pytest.param(
            "BuchwaldHartwigAmination",
            "Buchwald-Hartwig amination",
            1.0,
            id="no spaces",
        ),
        pytest.param(
            "Buchwald\u2013Hartwig amination",
            "Buchwald-Hartwig amination",
            1.0,
            id="en dash",
        ),
        pytest.param(
            "Buchwald\u2013Hartwig animation",
            "Buchwald-Hartwig amination",
            0.0,
            id="false positive",
        ),
    ],
)
def test_rxn_eval(yhat: str, y: str, expected: float) -> None:
    """Check the score `rxn_eval` assigns to a predicted reaction name.

    Variants of the reference that differ in case, hyphenation, spacing or
    use an en dash are expected to score 1.0; a misspelled name ("animation")
    is expected to score 0.0.
    """
    score = rxn_eval(yhat, y)
    assert score == expected
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "y", "expected"),
    [
        pytest.param(
            "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14)C=6C=CC=CC6",
            "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
            1.0,
            id="full-answer",
        ),
        pytest.param(
            ")C=6C=CC=CC6",
            "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
            1.0,
            id="partial-answer",
        ),
        pytest.param(
            "",
            "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
            0.0,
            id="empty-generation",
        ),
        pytest.param(
            "CCC",
            "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
            0.0,
            id="wrong-valid-SMILES",
        ),
        pytest.param(
            "applesauce",
            "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
            0.0,
            id="non-SMILES-yhat",
        ),
    ],
)
def test_valid_mol_eval(yhat: str, y: str, expected: float) -> None:
    """Check `valid_mol_eval` on completions of a truncated SMILES prefix.

    `y` is the same incomplete SMILES string in every case; `yhat` is either
    the whole molecule, only the missing suffix, or an unusable answer.
    """
    # Handed to the evaluator and echoed in the failure message.
    metadata: dict[str, JsonValue] = {}
    reward = valid_mol_eval(yhat, y, metadata=metadata)
    assert reward == expected, f"Reason for failure: {metadata}"
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "y", "expected_reward", "expected_reason"),
    [
        pytest.param(
            "CCCO",
            "CCCO",
            1.0,
            None,
            id="exact-match",
        ),
        pytest.param(
            "CCCO",
            "C#N",
            0.0,
            RewardReason.INVALID_GROUND_TRUTH,
            id="chembench-8ee3546d-a3b8-4c7b-90ef-ead9ff11a50d-removed",
        ),
    ],
)
def test_product_eval(
    yhat: str,
    y: str,
    expected_reward: float,
    expected_reason: RewardReason | None,
) -> None:
    """Check the reward and recorded reason from `product_eval`.

    After `product_eval` has been checked, `caption_eval` is called on the
    same inputs (and the same metadata dict) and must give the same reward.
    """
    metadata: dict[str, JsonValue] = {}

    product_reward = product_eval(yhat, y, metadata=metadata)
    assert product_reward == expected_reward
    assert metadata.get("reward_reason") == expected_reason

    caption_reward = caption_eval(yhat, y, metadata=metadata)
    assert caption_reward == expected_reward
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "y", "expected"),
    [
        pytest.param(
            r"C/C=C(/C)\C(=O)O[C@@H]1C[C@@]2(C(=O)C=C(O2)/C(=C\[C@@H]3[C@@H]1C(=C)C(=O)O3)/CO)C",
            "C=C1C(=O)O[C@@H]2/C=C(/CO)C3=CC(=O)[C@@](C)(C[C@@H](OC(=O)C(C)=CC)[C@@H]12)O3",
            1.0,
            id="match",
        ),
        pytest.param(
            "CC1=CC(=C(C(=C1C(=O)O)O)C)OC(=O)C2=C(C(=C(C=C2C)OC)C)OC",
            "C=C1C(=O)O[C@@H]2/C=C(/CO)C3=CC(=O)[C@@](C)(C[C@@H](OC(=O)C(C)=CC)[C@@H]12)O3",
            0.05,
            id="formula-match",
        ),
        pytest.param(
            "CC1=CC(=C(C(=C1C(=O)O)O)C)OC(=O",
            "C=C1C(=O)O[C@@H]2/C=C(/CO)C3=CC(=O)[C@@](C)(C[C@@H](OC(=O)C(C)=CC)[C@@H]12)O3",
            0.0,
            id="bad-mol",
        ),
        pytest.param(
            "CC1=C[C@@H]2O[C@H]3C[C@H]4OC(=O)C=CC=CC(=O)OCC[C@@]5(C)O[C@@H]5C(=O)OC[C@]2(CC1)[C@@]4(C)[C@]31CO1",
            "CC1=C[C@@H]2O[C@H]3C[C@H]4OC(=O)C=CC=CC(=O)OCC[C@@]5(C)O[C@@H]5C(=O)OC[C@]2(CC1)[C@@]4(C)[C@]31CO1",
            1.0,
            id="wild-molecule",
        ),
    ],
)
def test_formula_eval(yhat: str, y: str, expected: float) -> None:
    """Check that `formula_eval` with `soft=True` gives at least `expected`.

    `expected` is a lower bound, not an exact value: the assertion uses `>=`.
    """
    metadata: dict[str, JsonValue] = {}
    reward = formula_eval(yhat, y, soft=True, metadata=metadata)
    # NOTE(review): for the 0.0 case any non-negative reward satisfies this
    # bound - confirm whether an exact comparison was intended there.
    assert reward >= expected, f"Reason for failure: {metadata}"
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "y", "expected"),
    [
        pytest.param(
            r"Cc1nc(NC(=O)[C@@H](N)CO)sc1-c1cnc(Cl)c(NS(=O)(=O)c2ccccc2)c1",
            "('C18H18ClN5O4S2', ['imidoylhalide cyclic'])",
            1.0,
            id="match",
        ),
        pytest.param(
            r"Cc1nc(NC(=O)[C@@H](N)CO)sc1-c1cnc(Cl)c(NS(=O)(=O)c2ccccc2)c1",
            "('C18H18ClN5O4S2', ['imidoylhalide cyclic', 'non-existing'])",
            0.0,
            id="bad groups",
        ),
        pytest.param(
            r"Cc1nc(NC(=O)[C@@H](N)CO)sc1-c1cnc(Cl)c(NS(=O)(=O)c2ccccc2)c1",
            "('C18H18ClN5O4S3', ['imidoylhalide cyclic'])",
            0.0,
            id="bad formula",
        ),
        pytest.param(
            r"CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(OC)[C@@](O)(c3ccccc3)[C@H](C)O2)[C@H](C)[C@@H](O[C@@H]2O[C@H](C)C[C@H](N(C)C)[C@H]2O)[C@](C)(O)C[C@@H](C)CN[C@H](C)[C@@H](O)[C@]1(C)O",
            "('C43H74N2O12', ['1,2-Aminoalcohol', 'hydroxylated heteroatom substituted glycosidic ring', 'tertiary alcohol'])",
            1.0,
            id="renamed-groups",
        ),
        pytest.param(r"CCC", "('C3H8', [])", 1.0, id="no-groups"),
        pytest.param(r"CCCNNNNN", "('C3H13N5', [])", 0.0, id="unreasonable-molecule"),
        pytest.param(r"C1CCCCC2C1CCCCCCCCC2", "('C16H30', [])", 0.0, id="bad-ring"),
        pytest.param(
            "CCCCCBr", "('C5H11Br',['alkylbromide'])", 1.0, id="observed-problem"
        ),
    ],
)
def test_functional_group_eval(yhat: str, y: str, expected: float) -> None:
    """Check `functional_group_eval` against formula/functional-group targets.

    `y` is the string form of a `(formula, [group, ...])` tuple and `yhat` is
    the proposed SMILES.
    """
    # Handed to the evaluator and echoed in the failure message.
    metadata: dict[str, JsonValue] = {}
    reward = functional_group_eval(yhat, y, metadata=metadata)
    assert reward == expected, f"Reason for failure: {metadata}"
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "y", "expected"),
    [
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
            "CCCN(C)C1CC(C)(C)NC1=O",
            1.0,
            id="match",
        ),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>CCCN(C)C1CC(C)(C)NC1=O",
            "CCCN(C)C1CC(C)(C)NC1=O",
            1.0,
            id="match-w-product",
        ),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>CCCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
            "CCCN(C)C1CC(C)(C)NC1=O",
            0.0,
            id="match-w-non-matching-product",
        ),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>CCCXeN(C=O)C1CC(C)(C)N(C(=O)C)O1",
            "CCCN(C)C1CC(C)(C)NC1=O",
            0.0,
            id="match-w-invalid-product",
        ),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O",
            "CCCN(C)C1CC(C)(C)NC1=O",
            0.0,
            id="match-wo-trailing",
        ),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>>>>",
            "CCCN(C)C1CC(C)(C)NC1=O",
            0.0,
            id="no-match-w-many-trailing",
        ),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1",
            "CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
            0.0,
            id="invalid",
        ),
        pytest.param(
            "C(P)(P)(P)CC=O.CC1(C)(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
            "CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
            0.0,
            id="no-purchase",
        ),
        pytest.param(
            "OB(O)c1cc(C2CC2)cnc1Cl.Cl -> OB(O)c1cc(C2CC2)cnc1Cl + HBr + HIO2 + HIO3S + CH3COOH || 3s | 3*375I | 9*63BrI | 3*55Br | 3*657s*3*6I | 3*3*7Br*I*P | 3s*369I | 3*7*6s",
            "OB(O)c1cc(C2CC2)cnc1Cl",
            0.0,
            id="insane-reward-hacking",
        ),
        pytest.param(
            "CNCCC1CC1(F)F>CC#CC>",
            "CNCCC1CC1(F)F",
            0.0,
            id="trivial-reactants",
        ),
        pytest.param(
            "CC(C)CN1CC(O)C1.CC(C)CN1CC(O)CBr.CCO>CC#CC>",
            "CC(C)CN1CC(O)C1",
            0.0,
            id="disallow-product-in-reactants",
        ),
        pytest.param(
            "N#N.CCO>CC#CC.CC(C)CN1CC(O)C1>",
            "CC(C)CN1CC(O)C1",
            0.0,
            id="disallow-product-in-reagents",
        ),
        pytest.param(
            "C1(CN(C1)CC(C)C)O.CC(C)CN1CC(O)CBr.CCO>CC#CC>",
            "CC(C)CN1CC(O)C1",
            0.0,
            id="disallow-product-in-reactants-with-different-smiles",
        ),
        pytest.param(
            "C=CCNC(=O)Br.BrC#Cc1ccccc1.CCO>[Mg].c1ccccc1>",
            "C=CCNC(=O)C#Cc1ccccc1",
            0.0,
            id="hacked-purchasability",
        ),
        pytest.param(
            "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O.[THF]>CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
            "CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
            0.0,
            id="invalid-reagent",
        ),
    ],
)
def test_oracle_rxn_eval(yhat: str, y: str, expected: float) -> None:
    """Check `oracle_rxn_eval` with both of its lookups replaced by mocks.

    `ether0.rewards.fetch_purchasable` is patched to return a fixed
    SMILES -> bool table, and `ether0.rewards.fetch_forward_rxn` is patched
    to return `{"product": y}`, so the mocked forward prediction always
    equals the reference product.
    """
    # Canned answer for the patched purchasability lookup.
    mock_purchasable = {
        "CC1(C)CC(N)C(=O)N1": True,
        "XeCC1(C)CC(N)C(=O)N1": False,
        "C=CCNC(=O)Br": False,
        "CC(C)CN1CC(O)C1": True,
        "CC1(C)(C)CC(N)C(=O)N1": False,
        "C(P)(P)(P)CC=O": False,
    }
    purchasable_patch = patch(
        "ether0.rewards.fetch_purchasable", return_value=mock_purchasable
    )
    forward_rxn_patch = patch(
        "ether0.rewards.fetch_forward_rxn", return_value={"product": y}
    )

    metadata: dict[str, JsonValue] = {}
    with purchasable_patch, forward_rxn_patch:
        result = oracle_rxn_eval(yhat, y, metadata=metadata)

    assert result == expected, (
        f"Given {yhat=} and {y=}, expected {expected} but got {result} with"
        f" {metadata=}."
    )
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("f1", "f2", "expected"),
    [
        pytest.param("C1", "C2", 1.0, id="simple-1"),
        pytest.param("C1", "C1H1", 1.0, id="simple-2"),
        pytest.param("C1H2", "C1H2", 0.0, id="simple-3"),
        pytest.param("N2", "O2", 8**0.5, id="simple-4"),
        pytest.param("X100C1", "X100C2", 1.0, id="bad-element-5"),
        pytest.param("C100", "C100H100", 100, id="big-digits"),
        pytest.param("CH2", "H2", 1, id="implicit"),
    ],
)
def test_formula_diff(f1: str, f2: str, expected: float) -> None:
    """Check the distance `formula_diff` reports between two formulas.

    Args:
        f1: First molecular formula string.
        f2: Second molecular formula string.
        expected: Distance the two formulas should be apart.
    """
    # Compare with a tolerance: expectations such as 8**0.5 are irrational,
    # so an exact float `==` would depend on how the result was computed.
    assert formula_diff(f1, f2) == pytest.approx(expected)
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("mol", "ref_mol", "expected"),
    [
        pytest.param(
            "O=C(/C=C/C1=CC=CC=C1)OC[C@H]1O[C@@H](O[C@@H]2O[C@@H]3C[C@H]4[C@H](O)[C@@H](O)[C@@](O)(CO3)[C@@H]24)[C@H](O)[C@@H](O)[C@@H]1O",
            None,
            1,
            id="passing-1",
        ),
        pytest.param(
            "CC(C)C[C@H](NC(=O)[C@H](Cc1c[nH]cn1)NC(=O)[C@H](Cc1ccccc1)NC(=O)OC(C)(C)C)[C@@H](O)[C@@H](O)CC(C)C",
            None,
            1,
            id="passing-2",
        ),
        pytest.param("CCCCCBr", "CCCCCBr", 1, id="passing-3"),
    ],
)
def test_is_reasonable_ring_system(
    mol: str, ref_mol: str | None, expected: float
) -> None:
    """Check `is_reasonable_ring_system` on parsed molecules.

    `mol` must parse; `ref_mol`, when given and non-empty, is parsed and
    passed as the second positional argument, otherwise `None` is passed.
    """
    candidate = mol_from_smiles(mol)
    assert candidate is not None
    reference = mol_from_smiles(ref_mol) if ref_mol else None
    verdict = is_reasonable_ring_system(candidate, reference)
    assert verdict == expected
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("mol", "ref_mol", "expected"),
    [
        pytest.param(
            "O=C1OC2=CC=CC=C2C=C1c3ccc(O)c(O)c3c4ccc(O)cc4OCC=CCCCCCCC(N)(N)NS",
            None,
            False,
            id="weird-nitrogen-group",
        ),
        pytest.param(
            "O=S(=O)(N)c1c(Cl)cc2c(c1)S(=O)(=O)NCN2",
            None,
            True,
            id="sulfonamide",
        ),
        pytest.param(
            "C1=NC=NC=C1OCC=CCCC(N)S",
            None,
            False,
            id="weird-S-C-N-group",
        ),
        pytest.param(
            "CCC",
            None,
            True,
            id="simple-alkane",
        ),
    ],
)
def test_is_reasonable_fp(mol: str, ref_mol: str | None, expected: bool) -> None:
    """Check `is_reasonable_fp` on parsed molecules.

    `mol` must parse; `ref_mol`, when given and non-empty, is parsed and
    passed via the `ref_mol` keyword, otherwise `None` is passed.
    """
    candidate = mol_from_smiles(mol)
    assert candidate is not None
    reference = mol_from_smiles(ref_mol) if ref_mol else None
    verdict = is_reasonable_fp(candidate, ref_mol=reference)
    assert verdict == expected
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("yhat", "expected"),
    [
        ("CC(C)CCC", 1.0),
        ("CC(C)(C)(C)C", 0.0),
        ("", 0.0),
        ("INVALID", 0.0),
    ],
)
def test_valid_molecule_eval(yhat: str, expected: float) -> None:
    """Check the reward `valid_molecule_eval` gives a lone SMILES answer.

    The reference `y` is always passed as an empty string here.

    Args:
        yhat: Candidate SMILES string (may be empty or unparsable).
        expected: Reward the evaluator should return for `yhat`.
    """
    reward = valid_molecule_eval(yhat, y="")
    # Report both values so a failing case is readable without a debugger.
    assert reward == expected, f"Given {yhat=}, expected {expected} but got {reward}."
|
|
|