File size: 6,092 Bytes
4c346eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import re
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass, field
from functools import wraps
from itertools import starmap
from typing import Any, ParamSpec, TypeVar, cast

from ether0.model_prompts import (
    ANSWER_END,
    ANSWER_START,
    THINK_END,
    THINK_START,
    ProblemPrompt,
    SysPrompt,
    extract_answer_loose,
)
from ether0.rewards import accuracy_reward, format_reward

P = ParamSpec("P")
R = TypeVar("R")


def wrap_reward_func(func: Callable[P, R], **wrap_kwargs: Any) -> Callable[P, R]:
    @wraps(func)  # needed by GRPOTrainer for logging
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        return func(*args, **wrap_kwargs, **kwargs)

    return wrapped


@dataclass
class ChatArguments:
    """Arguments for making a chat conversation for SFT or RL training."""

    sys_prompt: SysPrompt | None = field(
        default=None,
        metadata={
            "help": (
                "If provided, use this system prompt. If not provided, the chat"
                " template may inject one."
            )
        },
    )

    problem_prompt: ProblemPrompt = field(
        default=ProblemPrompt.NONE,
        metadata={
            "help": (
                "Prompt to put before the problem in the first user message, relevant"
                " for both RL or SFT. Make sure this matches between SFT and RL, so if"
                " the SFT'd model wasn't passed this during SFT, don't pass this to RL."
            )
        },
    )

    reasoning: bool = field(
        default=True,
        metadata={
            "help": (
                "If True (default), it is assumed that the model's response contains"
                f" reasoning enclosed in `{THINK_START}` and `{THINK_END}`."
            )
        },
    )

    def make_rl_conversation(
        self, row: MutableMapping[str, str | list[str]]
    ) -> dict[str, list[dict] | list[list[dict]]]:
        """Format a dataset row into a chat-like conversation structure for RL.

        Args:
            row: Dataset row (or batched row) containing a "problem" key whose
                value is a single problem string or a list of problem strings.

        Returns:
            Mapping with a "prompt" key holding the conversation (a list of
            messages for a single row, or a list of such lists for a batch).
            Unlike make_sft_conversation, the answer is not included.
        """
        if not self.sys_prompt:
            msgs: list[dict] = []
        else:
            msgs = [{
                "role": "system",
                "content": SysPrompt(self.sys_prompt).get_sys_prompt(),
            }]
        problem_prompt = ProblemPrompt(self.problem_prompt).get_prompt()
        if problem_prompt:
            # Blank line visually separates the instruction prompt from the problem
            problem_prompt += "\n\n"

        def add_user(problem: str) -> list[dict]:
            return [*msgs, {"role": "user", "content": problem_prompt + problem}]

        if isinstance(row["problem"], str):  # Single
            all_msgs: list[dict] | list[list[dict]] = add_user(row["problem"])
        else:  # Batched
            all_msgs = [add_user(p) for p in row["problem"]]
        return {"prompt": all_msgs}

    def make_sft_conversation(
        self, row: MutableMapping[str, str | list[str]]
    ) -> dict[str, list[dict] | list[list[dict]]]:
        """Format a dataset row into a chat-like conversation structure for SFT.

        Args:
            row: Dataset row (or batched row) with "problem", "answer", and
                "thought" keys (each a string, or a list of strings if batched).

        Returns:
            Mapping with a "messages" key holding the full conversation,
            including the assistant turn with the answer (and optional thought).

        Raises:
            ValueError: If reasoning is enabled while the problem prompt only
                asks for answer XML, or if an answer wrapped in XML tags cannot
                be reduced to a bare answer.
        """
        if (
            self.reasoning
            and ProblemPrompt(self.problem_prompt) == ProblemPrompt.ANSWER
        ):
            raise ValueError(
                "It does not make sense to include reasoning in the SFT traces,"
                " but then only prompt about answer XML (without thoughts)."
            )

        def add_assistant(
            raw_answer: str, thought: str, prior_msgs: list[dict]
        ) -> list[dict]:
            # Plain substring test — the original regex searched for this exact
            # literal (with a redundant `\/` escape), so behavior is unchanged.
            if "</answer>" in raw_answer:
                # Remove prelude and postlude plus XML tags,
                # because an OpenRouter-hosted DeepSeek R1 can give answer
                # with a prelude and XML tags, but our training expects just an answer
                # > The reaction involves sodium borohydride ([BH4-].[Na+]), <redacted>.
                # > Under these conditions, <redacted>.
                # > <answer>N1(CCOCC1)C1=CC=C(C(O))C=C1</answer>
                answer = extract_answer_loose(raw_answer)
                if not answer:
                    raise ValueError(
                        "Failed to extract just the answer from the answer"
                        f" {raw_answer!r}."
                    )
            else:
                answer = raw_answer

            return [
                *prior_msgs,
                {
                    "role": "assistant",
                    "content": (
                        (f"{THINK_START}{thought}{THINK_END}" if self.reasoning else "")
                        + f"{ANSWER_START}{answer}{ANSWER_END}"
                    ),
                },
            ]

        # The first part will be the same as the RL conversation
        msgs = self.make_rl_conversation(row)["prompt"]
        # Now add the answer, with optional thinking
        if isinstance(row["problem"], str):  # Single
            all_msgs: list[dict] | list[list[dict]] = add_assistant(
                cast(str, row["answer"]),
                cast(str, row["thought"]),
                cast(list[dict], msgs),
            )
        else:  # Batched
            all_msgs = list(
                starmap(
                    add_assistant, zip(row["answer"], row["thought"], msgs, strict=True)
                )
            )
        return {"messages": all_msgs}

    def get_reward_funcs(
        self,
        format_reward_value: float = 1.0,
        soft: bool = False,
        test: bool = False,
        good_molecule_bonus: float = 0.0,
    ) -> list[Callable]:
        """Build the reward functions for RL training.

        Args:
            format_reward_value: Reward granted by the format check.
            soft: Passed through to accuracy_reward.
            test: Passed through to accuracy_reward.
            good_molecule_bonus: Passed through to accuracy_reward.

        Returns:
            Two-element list: the format reward function followed by the
            accuracy reward function, each with its configuration pre-bound.
        """
        return [
            wrap_reward_func(
                format_reward,
                reasoning=self.reasoning,
                reward=format_reward_value,
            ),
            wrap_reward_func(
                accuracy_reward,
                reasoning=self.reasoning,
                soft=soft,
                test=test,
                good_molecule_bonus=good_molecule_bonus,
            ),
        ]