import re
from collections.abc import Collection
from enum import StrEnum, auto
from typing import Any

from datasets import DatasetDict
from pydantic import BaseModel, Field, model_validator

from ether0.utils import TDataset

REWARD_REASON_KEY = "reward_reason"  # Sentinel key


class RewardReason(StrEnum):
    FORMAT_FAILED = auto()
    INVALID_MOL = auto()
    # Catch-all for invalid values that aren't a molecule or a reaction
    INVALID_VALUE = auto()
    # Oracle regression values
    WRONG_NUMERICAL_ANSWER = auto()
    # Reaction/retro-synthesis failures
    INVALID_RXN = auto()
    WRONG_PRODUCT = auto()
    PRODUCT_IS_REACTANT = auto()
    NOT_PURCHASABLE = auto()
    # Molecule formula/functional group failures
    WRONG_FORMULA = auto()
    FAILED_CONSTRAINT = auto()
    # Unreasonable molecules
    FAILED_REOS_CHECK = auto()
    FAILED_RING_CHECK = auto()
    FAILED_COUNTERION_CHECK = auto()
    # Really this is a bug, but we don't want to blow up training if a
    # few bad examples slip through.
    INVALID_GROUND_TRUTH = auto()
    # Failover reason if we have an exception during a reward function.
    # NOTE: not using "failed" or "error" since an unhandled exception
    # may be something else
    REWARD_FUNCTION_EXCEPTION = auto()
    # These are automatically added if no other reason is given
    WRONG_ANSWER = auto()
    RIGHT_ANSWER = auto()

    def set_reason(self, metadata: dict | None) -> None:
        if metadata is not None:
            metadata[REWARD_REASON_KEY] = self.value

    @classmethod
    def set_default_reason(cls, reward: float, metadata: dict | None) -> None:
        if metadata is not None and REWARD_REASON_KEY not in metadata:
            (cls.RIGHT_ANSWER if reward >= 1.0 else cls.WRONG_ANSWER).set_reason(
                metadata
            )
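

# Illustrative sketch (assumed usage, not executed at import time): reward
# functions record why a reward was assigned by mutating a shared metadata dict:
#
#     metadata: dict = {}
#     RewardReason.INVALID_MOL.set_reason(metadata)
#     assert metadata[REWARD_REASON_KEY] == "invalid_mol"
#
#     # If no reason was recorded, a default is inferred from the reward value:
#     other: dict = {}
#     RewardReason.set_default_reason(0.0, other)
#     assert other[REWARD_REASON_KEY] == "wrong_answer"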


SOLUTION_DELIMITER = "!:!"


class RewardFunctionInfo(BaseModel):
    """Metadata used by a reward function to evaluate a solution."""

    fxn_name: str = Field(description="Name of the reward function to use.")
    answer_info: str = Field(
        description="Serialized metadata used by the reward function."
    )
    problem_type: str = Field(description="Problem type, for reference.")

    @model_validator(mode="before")
    @classmethod
    def deserialize_from_string(cls, data: Any) -> Any:
        if isinstance(data, str):
            # Deserialize from a string 3-tuple
            fn, ainfo, pt = data.split(SOLUTION_DELIMITER, maxsplit=2)
            return {"fxn_name": fn, "answer_info": ainfo, "problem_type": pt}
        return data
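

# Illustrative sketch (the reward function name and problem type below are
# assumed placeholders): a solution can be validated from either a dict or a
# single SOLUTION_DELIMITER-joined string:
#
#     info = RewardFunctionInfo.model_validate(
#         "my_reward_fxn!:!CCO!:!molecule-completion"
#     )
#     assert (info.fxn_name, info.answer_info, info.problem_type) == (
#         "my_reward_fxn", "CCO", "molecule-completion"
#     )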


class QAExample(BaseModel):
    """Question-answer example with reward function info."""

    id: str = Field(description="Unique identifier for this example.")
    problem: str = Field(description="Problem to solve.")
    problem_type: str = Field(description="Problem type, for reference or filtering.")
    solution: RewardFunctionInfo = Field(
        description="Metadata for the reward function."
    )
    ideal: str | None = Field(
        description=(
            "An optional ideal answer. This could be a candidate SMILES, a log10 of"
            " water solubility, or None if having an ideal does not make sense."
        )
    )
    unformatted: str | None = Field(
        description=(
            "Optional raw data used to generate the problem, used for traceability."
        )
    )


def filter_problem_types(
    dataset: TDataset, problem_types: str | Collection[str] | None
) -> TDataset:
    """Filter a dataset by problem types.

    Args:
        dataset: The dataset to filter. Can be a single Dataset or a DatasetDict.
        problem_types: A string or collection of strings specifying the problem
            types to filter by.

            - If None, the original dataset is returned.
            - If a string or a collection of strings:
                - Strings starting with "re:" are treated as regex patterns.
                  If a regex filter is provided, then it must be the only filter.
                - Strings starting with "!" are treated as problem types to exclude.
                - Other strings are treated as exact problem types to include.
                - Mixing inclusion and exclusion rules (e.g. ["type_a", "!type_b"])
                  is not allowed.

    Returns:
        The filtered dataset.
    """
    if problem_types is None:
        return dataset
    if isinstance(problem_types, str):  # Assume single problem type as a string
        problem_types = [problem_types]
    problem_types = {pt.strip() for pt in problem_types}
    columns = (
        next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset
    ).column_names
    # ether0-benchmark uses 'problem_type'; some variants may use 'type'
    type_col = "problem_type" if "problem_type" in columns else "type"
    if any(pt.startswith("re:") for pt in problem_types):
        # A regex was passed in
        if len(problem_types) != 1:
            raise ValueError(
                "If filtering by regex, only one filter is supported,"
                f" passed {problem_types}."
            )
        regex = re.compile(next(iter(problem_types)).removeprefix("re:"))

        def filter_func(x):
            return regex.match(x[type_col]) is not None

    else:
        # Treat as exact string match
        valid_problem_types = {pt for pt in problem_types if not pt.startswith("!")}
        invalid_problem_types = {
            pt.removeprefix("!") for pt in problem_types if pt.startswith("!")
        }
        if valid_problem_types:
            if invalid_problem_types:
                raise ValueError(
                    "Cannot specify both problem types to keep and to exclude,"
                    f" passed {problem_types}."
                )

            def filter_func(x):
                return x[type_col] in valid_problem_types

        else:

            def filter_func(x):
                return x[type_col] not in invalid_problem_types

    return dataset.filter(filter_func, desc="Filtering problem types")
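

# Illustrative usage sketch (the dataset repository and problem type names
# below are assumptions, not guaranteed to exist):
#
#     from datasets import load_dataset
#
#     ds = load_dataset("futurehouse/ether0-benchmark", split="test")
#     only_a = filter_problem_types(ds, "type_a")      # keep one problem type
#     not_b = filter_problem_types(ds, "!type_b")      # exclude one problem type
#     retro = filter_problem_types(ds, "re:^retro.*")  # regex match on the type column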