import requests
from pydantic.v1 import SecretStr
from typing_extensions import override

from langflow.base.models.groq_constants import GROQ_MODELS
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs.inputs import HandleInput
from langflow.io import DropdownInput, FloatInput, IntInput, MessageTextInput, SecretStrInput


class GroqModel(LCModelComponent):
    display_name: str = "Groq"
    description: str = "Generate text using Groq."
    icon = "Groq"
    name = "GroqModel"

    inputs = [
        *LCModelComponent._base_inputs,
        SecretStrInput(name="groq_api_key", display_name="Groq API Key", info="API key for the Groq API."),
        MessageTextInput(
            name="groq_api_base",
            display_name="Groq API Base",
            info="Base URL path for API requests; leave blank if not using a proxy or service emulator.",
            advanced=True,
            value="https://api.groq.com",
        ),
        IntInput(
            name="max_tokens",
            display_name="Max Output Tokens",
            info="The maximum number of tokens to generate.",
            advanced=True,
        ),
        FloatInput(
            name="temperature",
            display_name="Temperature",
            info="Run inference with this temperature. Must be in the closed interval [0.0, 1.0].",
            value=0.1,
        ),
        IntInput(
            name="n",
            display_name="N",
            info="Number of chat completions to generate for each prompt. "
            "Note that the API may not return the full n completions if duplicates are generated.",
            advanced=True,
        ),
        DropdownInput(
            name="model_name",
            display_name="Model",
            info="The name of the model to use.",
            options=GROQ_MODELS,
            value="llama-3.1-8b-instant",
            refresh_button=True,
        ),
        HandleInput(
            name="output_parser",
            display_name="Output Parser",
            info="The parser to use to parse the output of the model",
            advanced=True,
            input_types=["OutputParser"],
        ),
    ]

    def get_models(self) -> list[str]:
        api_key = self.groq_api_key
        base_url = self.groq_api_base or "https://api.groq.com"
        url = f"{base_url}/openai/v1/models"
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            model_list = response.json()
            return [model["id"] for model in model_list.get("data", [])]
        except requests.RequestException as e:
            self.status = f"Error fetching models: {e}"
            return GROQ_MODELS
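    # Note: the parsing above assumes an OpenAI-style model-listing response, roughly
    # {"data": [{"id": "<model-name>", ...}, ...]}; this shape is inferred from the code
    # rather than taken from Groq's documentation. On any request error the component
    # falls back to the static GROQ_MODELS list.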

    @override
    def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None):
        if field_name in {"groq_api_key", "groq_api_base", "model_name"}:
            models = self.get_models()
            build_config["model_name"]["options"] = models
        return build_config

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
        try:
            from langchain_groq import ChatGroq
        except ImportError as e:
            msg = "langchain-groq is not installed. Please install it with `pip install langchain-groq`."
            raise ImportError(msg) from e

        groq_api_key = self.groq_api_key
        model_name = self.model_name
        max_tokens = self.max_tokens
        temperature = self.temperature
        groq_api_base = self.groq_api_base
        n = self.n
        stream = self.stream

        return ChatGroq(
            model=model_name,
            max_tokens=max_tokens or None,
            temperature=temperature,
            base_url=groq_api_base,
            n=n or 1,
            api_key=SecretStr(groq_api_key).get_secret_value(),
            streaming=stream,
        )
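

A minimal usage sketch of the underlying chat model, separate from the Langflow component above. It assumes `langchain-groq` is installed and a `GROQ_API_KEY` environment variable is set; the model name and prompt are illustrative, not prescribed by the component.

if __name__ == "__main__":
    import os

    from langchain_groq import ChatGroq

    # Standalone construction mirroring what build_model() returns above.
    chat = ChatGroq(
        model="llama-3.1-8b-instant",
        temperature=0.1,
        api_key=os.environ["GROQ_API_KEY"],
    )
    print(chat.invoke("Say hello in one short sentence.").content)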