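"""
Unit tests for the LiteLLM `litellm_proxy/` provider.

These tests mock the underlying OpenAI client / HTTP handler and verify that
SDK calls (chat completions, embeddings, image generation, transcription,
speech, rerank) are routed through the LiteLLM gateway with their parameters
passed through unchanged.
"""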
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from litellm import completion, embedding
import pytest
from unittest.mock import MagicMock, patch
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
import pytest_asyncio
from openai import AsyncOpenAI


@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk():
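    """
    SDK chat.completions call against a `litellm_proxy/` model: any extra,
    non-OpenAI kwargs (here `hello="world"`) should be forwarded to the
    gateway inside `extra_body`.
    """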
    litellm.set_verbose = True
    messages = [
        {
            "role": "user",
            "content": "Hello world",
        }
    ]
    from openai import OpenAI

    openai_client = OpenAI(api_key="fake-key")

    with patch.object(
        openai_client.chat.completions.with_raw_response, "create", new=MagicMock()
    ) as mock_call:
        try:
            completion(
                model="litellm_proxy/my-vllm-model",
                messages=messages,
                response_format={"type": "json_object"},
                client=openai_client,
                api_base="my-custom-api-base",
                hello="world",
            )
        except Exception as e:
            print(e)

        mock_call.assert_called_once()
        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))

        assert "hello" in mock_call.call_args.kwargs["extra_body"]


@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_structured_output():
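    """
    A Pydantic model passed as `response_format` should reach the gateway as
    an OpenAI-style `json_schema` response_format.
    """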
    from pydantic import BaseModel

    class Result(BaseModel):
        answer: str

    litellm.set_verbose = True
    from openai import OpenAI

    openai_client = OpenAI(api_key="fake-key")

    with patch.object(
        openai_client.chat.completions, "create", new=MagicMock()
    ) as mock_call:
        try:
            litellm.completion(
                model="litellm_proxy/openai/gpt-4o",
                messages=[
                    {"role": "user", "content": "What is the capital of France?"}
                ],
                api_key="my-test-api-key",
                user="test",
                response_format=Result,
                base_url="https://litellm.ml-serving-internal.scale.com",
                client=openai_client,
            )
        except Exception as e:
            print(e)

        mock_call.assert_called_once()
        print("Call KWARGS - {}".format(mock_call.call_args.kwargs))

        json_schema = mock_call.call_args.kwargs["response_format"]
        assert "json_schema" in json_schema


@pytest.mark.parametrize("is_async", [True, False])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_embedding(is_async):
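    """
    Sync and async embedding calls should hit the mocked OpenAI client with
    the `litellm_proxy/` prefix stripped from the model name and the input
    passed through unchanged.
    """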
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.embeddings.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.embeddings.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.aembedding(
                    model="litellm_proxy/my-vllm-model",
                    input="Hello world",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.embedding(
                    model="litellm_proxy/my-vllm-model",
                    input="Hello world",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert "Hello world" == mock_method.call_args.kwargs["input"]
        assert "my-vllm-model" == mock_method.call_args.kwargs["model"]


@pytest.mark.parametrize("is_async", [True, False])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_image_generation(is_async):
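    """
    Sync and async image generation calls should forward the prompt and the
    un-prefixed model name to the mocked OpenAI client.
    """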
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.images.generate
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.images.generate

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                response = await litellm.aimage_generation(
                    model="litellm_proxy/dall-e-3",
                    prompt="A beautiful sunset over mountains",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                response = litellm.image_generation(
                    model="litellm_proxy/dall-e-3",
                    prompt="A beautiful sunset over mountains",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            print("response=", response)
        except Exception as e:
            print("got error", e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert (
            "A beautiful sunset over mountains"
            == mock_method.call_args.kwargs["prompt"]
        )
        assert "dall-e-3" == mock_method.call_args.kwargs["model"]


@pytest.mark.parametrize("is_async", [True, False])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_transcription(is_async):
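    """
    Sync and async transcription calls should forward the un-prefixed model
    name to the mocked OpenAI client.
    """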
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.audio.transcriptions.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.audio.transcriptions.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.atranscription(
                    model="litellm_proxy/whisper-1",
                    file=b"sample_audio",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.transcription(
                    model="litellm_proxy/whisper-1",
                    file=b"sample_audio",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert "whisper-1" == mock_method.call_args.kwargs["model"]


@pytest.mark.parametrize("is_async", [True, False])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_speech(is_async):
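    """
    Sync and async speech calls should forward the input, the un-prefixed
    model name, and the voice to the mocked OpenAI client.
    """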
    litellm.set_verbose = True

    if is_async:
        from openai import AsyncOpenAI

        openai_client = AsyncOpenAI(api_key="fake-key")
        mock_method = AsyncMock()
        patch_target = openai_client.audio.speech.create
    else:
        from openai import OpenAI

        openai_client = OpenAI(api_key="fake-key")
        mock_method = MagicMock()
        patch_target = openai_client.audio.speech.create

    with patch.object(patch_target.__self__, patch_target.__name__, new=mock_method):
        try:
            if is_async:
                await litellm.aspeech(
                    model="litellm_proxy/tts-1",
                    input="Hello, this is a test of text to speech",
                    voice="alloy",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
            else:
                litellm.speech(
                    model="litellm_proxy/tts-1",
                    input="Hello, this is a test of text to speech",
                    voice="alloy",
                    client=openai_client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        mock_method.assert_called_once()
        print("Call KWARGS - {}".format(mock_method.call_args.kwargs))

        assert (
            "Hello, this is a test of text to speech"
            == mock_method.call_args.kwargs["input"]
        )
        assert "tts-1" == mock_method.call_args.kwargs["model"]
        assert "alloy" == mock_method.call_args.kwargs["voice"]


@pytest.mark.parametrize("is_async", [True, False])
@pytest.mark.asyncio
async def test_litellm_gateway_from_sdk_rerank(is_async):
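    """
    Sync and async rerank calls should POST to `<api_base>/v1/rerank` with the
    query, the un-prefixed model name, and the documents in the request body.
    """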
    litellm.set_verbose = True
    litellm._turn_on_debug()

    if is_async:
        client = AsyncHTTPHandler()
        mock_method = AsyncMock()
    else:
        client = HTTPHandler()
        mock_method = MagicMock()

    with patch.object(client, "post", new=mock_method):
        # Create a mock response similar to OpenAI's rerank response
        mock_response = MagicMock()
        mock_response.text = json.dumps(
            {
                "id": "rerank-123456",
                "object": "reranking",
                "results": [
                    {
                        "index": 0,
                        "relevance_score": 0.9,
                        "document": {
                            "id": "0",
                            "text": "Machine learning is a field of study in artificial intelligence",
                        },
                    },
                    {
                        "index": 1,
                        "relevance_score": 0.2,
                        "document": {
                            "id": "1",
                            "text": "Biology is the study of living organisms",
                        },
                    },
                ],
                "model": "rerank-english-v2.0",
                "usage": {"prompt_tokens": 10, "total_tokens": 10},
            }
        )
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json = lambda: json.loads(mock_response.text)

        mock_method.return_value = mock_response

        try:
            if is_async:
                response = await litellm.arerank(
                    model="litellm_proxy/rerank-english-v2.0",
                    query="What is machine learning?",
                    documents=[
                        "Machine learning is a field of study in artificial intelligence",
                        "Biology is the study of living organisms",
                    ],
                    client=client,
                    api_base="my-custom-api-base",
                )
            else:
                response = litellm.rerank(
                    model="litellm_proxy/rerank-english-v2.0",
                    query="What is machine learning?",
                    documents=[
                        "Machine learning is a field of study in artificial intelligence",
                        "Biology is the study of living organisms",
                    ],
                    client=client,
                    api_base="my-custom-api-base",
                )
        except Exception as e:
            print(e)

        # Verify the request
        mock_method.assert_called_once()
        call_args = mock_method.call_args
        print("call_args=", call_args)

        # Check that the URL is correct
        assert "my-custom-api-base/v1/rerank" == call_args.kwargs["url"]

        # Check that the request body contains the expected data
        request_body = json.loads(call_args.kwargs["data"])
        assert request_body["query"] == "What is machine learning?"
        assert request_body["model"] == "rerank-english-v2.0"
        assert len(request_body["documents"]) == 2


def test_litellm_gateway_from_sdk_with_response_cost_in_additional_headers():
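    """
    The gateway's `x-litellm-response-cost` header should be surfaced in
    `response._hidden_params["additional_headers"]` and parsed into
    `response_cost`.
    """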
    litellm.set_verbose = True
    litellm._turn_on_debug()
    from openai import OpenAI

    openai_client = OpenAI(api_key="fake-key")

    # Create mock response object
    mock_response = MagicMock()
    mock_response.headers = {"x-litellm-response-cost": "120"}
    mock_response.parse.return_value = litellm.ModelResponse(
        **{
            "id": "chatcmpl-BEkxQvRGp9VAushfAsOZCbhMFLsoy",
            "choices": [
                {
                    "finish_reason": "stop",
                    "index": 0,
                    "logprobs": None,
                    "message": {
                        "content": "Hello! How can I assist you today?",
                        "refusal": None,
                        "role": "assistant",
                        "annotations": [],
                        "audio": None,
                        "function_call": None,
                        "tool_calls": None,
                    },
                }
            ],
            "created": 1742856796,
            "model": "gpt-4o-2024-08-06",
            "object": "chat.completion",
            "service_tier": "default",
            "system_fingerprint": "fp_6ec83003ad",
            "usage": {
                "completion_tokens": 10,
                "prompt_tokens": 9,
                "total_tokens": 19,
                "completion_tokens_details": {
                    "accepted_prediction_tokens": 0,
                    "audio_tokens": 0,
                    "reasoning_tokens": 0,
                    "rejected_prediction_tokens": 0,
                },
                "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
            },
        }
    )

    with patch.object(
        openai_client.chat.completions.with_raw_response,
        "create",
        return_value=mock_response,
    ) as mock_call:
        response = litellm.completion(
            model="litellm_proxy/gpt-4o",
            messages=[{"role": "user", "content": "Hello world"}],
            api_base="http://0.0.0.0:4000",
            api_key="sk-PIp1h0RekR",
            client=openai_client,
        )

        # Assert the headers were properly passed through
        print(f"additional_headers: {response._hidden_params['additional_headers']}")
        assert (
            response._hidden_params["additional_headers"][
                "llm_provider-x-litellm-response-cost"
            ]
            == "120"
        )
        assert response._hidden_params["response_cost"] == 120


def test_litellm_gateway_from_sdk_with_thinking_param():
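    """
    Passing the `thinking` param to a `litellm_proxy/` model should not raise
    a client-side validation error; with no proxy running, only a connection
    error is expected.
    """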
    try:
        response = litellm.completion(
            model="litellm_proxy/anthropic.claude-3-7-sonnet-20250219-v1:0",
            messages=[{"role": "user", "content": "Hello world"}],
            api_base="http://0.0.0.0:4000",
            api_key="sk-PIp1h0RekR",
            # client=openai_client,
            thinking={"type": "enabled", "max_budget": 100},
        )
        pytest.fail("Expected an error to be raised")
    except Exception as e:
        assert "Connection error." in str(e)