File size: 8,147 Bytes
fcaa164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import threading
import time
from typing import Any, Dict, List, Optional, Union

from openai import OpenAI, Stream

from camel.configs import SGLANG_API_PARAMS, SGLangConfig
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import (
    ChatCompletion,
    ChatCompletionChunk,
    ModelType,
)
from camel.utils import BaseTokenCounter, OpenAITokenCounter


class SGLangModel(BaseModelBackend):
    r"""SGLang service interface.

    Args:
        model_type (Union[ModelType, str]): Model for which a backend is
            created.
        model_config_dict (Optional[Dict[str, Any]], optional): A dictionary
            that will be fed into:obj:`openai.ChatCompletion.create()`. If
            :obj:`None`, :obj:`SGLangConfig().as_dict()` will be used.
            (default: :obj:`None`)
        api_key (Optional[str], optional): The API key for authenticating with
            the model service. SGLang doesn't need API key, it would be ignored
            if set. (default: :obj:`None`)
        url (Optional[str], optional): The url to the model service. If not
            provided, :obj:`"http://127.0.0.1:30000/v1"` will be used.
            (default: :obj:`None`)
        token_counter (Optional[BaseTokenCounter], optional): Token counter to
            use for the model. If not provided, :obj:`OpenAITokenCounter(
            ModelType.GPT_4O_MINI)` will be used.
            (default: :obj:`None`)

    Reference: https://sgl-project.github.io/backend/openai_api_completions.html
    """

    def __init__(
        self,
        model_type: Union[ModelType, str],
        model_config_dict: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        token_counter: Optional[BaseTokenCounter] = None,
    ) -> None:
        if model_config_dict is None:
            model_config_dict = SGLangConfig().as_dict()

        self.server_process = None
        self.last_run_time: Optional[float] = (
            None  # Will be set when the server starts
        )
        self._lock = threading.Lock()
        self._inactivity_thread: Optional[threading.Thread] = None

        super().__init__(
            model_type, model_config_dict, api_key, url, token_counter
        )

        self._client = None

        if self._url:
            # Initialize the client if an existing URL is provided
            self._client = OpenAI(
                timeout=180,
                max_retries=3,
                api_key="Set-but-ignored",  # required but ignored
                base_url=self._url,
            )

    def _start_server(self) -> None:
        from sglang.utils import (  # type: ignore[import-untyped]
            execute_shell_command,
            wait_for_server,
        )

        try:
            if not self._url:
                cmd = (
                    f"python -m sglang.launch_server "
                    f"--model-path {self.model_type} "
                    f"--port 30000 "
                    f"--host 0.0.0.0"
                )

                server_process = execute_shell_command(cmd)
                wait_for_server("http://localhost:30000")
                self._url = "http://127.0.0.1:30000/v1"
                self.server_process = server_process
                # Start the inactivity monitor in a background thread
                self._inactivity_thread = threading.Thread(
                    target=self._monitor_inactivity, daemon=True
                )
                self._inactivity_thread.start()
            self.last_run_time = time.time()
            # Initialize the client after the server starts
            self._client = OpenAI(
                timeout=180,
                max_retries=3,
                api_key="Set-but-ignored",  # required but ignored
                base_url=self._url,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to start SGLang server: {e}") from e

    def _ensure_server_running(self) -> None:
        r"""Ensures that the server is running. If not, starts the server."""
        with self._lock:
            if self.server_process is None:
                self._start_server()

    def _monitor_inactivity(self):
        r"""Monitor whether the server process has been inactive for over 10
        minutes.
        """
        from sglang.utils import terminate_process

        while True:
            # Check every 10 seconds
            time.sleep(10)
            # Over 10 minutes
            with self._lock:
                # Over 10 minutes
                if self.last_run_time and (
                    time.time() - self.last_run_time > 600
                ):
                    if self.server_process:
                        terminate_process(self.server_process)
                        self.server_process = None
                        self._client = None  # Invalidate the client
                        logging.info(
                            "Server process terminated due to inactivity."
                        )
                    break

    @property
    def token_counter(self) -> BaseTokenCounter:
        r"""Initialize the token counter for the model backend.

        Returns:
            BaseTokenCounter: The token counter following the model's
                tokenization style.
        """
        if not self._token_counter:
            self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
        return self._token_counter

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to SGLang API.

        Raises:
            ValueError: If the model configuration dictionary contains any
                unexpected arguments to OpenAI API.
        """
        for param in self.model_config_dict:
            if param not in SGLANG_API_PARAMS:
                raise ValueError(
                    f"Unexpected argument `{param}` is "
                    "input into SGLang model backend."
                )

    def run(
        self,
        messages: List[OpenAIMessage],
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        r"""Runs inference of OpenAI chat completion.

        Args:
            messages (List[OpenAIMessage]): Message list with the chat history
                in OpenAI API format.

        Returns:
            Union[ChatCompletion, Stream[ChatCompletionChunk]]:
                `ChatCompletion` in the non-stream mode, or
                `Stream[ChatCompletionChunk]` in the stream mode.
        """

        # Ensure server is running
        self._ensure_server_running()

        with self._lock:
            # Update last run time
            self.last_run_time = time.time()

        if self._client is None:
            raise RuntimeError(
                "Client is not initialized. Ensure the server is running."
            )

        response = self._client.chat.completions.create(
            messages=messages,
            model=self.model_type,
            **self.model_config_dict,
        )

        return response

    @property
    def stream(self) -> bool:
        r"""Returns whether the model is in stream mode, which sends partial
        results each time.

        Returns:
            bool: Whether the model is in stream mode.
        """
        return self.model_config_dict.get('stream', False)