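"""LM Studio model component for Langflow.

Wraps LM Studio's OpenAI-compatible local server behind Langflow's
LCModelComponent interface, so locally served models can be used like any
other chat model in a flow.
"""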
from typing import Any
from urllib.parse import urljoin

import httpx
from langchain_openai import ChatOpenAI
from pydantic.v1 import SecretStr
from typing_extensions import override

from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.inputs import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
from langflow.inputs.inputs import HandleInput


class LMStudioModelComponent(LCModelComponent):
    display_name = "LM Studio"
    description = "Generate text using LM Studio Local LLMs."
    icon = "LMStudio"
    name = "LMStudioModel"

    @override
    def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
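        """Refresh the model_name dropdown options.

        When the model_name field is rebuilt (e.g. via its refresh button),
        resolve the configured base URL, falling back to the default
        http://localhost:1234/v1, and repopulate the options with the model
        ids reported by the LM Studio server.
        """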
        if field_name == "model_name":
            base_url_dict = build_config.get("base_url", {})
            base_url_load_from_db = base_url_dict.get("load_from_db", False)
            base_url_value = base_url_dict.get("value")
            if base_url_load_from_db:
                # Resolve the stored variable to its actual value.
                base_url_value = self.variables(base_url_value)
            elif not base_url_value:
                base_url_value = "http://localhost:1234/v1"
            build_config["model_name"]["options"] = self.get_model(base_url_value)
        return build_config

    def get_model(self, base_url_value: str) -> list[str]:
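        """Return the ids of the models served at the LM Studio /v1/models endpoint."""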
        try:
            url = urljoin(base_url_value, "/v1/models")
            with httpx.Client() as client:
                response = client.get(url)
                response.raise_for_status()
                data = response.json()
                return [model["id"] for model in data.get("data", [])]
        except Exception as e:
            msg = "Could not retrieve models. Please make sure the LM Studio server is running."
            raise ValueError(msg) from e

    inputs = [
        *LCModelComponent._base_inputs,
        IntInput(
            name="max_tokens",
            display_name="Max Tokens",
            advanced=True,
            info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
            range_spec=RangeSpec(min=0, max=128000),
        ),
        DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
        DropdownInput(
            name="model_name",
            display_name="Model Name",
            advanced=False,
            refresh_button=True,
        ),
        StrInput(
            name="base_url",
            display_name="Base URL",
            advanced=False,
            info="Endpoint of the LM Studio API. Defaults to 'http://localhost:1234/v1' if not specified.",
            value="http://localhost:1234/v1",
        ),
        SecretStrInput(
            name="api_key",
            display_name="LM Studio API Key",
            info="The API key to use for the LM Studio server.",
            advanced=True,
            value="LMSTUDIO_API_KEY",
        ),
        FloatInput(name="temperature", display_name="Temperature", value=0.1),
        IntInput(
            name="seed",
            display_name="Seed",
            info="The seed controls the reproducibility of the job.",
            advanced=True,
            value=1,
        ),
        HandleInput(
            name="output_parser",
            display_name="Output Parser",
            info="The parser to use to parse the output of the model.",
            advanced=True,
            input_types=["OutputParser"],
        ),
    ]

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
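        """Build a ChatOpenAI client configured against the LM Studio endpoint."""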
        lmstudio_api_key = self.api_key
        temperature = self.temperature
        model_name: str = self.model_name
        max_tokens = self.max_tokens
        model_kwargs = self.model_kwargs or {}
        base_url = self.base_url or "http://localhost:1234/v1"
        seed = self.seed
        # LM Studio usually does not require a key, but ChatOpenAI expects a SecretStr when one is set.
        api_key = SecretStr(lmstudio_api_key) if lmstudio_api_key else None

        return ChatOpenAI(
            max_tokens=max_tokens or None,
            model_kwargs=model_kwargs,
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature if temperature is not None else 0.1,
            seed=seed,
        )

    def _get_exception_message(self, e: Exception):
        """Get a message from an LM Studio exception.

        Args:
            e (Exception): The exception to get the message from.

        Returns:
            str | None: The message from the exception, or None if no message could be extracted.
        """
        try:
            from openai import BadRequestError
        except ImportError:
            return None
        if isinstance(e, BadRequestError):
            # e.body may be None or a non-dict payload, so guard before looking up the message.
            message = e.body.get("message") if isinstance(e.body, dict) else None
            if message:
                return message
        return None
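

# A minimal sketch of the request get_model() issues, runnable on its own.
# It assumes an LM Studio server is listening on the default port; the
# `base` value below is an illustrative assumption, not part of the component.
#
#     import httpx
#     from urllib.parse import urljoin
#
#     base = "http://localhost:1234/v1"
#     response = httpx.get(urljoin(base, "/v1/models"), timeout=5.0)
#     response.raise_for_status()
#     print([model["id"] for model in response.json().get("data", [])])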