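"""
Tests for LiteLLM Router endpoints: audio transcription and speech, rerank,
moderation, text completion, adapter completion, and Responses API endpoint
initialization.
"""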
import sys
import os
import json
import traceback
from typing import Optional
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock

load_dotenv()  # load API keys (OPENAI_API_KEY, COHERE_API_KEY, etc.) from a local .env

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from litellm import Router, CustomLogger
from litellm.types.utils import StandardLoggingPayload

# Get the current directory of the file being run
pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)

file_path = os.path.join(pwd, "gettysburg.wav")
audio_file = open(file_path, "rb")

from pathlib import Path
import litellm
import pytest
import asyncio
@pytest.fixture
def model_list():
    return [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "dall-e-3",
            "litellm_params": {
                "model": "dall-e-3",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        {
            "model_name": "cohere-rerank",
            "litellm_params": {
                "model": "cohere/rerank-english-v3.0",
                "api_key": os.getenv("COHERE_API_KEY"),
            },
        },
        {
            "model_name": "claude-3-5-sonnet-20240620",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "hi this is macintosh.",
            },
        },
    ]
# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
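# A minimal sketch of that registration (not used by these tests), assuming this
# module is importable as `custom_callbacks` and exposes
# `proxy_handler_instance = MyCustomHandler()`:
#
#   litellm_settings:
#     callbacks: custom_callbacks.proxy_handler_instance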
class MyCustomHandler(CustomLogger):
    def __init__(self):
        self.openai_client = None
        # initialized here so tests can assert on it even if no success event fires
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            # capture the client and standard logging payload from the callback kwargs
            print("logging a transcript kwargs: ", kwargs)
            print("openai client=", kwargs.get("client"))
            self.openai_client = kwargs.get("client")
            self.standard_logging_object = kwargs.get("standard_logging_object")
        except Exception:
            pass
# On the proxy, set litellm.callbacks = [proxy_handler_instance]
@pytest.mark.asyncio
async def test_transcription_on_router():
    proxy_handler_instance = MyCustomHandler()
    litellm.set_verbose = True
    litellm.callbacks = [proxy_handler_instance]
    print("\n Testing async transcription on router\n")
    try:
        model_list = [
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "whisper-1",
                },
            },
            {
                "model_name": "whisper",
                "litellm_params": {
                    "model": "azure/azure-whisper",
                    "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
                    "api_key": os.getenv("AZURE_EUROPE_API_KEY"),
                    "api_version": "2024-02-15-preview",
                },
            },
        ]

        router = Router(model_list=model_list)

        # collect the OpenAI/Azure clients initialized at the router level
        router_level_clients = []
        for deployment in router.model_list:
            _deployment_openai_client = router._get_client(
                deployment=deployment,
                kwargs={"model": "whisper-1"},
                client_type="async",
            )
            router_level_clients.append(str(_deployment_openai_client))

        ## test 1: user facing function
        response = await router.atranscription(
            model="whisper",
            file=audio_file,
        )

        ## test 2: underlying function
        response = await router._atranscription(
            model="whisper",
            file=audio_file,
        )
        print(response)

        # PROD Test
        # Ensure we ONLY use an OpenAI/Azure client initialized on the router level
        await asyncio.sleep(5)
        print("OpenAI Client used= ", proxy_handler_instance.openai_client)
        print("all router level clients= ", router_level_clients)
        # router_level_clients holds str() of each client, so compare string forms
        assert str(proxy_handler_instance.openai_client) in router_level_clients
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
| # "file", | |
| async def test_audio_speech_router(mode): | |
| litellm.set_verbose = True | |
| test_logger = MyCustomHandler() | |
| litellm.callbacks = [test_logger] | |
| from litellm import Router | |
| client = Router( | |
| model_list=[ | |
| { | |
| "model_name": "tts", | |
| "litellm_params": { | |
| "model": "openai/tts-1", | |
| }, | |
| }, | |
| ] | |
| ) | |
| response = await client.aspeech( | |
| model="tts", | |
| voice="alloy", | |
| input="the quick brown fox jumped over the lazy dogs", | |
| api_base=None, | |
| api_key=None, | |
| organization=None, | |
| project=None, | |
| max_retries=1, | |
| timeout=600, | |
| client=None, | |
| optional_params={}, | |
| ) | |
| await asyncio.sleep(3) | |
| from litellm.llms.openai.openai import HttpxBinaryResponseContent | |
| assert isinstance(response, HttpxBinaryResponseContent) | |
| assert test_logger.standard_logging_object is not None | |
| print( | |
| "standard_logging_object=", | |
| json.dumps(test_logger.standard_logging_object, indent=4), | |
| ) | |
| assert test_logger.standard_logging_object["model_group"] == "tts" | |
@pytest.mark.asyncio
async def test_rerank_endpoint(model_list):
    from litellm.types.utils import RerankResponse

    router = Router(model_list=model_list)

    ## Test 1: user facing function
    response = await router.arerank(
        model="cohere-rerank",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )

    ## Test 2: underlying function
    response = await router._arerank(
        model="cohere-rerank",
        query="hello",
        documents=["hello", "world"],
        top_n=3,
    )

    print("async rerank response: ", response)

    assert response.id is not None
    assert response.results is not None

    RerankResponse.model_validate(response)
@pytest.mark.parametrize(
    "model", ["openai/omni-moderation-latest", None]
)  # assumed values; the test only requires one model name and one None
@pytest.mark.asyncio
async def test_moderation_endpoint(model):
    litellm.set_verbose = True
    router = Router(
        model_list=[
            {
                "model_name": "openai/*",
                "litellm_params": {
                    "model": "openai/*",
                },
            },
            {
                "model_name": "*",
                "litellm_params": {
                    "model": "openai/*",
                },
            },
        ]
    )

    if model is None:
        response = await router.amoderation(input="hello this is a test")
    else:
        response = await router.amoderation(model=model, input="hello this is a test")

    print("moderation response: ", response)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
    router = Router(model_list=model_list)

    if sync_mode:
        response = router.text_completion(
            model="gpt-3.5-turbo",
            prompt="Hello, how are you?",
            mock_response="I'm fine, thank you!",
        )
    else:
        ## Test 1: user facing function
        response = await router.atext_completion(
            model="gpt-3.5-turbo",
            prompt="Hello, how are you?",
            mock_response="I'm fine, thank you!",
        )

        ## Test 2: underlying function
        response_2 = await router._atext_completion(
            model="gpt-3.5-turbo",
            prompt="Hello, how are you?",
            mock_response="I'm fine, thank you!",
        )
        assert response_2.choices[0].text == "I'm fine, thank you!"

    assert response.choices[0].text == "I'm fine, thank you!"
@pytest.mark.asyncio
async def test_router_with_empty_choices(model_list):
    """
    https://github.com/BerriAI/litellm/issues/8306
    """
    router = Router(model_list=model_list)
    mock_response = litellm.ModelResponse(
        choices=[],
        usage=litellm.Usage(
            prompt_tokens=10,
            completion_tokens=10,
            total_tokens=20,
        ),
        model="gpt-3.5-turbo",
        object="chat.completion",
        created=1723081200,
    ).model_dump()
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        mock_response=mock_response,
    )
    assert response is not None
@pytest.mark.parametrize("sync_mode", [True, False])
def test_generic_api_call_with_fallbacks_basic(sync_mode):
    """
    Test both the sync and async versions of generic_api_call_with_fallbacks
    with a basic successful call
    """
    # Create a mock function that will be passed to generic_api_call_with_fallbacks
    if sync_mode:
        from unittest.mock import Mock

        mock_function = Mock()
        mock_function.__name__ = "test_function"
    else:
        mock_function = AsyncMock()
        mock_function.__name__ = "test_function"

    # Create a mock response
    mock_response = {
        "id": "resp_123456",
        "role": "assistant",
        "content": "This is a test response",
        "model": "test-model",
        "usage": {"input_tokens": 10, "output_tokens": 20},
    }
    mock_function.return_value = mock_response

    # Create a router with a test model
    router = Router(
        model_list=[
            {
                "model_name": "test-model-alias",
                "litellm_params": {
                    "model": "anthropic/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    # Call the appropriate generic_api_call_with_fallbacks method
    if sync_mode:
        response = router._generic_api_call_with_fallbacks(
            model="test-model-alias",
            original_function=mock_function,
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=100,
        )
    else:
        response = asyncio.run(
            router._ageneric_api_call_with_fallbacks(
                model="test-model-alias",
                original_function=mock_function,
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=100,
            )
        )

    # Verify the mock function was called
    mock_function.assert_called_once()

    # Verify the response
    assert response == mock_response
@pytest.mark.asyncio
async def test_aadapter_completion():
    """
    Test the aadapter_completion method, which uses async_function_with_fallbacks
    """
    # Create a mock response for the _aadapter_completion method
    mock_response = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    # Create a router with a patched _aadapter_completion method
    with patch.object(
        Router, "_aadapter_completion", new_callable=AsyncMock
    ) as mock_method:
        mock_method.return_value = mock_response

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        # Replace async_function_with_fallbacks with a mock
        router.async_function_with_fallbacks = AsyncMock(return_value=mock_response)

        # Call the aadapter_completion method
        response = await router.aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        # Verify the response
        assert response == mock_response

        # Verify async_function_with_fallbacks was called with the right parameters
        router.async_function_with_fallbacks.assert_called_once()
        call_kwargs = router.async_function_with_fallbacks.call_args.kwargs
        assert call_kwargs["adapter_id"] == "test-adapter-id"
        assert call_kwargs["model"] == "test-adapter-model"
        assert call_kwargs["prompt"] == "This is a test prompt"
        assert call_kwargs["max_tokens"] == 100
        assert call_kwargs["original_function"] == router._aadapter_completion
        assert "metadata" in call_kwargs
        assert call_kwargs["metadata"]["model_group"] == "test-adapter-model"
@pytest.mark.asyncio
async def test__aadapter_completion():
    """
    Test the _aadapter_completion method directly
    """
    # Create a mock response for litellm.aadapter_completion
    mock_response = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    # Create a router with a mocked litellm.aadapter_completion
    with patch(
        "litellm.aadapter_completion", new_callable=AsyncMock
    ) as mock_adapter_completion:
        mock_adapter_completion.return_value = mock_response

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        # Mock the async_get_available_deployment method
        router.async_get_available_deployment = AsyncMock(
            return_value={
                "model_name": "test-adapter-model",
                "litellm_params": {
                    "model": "test-model",
                    "api_key": "fake-api-key",
                },
                "model_info": {
                    "id": "test-unique-id",
                },
            }
        )

        # Mock the async_routing_strategy_pre_call_checks method
        router.async_routing_strategy_pre_call_checks = AsyncMock()

        # Call the _aadapter_completion method
        response = await router._aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        # Verify the response
        assert response == mock_response

        # Verify litellm.aadapter_completion was called with the right parameters
        mock_adapter_completion.assert_called_once()
        call_kwargs = mock_adapter_completion.call_args.kwargs
        assert call_kwargs["adapter_id"] == "test-adapter-id"
        assert call_kwargs["model"] == "test-model"
        assert call_kwargs["prompt"] == "This is a test prompt"
        assert call_kwargs["max_tokens"] == 100
        assert call_kwargs["api_key"] == "fake-api-key"
        assert call_kwargs["caching"] == router.cache_responses

        # Verify the success call was recorded
        assert router.success_calls["test-model"] == 1
        assert router.total_calls["test-model"] == 1

        # Verify async_routing_strategy_pre_call_checks was called
        router.async_routing_strategy_pre_call_checks.assert_called_once()
def test_initialize_router_endpoints():
    """
    Test that initialize_router_endpoints correctly sets up all router endpoints
    """
    # Create a router with a basic model
    router = Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {
                    "model": "anthropic/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    # Explicitly call initialize_router_endpoints
    router.initialize_router_endpoints()

    # Verify all expected endpoints are initialized
    assert hasattr(router, "amoderation")
    assert hasattr(router, "aanthropic_messages")
    assert hasattr(router, "aresponses")
    assert hasattr(router, "responses")
    assert hasattr(router, "aget_responses")
    assert hasattr(router, "adelete_responses")

    # Verify the endpoints are callable
    assert callable(router.amoderation)
    assert callable(router.aanthropic_messages)
    assert callable(router.aresponses)
    assert callable(router.responses)
    assert callable(router.aget_responses)
    assert callable(router.adelete_responses)
@pytest.mark.asyncio
async def test_init_responses_api_endpoints():
    """
    A simpler test for _init_responses_api_endpoints that focuses on the basic
    functionality
    """
    from litellm.responses.utils import ResponsesAPIRequestUtils

    # Create a router with a basic model
    router = Router(
        model_list=[
            {
                "model_name": "test-model",
                "litellm_params": {
                    "model": "openai/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    # Just mock the _ageneric_api_call_with_fallbacks method
    router._ageneric_api_call_with_fallbacks = AsyncMock()

    # Mock get_model_id_from_response_id on the utils class
    ResponsesAPIRequestUtils.get_model_id_from_response_id = MagicMock(
        return_value=None
    )

    # Call without a response_id (no model extraction should happen)
    await router._init_responses_api_endpoints(
        original_function=AsyncMock(),
        thread_id="thread_xyz",
    )

    # Verify _ageneric_api_call_with_fallbacks was called but model wasn't changed
    first_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs
    assert "model" not in first_call_kwargs
    assert first_call_kwargs["thread_id"] == "thread_xyz"

    # Reset the mock
    router._ageneric_api_call_with_fallbacks.reset_mock()

    # Change the return value for the second call
    ResponsesAPIRequestUtils.get_model_id_from_response_id.return_value = (
        "claude-3-sonnet"
    )

    # Call with a response_id
    await router._init_responses_api_endpoints(
        original_function=AsyncMock(),
        response_id="resp_claude_123",
    )

    # Verify model was updated in the kwargs
    second_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs
    assert second_call_kwargs["model"] == "claude-3-sonnet"
    assert second_call_kwargs["response_id"] == "resp_claude_123"