| """Create and return a model.""" | |
| # ruff: noqa: F841 | |
| import os | |
| import re | |
| from platform import node | |
| from loguru import logger | |
| from smolagents import InferenceClientModel as HfApiModel | |
| from smolagents import LiteLLMRouterModel, OpenAIServerModel | |
| # FutureWarning: HfApiModel was renamed to InferenceClientModel in version 1.14.0 and will be removed in 1.17.0. | |
| from get_gemini_keys import get_gemini_keys | |


def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: category, one of "hf", "gemini", "llama" (default and fallback: "hf")
        provider: provider for HfApiModel (cat="hf")
        model_id: model name

    If no gemini_api_keys are found, return HfApiModel().
    """
    if cat.lower() in ["hf"]:
        logger.info(" using HfApiModel, make sure you set HF_TOKEN")
        return HfApiModel(provider=provider, model_id=model_id)

    # set up a proxy for gemini and llama on golay (local testing)
    if "golay" in node() and cat.lower() in ["gemini", "llama"]:
        os.environ.update(
            HTTPS_PROXY="http://localhost:8081",
            HTTP_PROXY="http://localhost:8081",
            ALL_PROXY="http://localhost:8081",
            NO_PROXY="localhost,127.0.0.1",
        )

    if cat.lower() in ["gemini"]:
        # collect gemini_api_keys (from get_gemini_keys() and the GEMINI_API_KEYS env var), then dedup
        _ = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + _)]
        # assert gemini_api_keys, "No GEMINI_API_KEYS, set env var GEMINI_API_KEYS or put them in .env-gemini and try again."
        if not gemini_api_keys:
            logger.warning(
                "cat='gemini' but no GEMINI_API_KEYS found, returning HfApiModel()... "
                "Set the env var GEMINI_API_KEYS and/or put free-tier gemini API keys in .env-gemini if you want to try 'gemini'."
            )
            logger.info(" set gemini but return HfApiModel()")
            return HfApiModel()

        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"

        llm_loadbalancer_model_list_gemini = []
        for api_key in gemini_api_keys:
            llm_loadbalancer_model_list_gemini.append(
                {
                    "model_name": "model-group-1",
                    "litellm_params": {
                        "model": f"gemini/{model_id}",
                        "api_key": api_key,
                    },
                },
            )
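
        # Note: every entry shares the same model_name ("model-group-1"), so the router's
        # "simple-shuffle" strategy spreads requests across all the gemini API keys.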
| model_id = "deepseek-ai/DeepSeek-V3" | |
| llm_loadbalancer_model_list_siliconflow = [ | |
| { | |
| "model_name": "model-group-2", | |
| "litellm_params": { | |
| "model": f"openai/{model_id}", | |
| "api_key": os.getenv("SILICONFLOW_API_KEY"), | |
| "api_base": "https://api.siliconflow.cn/v1", | |
| }, | |
| }, | |
| ] | |

        # gemma-3-27b-it
        llm_loadbalancer_model_list_gemma = [
            {
                "model_name": "model-group-3",
                "litellm_params": {
                    "model": "gemini/gemma-3-27b-it",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                },
            },
        ]

        fallbacks = []
        model_list = llm_loadbalancer_model_list_gemini
        if os.getenv("SILICONFLOW_API_KEY"):
            fallbacks = [{"model-group-1": "model-group-2"}]
            model_list += llm_loadbalancer_model_list_siliconflow
        model_list += llm_loadbalancer_model_list_gemma
        # candidate fallback routes (kept for experimentation, hence the noqa: F841 above)
        fallbacks13 = [{"model-group-1": "model-group-3"}]
        fallbacks31 = [{"model-group-3": "model-group-1"}]

        model = LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "retry_after": 180,  # minimum seconds to wait before retrying a request
                "fallbacks": fallbacks13,  # fallbacks don't seem to work
            },
        )
        if os.getenv("SILICONFLOW_API_KEY"):
            logger.info(" set gemini, return LiteLLMRouterModel + fallbacks")
        else:
            logger.info(" set gemini, return LiteLLMRouterModel")
        return model

    if cat.lower() in ["llama"]:
        api_key = os.getenv("LLAMA_API_KEY")
        if api_key is None:
            logger.warning(" LLAMA_API_KEY not set, using HfApiModel(), make sure you set HF_TOKEN")
            return HfApiModel()
        # default model_id (the second assignment overrides the first)
        if model_id is None:
            model_id = "Llama-4-Maverick-17B-128E-Instruct-FP8"
            model_id = "Llama-4-Scout-17B-16E-Instruct-FP8"
        model_llama = OpenAIServerModel(
            model_id,
            api_base="https://api.llama.com/compat/v1",
            api_key=api_key,
            # temperature=0.,
        )
        return model_llama
| logger.info(" default return default HfApiModel(provider=None, model_id=None)") | |
| # if cat.lower() in ["hf"]: default | |
| return HfApiModel(provider=provider, model_id=model_id) | |
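

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original module): exercise the
    # default HF path and the gemini path. The gemini path falls back to HfApiModel()
    # when no GEMINI_API_KEYS are available; the HF path expects HF_TOKEN to be set.
    default_model = get_model()
    logger.info(f"default model: {default_model}")

    gemini_model = get_model(cat="gemini")
    logger.info(f"gemini model: {gemini_model}")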