| import sys, os, platform, time, copy, re, asyncio, inspect | |
| import threading, ast | |
| import shutil, random, traceback, requests | |
| from datetime import datetime, timedelta, timezone | |
| from typing import Optional, List | |
| import secrets, subprocess | |
| import hashlib, uuid | |
| import warnings | |
| import importlib | |
| messages: list = [] | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system path - for litellm local dev | |
| sample = """ | |
| from openai import OpenAI | |
| import json | |
| base_url = "https://ka1kuk-litellm.hf.space" | |
| api_key = "hf_xxxx" | |
| client = OpenAI(base_url=base_url, api_key=api_key) | |
| messages = [{"role": "user", "content": "What's the capital of France?"}] | |
| response = client.chat.completions.create( | |
| model="huggingface/mistralai/Mixtral-8x7B-Instruct-v0.1", | |
| response_format={ "type": "json_object" }, | |
| messages=messages, | |
| stream=False, | |
| ) | |
| print(response.choices[0].message.content) | |
| """ | |
| description = f"Proxy Server to call 100+ LLMs in the OpenAI format\n\nSample with openai library:\n\n{sample}" | |
| try: | |
| import fastapi | |
| import backoff | |
| import yaml | |
| import orjson | |
| import logging | |
| except ImportError as e: | |
| raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`") | |
| import litellm | |
| from litellm.proxy.utils import ( | |
| PrismaClient, | |
| DBClient, | |
| get_instance_fn, | |
| ProxyLogging, | |
| _cache_user_row, | |
| send_email, | |
| ) | |
| from litellm.proxy.secret_managers.google_kms import load_google_kms | |
| import pydantic | |
| from litellm.proxy._types import * | |
| from litellm.caching import DualCache | |
| from litellm.proxy.health_check import perform_health_check | |
| from litellm._logging import verbose_router_logger, verbose_proxy_logger | |
| litellm.suppress_debug_info = True | |
| from fastapi import ( | |
| FastAPI, | |
| Request, | |
| HTTPException, | |
| status, | |
| Depends, | |
| BackgroundTasks, | |
| Header, | |
| Response, | |
| ) | |
| from fastapi.routing import APIRouter | |
| from fastapi.security import OAuth2PasswordBearer | |
| from fastapi.encoders import jsonable_encoder | |
| from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.security.api_key import APIKeyHeader | |
| import json | |
| import logging | |
| from typing import Union | |
| app = FastAPI( | |
| docs_url="/", | |
| title="LiteLLM API", | |
description=description,
| ) | |
| router = APIRouter() | |
| origins = ["*"] | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=origins, | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| from typing import Dict | |
| api_key_header = APIKeyHeader(name="Authorization", auto_error=False) | |
| user_api_base = None | |
| user_model = None | |
| user_debug = False | |
| user_max_tokens = None | |
| user_request_timeout = None | |
| user_temperature = None | |
| user_telemetry = True | |
| user_config = None | |
| user_headers = None | |
| user_config_file_path = f"config_{int(time.time())}.yaml" | |
| local_logging = True # writes logs to a local api_log.json file for debugging | |
| experimental = False | |
| #### GLOBAL VARIABLES #### | |
| llm_router: Optional[litellm.Router] = None | |
| llm_model_list: Optional[list] = None | |
| general_settings: dict = {} | |
| log_file = "api_log.json" | |
| worker_config = None | |
| master_key = None | |
| otel_logging = False | |
| prisma_client: Optional[PrismaClient] = None | |
| custom_db_client: Optional[DBClient] = None | |
| user_api_key_cache = DualCache() | |
| user_custom_auth = None | |
| use_background_health_checks = None | |
| use_queue = False | |
| health_check_interval = None | |
| health_check_results = {} | |
| queue: List = [] | |
| ### INITIALIZE GLOBAL LOGGING OBJECT ### | |
| proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) | |
| ### REDIS QUEUE ### | |
| async_result = None | |
| celery_app_conn = None | |
| celery_fn = None # Redis Queue for handling requests | |
| ### logger ### | |
| def usage_telemetry( | |
| feature: str, | |
):  # helps us know if people are using this feature. Pass `--telemetry False` in your `litellm` CLI call to turn this off
| if user_telemetry: | |
| data = {"feature": feature} # "local_proxy_server" | |
| threading.Thread( | |
| target=litellm.utils.litellm_telemetry, args=(data,), daemon=True | |
| ).start() | |
| def _get_bearer_token(api_key: str): | |
| assert api_key.startswith("Bearer ") # ensure Bearer token passed in | |
| api_key = api_key.replace("Bearer ", "") # extract the token | |
| return api_key | |
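# Illustrative example: _get_bearer_token("Bearer sk-1234") returns "sk-1234";
# a header value without the "Bearer " prefix fails the assert above.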
| def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict: | |
| try: | |
| return pydantic_obj.model_dump() # type: ignore | |
| except: | |
| # if using pydantic v1 | |
| return pydantic_obj.dict() | |
| async def user_api_key_auth( | |
| request: Request, api_key: str = fastapi.Security(api_key_header) | |
| ) -> UserAPIKeyAuth: | |
| global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client | |
| try: | |
| if isinstance(api_key, str): | |
| api_key = _get_bearer_token(api_key=api_key) | |
| ### USER-DEFINED AUTH FUNCTION ### | |
| if user_custom_auth is not None: | |
| response = await user_custom_auth(request=request, api_key=api_key) | |
| return UserAPIKeyAuth.model_validate(response) | |
| ### LITELLM-DEFINED AUTH FUNCTION ### | |
| if master_key is None: | |
| if isinstance(api_key, str): | |
| return UserAPIKeyAuth(api_key=api_key) | |
| else: | |
| return UserAPIKeyAuth() | |
| route: str = request.url.path | |
| if route == "/user/auth": | |
| if general_settings.get("allow_user_auth", False) == True: | |
| return UserAPIKeyAuth() | |
| else: | |
| raise HTTPException( | |
| status_code=status.HTTP_403_FORBIDDEN, | |
| detail="'allow_user_auth' not set or set to False", | |
| ) | |
| if api_key is None: # only require api key if master key is set | |
| raise Exception(f"No api key passed in.") | |
# note: never string-compare API keys - this is vulnerable to a timing attack. Use secrets.compare_digest instead
| is_master_key_valid = secrets.compare_digest(api_key, master_key) | |
| if is_master_key_valid: | |
| return UserAPIKeyAuth(api_key=master_key) | |
| if route.startswith("/config/") and not is_master_key_valid: | |
| raise Exception(f"Only admin can modify config") | |
if (
    (
        route.startswith("/key/")
        or route.startswith("/user/")
        or route.startswith("/model/")
    )
    and not is_master_key_valid
    and general_settings.get("allow_user_auth", False) != True
):
| raise Exception( | |
| f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users" | |
| ) | |
| if ( | |
| prisma_client is None and custom_db_client is None | |
| ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error | |
| raise Exception("No connected db.") | |
| ## check for cache hit (In-Memory Cache) | |
| valid_token = user_api_key_cache.get_cache(key=api_key) | |
| verbose_proxy_logger.debug(f"valid_token from cache: {valid_token}") | |
| if valid_token is None: | |
| ## check db | |
| verbose_proxy_logger.debug(f"api key: {api_key}") | |
| if prisma_client is not None: | |
| valid_token = await prisma_client.get_data( | |
| token=api_key, | |
| ) | |
| expires = datetime.utcnow().replace(tzinfo=timezone.utc) | |
| elif custom_db_client is not None: | |
| valid_token = await custom_db_client.get_data( | |
| key=api_key, table_name="key" | |
| ) | |
# Token exists, now check expiration.
if valid_token.expires is not None:
    expiry_time = datetime.fromisoformat(valid_token.expires)
    if expiry_time < datetime.utcnow():
        # Token exists but is expired.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="expired user key",
        )
    # Token exists and is not expired - fall through to the shared checks below.
| verbose_proxy_logger.debug(f"valid token from prisma: {valid_token}") | |
| user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60) | |
| elif valid_token is not None: | |
| verbose_proxy_logger.debug(f"API Key Cache Hit!") | |
| if valid_token: | |
| litellm.model_alias_map = valid_token.aliases | |
| config = valid_token.config | |
| if config != {}: | |
| model_list = config.get("model_list", []) | |
| llm_model_list = model_list | |
| verbose_proxy_logger.debug( | |
| f"\n new llm router model list {llm_model_list}" | |
| ) | |
| if ( | |
| len(valid_token.models) == 0 | |
| ): # assume an empty model list means all models are allowed to be called | |
| pass | |
| else: | |
| try: | |
| data = await request.json() | |
| except json.JSONDecodeError: | |
| data = {} # Provide a default value, such as an empty dictionary | |
| model = data.get("model", None) | |
| if model in litellm.model_alias_map: | |
| model = litellm.model_alias_map[model] | |
| if model and model not in valid_token.models: | |
| raise Exception(f"Token not allowed to access model") | |
| api_key = valid_token.token | |
| valid_token_dict = _get_pydantic_json_dict(valid_token) | |
| valid_token_dict.pop("token", None) | |
| """ | |
| asyncio create task to update the user api key cache with the user db table as well | |
| This makes the user row data accessible to pre-api call hooks. | |
| """ | |
| if prisma_client is not None: | |
| asyncio.create_task( | |
| _cache_user_row( | |
| user_id=valid_token.user_id, | |
| cache=user_api_key_cache, | |
| db=prisma_client, | |
| ) | |
| ) | |
| elif custom_db_client is not None: | |
| asyncio.create_task( | |
| _cache_user_row( | |
| user_id=valid_token.user_id, | |
| cache=user_api_key_cache, | |
| db=custom_db_client, | |
| ) | |
| ) | |
| return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) | |
| else: | |
| raise Exception(f"Invalid token") | |
| except Exception as e: | |
| # verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}") | |
| traceback.print_exc() | |
| if isinstance(e, HTTPException): | |
| raise e | |
| else: | |
| raise HTTPException( | |
| status_code=status.HTTP_401_UNAUTHORIZED, | |
| detail="invalid user key", | |
| ) | |
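# Minimal sketch of how `user_api_key_auth` is consumed as a FastAPI dependency
# (the route name below is hypothetical; the real endpoints further down use the same pattern):
#
# @router.get("/example/whoami")  # hypothetical route
# async def whoami(
#     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
# ):
#     return {"user_id": user_api_key_dict.user_id}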
| def prisma_setup(database_url: Optional[str]): | |
| global prisma_client, proxy_logging_obj, user_api_key_cache | |
| if database_url is not None: | |
| try: | |
| prisma_client = PrismaClient( | |
| database_url=database_url, proxy_logging_obj=proxy_logging_obj | |
| ) | |
| except Exception as e: | |
| raise e | |
| def load_from_azure_key_vault(use_azure_key_vault: bool = False): | |
| if use_azure_key_vault is False: | |
| return | |
| try: | |
| from azure.keyvault.secrets import SecretClient | |
| from azure.identity import ClientSecretCredential | |
| # Set your Azure Key Vault URI | |
| KVUri = os.getenv("AZURE_KEY_VAULT_URI", None) | |
| # Set your Azure AD application/client ID, client secret, and tenant ID | |
| client_id = os.getenv("AZURE_CLIENT_ID", None) | |
| client_secret = os.getenv("AZURE_CLIENT_SECRET", None) | |
| tenant_id = os.getenv("AZURE_TENANT_ID", None) | |
| if ( | |
| KVUri is not None | |
| and client_id is not None | |
| and client_secret is not None | |
| and tenant_id is not None | |
| ): | |
| # Initialize the ClientSecretCredential | |
| credential = ClientSecretCredential( | |
| client_id=client_id, client_secret=client_secret, tenant_id=tenant_id | |
| ) | |
| # Create the SecretClient using the credential | |
| client = SecretClient(vault_url=KVUri, credential=credential) | |
| litellm.secret_manager_client = client | |
| litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT | |
| else: | |
| raise Exception( | |
| f"Missing KVUri or client_id or client_secret or tenant_id from environment" | |
| ) | |
| except Exception as e: | |
| verbose_proxy_logger.debug( | |
| "Error when loading keys from Azure Key Vault. Ensure you run `pip install azure-identity azure-keyvault-secrets`" | |
| ) | |
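# Illustrative environment expected by load_from_azure_key_vault (values are placeholders):
#   AZURE_KEY_VAULT_URI="https://my-vault.vault.azure.net/"
#   AZURE_CLIENT_ID="..."  AZURE_CLIENT_SECRET="..."  AZURE_TENANT_ID="..."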
| def cost_tracking(): | |
| global prisma_client, custom_db_client | |
| if prisma_client is not None or custom_db_client is not None: | |
| if isinstance(litellm.success_callback, list): | |
| verbose_proxy_logger.debug("setting litellm success callback to track cost") | |
if track_cost_callback not in litellm.success_callback:  # type: ignore
| litellm.success_callback.append(track_cost_callback) # type: ignore | |
| async def track_cost_callback( | |
| kwargs, # kwargs to completion | |
| completion_response: litellm.ModelResponse, # response from completion | |
| start_time=None, | |
| end_time=None, # start/end time for completion | |
| ): | |
| global prisma_client, custom_db_client | |
| try: | |
| # check if it has collected an entire stream response | |
| verbose_proxy_logger.debug( | |
| f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" | |
| ) | |
| if "complete_streaming_response" in kwargs: | |
| # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost | |
| completion_response = kwargs["complete_streaming_response"] | |
| response_cost = litellm.completion_cost( | |
| completion_response=completion_response | |
| ) | |
| verbose_proxy_logger.debug(f"streaming response_cost {response_cost}") | |
| user_api_key = kwargs["litellm_params"]["metadata"].get( | |
| "user_api_key", None | |
| ) | |
| user_id = kwargs["litellm_params"]["metadata"].get( | |
| "user_api_key_user_id", None | |
| ) | |
| if user_api_key and ( | |
| prisma_client is not None or custom_db_client is not None | |
| ): | |
| await update_database(token=user_api_key, response_cost=response_cost) | |
| elif kwargs["stream"] == False: # for non streaming responses | |
| response_cost = litellm.completion_cost( | |
| completion_response=completion_response | |
| ) | |
| user_api_key = kwargs["litellm_params"]["metadata"].get( | |
| "user_api_key", None | |
| ) | |
| user_id = kwargs["litellm_params"]["metadata"].get( | |
| "user_api_key_user_id", None | |
| ) | |
| if user_api_key and ( | |
| prisma_client is not None or custom_db_client is not None | |
| ): | |
| await update_database( | |
| token=user_api_key, response_cost=response_cost, user_id=user_id | |
| ) | |
| except Exception as e: | |
| verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}") | |
| async def update_database(token, response_cost, user_id=None): | |
| try: | |
| verbose_proxy_logger.debug( | |
| f"Enters prisma db call, token: {token}; user_id: {user_id}" | |
| ) | |
| ### UPDATE USER SPEND ### | |
| async def _update_user_db(): | |
| if user_id is None: | |
| return | |
| if prisma_client is not None: | |
| existing_spend_obj = await prisma_client.get_data(user_id=user_id) | |
| elif custom_db_client is not None: | |
| existing_spend_obj = await custom_db_client.get_data( | |
| key=user_id, table_name="user" | |
| ) | |
| if existing_spend_obj is None: | |
| existing_spend = 0 | |
| else: | |
| existing_spend = existing_spend_obj.spend | |
| # Calculate the new cost by adding the existing cost and response_cost | |
| new_spend = existing_spend + response_cost | |
| verbose_proxy_logger.debug(f"new cost: {new_spend}") | |
| # Update the cost column for the given user id | |
| if prisma_client is not None: | |
| await prisma_client.update_data( | |
| user_id=user_id, data={"spend": new_spend} | |
| ) | |
| elif custom_db_client is not None: | |
| await custom_db_client.update_data( | |
| key=user_id, value={"spend": new_spend}, table_name="user" | |
| ) | |
| ### UPDATE KEY SPEND ### | |
| async def _update_key_db(): | |
| if prisma_client is not None: | |
| # Fetch the existing cost for the given token | |
| existing_spend_obj = await prisma_client.get_data(token=token) | |
| verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") | |
| if existing_spend_obj is None: | |
| existing_spend = 0 | |
| else: | |
| existing_spend = existing_spend_obj.spend | |
| # Calculate the new cost by adding the existing cost and response_cost | |
| new_spend = existing_spend + response_cost | |
| verbose_proxy_logger.debug(f"new cost: {new_spend}") | |
| # Update the cost column for the given token | |
| await prisma_client.update_data(token=token, data={"spend": new_spend}) | |
| elif custom_db_client is not None: | |
| # Fetch the existing cost for the given token | |
| existing_spend_obj = await custom_db_client.get_data( | |
| key=token, table_name="key" | |
| ) | |
| verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") | |
| if existing_spend_obj is None: | |
| existing_spend = 0 | |
| else: | |
| existing_spend = existing_spend_obj.spend | |
| # Calculate the new cost by adding the existing cost and response_cost | |
| new_spend = existing_spend + response_cost | |
| verbose_proxy_logger.debug(f"new cost: {new_spend}") | |
| # Update the cost column for the given token | |
| await custom_db_client.update_data( | |
| key=token, value={"spend": new_spend}, table_name="key" | |
| ) | |
| tasks = [] | |
| tasks.append(_update_user_db()) | |
| tasks.append(_update_key_db()) | |
| await asyncio.gather(*tasks) | |
| except Exception as e: | |
| verbose_proxy_logger.debug( | |
| f"Error updating Prisma database: {traceback.format_exc()}" | |
| ) | |
| pass | |
| def run_ollama_serve(): | |
| try: | |
| command = ["ollama", "serve"] | |
| with open(os.devnull, "w") as devnull: | |
| process = subprocess.Popen(command, stdout=devnull, stderr=devnull) | |
| except Exception as e: | |
| verbose_proxy_logger.debug( | |
| f""" | |
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception: {e}.\nEnsure you run `ollama serve`
| """ | |
| ) | |
| async def _run_background_health_check(): | |
| """ | |
| Periodically run health checks in the background on the endpoints. | |
Update health_check_results based on the outcome.
| """ | |
| global health_check_results, llm_model_list, health_check_interval | |
| while True: | |
| healthy_endpoints, unhealthy_endpoints = await perform_health_check( | |
| model_list=llm_model_list | |
| ) | |
| # Update the global variable with the health check results | |
| health_check_results["healthy_endpoints"] = healthy_endpoints | |
| health_check_results["unhealthy_endpoints"] = unhealthy_endpoints | |
| health_check_results["healthy_count"] = len(healthy_endpoints) | |
| health_check_results["unhealthy_count"] = len(unhealthy_endpoints) | |
| await asyncio.sleep(health_check_interval) | |
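# Illustrative shape of health_check_results after one pass (endpoint entries come from
# perform_health_check):
#   {"healthy_endpoints": [...], "unhealthy_endpoints": [...], "healthy_count": 2, "unhealthy_count": 0}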
| class ProxyConfig: | |
| """ | |
| Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic. | |
| """ | |
| def __init__(self) -> None: | |
| pass | |
| def is_yaml(self, config_file_path: str) -> bool: | |
| if not os.path.isfile(config_file_path): | |
| return False | |
| _, file_extension = os.path.splitext(config_file_path) | |
| return file_extension.lower() == ".yaml" or file_extension.lower() == ".yml" | |
| async def get_config(self, config_file_path: Optional[str] = None) -> dict: | |
| global prisma_client, user_config_file_path | |
| file_path = config_file_path or user_config_file_path | |
| if config_file_path is not None: | |
| user_config_file_path = config_file_path | |
| # Load existing config | |
| ## Yaml | |
| if os.path.exists(f"{file_path}"): | |
| with open(f"{file_path}", "r") as config_file: | |
| config = yaml.safe_load(config_file) | |
| else: | |
| config = { | |
| "model_list": [], | |
| "general_settings": {}, | |
| "router_settings": {}, | |
| "litellm_settings": {}, | |
| } | |
| ## DB | |
| if ( | |
| prisma_client is not None | |
| and litellm.get_secret("SAVE_CONFIG_TO_DB", False) == True | |
| ): | |
| prisma_setup(database_url=None) # in case it's not been connected yet | |
| _tasks = [] | |
| keys = [ | |
| "model_list", | |
| "general_settings", | |
| "router_settings", | |
| "litellm_settings", | |
| ] | |
| for k in keys: | |
| response = prisma_client.get_generic_data( | |
| key="param_name", value=k, table_name="config" | |
| ) | |
| _tasks.append(response) | |
| responses = await asyncio.gather(*_tasks) | |
| return config | |
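# Sketch of the dict returned by get_config (all values below are illustrative, not defaults):
# {
#     "model_list": [
#         {
#             "model_name": "gpt-3.5-turbo",
#             "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "os.environ/OPENAI_API_KEY"},
#         }
#     ],
#     "general_settings": {"master_key": "os.environ/LITELLM_MASTER_KEY"},
#     "router_settings": {},
#     "litellm_settings": {"drop_params": True},
# }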
| async def save_config(self, new_config: dict): | |
| global prisma_client, llm_router, user_config_file_path, llm_model_list, general_settings | |
| # Load existing config | |
| backup_config = await self.get_config() | |
| # Save the updated config | |
| ## YAML | |
| with open(f"{user_config_file_path}", "w") as config_file: | |
| yaml.dump(new_config, config_file, default_flow_style=False) | |
| # update Router - verifies if this is a valid config | |
| try: | |
| ( | |
| llm_router, | |
| llm_model_list, | |
| general_settings, | |
| ) = await proxy_config.load_config( | |
| router=llm_router, config_file_path=user_config_file_path | |
| ) | |
| except Exception as e: | |
| traceback.print_exc() | |
| # Revert to old config instead | |
| with open(f"{user_config_file_path}", "w") as config_file: | |
| yaml.dump(backup_config, config_file, default_flow_style=False) | |
| raise HTTPException(status_code=400, detail="Invalid config passed in") | |
| ## DB - writes valid config to db | |
| """ | |
| - Do not write restricted params like 'api_key' to the database | |
| - if api_key is passed, save that to the local environment or connected secret manage (maybe expose `litellm.save_secret()`) | |
| """ | |
| if ( | |
| prisma_client is not None | |
| and litellm.get_secret("SAVE_CONFIG_TO_DB", default_value=False) == True | |
| ): | |
| ### KEY REMOVAL ### | |
| models = new_config.get("model_list", []) | |
| for m in models: | |
| if m.get("litellm_params", {}).get("api_key", None) is not None: | |
| # pop the key | |
| api_key = m["litellm_params"].pop("api_key") | |
| # store in local env | |
| key_name = f"LITELLM_MODEL_KEY_{uuid.uuid4()}" | |
| os.environ[key_name] = api_key | |
| # save the key name (not the value) | |
| m["litellm_params"]["api_key"] = f"os.environ/{key_name}" | |
| await prisma_client.insert_data(data=new_config, table_name="config") | |
| async def load_config( | |
| self, router: Optional[litellm.Router], config_file_path: str | |
| ): | |
| """ | |
| Load config values into proxy global state | |
| """ | |
| global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client | |
| # Load existing config | |
| config = await self.get_config(config_file_path=config_file_path) | |
| ## PRINT YAML FOR CONFIRMING IT WORKS | |
| printed_yaml = copy.deepcopy(config) | |
| printed_yaml.pop("environment_variables", None) | |
| verbose_proxy_logger.debug( | |
| f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}" | |
| ) | |
| ## ENVIRONMENT VARIABLES | |
| environment_variables = config.get("environment_variables", None) | |
| if environment_variables: | |
| for key, value in environment_variables.items(): | |
| os.environ[key] = value | |
| ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) | |
| litellm_settings = config.get("litellm_settings", None) | |
| if litellm_settings is None: | |
| litellm_settings = {} | |
| if litellm_settings: | |
| # ANSI escape code for blue text | |
| blue_color_code = "\033[94m" | |
| reset_color_code = "\033[0m" | |
| for key, value in litellm_settings.items(): | |
| if key == "cache": | |
| print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa | |
| from litellm.caching import Cache | |
| cache_params = {} | |
| if "cache_params" in litellm_settings: | |
| cache_params_in_config = litellm_settings["cache_params"] | |
# overwrite cache_params with cache_params_in_config
| cache_params.update(cache_params_in_config) | |
| cache_type = cache_params.get("type", "redis") | |
| verbose_proxy_logger.debug(f"passed cache type={cache_type}") | |
| if cache_type == "redis": | |
| cache_host = litellm.get_secret("REDIS_HOST", None) | |
| cache_port = litellm.get_secret("REDIS_PORT", None) | |
| cache_password = litellm.get_secret("REDIS_PASSWORD", None) | |
| cache_params.update( | |
| { | |
| "type": cache_type, | |
| "host": cache_host, | |
| "port": cache_port, | |
| "password": cache_password, | |
| } | |
| ) | |
| # Assuming cache_type, cache_host, cache_port, and cache_password are strings | |
| print( # noqa | |
| f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}" | |
| ) # noqa | |
| print( # noqa | |
| f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}" | |
| ) # noqa | |
| print( # noqa | |
| f"{blue_color_code}Cache Port:{reset_color_code} {cache_port}" | |
| ) # noqa | |
| print( # noqa | |
| f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}" | |
| ) | |
| print() # noqa | |
| # users can pass os.environ/ variables on the proxy - we should read them from the env | |
| for key, value in cache_params.items(): | |
| if type(value) is str and value.startswith("os.environ/"): | |
| cache_params[key] = litellm.get_secret(value) | |
## to pass a complete url or set ssl=True etc., just set `os.environ["REDIS_URL"] = "<your-redis-url>"` - _redis.py checks for REDIS-specific environment variables
| litellm.cache = Cache(**cache_params) | |
| print( # noqa | |
| f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}" | |
| ) | |
| elif key == "callbacks": | |
| litellm.callbacks = [ | |
| get_instance_fn(value=value, config_file_path=config_file_path) | |
| ] | |
| verbose_proxy_logger.debug( | |
| f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}" | |
| ) | |
| elif key == "post_call_rules": | |
| litellm.post_call_rules = [ | |
| get_instance_fn(value=value, config_file_path=config_file_path) | |
| ] | |
| verbose_proxy_logger.debug( | |
| f"litellm.post_call_rules: {litellm.post_call_rules}" | |
| ) | |
| elif key == "success_callback": | |
| litellm.success_callback = [] | |
# initialize success callbacks
for callback in value:
    # user passed custom_callbacks.async_on_success_logger. They need us to import a function
| if "." in callback: | |
| litellm.success_callback.append( | |
| get_instance_fn(value=callback) | |
| ) | |
| # these are litellm callbacks - "langfuse", "sentry", "wandb" | |
| else: | |
| litellm.success_callback.append(callback) | |
| verbose_proxy_logger.debug( | |
| f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}" | |
| ) | |
| elif key == "failure_callback": | |
| litellm.failure_callback = [] | |
# initialize failure callbacks
for callback in value:
    # user passed custom_callbacks.async_on_success_logger. They need us to import a function
| if "." in callback: | |
| litellm.failure_callback.append( | |
| get_instance_fn(value=callback) | |
| ) | |
| # these are litellm callbacks - "langfuse", "sentry", "wandb" | |
| else: | |
| litellm.failure_callback.append(callback) | |
| verbose_proxy_logger.debug( | |
| f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}" | |
| ) | |
| elif key == "cache_params": | |
| # this is set in the cache branch | |
| # see usage here: https://docs.litellm.ai/docs/proxy/caching | |
| pass | |
| else: | |
| setattr(litellm, key, value) | |
| ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging | |
| general_settings = config.get("general_settings", {}) | |
| if general_settings is None: | |
| general_settings = {} | |
| if general_settings: | |
| ### LOAD SECRET MANAGER ### | |
| key_management_system = general_settings.get("key_management_system", None) | |
| if key_management_system is not None: | |
| if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value: | |
| ### LOAD FROM AZURE KEY VAULT ### | |
| load_from_azure_key_vault(use_azure_key_vault=True) | |
| elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value: | |
| ### LOAD FROM GOOGLE KMS ### | |
| load_google_kms(use_google_kms=True) | |
| else: | |
| raise ValueError("Invalid Key Management System selected") | |
| ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms | |
| use_google_kms = general_settings.get("use_google_kms", False) | |
| load_google_kms(use_google_kms=use_google_kms) | |
| ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager | |
| use_azure_key_vault = general_settings.get("use_azure_key_vault", False) | |
| load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault) | |
| ### ALERTING ### | |
| proxy_logging_obj.update_values( | |
| alerting=general_settings.get("alerting", None), | |
| alerting_threshold=general_settings.get("alerting_threshold", 600), | |
| ) | |
| ### CONNECT TO DATABASE ### | |
| database_url = general_settings.get("database_url", None) | |
| if database_url and database_url.startswith("os.environ/"): | |
| verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!") | |
| database_url = litellm.get_secret(database_url) | |
| verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}") | |
| ### MASTER KEY ### | |
| master_key = general_settings.get( | |
| "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) | |
| ) | |
| if master_key and master_key.startswith("os.environ/"): | |
| master_key = litellm.get_secret(master_key) | |
| ### CUSTOM API KEY AUTH ### | |
| ## pass filepath | |
| custom_auth = general_settings.get("custom_auth", None) | |
| if custom_auth is not None: | |
| user_custom_auth = get_instance_fn( | |
| value=custom_auth, config_file_path=config_file_path | |
| ) | |
| ## dynamodb | |
| database_type = general_settings.get("database_type", None) | |
| if database_type is not None and ( | |
| database_type == "dynamo_db" or database_type == "dynamodb" | |
| ): | |
| database_args = general_settings.get("database_args", None) | |
| custom_db_client = DBClient( | |
| custom_db_args=database_args, custom_db_type=database_type | |
| ) | |
| ## COST TRACKING ## | |
| cost_tracking() | |
| ### BACKGROUND HEALTH CHECKS ### | |
| # Enable background health checks | |
| use_background_health_checks = general_settings.get( | |
| "background_health_checks", False | |
| ) | |
| health_check_interval = general_settings.get("health_check_interval", 300) | |
router_params: dict = {
    "num_retries": 3,
    "cache_responses": litellm.cache is not None,  # cache if user passed in cache values
}
| ## MODEL LIST | |
| model_list = config.get("model_list", None) | |
| if model_list: | |
| router_params["model_list"] = model_list | |
| print( # noqa | |
| f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m" | |
| ) # noqa | |
| for model in model_list: | |
| ### LOAD FROM os.environ/ ### | |
| for k, v in model["litellm_params"].items(): | |
| if isinstance(v, str) and v.startswith("os.environ/"): | |
| model["litellm_params"][k] = litellm.get_secret(v) | |
| print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa | |
| litellm_model_name = model["litellm_params"]["model"] | |
| litellm_model_api_base = model["litellm_params"].get("api_base", None) | |
| if "ollama" in litellm_model_name and litellm_model_api_base is None: | |
| run_ollama_serve() | |
| ## ROUTER SETTINGS (e.g. routing_strategy, ...) | |
| router_settings = config.get("router_settings", None) | |
| if router_settings and isinstance(router_settings, dict): | |
| arg_spec = inspect.getfullargspec(litellm.Router) | |
| # model list already set | |
| exclude_args = { | |
| "self", | |
| "model_list", | |
| } | |
| available_args = [x for x in arg_spec.args if x not in exclude_args] | |
| for k, v in router_settings.items(): | |
| if k in available_args: | |
| router_params[k] = v | |
| router = litellm.Router(**router_params) # type:ignore | |
| return router, model_list, general_settings | |
| proxy_config = ProxyConfig() | |
| async def generate_key_helper_fn( | |
| duration: Optional[str], | |
| models: list, | |
| aliases: dict, | |
| config: dict, | |
| spend: float, | |
| max_budget: Optional[float] = None, | |
| token: Optional[str] = None, | |
| user_id: Optional[str] = None, | |
| user_email: Optional[str] = None, | |
| max_parallel_requests: Optional[int] = None, | |
| metadata: Optional[dict] = {}, | |
| ): | |
| global prisma_client, custom_db_client | |
| if prisma_client is None and custom_db_client is None: | |
| raise Exception( | |
| f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " | |
| ) | |
| if token is None: | |
| token = f"sk-{secrets.token_urlsafe(16)}" | |
| def _duration_in_seconds(duration: str): | |
| match = re.match(r"(\d+)([smhd]?)", duration) | |
| if not match: | |
| raise ValueError("Invalid duration format") | |
| value, unit = match.groups() | |
| value = int(value) | |
| if unit == "s": | |
| return value | |
| elif unit == "m": | |
| return value * 60 | |
| elif unit == "h": | |
| return value * 3600 | |
| elif unit == "d": | |
| return value * 86400 | |
| else: | |
| raise ValueError("Unsupported duration unit") | |
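# Illustrative conversions: "30s" -> 30, "10m" -> 600, "24h" -> 86400, "7d" -> 604800.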
| if duration is None: # allow tokens that never expire | |
| expires = None | |
| else: | |
| duration_s = _duration_in_seconds(duration=duration) | |
| expires = datetime.utcnow() + timedelta(seconds=duration_s) | |
| aliases_json = json.dumps(aliases) | |
| config_json = json.dumps(config) | |
| metadata_json = json.dumps(metadata) | |
| user_id = user_id or str(uuid.uuid4()) | |
| try: | |
| # Create a new verification token (you may want to enhance this logic based on your needs) | |
| user_data = { | |
| "max_budget": max_budget, | |
| "user_email": user_email, | |
| "user_id": user_id, | |
| "spend": spend, | |
| } | |
| key_data = { | |
| "token": token, | |
| "expires": expires, | |
| "models": models, | |
| "aliases": aliases_json, | |
| "config": config_json, | |
| "spend": spend, | |
| "user_id": user_id, | |
| "max_parallel_requests": max_parallel_requests, | |
| "metadata": metadata_json, | |
| } | |
| if prisma_client is not None: | |
| verification_token_data = dict(key_data) | |
| verification_token_data.update(user_data) | |
| verbose_proxy_logger.debug("PrismaClient: Before Insert Data") | |
| await prisma_client.insert_data(data=verification_token_data) | |
| elif custom_db_client is not None: | |
| ## CREATE USER (If necessary) | |
| await custom_db_client.insert_data(value=user_data, table_name="user") | |
| ## CREATE KEY | |
| await custom_db_client.insert_data(value=key_data, table_name="key") | |
| except Exception as e: | |
| traceback.print_exc() | |
| raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) | |
| return { | |
| "token": token, | |
| "expires": expires, | |
| "user_id": user_id, | |
| "max_budget": max_budget, | |
| } | |
| async def delete_verification_token(tokens: List): | |
| global prisma_client | |
| try: | |
| if prisma_client: | |
# delete the specified tokens via the Prisma client
| deleted_tokens = await prisma_client.delete_data(tokens=tokens) | |
| else: | |
| raise Exception | |
| except Exception as e: | |
| traceback.print_exc() | |
| raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) | |
| return deleted_tokens | |
| def save_worker_config(**data): | |
| import json | |
| os.environ["WORKER_CONFIG"] = json.dumps(data) | |
| async def initialize( | |
| model=None, | |
| alias=None, | |
| api_base=None, | |
| api_version=None, | |
| debug=False, | |
| detailed_debug=False, | |
| temperature=None, | |
| max_tokens=None, | |
| request_timeout=600, | |
| max_budget=None, | |
| telemetry=False, | |
| drop_params=True, | |
| add_function_to_prompt=True, | |
| headers=None, | |
| save=False, | |
| use_queue=False, | |
| config=None, | |
| ): | |
global user_model, user_api_base, user_debug, user_detailed_debug, user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
| user_model = model | |
| user_debug = debug | |
if debug == True:  # this needs to be first, so users can see Router init debug logs
| from litellm._logging import verbose_router_logger, verbose_proxy_logger | |
| import logging | |
| # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS | |
| verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info | |
| verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info | |
| if detailed_debug == True: | |
| from litellm._logging import verbose_router_logger, verbose_proxy_logger | |
| import logging | |
verbose_router_logger.setLevel(level=logging.DEBUG)  # set router logs to debug
| verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug | |
| litellm.set_verbose = True | |
| elif debug == False and detailed_debug == False: | |
| # users can control proxy debugging using env variable = 'LITELLM_LOG' | |
| litellm_log_setting = os.environ.get("LITELLM_LOG", "") | |
if litellm_log_setting is not None:
| if litellm_log_setting.upper() == "INFO": | |
| from litellm._logging import verbose_router_logger, verbose_proxy_logger | |
| import logging | |
| # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS | |
| verbose_router_logger.setLevel( | |
| level=logging.INFO | |
| ) # set router logs to info | |
| verbose_proxy_logger.setLevel( | |
| level=logging.INFO | |
| ) # set proxy logs to info | |
| elif litellm_log_setting.upper() == "DEBUG": | |
| from litellm._logging import verbose_router_logger, verbose_proxy_logger | |
| import logging | |
| verbose_router_logger.setLevel( | |
| level=logging.DEBUG | |
)  # set router logs to debug
| verbose_proxy_logger.setLevel( | |
| level=logging.DEBUG | |
| ) # set proxy logs to debug | |
| litellm.set_verbose = True | |
| dynamic_config = {"general": {}, user_model: {}} | |
| if config: | |
| ( | |
| llm_router, | |
| llm_model_list, | |
| general_settings, | |
| ) = await proxy_config.load_config(router=llm_router, config_file_path=config) | |
| if headers: # model-specific param | |
| user_headers = headers | |
| dynamic_config[user_model]["headers"] = headers | |
| if api_base: # model-specific param | |
| user_api_base = api_base | |
| dynamic_config[user_model]["api_base"] = api_base | |
| if api_version: | |
| os.environ[ | |
| "AZURE_API_VERSION" | |
| ] = api_version # set this for azure - litellm can read this from the env | |
| if max_tokens: # model-specific param | |
| user_max_tokens = max_tokens | |
| dynamic_config[user_model]["max_tokens"] = max_tokens | |
| if temperature: # model-specific param | |
| user_temperature = temperature | |
| dynamic_config[user_model]["temperature"] = temperature | |
| if request_timeout: | |
| user_request_timeout = request_timeout | |
| dynamic_config[user_model]["request_timeout"] = request_timeout | |
| if alias: # model-specific param | |
| dynamic_config[user_model]["alias"] = alias | |
| if drop_params == True: # litellm-specific param | |
| litellm.drop_params = True | |
| dynamic_config["general"]["drop_params"] = True | |
| if add_function_to_prompt == True: # litellm-specific param | |
| litellm.add_function_to_prompt = True | |
| dynamic_config["general"]["add_function_to_prompt"] = True | |
| if max_budget: # litellm-specific param | |
| litellm.max_budget = max_budget | |
| dynamic_config["general"]["max_budget"] = max_budget | |
| if experimental: | |
| pass | |
| user_telemetry = telemetry | |
| usage_telemetry(feature="local_proxy_server") | |
| # for streaming | |
| def data_generator(response): | |
| verbose_proxy_logger.debug("inside generator") | |
| for chunk in response: | |
| verbose_proxy_logger.debug(f"returned chunk: {chunk}") | |
| try: | |
| yield f"data: {json.dumps(chunk.dict())}\n\n" | |
| except: | |
| yield f"data: {json.dumps(chunk)}\n\n" | |
| async def async_data_generator(response, user_api_key_dict): | |
| verbose_proxy_logger.debug("inside generator") | |
| try: | |
| start_time = time.time() | |
| async for chunk in response: | |
| verbose_proxy_logger.debug(f"returned chunk: {chunk}") | |
| try: | |
| yield f"data: {json.dumps(chunk.dict())}\n\n" | |
| except Exception as e: | |
| yield f"data: {str(e)}\n\n" | |
| ### ALERTING ### | |
| end_time = time.time() | |
| asyncio.create_task( | |
| proxy_logging_obj.response_taking_too_long( | |
| start_time=start_time, end_time=end_time, type="slow_response" | |
| ) | |
| ) | |
| # Streaming is done, yield the [DONE] chunk | |
| done_message = "[DONE]" | |
| yield f"data: {done_message}\n\n" | |
| except Exception as e: | |
| yield f"data: {str(e)}\n\n" | |
| def get_litellm_model_info(model: dict = {}): | |
| model_info = model.get("model_info", {}) | |
| model_to_lookup = model.get("litellm_params", {}).get("model", None) | |
| try: | |
| if "azure" in model_to_lookup: | |
| model_to_lookup = model_info.get("base_model", None) | |
| litellm_model_info = litellm.get_model_info(model_to_lookup) | |
| return litellm_model_info | |
| except: | |
| # this should not block returning on /model/info | |
| # if litellm does not have info on the model it should return {} | |
| return {} | |
| def parse_cache_control(cache_control): | |
| cache_dict = {} | |
| directives = cache_control.split(", ") | |
| for directive in directives: | |
| if "=" in directive: | |
| key, value = directive.split("=") | |
| cache_dict[key] = value | |
| else: | |
| cache_dict[directive] = True | |
| return cache_dict | |
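# Illustrative example: parse_cache_control("s-maxage=3600, no-cache") returns
# {"s-maxage": "3600", "no-cache": True}; /chat/completions below uses the "s-maxage"
# value as the per-request cache "ttl".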
| async def startup_event(): | |
| global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings | |
| import json | |
| ### LOAD MASTER KEY ### | |
| # check if master key set in environment - load from there | |
| master_key = litellm.get_secret("LITELLM_MASTER_KEY", None) | |
| # check if DATABASE_URL in environment - load from there | |
| if prisma_client is None: | |
| prisma_setup(database_url=os.getenv("DATABASE_URL")) | |
| ### LOAD CONFIG ### | |
| worker_config = litellm.get_secret("WORKER_CONFIG") | |
| verbose_proxy_logger.debug(f"worker_config: {worker_config}") | |
| # check if it's a valid file path | |
| if os.path.isfile(worker_config): | |
| if proxy_config.is_yaml(config_file_path=worker_config): | |
| ( | |
| llm_router, | |
| llm_model_list, | |
| general_settings, | |
| ) = await proxy_config.load_config( | |
| router=llm_router, config_file_path=worker_config | |
| ) | |
| else: | |
| await initialize(**worker_config) | |
| else: | |
| # if not, assume it's a json string | |
| worker_config = json.loads(os.getenv("WORKER_CONFIG")) | |
| await initialize(**worker_config) | |
| proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made | |
| if use_background_health_checks: | |
| asyncio.create_task( | |
| _run_background_health_check() | |
| ) # start the background health check coroutine. | |
| verbose_proxy_logger.debug(f"prisma client - {prisma_client}") | |
| if prisma_client is not None: | |
| await prisma_client.connect() | |
| if custom_db_client is not None: | |
| await custom_db_client.connect() | |
| if prisma_client is not None and master_key is not None: | |
| # add master key to db | |
| await generate_key_helper_fn( | |
| duration=None, models=[], aliases={}, config={}, spend=0, token=master_key | |
| ) | |
| if custom_db_client is not None and master_key is not None: | |
| # add master key to db | |
| await generate_key_helper_fn( | |
| duration=None, models=[], aliases={}, config={}, spend=0, token=master_key | |
| ) | |
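# Illustrative WORKER_CONFIG values handled above: either a path to a yaml config file
# (e.g. "/app/config.yaml") or a JSON string of initialize() kwargs,
# e.g. '{"model": "gpt-3.5-turbo", "debug": true}' - both examples are hypothetical.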
| #### API ENDPOINTS #### | |
| # if project requires model list | |
| def model_list(): | |
| global llm_model_list, general_settings | |
| all_models = [] | |
| if general_settings.get("infer_model_from_keys", False): | |
| all_models = litellm.utils.get_valid_models() | |
| if llm_model_list: | |
| all_models = list(set(all_models + [m["model_name"] for m in llm_model_list])) | |
| if user_model is not None: | |
| all_models += [user_model] | |
| verbose_proxy_logger.debug(f"all_models: {all_models}") | |
| ### CHECK OLLAMA MODELS ### | |
| try: | |
| response = requests.get("http://0.0.0.0:11434/api/tags") | |
| models = response.json()["models"] | |
| ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models] | |
| all_models.extend(ollama_models) | |
| except Exception as e: | |
| pass | |
| return dict( | |
| data=[ | |
| { | |
| "id": model, | |
| "object": "model", | |
| "created": 1677610602, | |
| "owned_by": "openai", | |
| } | |
| for model in all_models | |
| ], | |
| object="list", | |
| ) | |
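# Illustrative return value (OpenAI models-list format):
# {
#     "data": [{"id": "gpt-3.5-turbo", "object": "model", "created": 1677610602, "owned_by": "openai"}],
#     "object": "list",
# }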
| async def completion( | |
| request: Request, | |
| fastapi_response: Response, | |
| model: Optional[str] = None, | |
| user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), | |
| background_tasks: BackgroundTasks = BackgroundTasks(), | |
| ): | |
| global user_temperature, user_request_timeout, user_max_tokens, user_api_base | |
| try: | |
| body = await request.body() | |
| body_str = body.decode() | |
| try: | |
| data = ast.literal_eval(body_str) | |
| except: | |
| data = json.loads(body_str) | |
| data["user"] = data.get("user", user_api_key_dict.user_id) | |
| data["model"] = ( | |
| general_settings.get("completion_model", None) # server default | |
| or user_model # model name passed via cli args | |
| or model # for azure deployments | |
| or data["model"] # default passed in http request | |
| ) | |
| if user_model: | |
| data["model"] = user_model | |
| if "metadata" in data: | |
| data["metadata"]["user_api_key"] = user_api_key_dict.api_key | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| data["metadata"]["headers"] = dict(request.headers) | |
| else: | |
| data["metadata"] = { | |
| "user_api_key": user_api_key_dict.api_key, | |
| "user_api_key_user_id": user_api_key_dict.user_id, | |
| } | |
| data["metadata"]["headers"] = dict(request.headers) | |
| # override with user settings, these are params passed via cli | |
| if user_temperature: | |
| data["temperature"] = user_temperature | |
| if user_request_timeout: | |
| data["request_timeout"] = user_request_timeout | |
| if user_max_tokens: | |
| data["max_tokens"] = user_max_tokens | |
| if user_api_base: | |
| data["api_base"] = user_api_base | |
| ### CALL HOOKS ### - modify incoming data before calling the model | |
| data = await proxy_logging_obj.pre_call_hook( | |
| user_api_key_dict=user_api_key_dict, data=data, call_type="completion" | |
| ) | |
| start_time = time.time() | |
### ROUTE THE REQUEST ###
| router_model_names = ( | |
| [m["model_name"] for m in llm_model_list] | |
| if llm_model_list is not None | |
| else [] | |
| ) | |
| # skip router if user passed their key | |
| if "api_key" in data: | |
| response = await litellm.atext_completion(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in router_model_names | |
| ): # model in router model list | |
| response = await llm_router.atext_completion(**data) | |
| elif ( | |
| llm_router is not None | |
| and llm_router.model_group_alias is not None | |
| and data["model"] in llm_router.model_group_alias | |
| ): # model set in model_group_alias | |
| response = await llm_router.atext_completion(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in llm_router.deployment_names | |
| ): # model in router deployments, calling a specific deployment on the router | |
| response = await llm_router.atext_completion( | |
| **data, specific_deployment=True | |
| ) | |
| else: # router is not set | |
| response = await litellm.atext_completion(**data) | |
| if hasattr(response, "_hidden_params"): | |
| model_id = response._hidden_params.get("model_id", None) or "" | |
| else: | |
| model_id = "" | |
| verbose_proxy_logger.debug(f"final response: {response}") | |
| if ( | |
| "stream" in data and data["stream"] == True | |
| ): # use generate_responses to stream responses | |
| custom_headers = {"x-litellm-model-id": model_id} | |
| return StreamingResponse( | |
| async_data_generator( | |
| user_api_key_dict=user_api_key_dict, | |
| response=response, | |
| ), | |
| media_type="text/event-stream", | |
| headers=custom_headers, | |
| ) | |
| ### ALERTING ### | |
| end_time = time.time() | |
| asyncio.create_task( | |
| proxy_logging_obj.response_taking_too_long( | |
| start_time=start_time, end_time=end_time, type="slow_response" | |
| ) | |
| ) | |
| fastapi_response.headers["x-litellm-model-id"] = model_id | |
| return response | |
| except Exception as e: | |
| verbose_proxy_logger.debug(f"EXCEPTION RAISED IN PROXY MAIN.PY") | |
| verbose_proxy_logger.debug( | |
| f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" | |
| ) | |
| traceback.print_exc() | |
| error_traceback = traceback.format_exc() | |
| error_msg = f"{str(e)}\n\n{error_traceback}" | |
| try: | |
| status = e.status_code # type: ignore | |
| except: | |
| status = 500 | |
| raise HTTPException(status_code=status, detail=error_msg) | |
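# Illustrative request body handled by the completion endpoint above (OpenAI
# text-completions format; the model name is a placeholder):
# {"model": "gpt-3.5-turbo-instruct", "prompt": "Say hi", "stream": false}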
| # azure compatible endpoint | |
| async def chat_completion( | |
| request: Request, | |
| fastapi_response: Response, | |
| model: Optional[str] = None, | |
| user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), | |
| background_tasks: BackgroundTasks = BackgroundTasks(), | |
| ): | |
| global general_settings, user_debug, proxy_logging_obj, llm_model_list | |
| try: | |
| data = {} | |
| body = await request.body() | |
| body_str = body.decode() | |
| try: | |
| data = ast.literal_eval(body_str) | |
| except: | |
| data = json.loads(body_str) | |
| # Include original request and headers in the data | |
| data["proxy_server_request"] = { | |
| "url": str(request.url), | |
| "method": request.method, | |
| "headers": dict(request.headers), | |
| "body": copy.copy(data), # use copy instead of deepcopy | |
| } | |
| ## Cache Controls | |
| headers = request.headers | |
| verbose_proxy_logger.debug(f"Request Headers: {headers}") | |
| cache_control_header = headers.get("Cache-Control", None) | |
| if cache_control_header: | |
| cache_dict = parse_cache_control(cache_control_header) | |
| data["ttl"] = cache_dict.get("s-maxage") | |
| verbose_proxy_logger.debug(f"receiving data: {data}") | |
| data["model"] = ( | |
| general_settings.get("completion_model", None) # server default | |
| or user_model # model name passed via cli args | |
| or model # for azure deployments | |
| or data["model"] # default passed in http request | |
| ) | |
| # users can pass in 'user' param to /chat/completions. Don't override it | |
| if data.get("user", None) is None and user_api_key_dict.user_id is not None: | |
| # if users are using user_api_key_auth, set `user` in `data` | |
| data["user"] = user_api_key_dict.user_id | |
| if "metadata" in data: | |
| verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}') | |
| data["metadata"]["user_api_key"] = user_api_key_dict.api_key | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| data["metadata"]["headers"] = dict(request.headers) | |
| else: | |
| data["metadata"] = {"user_api_key": user_api_key_dict.api_key} | |
| data["metadata"]["headers"] = dict(request.headers) | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| global user_temperature, user_request_timeout, user_max_tokens, user_api_base | |
| # override with user settings, these are params passed via cli | |
| if user_temperature: | |
| data["temperature"] = user_temperature | |
| if user_request_timeout: | |
| data["request_timeout"] = user_request_timeout | |
| if user_max_tokens: | |
| data["max_tokens"] = user_max_tokens | |
| if user_api_base: | |
| data["api_base"] = user_api_base | |
| ### CALL HOOKS ### - modify incoming data before calling the model | |
| data = await proxy_logging_obj.pre_call_hook( | |
| user_api_key_dict=user_api_key_dict, data=data, call_type="completion" | |
| ) | |
| start_time = time.time() | |
| ### ROUTE THE REQUEST ### | |
| router_model_names = ( | |
| [m["model_name"] for m in llm_model_list] | |
| if llm_model_list is not None | |
| else [] | |
| ) | |
| # skip router if user passed their key | |
| if "api_key" in data: | |
| response = await litellm.acompletion(**data) | |
| elif "user_config" in data: | |
| # initialize a new router instance. make request using this Router | |
| router_config = data.pop("user_config") | |
| user_router = litellm.Router(**router_config) | |
| response = await user_router.acompletion(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in router_model_names | |
| ): # model in router model list | |
| response = await llm_router.acompletion(**data) | |
| elif ( | |
| llm_router is not None | |
| and llm_router.model_group_alias is not None | |
| and data["model"] in llm_router.model_group_alias | |
| ): # model set in model_group_alias | |
| response = await llm_router.acompletion(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in llm_router.deployment_names | |
| ): # model in router deployments, calling a specific deployment on the router | |
| response = await llm_router.acompletion(**data, specific_deployment=True) | |
| else: # router is not set | |
| response = await litellm.acompletion(**data) | |
| if hasattr(response, "_hidden_params"): | |
| model_id = response._hidden_params.get("model_id", None) or "" | |
| else: | |
| model_id = "" | |
| if ( | |
| "stream" in data and data["stream"] == True | |
| ): # use generate_responses to stream responses | |
| custom_headers = {"x-litellm-model-id": model_id} | |
| return StreamingResponse( | |
| async_data_generator( | |
| user_api_key_dict=user_api_key_dict, | |
| response=response, | |
| ), | |
| media_type="text/event-stream", | |
| headers=custom_headers, | |
| ) | |
| ### ALERTING ### | |
| end_time = time.time() | |
| asyncio.create_task( | |
| proxy_logging_obj.response_taking_too_long( | |
| start_time=start_time, end_time=end_time, type="slow_response" | |
| ) | |
| ) | |
| fastapi_response.headers["x-litellm-model-id"] = model_id | |
| return response | |
| except Exception as e: | |
| traceback.print_exc() | |
| await proxy_logging_obj.post_call_failure_hook( | |
| user_api_key_dict=user_api_key_dict, original_exception=e | |
| ) | |
| verbose_proxy_logger.debug( | |
| f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" | |
| ) | |
| router_model_names = ( | |
| [m["model_name"] for m in llm_model_list] | |
| if llm_model_list is not None | |
| else [] | |
| ) | |
| if llm_router is not None and data.get("model", "") in router_model_names: | |
| verbose_proxy_logger.debug("Results from router") | |
| verbose_proxy_logger.debug("\nRouter stats") | |
| verbose_proxy_logger.debug("\nTotal Calls made") | |
| for key, value in llm_router.total_calls.items(): | |
| verbose_proxy_logger.debug(f"{key}: {value}") | |
| verbose_proxy_logger.debug("\nSuccess Calls made") | |
| for key, value in llm_router.success_calls.items(): | |
| verbose_proxy_logger.debug(f"{key}: {value}") | |
| verbose_proxy_logger.debug("\nFail Calls made") | |
| for key, value in llm_router.fail_calls.items(): | |
| verbose_proxy_logger.debug(f"{key}: {value}") | |
| if user_debug: | |
| traceback.print_exc() | |
| if isinstance(e, HTTPException): | |
| raise e | |
| else: | |
| error_traceback = traceback.format_exc() | |
| error_msg = f"{str(e)}\n\n{error_traceback}" | |
| try: | |
| status = e.status_code # type: ignore | |
| except: | |
| status = 500 | |
| raise HTTPException(status_code=status, detail=error_msg) | |
| async def embeddings( | |
| request: Request, | |
| user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), | |
| background_tasks: BackgroundTasks = BackgroundTasks(), | |
| ): | |
| global proxy_logging_obj | |
| try: | |
| # Use orjson to parse JSON data, orjson speeds up requests significantly | |
| body = await request.body() | |
| data = orjson.loads(body) | |
| # Include original request and headers in the data | |
| data["proxy_server_request"] = { | |
| "url": str(request.url), | |
| "method": request.method, | |
| "headers": dict(request.headers), | |
| "body": copy.copy(data), # use copy instead of deepcopy | |
| } | |
| if data.get("user", None) is None and user_api_key_dict.user_id is not None: | |
| data["user"] = user_api_key_dict.user_id | |
| data["model"] = ( | |
| general_settings.get("embedding_model", None) # server default | |
| or user_model # model name passed via cli args | |
| or data["model"] # default passed in http request | |
| ) | |
| if user_model: | |
| data["model"] = user_model | |
| if "metadata" in data: | |
| data["metadata"]["user_api_key"] = user_api_key_dict.api_key | |
| data["metadata"]["headers"] = dict(request.headers) | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| else: | |
| data["metadata"] = {"user_api_key": user_api_key_dict.api_key} | |
| data["metadata"]["headers"] = dict(request.headers) | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| router_model_names = ( | |
| [m["model_name"] for m in llm_model_list] | |
| if llm_model_list is not None | |
| else [] | |
| ) | |
| if ( | |
| "input" in data | |
| and isinstance(data["input"], list) | |
| and isinstance(data["input"][0], list) | |
| and isinstance(data["input"][0][0], int) | |
| ): # check if array of tokens passed in | |
| # check if non-openai/azure model called - e.g. for langchain integration | |
| if llm_model_list is not None and data["model"] in router_model_names: | |
| for m in llm_model_list: | |
| if m["model_name"] == data["model"] and ( | |
| m["litellm_params"]["model"] in litellm.open_ai_embedding_models | |
| or m["litellm_params"]["model"].startswith("azure/") | |
| ): | |
| pass | |
| else: | |
| # non-openai/azure embedding model called with token input | |
| input_list = [] | |
| for i in data["input"]: | |
| input_list.append( | |
| litellm.decode(model="gpt-3.5-turbo", tokens=i) | |
| ) | |
| data["input"] = input_list | |
| break | |
| ### CALL HOOKS ### - modify incoming data / reject request before calling the model | |
| data = await proxy_logging_obj.pre_call_hook( | |
| user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings" | |
| ) | |
| start_time = time.time() | |
| ## ROUTE TO CORRECT ENDPOINT ## | |
| # skip router if user passed their key | |
| if "api_key" in data: | |
| response = await litellm.aembedding(**data) | |
| elif "user_config" in data: | |
| # initialize a new router instance. make request using this Router | |
| router_config = data.pop("user_config") | |
| user_router = litellm.Router(**router_config) | |
| response = await user_router.aembedding(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in router_model_names | |
| ): # model in router model list | |
| response = await llm_router.aembedding(**data) | |
| elif ( | |
| llm_router is not None | |
| and llm_router.model_group_alias is not None | |
| and data["model"] in llm_router.model_group_alias | |
| ): # model set in model_group_alias | |
| response = await llm_router.aembedding( | |
| **data | |
| ) # ensure this goes through the llm_router; the router handles the alias mapping | |
| elif ( | |
| llm_router is not None and data["model"] in llm_router.deployment_names | |
| ): # model in router deployments, calling a specific deployment on the router | |
| response = await llm_router.aembedding(**data, specific_deployment=True) | |
| else: | |
| response = await litellm.aembedding(**data) | |
| ### ALERTING ### | |
| end_time = time.time() | |
| asyncio.create_task( | |
| proxy_logging_obj.response_taking_too_long( | |
| start_time=start_time, end_time=end_time, type="slow_response" | |
| ) | |
| ) | |
| return response | |
| except Exception as e: | |
| await proxy_logging_obj.post_call_failure_hook( | |
| user_api_key_dict=user_api_key_dict, original_exception=e | |
| ) | |
| traceback.print_exc() | |
| if isinstance(e, HTTPException): | |
| raise e | |
| else: | |
| error_traceback = traceback.format_exc() | |
| error_msg = f"{str(e)}\n\n{error_traceback}" | |
| # use the exception's status code if it has one, otherwise default to 500 | |
| status = getattr(e, "status_code", 500) | |
| raise HTTPException(status_code=status, detail=error_msg) | |
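| # Illustrative sketch (never called, not registered as a route): the request shape the | |
| # embeddings handler above expects. The base URL, API key, and /embeddings path are | |
| # placeholder assumptions. | |
| def _example_embeddings_request(): | |
|     import requests | |
|     resp = requests.post( | |
|         "http://0.0.0.0:8000/embeddings", | |
|         headers={"Authorization": "Bearer sk-1234"}, | |
|         json={ | |
|             "model": "text-embedding-ada-002", | |
|             # a list of lists of ints would be treated as pre-tokenized input | |
|             # and decoded above for non-OpenAI/Azure models | |
|             "input": ["hello world"], | |
|         }, | |
|     ) | |
|     print(len(resp.json()["data"][0]["embedding"])) | |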
| async def image_generation( | |
| request: Request, | |
| user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), | |
| background_tasks: BackgroundTasks = BackgroundTasks(), | |
| ): | |
| global proxy_logging_obj | |
| try: | |
| # Use orjson to parse JSON data, orjson speeds up requests significantly | |
| body = await request.body() | |
| data = orjson.loads(body) | |
| # Include original request and headers in the data | |
| data["proxy_server_request"] = { | |
| "url": str(request.url), | |
| "method": request.method, | |
| "headers": dict(request.headers), | |
| "body": copy.copy(data), # use copy instead of deepcopy | |
| } | |
| if data.get("user", None) is None and user_api_key_dict.user_id is not None: | |
| data["user"] = user_api_key_dict.user_id | |
| data["model"] = ( | |
| general_settings.get("image_generation_model", None) # server default | |
| or user_model # model name passed via cli args | |
| or data["model"] # default passed in http request | |
| ) | |
| if user_model: | |
| data["model"] = user_model | |
| if "metadata" in data: | |
| data["metadata"]["user_api_key"] = user_api_key_dict.api_key | |
| data["metadata"]["headers"] = dict(request.headers) | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| else: | |
| data["metadata"] = {"user_api_key": user_api_key_dict.api_key} | |
| data["metadata"]["headers"] = dict(request.headers) | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| router_model_names = ( | |
| [m["model_name"] for m in llm_model_list] | |
| if llm_model_list is not None | |
| else [] | |
| ) | |
| ### CALL HOOKS ### - modify incoming data / reject request before calling the model | |
| data = await proxy_logging_obj.pre_call_hook( | |
| user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings" # note: image requests reuse the "embeddings" call type here | |
| ) | |
| start_time = time.time() | |
| ## ROUTE TO CORRECT ENDPOINT ## | |
| # skip router if user passed their key | |
| if "api_key" in data: | |
| response = await litellm.aimage_generation(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in router_model_names | |
| ): # model in router model list | |
| response = await llm_router.aimage_generation(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in llm_router.deployment_names | |
| ): # model in router deployments, calling a specific deployment on the router | |
| response = await llm_router.aimage_generation( | |
| **data, specific_deployment=True | |
| ) | |
| elif ( | |
| llm_router is not None | |
| and llm_router.model_group_alias is not None | |
| and data["model"] in llm_router.model_group_alias | |
| ): # model set in model_group_alias | |
| response = await llm_router.aimage_generation( | |
| **data | |
| ) # ensure this goes through the llm_router; the router handles the alias mapping | |
| else: | |
| response = await litellm.aimage_generation(**data) | |
| ### ALERTING ### | |
| end_time = time.time() | |
| asyncio.create_task( | |
| proxy_logging_obj.response_taking_too_long( | |
| start_time=start_time, end_time=end_time, type="slow_response" | |
| ) | |
| ) | |
| return response | |
| except Exception as e: | |
| await proxy_logging_obj.post_call_failure_hook( | |
| user_api_key_dict=user_api_key_dict, original_exception=e | |
| ) | |
| traceback.print_exc() | |
| if isinstance(e, HTTPException): | |
| raise e | |
| else: | |
| error_traceback = traceback.format_exc() | |
| error_msg = f"{str(e)}\n\n{error_traceback}" | |
| # use the exception's status code if it has one, otherwise default to 500 | |
| status = getattr(e, "status_code", 500) | |
| raise HTTPException(status_code=status, detail=error_msg) | |
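| # Illustrative sketch (never called, not registered as a route): an OpenAI-style image | |
| # generation request against the handler above. The base URL, API key, and | |
| # /images/generations path are placeholder assumptions. | |
| def _example_image_generation_request(): | |
|     import requests | |
|     resp = requests.post( | |
|         "http://0.0.0.0:8000/images/generations", | |
|         headers={"Authorization": "Bearer sk-1234"}, | |
|         json={"model": "dall-e-3", "prompt": "a watercolor fox", "n": 1}, | |
|     ) | |
|     print(resp.json()) | |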
| #### KEY MANAGEMENT #### | |
| async def generate_key_fn( | |
| request: Request, | |
| data: GenerateKeyRequest, | |
| Authorization: Optional[str] = Header(None), | |
| ): | |
| """ | |
| Generate an API key based on the provided data. | |
| Docs: https://docs.litellm.ai/docs/proxy/virtual_keys | |
| Parameters: | |
| - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)** | |
| - models: Optional[list] - Model names the key is allowed to call. (If empty, the key can call all models.) | |
| - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models | |
| - config: Optional[dict] - any key-specific configs, overrides config in config.yaml | |
| - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend | |
| - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises a 429 error if the user's parallel requests exceed this limit. | |
| - metadata: Optional[dict] - Metadata for the key; use it to store information about the key. Example: metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai"} | |
| Returns: | |
| - key: (str) The generated api key | |
| - expires: (datetime) Datetime object for when key expires. | |
| - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. | |
| """ | |
| verbose_proxy_logger.debug("entered /key/generate") | |
| data_json = data.json() # type: ignore | |
| response = await generate_key_helper_fn(**data_json) | |
| return GenerateKeyResponse( | |
| key=response["token"], expires=response["expires"], user_id=response["user_id"] | |
| ) | |
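| # Illustrative sketch (never called, not registered as a route): a request body for the | |
| # /key/generate handler above, mirroring the parameters documented in its docstring. | |
| # The base URL and master key value are placeholder assumptions. | |
| def _example_generate_key_request(): | |
|     import requests | |
|     resp = requests.post( | |
|         "http://0.0.0.0:8000/key/generate", | |
|         headers={"Authorization": "Bearer sk-master-1234"}, | |
|         json={ | |
|             "duration": "30d", | |
|             "models": ["gpt-3.5-turbo"], | |
|             "max_parallel_requests": 10, | |
|             "metadata": {"team": "core-infra"}, | |
|         }, | |
|     ) | |
|     # the response carries key, expires and user_id (see GenerateKeyResponse above) | |
|     print(resp.json()["key"]) | |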
| async def update_key_fn(request: Request, data: UpdateKeyRequest): | |
| """ | |
| Update an existing key | |
| """ | |
| global prisma_client | |
| try: | |
| data_json: dict = data.json() | |
| key = data_json.pop("key") | |
| # get the row from db | |
| if prisma_client is None: | |
| raise Exception("Not connected to DB!") | |
| # update based on the remaining non-default values passed in | |
| non_default_values = {k: v for k, v in data_json.items() if v is not None} | |
| response = await prisma_client.update_data( | |
| token=key, data={**non_default_values, "token": key} | |
| ) | |
| return {"key": key, **non_default_values} | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail={"error": str(e)}, | |
| ) | |
| async def delete_key_fn(request: Request, data: DeleteKeyRequest): | |
| try: | |
| keys = data.keys | |
| deleted_keys = await delete_verification_token(tokens=keys) | |
| assert len(keys) == deleted_keys | |
| return {"deleted_keys": keys} | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail={"error": str(e)}, | |
| ) | |
| async def info_key_fn( | |
| key: str = fastapi.Query(..., description="Key in the request parameters") | |
| ): | |
| global prisma_client | |
| try: | |
| if prisma_client is None: | |
| raise Exception( | |
| f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" | |
| ) | |
| key_info = await prisma_client.get_data(token=key) | |
| return {"key": key, "info": key_info} | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail={"error": str(e)}, | |
| ) | |
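| # Illustrative sketch (never called, not registered as a route): request shapes for the | |
| # update/delete/info handlers above. The base URL, master key, and the /key/update, | |
| # /key/delete, /key/info paths and HTTP methods are placeholder assumptions - the | |
| # route decorators are not shown in this file excerpt. | |
| def _example_key_lifecycle(): | |
|     import requests | |
|     base = "http://0.0.0.0:8000" | |
|     headers = {"Authorization": "Bearer sk-master-1234"} | |
|     key = "sk-generated-key" | |
|     # update: only non-None fields are applied (see update_key_fn above) | |
|     requests.post(f"{base}/key/update", headers=headers, json={"key": key, "models": ["gpt-4"]}) | |
|     # info: the key is passed as a query parameter (see info_key_fn above) | |
|     print(requests.get(f"{base}/key/info", headers=headers, params={"key": key}).json()) | |
|     # delete: accepts a list of keys (see delete_key_fn above) | |
|     requests.post(f"{base}/key/delete", headers=headers, json={"keys": [key]}) | |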
| #### USER MANAGEMENT #### | |
| async def new_user(data: NewUserRequest): | |
| """ | |
| Use this to create a new user with a budget. | |
| Returns user id, budget + new key. | |
| Parameters: | |
| - user_id: Optional[str] - Specify a user id. If not set, a unique id will be generated. | |
| - max_budget: Optional[float] - Specify max budget for a given user. | |
| - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). **(Default is set to 1 hour.)** | |
| - models: Optional[list] - Model names the key is allowed to call. (If empty, the key can call all models.) | |
| - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models | |
| - config: Optional[dict] - any key-specific configs, overrides config in config.yaml | |
| - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend | |
| - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises a 429 error if the user's parallel requests exceed this limit. | |
| - metadata: Optional[dict] - Metadata for the key; use it to store information about the key. Example: metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai"} | |
| Returns: | |
| - key: (str) The generated api key | |
| - expires: (datetime) Datetime object for when key expires. | |
| - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. | |
| - max_budget: (float|None) Max budget for given user. | |
| """ | |
| data_json = data.json() # type: ignore | |
| response = await generate_key_helper_fn(**data_json) | |
| return NewUserResponse( | |
| key=response["token"], | |
| expires=response["expires"], | |
| user_id=response["user_id"], | |
| max_budget=response["max_budget"], | |
| ) | |
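| # Illustrative sketch (never called, not registered as a route): a request body for the | |
| # new_user handler above, using fields documented in its docstring. The base URL, | |
| # master key, and /user/new path are placeholder assumptions. | |
| def _example_new_user_request(): | |
|     import requests | |
|     resp = requests.post( | |
|         "http://0.0.0.0:8000/user/new", | |
|         headers={"Authorization": "Bearer sk-master-1234"}, | |
|         json={"max_budget": 50.0, "duration": "30d", "models": []}, | |
|     ) | |
|     # the response carries key, expires, user_id and max_budget (see NewUserResponse above) | |
|     print(resp.json()["user_id"]) | |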
| async def user_auth(request: Request): | |
| """ | |
| Allows the UI ("https://dashboard.litellm.ai/", or self-hosted via os.getenv("LITELLM_HOSTED_UI")) to request a magic link be sent to a user's email, for auth to the proxy. | |
| Only allows emails from accepted email subdomains. | |
| Rate limit: 1 request every 60s. | |
| Only works if you enable 'allow_user_auth' in general_settings: | |
| e.g.: | |
| ```yaml | |
| general_settings: | |
| allow_user_auth: true | |
| ``` | |
| Requirements: | |
| SMTP server details saved in .env: | |
| - os.environ["SMTP_HOST"] | |
| - os.environ["SMTP_PORT"] | |
| - os.environ["SMTP_USERNAME"] | |
| - os.environ["SMTP_PASSWORD"] | |
| - os.environ["SMTP_SENDER_EMAIL"] | |
| """ | |
| global prisma_client | |
| data = await request.json() # type: ignore | |
| user_email = data["user_email"] | |
| page_params = data["page"] | |
| if user_email is None: | |
| raise HTTPException(status_code=400, detail="User email is none") | |
| if prisma_client is None: # if no db connected, raise an error | |
| raise Exception("No connected db.") | |
| ### Check if user email in user table | |
| response = await prisma_client.get_generic_data( | |
| key="user_email", value=user_email, table_name="users" | |
| ) | |
| ### if so - generate a 24 hr key with that user id | |
| if response is not None: | |
| user_id = response.user_id | |
| response = await generate_key_helper_fn( | |
| **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id} # type: ignore | |
| ) | |
| else: ### else - create new user | |
| response = await generate_key_helper_fn( | |
| **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_email": user_email} # type: ignore | |
| ) | |
| base_url = os.getenv("LITELLM_HOSTED_UI", "https://dashboard.litellm.ai/") | |
| params = { | |
| "sender_name": "LiteLLM Proxy", | |
| "sender_email": os.getenv("SMTP_SENDER_EMAIL"), | |
| "receiver_email": user_email, | |
| "subject": "Your Magic Link", | |
| "html": f"<strong> Follow this link, to login:\n\n{base_url}user/?token={response['token']}&user_id={response['user_id']}&page={page_params}</strong>", | |
| } | |
| await send_email(**params) | |
| return "Email sent!" | |
| async def user_info( | |
| user_id: str = fastapi.Query(..., description="User ID in the request parameters") | |
| ): | |
| """ | |
| Use this to get user information. (user row + all user key info) | |
| """ | |
| global prisma_client | |
| try: | |
| if prisma_client is None: | |
| raise Exception( | |
| f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" | |
| ) | |
| ## GET USER ROW ## | |
| user_info = await prisma_client.get_data(user_id=user_id) | |
| ## GET ALL KEYS ## | |
| keys = await prisma_client.get_data( | |
| user_id=user_id, table_name="key", query_type="find_all" | |
| ) | |
| return {"user_id": user_id, "user_info": user_info, "keys": keys} | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail={"error": str(e)}, | |
| ) | |
| async def user_update(request: Request): | |
| """ | |
| [TODO]: Use this to update user budget | |
| """ | |
| pass | |
| #### MODEL MANAGEMENT #### | |
| #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 | |
| async def add_new_model(model_params: ModelParams): | |
| global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config | |
| try: | |
| # Load existing config | |
| config = await proxy_config.get_config() | |
| verbose_proxy_logger.debug(f"User config path: {user_config_file_path}") | |
| verbose_proxy_logger.debug(f"Loaded config: {config}") | |
| # Add the new model to the config | |
| model_info = model_params.model_info.json() | |
| model_info = {k: v for k, v in model_info.items() if v is not None} | |
| config["model_list"].append( | |
| { | |
| "model_name": model_params.model_name, | |
| "litellm_params": model_params.litellm_params, | |
| "model_info": model_info, | |
| } | |
| ) | |
| verbose_proxy_logger.debug(f"updated model list: {config['model_list']}") | |
| # Save new config | |
| await proxy_config.save_config(new_config=config) | |
| return {"message": "Model added successfully"} | |
| except Exception as e: | |
| traceback.print_exc() | |
| if isinstance(e, HTTPException): | |
| raise e | |
| else: | |
| raise HTTPException( | |
| status_code=500, detail=f"Internal Server Error: {str(e)}" | |
| ) | |
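| # Illustrative sketch (never called, not registered as a route): a ModelParams-shaped | |
| # payload for the add_new_model handler above. The base URL, master key, /model/new | |
| # path, and the Azure deployment values are placeholder assumptions. | |
| def _example_add_new_model_request(): | |
|     import requests | |
|     resp = requests.post( | |
|         "http://0.0.0.0:8000/model/new", | |
|         headers={"Authorization": "Bearer sk-master-1234"}, | |
|         json={ | |
|             "model_name": "my-gpt-35", | |
|             "litellm_params": {"model": "azure/gpt-35-turbo", "api_base": "https://example-resource.openai.azure.com"}, | |
|             "model_info": {"mode": "chat"}, | |
|         }, | |
|     ) | |
|     print(resp.json())  # expected: {"message": "Model added successfully"} | |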
| #### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info | |
| async def model_info_v1(request: Request): | |
| global llm_model_list, general_settings, user_config_file_path, proxy_config | |
| # Load existing config | |
| config = await proxy_config.get_config() | |
| all_models = config["model_list"] | |
| for model in all_models: | |
| # provided model_info in config.yaml | |
| model_info = model.get("model_info", {}) | |
| # read litellm model_prices_and_context_window.json to get the following: | |
| # input_cost_per_token, output_cost_per_token, max_tokens | |
| litellm_model_info = get_litellm_model_info(model=model) | |
| for k, v in litellm_model_info.items(): | |
| if k not in model_info: | |
| model_info[k] = v | |
| model["model_info"] = model_info | |
| # don't return the api key | |
| model["litellm_params"].pop("api_key", None) | |
| verbose_proxy_logger.debug(f"all_models: {all_models}") | |
| return {"data": all_models} | |
| #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/964 | |
| async def delete_model(model_info: ModelInfoDelete): | |
| global llm_router, llm_model_list, general_settings, user_config_file_path, proxy_config | |
| try: | |
| if not os.path.exists(user_config_file_path): | |
| raise HTTPException(status_code=404, detail="Config file does not exist.") | |
| # Load existing config | |
| config = await proxy_config.get_config() | |
| # If model_list is not in the config, nothing can be deleted | |
| if len(config.get("model_list", [])) == 0: | |
| raise HTTPException( | |
| status_code=400, detail="No model list available in the config." | |
| ) | |
| # Check if the model with the specified model_id exists | |
| model_to_delete = None | |
| for model in config["model_list"]: | |
| if model.get("model_info", {}).get("id", None) == model_info.id: | |
| model_to_delete = model | |
| break | |
| # If the model was not found, return an error | |
| if model_to_delete is None: | |
| raise HTTPException( | |
| status_code=400, detail="Model with given model_id not found." | |
| ) | |
| # Remove model from the list and save the updated config | |
| config["model_list"].remove(model_to_delete) | |
| # Save updated config | |
| config = await proxy_config.save_config(new_config=config) | |
| return {"message": "Model deleted successfully"} | |
| except HTTPException as e: | |
| # Re-raise the HTTP exceptions to be handled by FastAPI | |
| raise | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") | |
| #### EXPERIMENTAL QUEUING #### | |
| async def _litellm_chat_completions_worker(data, user_api_key_dict): | |
| """ | |
| Worker that makes litellm completion calls. | |
| """ | |
| while True: | |
| try: | |
| ### CALL HOOKS ### - modify incoming data before calling the model | |
| data = await proxy_logging_obj.pre_call_hook( | |
| user_api_key_dict=user_api_key_dict, data=data, call_type="completion" | |
| ) | |
| verbose_proxy_logger.debug(f"_litellm_chat_completions_worker started") | |
| ### ROUTE THE REQUEST ### | |
| router_model_names = ( | |
| [m["model_name"] for m in llm_model_list] | |
| if llm_model_list is not None | |
| else [] | |
| ) | |
| if ( | |
| llm_router is not None and data["model"] in router_model_names | |
| ): # model in router model list | |
| response = await llm_router.acompletion(**data) | |
| elif ( | |
| llm_router is not None and data["model"] in llm_router.deployment_names | |
| ): # model in router deployments, calling a specific deployment on the router | |
| response = await llm_router.acompletion( | |
| **data, specific_deployment=True | |
| ) | |
| elif ( | |
| llm_router is not None | |
| and llm_router.model_group_alias is not None | |
| and data["model"] in llm_router.model_group_alias | |
| ): # model set in model_group_alias | |
| response = await llm_router.acompletion(**data) | |
| else: # router is not set | |
| response = await litellm.acompletion(**data) | |
| verbose_proxy_logger.debug(f"final response: {response}") | |
| return response | |
| except HTTPException as e: | |
| verbose_proxy_logger.debug( | |
| f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}" | |
| ) | |
| if ( | |
| e.status_code == 429 | |
| and "Max parallel request limit reached" in e.detail | |
| ): | |
| verbose_proxy_logger.debug(f"Max parallel request limit reached!") | |
| timeout = litellm._calculate_retry_after( | |
| remaining_retries=3, max_retries=3, min_timeout=1 | |
| ) | |
| await asyncio.sleep(timeout) | |
| else: | |
| raise e | |
| async def async_queue_request( | |
| request: Request, | |
| model: Optional[str] = None, | |
| user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), | |
| background_tasks: BackgroundTasks = BackgroundTasks(), | |
| ): | |
| """ | |
| v2 attempt at a background worker to handle queuing. | |
| Currently only supports /chat/completions calls. | |
| Now uses a FastAPI background task + a /chat/completions compatible endpoint. | |
| """ | |
| global general_settings, user_debug, proxy_logging_obj | |
| try: | |
| data = {} | |
| data = await request.json() # type: ignore | |
| # Include original request and headers in the data | |
| data["proxy_server_request"] = { | |
| "url": str(request.url), | |
| "method": request.method, | |
| "headers": dict(request.headers), | |
| "body": copy.copy(data), # use copy instead of deepcopy | |
| } | |
| verbose_proxy_logger.debug(f"receiving data: {data}") | |
| data["model"] = ( | |
| general_settings.get("completion_model", None) # server default | |
| or user_model # model name passed via cli args | |
| or model # for azure deployments | |
| or data["model"] # default passed in http request | |
| ) | |
| # users can pass in 'user' param to /chat/completions. Don't override it | |
| if data.get("user", None) is None and user_api_key_dict.user_id is not None: | |
| # if users are using user_api_key_auth, set `user` in `data` | |
| data["user"] = user_api_key_dict.user_id | |
| if "metadata" in data: | |
| verbose_proxy_logger.debug(f'received metadata: {data["metadata"]}') | |
| data["metadata"]["user_api_key"] = user_api_key_dict.api_key | |
| data["metadata"]["headers"] = dict(request.headers) | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| else: | |
| data["metadata"] = {"user_api_key": user_api_key_dict.api_key} | |
| data["metadata"]["headers"] = dict(request.headers) | |
| data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id | |
| global user_temperature, user_request_timeout, user_max_tokens, user_api_base | |
| # override with user settings, these are params passed via cli | |
| if user_temperature: | |
| data["temperature"] = user_temperature | |
| if user_request_timeout: | |
| data["request_timeout"] = user_request_timeout | |
| if user_max_tokens: | |
| data["max_tokens"] = user_max_tokens | |
| if user_api_base: | |
| data["api_base"] = user_api_base | |
| response = await asyncio.wait_for( | |
| _litellm_chat_completions_worker( | |
| data=data, user_api_key_dict=user_api_key_dict | |
| ), | |
| timeout=litellm.request_timeout, | |
| ) | |
| if ( | |
| "stream" in data and data["stream"] == True | |
| ): # use generate_responses to stream responses | |
| return StreamingResponse( | |
| async_data_generator( | |
| user_api_key_dict=user_api_key_dict, response=response | |
| ), | |
| media_type="text/event-stream", | |
| ) | |
| return response | |
| except Exception as e: | |
| await proxy_logging_obj.post_call_failure_hook( | |
| user_api_key_dict=user_api_key_dict, original_exception=e | |
| ) | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail={"error": str(e)}, | |
| ) | |
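| # Illustrative sketch (never called, not registered as a route): a /chat/completions | |
| # compatible payload sent to the queuing handler above. The base URL, API key, and | |
| # /queue/chat/completions path are placeholder assumptions. | |
| def _example_queue_request(): | |
|     import requests | |
|     resp = requests.post( | |
|         "http://0.0.0.0:8000/queue/chat/completions", | |
|         headers={"Authorization": "Bearer sk-1234"}, | |
|         json={ | |
|             "model": "gpt-3.5-turbo", | |
|             "messages": [{"role": "user", "content": "hello from the queue"}], | |
|         }, | |
|         timeout=600,  # the worker retries on 429s, so allow a generous client timeout | |
|     ) | |
|     print(resp.json()) | |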
| async def retrieve_server_log(request: Request): | |
| filepath = os.path.expanduser("~/.ollama/logs/server.log") | |
| return FileResponse(filepath) | |
| #### BASIC ENDPOINTS #### | |
| async def update_config(config_info: ConfigYAML): | |
| """ | |
| For Admin UI - allows admin to update config via UI | |
| Currently supports modifying General Settings + LiteLLM settings | |
| """ | |
| global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj | |
| try: | |
| # Load existing config | |
| config = await proxy_config.get_config() | |
| backup_config = copy.deepcopy(config) | |
| verbose_proxy_logger.debug(f"Loaded config: {config}") | |
| # update the general settings | |
| if config_info.general_settings is not None: | |
| config.setdefault("general_settings", {}) | |
| updated_general_settings = config_info.general_settings.dict( | |
| exclude_none=True | |
| ) | |
| config["general_settings"] = { | |
| **updated_general_settings, | |
| **config["general_settings"], | |
| } | |
| if config_info.environment_variables is not None: | |
| config.setdefault("environment_variables", {}) | |
| updated_environment_variables = config_info.environment_variables | |
| config["environment_variables"] = { | |
| **updated_environment_variables, | |
| **config["environment_variables"], | |
| } | |
| # update the litellm settings | |
| if config_info.litellm_settings is not None: | |
| config.setdefault("litellm_settings", {}) | |
| updated_litellm_settings = config_info.litellm_settings | |
| config["litellm_settings"] = { | |
| **updated_litellm_settings, | |
| **config["litellm_settings"], | |
| } | |
| # Save the updated config | |
| await proxy_config.save_config(new_config=config) | |
| # Test new connections | |
| ## Slack | |
| if "slack" in config.get("general_settings", {}).get("alerting", []): | |
| await proxy_logging_obj.alerting_handler( | |
| message="This is a test", level="Low" | |
| ) | |
| return {"message": "Config updated successfully"} | |
| except HTTPException as e: | |
| raise e | |
| except Exception as e: | |
| traceback.print_exc() | |
| raise HTTPException(status_code=500, detail=f"An error occurred - {str(e)}") | |
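| # Illustrative sketch (never called, not registered as a route): a ConfigYAML-shaped | |
| # payload for the update_config handler above (general_settings + litellm_settings). | |
| # The base URL, master key, /config/update path, and the specific settings shown are | |
| # placeholder assumptions. | |
| def _example_update_config_request(): | |
|     import requests | |
|     resp = requests.post( | |
|         "http://0.0.0.0:8000/config/update", | |
|         headers={"Authorization": "Bearer sk-master-1234"}, | |
|         json={ | |
|             "litellm_settings": {"drop_params": True, "set_verbose": False}, | |
|             "general_settings": {"alerting": ["slack"]}, | |
|         }, | |
|     ) | |
|     print(resp.json())  # expected: {"message": "Config updated successfully"} | |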
| async def config_yaml_endpoint(config_info: ConfigYAML): | |
| """ | |
| This is a mock endpoint that documents, in the Swagger UI, what you can set in config.yaml. | |
| Parameters: | |
| The config.yaml object has the following attributes: | |
| - **model_list**: *Optional[List[ModelParams]]* - A list of supported models on the server, along with model-specific configurations. ModelParams includes "model_name" (name of the model), "litellm_params" (litellm-specific parameters for the model), and "model_info" (additional info about the model such as id, mode, cost per token, etc). | |
| - **litellm_settings**: *Optional[dict]*: Settings for the litellm module. You can specify multiple properties like "drop_params", "set_verbose", "api_base", "cache". | |
| - **general_settings**: *Optional[ConfigGeneralSettings]*: General settings for the server like "completion_model" (default model for chat completion calls), "use_azure_key_vault" (option to load keys from azure key vault), "master_key" (key required for all calls to proxy), and others. | |
| Please refer to each class's description for a better understanding of its specific attributes. | |
| Note: This is a mock endpoint primarily meant for demonstration purposes, and does not actually provide or change any configurations. | |
| """ | |
| return {"hello": "world"} | |
| async def test_endpoint(request: Request): | |
| """ | |
| A test endpoint that pings the proxy server to check if it's healthy. | |
| Parameters: | |
| request (Request): The incoming request. | |
| Returns: | |
| dict: A dictionary containing the route of the request URL. | |
| """ | |
| # ping the proxy server to check if it's healthy | |
| return {"route": request.url.path} | |
| async def health_endpoint( | |
| request: Request, | |
| model: Optional[str] = fastapi.Query( | |
| None, description="Specify the model name (optional)" | |
| ), | |
| ): | |
| """ | |
| Check the health of all the endpoints in config.yaml | |
| To run health checks in the background, add this to config.yaml: | |
| ``` | |
| general_settings: | |
| # ... other settings | |
| background_health_checks: True | |
| ``` | |
| Otherwise, the health checks run on the models when /health is called. | |
| """ | |
| global health_check_results, use_background_health_checks, user_model | |
| if llm_model_list is None: | |
| # if no router set, check if user set a model using litellm --model ollama/llama2 | |
| if user_model is not None: | |
| healthy_endpoints, unhealthy_endpoints = await perform_health_check( | |
| model_list=[], cli_model=user_model | |
| ) | |
| return { | |
| "healthy_endpoints": healthy_endpoints, | |
| "unhealthy_endpoints": unhealthy_endpoints, | |
| "healthy_count": len(healthy_endpoints), | |
| "unhealthy_count": len(unhealthy_endpoints), | |
| } | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail={"error": "Model list not initialized"}, | |
| ) | |
| if use_background_health_checks: | |
| return health_check_results | |
| else: | |
| healthy_endpoints, unhealthy_endpoints = await perform_health_check( | |
| llm_model_list, model | |
| ) | |
| return { | |
| "healthy_endpoints": healthy_endpoints, | |
| "unhealthy_endpoints": unhealthy_endpoints, | |
| "healthy_count": len(healthy_endpoints), | |
| "unhealthy_count": len(unhealthy_endpoints), | |
| } | |
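| # Illustrative sketch (never called, not registered as a route): calls the /health | |
| # handler above and prints the healthy/unhealthy counts returned when checks run on | |
| # demand. The base URL, API key, and model filter value are placeholder assumptions; | |
| # the optional ?model= query parameter matches the signature above. | |
| def _example_health_check_request(): | |
|     import requests | |
|     resp = requests.get( | |
|         "http://0.0.0.0:8000/health", | |
|         headers={"Authorization": "Bearer sk-1234"}, | |
|         params={"model": "gpt-3.5-turbo"}, | |
|     ) | |
|     report = resp.json() | |
|     print(report["healthy_count"], report["unhealthy_count"]) | |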
| async def health_readiness(): | |
| """ | |
| Unprotected endpoint for checking if worker can receive requests | |
| """ | |
| global prisma_client | |
| if prisma_client is not None: # if db passed in, check if it's connected | |
| if prisma_client.db.is_connected() == True: | |
| return {"status": "healthy", "db": "connected"} | |
| else: | |
| return {"status": "healthy", "db": "Not connected"} | |
| raise HTTPException(status_code=503, detail="Service Unhealthy") | |
| async def health_liveliness(): | |
| """ | |
| Unprotected endpoint for checking if worker is alive | |
| """ | |
| return "I'm alive!" | |
| async def home(request: Request): | |
| return "LiteLLM: RUNNING" | |
| async def get_routes(): | |
| """ | |
| Get a list of available routes in the FastAPI application. | |
| """ | |
| routes = [] | |
| for route in app.routes: | |
| route_info = { | |
| "path": route.path, | |
| "methods": route.methods, | |
| "name": route.name, | |
| "endpoint": route.endpoint.__name__ if route.endpoint else None, | |
| } | |
| routes.append(route_info) | |
| return {"routes": routes} | |
| async def shutdown_event(): | |
| global prisma_client, master_key, user_custom_auth | |
| if prisma_client: | |
| verbose_proxy_logger.debug("Disconnecting from Prisma") | |
| await prisma_client.disconnect() | |
| ## RESET CUSTOM VARIABLES ## | |
| cleanup_router_config_variables() | |
| def cleanup_router_config_variables(): | |
| global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval | |
| # Set all variables to None | |
| master_key = None | |
| user_config_file_path = None | |
| otel_logging = None | |
| user_custom_auth = None | |
| user_custom_auth_path = None | |
| use_background_health_checks = None | |
| health_check_interval = None | |
| app.include_router(router) | |