import sys, os
import traceback
import json
import uuid
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime

load_dotenv()
import os, io, time

# this file is to test litellm/proxy

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
from litellm.proxy.management_endpoints.model_management_endpoints import (
    add_new_model,
    update_model,
)
from litellm.proxy._types import LitellmUserRoles
from litellm._logging import verbose_proxy_logger
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.proxy.management_endpoints.team_endpoints import new_team

verbose_proxy_logger.setLevel(level=logging.DEBUG)

from litellm.caching.caching import DualCache
from litellm.router import (
    Deployment,
    LiteLLM_Params,
)
from litellm.types.router import ModelInfo, updateDeployment, updateLiteLLMParams
from litellm.proxy._types import UserAPIKeyAuth, NewTeamRequest, LiteLLM_TeamTable

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())


@pytest.fixture
def prisma_client():
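    """
    Build a PrismaClient against DATABASE_URL (with connection-pool args appended)
    and prime the proxy_server globals so each test run starts from a clean state.
    """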
    from litellm.proxy.proxy_cli import append_query_params

    ### add connection pool + pool timeout args
    params = {"connection_limit": 100, "pool_timeout": 60}
    database_url = os.getenv("DATABASE_URL")
    modified_url = append_query_params(database_url, params)
    os.environ["DATABASE_URL"] = modified_url
    os.environ["STORE_MODEL_IN_DB"] = "true"

    # Assuming PrismaClient is a class that needs to be instantiated
    prisma_client = PrismaClient(
        database_url=os.environ["DATABASE_URL"], proxy_logging_obj=proxy_logging_obj
    )

    # give each run a unique proxy budget name and reset the custom key generator
    litellm.proxy.proxy_server.litellm_proxy_budget_name = (
        f"litellm-proxy-budget-{time.time()}"
    )
    litellm.proxy.proxy_server.user_custom_key_generate = None

    return prisma_client


@pytest.mark.asyncio
async def test_add_new_model(prisma_client):
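    """
    Add a new model via add_new_model and assert it shows up in the
    LiteLLM_ProxyModelTable with the id we generated.
    """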
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "store_model_in_db", True)

    await litellm.proxy.proxy_server.prisma_client.connect()
    from litellm.proxy.proxy_server import user_api_key_cache

    import uuid

    _new_model_id = f"local-test-{uuid.uuid4().hex}"

    await add_new_model(
        model_params=Deployment(
            model_name="test_model",
            litellm_params=LiteLLM_Params(
                model="azure/gpt-3.5-turbo",
                api_key="test_api_key",
                api_base="test_api_base",
                rpm=1000,
                tpm=1000,
            ),
            model_info=ModelInfo(
                id=_new_model_id,
            ),
        ),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN.value,
            api_key="sk-1234",
            user_id="1234",
        ),
    )

    _new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
    print("_new_models: ", _new_models)

    _new_model_in_db = None
    for model in _new_models:
        print("current model: ", model)
        if model.model_info["id"] == _new_model_id:
            print("FOUND MODEL: ", model)
            _new_model_in_db = model

    assert _new_model_in_db is not None


@pytest.mark.asyncio
async def test_add_update_model(prisma_client):
    # test that existing litellm_params are not updated
    # only new / updated params get updated
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "store_model_in_db", True)

    await litellm.proxy.proxy_server.prisma_client.connect()
    from litellm.proxy.proxy_server import user_api_key_cache

    import uuid

    _new_model_id = f"local-test-{uuid.uuid4().hex}"

    await add_new_model(
        model_params=Deployment(
            model_name="test_model",
            litellm_params=LiteLLM_Params(
                model="azure/gpt-3.5-turbo",
                api_key="test_api_key",
                api_base="test_api_base",
                rpm=1000,
                tpm=1000,
            ),
            model_info=ModelInfo(
                id=_new_model_id,
            ),
        ),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN.value,
            api_key="sk-1234",
            user_id="1234",
        ),
    )

    _new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
    print("_new_models: ", _new_models)

    _new_model_in_db = None
    for model in _new_models:
        print("current model: ", model)
        if model.model_info["id"] == _new_model_id:
            print("FOUND MODEL: ", model)
            _new_model_in_db = model

    assert _new_model_in_db is not None

    _original_model = _new_model_in_db
    _original_litellm_params = _new_model_in_db.litellm_params
    print("_original_litellm_params: ", _original_litellm_params)
    print("now updating the tpm for model")

    # run update to update "tpm"
    await update_model(
        model_params=updateDeployment(
            litellm_params=updateLiteLLMParams(tpm=123456),
            model_info=ModelInfo(
                id=_new_model_id,
            ),
        ),
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN.value,
            api_key="sk-1234",
            user_id="1234",
        ),
    )

    _new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
    _new_model_in_db = None
    for model in _new_models:
        if model.model_info["id"] == _new_model_id:
            print("\nFOUND MODEL: ", model)
            _new_model_in_db = model

    # assert all other litellm params are identical to _original_litellm_params
    for key, value in _original_litellm_params.items():
        if key == "tpm":
            # assert that tpm actually got updated
            assert _new_model_in_db.litellm_params[key] == 123456
        else:
            assert _new_model_in_db.litellm_params[key] == value

    assert _original_model.model_id == _new_model_in_db.model_id
    assert _original_model.model_name == _new_model_in_db.model_name
    assert _original_model.model_info == _new_model_in_db.model_info


async def _create_new_team(prisma_client):
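    """
    Helper: create a team with a random alias through the new_team endpoint
    and return it as a LiteLLM_TeamTable object.
    """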
    new_team_request = NewTeamRequest(
        team_alias=f"team_{uuid.uuid4().hex}",
    )
    _new_team = await new_team(
        data=new_team_request,
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN.value,
            api_key="sk-1234",
            user_id="1234",
        ),
        http_request=Request(
            scope={"type": "http", "method": "POST", "path": "/new_team"}
        ),
    )
    return LiteLLM_TeamTable(**_new_team)


@pytest.mark.asyncio
async def test_add_team_model_to_db(prisma_client):
    """
    Test adding a team model and verifying the team_public_model_name is stored correctly
    """
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "store_model_in_db", True)
    await litellm.proxy.proxy_server.prisma_client.connect()
    from litellm.proxy.management_endpoints.model_management_endpoints import (
        _add_team_model_to_db,
    )

    import uuid

    new_team = await _create_new_team(prisma_client)
    team_id = new_team.team_id
    public_model_name = "my-gpt4-model"
    model_id = f"local-test-{uuid.uuid4().hex}"

    # Create test model deployment
    model_params = Deployment(
        model_name=public_model_name,
        litellm_params=LiteLLM_Params(
            model="gpt-4",
            api_key="test_api_key",
        ),
        model_info=ModelInfo(
            id=model_id,
            team_id=team_id,
        ),
    )

    # Add model to db
    model_response = await _add_team_model_to_db(
        model_params=model_params,
        user_api_key_dict=UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN.value,
            api_key="sk-1234",
            user_id="1234",
            team_id=team_id,
        ),
        prisma_client=prisma_client,
    )

    # Verify model was created with correct attributes
    assert model_response is not None
    assert model_response.model_name.startswith(f"model_name_{team_id}")

    # Verify team_public_model_name was stored in model_info
    model_info = model_response.model_info
    assert model_info["team_public_model_name"] == public_model_name
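
    # brief pause so the team table update is visible before we read it back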
    await asyncio.sleep(1)

    # Verify team model alias was created
    team = await prisma_client.db.litellm_teamtable.find_first(
        where={
            "team_id": team_id,
        },
        include={"litellm_model_table": True},
    )
    print("team=", team.model_dump_json())
    assert team is not None

    team_model = team.model_id
    print("team model id=", team_model)
    litellm_model_table = team.litellm_model_table
    print("litellm_model_table=", litellm_model_table.model_dump_json())
    model_aliases = litellm_model_table.model_aliases
    print("model_aliases=", model_aliases)
    assert public_model_name in model_aliases
    assert model_aliases[public_model_name] == model_response.model_name