"""Create and return a model."""
import os
import re
from platform import node

from get_gemini_keys import get_gemini_keys
from loguru import logger
from smolagents import HfApiModel, LiteLLMRouterModel
def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: category — "gemini" builds a LiteLLM router load-balanced over
            all available Gemini API keys; anything else falls through to
            HfApiModel (the "hf" default).
        provider: provider forwarded to HfApiModel (cat="hf" path only).
        model_id: model name; each category supplies its own default when None.

    Returns:
        LiteLLMRouterModel when cat="gemini" and at least one key is found,
        otherwise HfApiModel.
    """
    if cat.lower() == "gemini":
        # Gather keys from .env-gemini (get_gemini_keys) plus the
        # GEMINI_API_KEYS env var; dict.fromkeys dedups preserving order.
        env_keys = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = dict.fromkeys(get_gemini_keys() + env_keys)
        if not gemini_api_keys:
            # Graceful fallback instead of asserting: warn and use HF.
            logger.warning(
                "cat='gemini' but no GEMINI_API_KEYS found, "
                " returning HfApiModel()..."
                " Set env var GEMINI_API_KEYS and/or .env-gemini "
                " with free space gemini-api-keys if you want to try 'gemini' "
            )
            return HfApiModel()

        # Proxy setup needed for Gemini access on the 'golay' (local) host.
        if "golay" in node():
            os.environ.update(
                HTTPS_PROXY="http://localhost:8081",
                HTTP_PROXY="http://localhost:8081",
                ALL_PROXY="http://localhost:8081",
                NO_PROXY="localhost,127.0.0.1,oracle",
            )

        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"

        # One router entry per key -> simple-shuffle load balancing across keys.
        model_list = [
            {
                "model_name": "model-group-1",
                "litellm_params": {
                    "model": f"gemini/{model_id}",
                    "api_key": api_key,
                },
            }
            for api_key in gemini_api_keys
        ]

        # Optional fallback group: DeepSeek-V3 via SiliconFlow. Only build the
        # entry when the key exists (the original constructed it with
        # api_key=None regardless, and clobbered model_id to do so).
        fallbacks = []
        siliconflow_api_key = os.getenv("SILICONFLOW_API_KEY")
        if siliconflow_api_key:
            fallbacks = [{"model-group-1": "model-group-2"}]
            model_list.append(
                {
                    "model_name": "model-group-2",
                    "litellm_params": {
                        "model": "openai/deepseek-ai/DeepSeek-V3",
                        "api_key": siliconflow_api_key,
                        "api_base": "https://api.siliconflow.cn/v1",
                    },
                }
            )

        return LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "fallbacks": fallbacks,
            },
        )

    # if cat.lower() in ["hf"]: default — Hugging Face Inference API.
    return HfApiModel(provider=provider, model_id=model_id)