Spaces:
Paused
Paused
| import pytest | |
| import requests | |
| from litellm.proxy.client.chat import ChatClient | |
| from litellm.proxy.client.exceptions import UnauthorizedError | |
@pytest.fixture
def base_url():
    """Return the root URL of the proxy that ChatClient is pointed at.

    Registered as a pytest fixture so tests (and the ``client`` fixture) can
    request it by parameter name. Without the decorator, pytest fails with
    "fixture 'base_url' not found".
    """
    return "http://localhost:8000"
@pytest.fixture
def api_key():
    """Return a dummy API key that tests expect to see in the Authorization header.

    Registered as a pytest fixture so tests can request it by parameter name.
    """
    return "test-api-key"
@pytest.fixture
def client(base_url, api_key):
    """Build a ChatClient wired to the ``base_url`` and ``api_key`` fixtures.

    Registered as a pytest fixture; pytest injects the two dependency
    fixtures by parameter name.
    """
    return ChatClient(base_url=base_url, api_key=api_key)
@pytest.fixture
def sample_messages():
    """Return a minimal chat history (system prompt + user question) for requests.

    Registered as a pytest fixture. With the default function scope, each test
    gets a freshly built list, so one test cannot mutate another's messages.
    """
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Name 3 countries"},
    ]
def test_client_initialization(base_url, api_key):
    """A ChatClient built with explicit arguments stores both of them verbatim."""
    chat_client = ChatClient(api_key=api_key, base_url=base_url)
    assert chat_client._base_url == base_url
    assert chat_client._api_key == api_key
def test_client_initialization_strips_trailing_slash():
    """Trailing slashes on base_url are removed when the client is constructed."""
    noisy_url = "http://localhost:8000/////"
    expected_url = "http://localhost:8000"
    constructed = ChatClient(base_url=noisy_url)
    assert constructed._base_url == expected_url
def test_client_without_api_key(base_url):
    """Omitting api_key is allowed and leaves the stored key as None."""
    keyless_client = ChatClient(base_url=base_url)
    assert keyless_client._api_key is None
def test_completions_request_creation(client, base_url, api_key, sample_messages):
    """completions(return_request=True) builds a POST with the right URL, headers and body."""
    expected_body = {
        "model": "gpt-4",
        "messages": sample_messages,
        "temperature": 0.7,
        "max_tokens": 100,
    }
    prepared = client.completions(
        model="gpt-4",
        messages=sample_messages,
        temperature=0.7,
        max_tokens=100,
        return_request=True,
    )

    # HTTP verb and target endpoint.
    assert prepared.method == "POST"
    assert prepared.url == f"{base_url}/chat/completions"

    # JSON content type plus bearer auth derived from the api_key fixture.
    sent_headers = prepared.headers
    assert sent_headers["Content-Type"] == "application/json"
    assert sent_headers["Authorization"] == f"Bearer {api_key}"

    # The body holds exactly the completion arguments; return_request is not sent.
    assert prepared.json == expected_body
def test_completions_minimal_request(client, sample_messages):
    """Only model and messages are required, and no optional keys leak into the body."""
    built = client.completions(
        model="gpt-4",
        messages=sample_messages,
        return_request=True,
    )
    required_only = {"model": "gpt-4", "messages": sample_messages}
    assert built.json == required_only
def test_completions_all_parameters(client, sample_messages):
    """Every optional sampling/metadata argument is forwarded into the JSON body."""
    # One table drives both the call and the expectation, so they cannot drift apart.
    optional_args = {
        "temperature": 0.7,
        "top_p": 0.9,
        "n": 2,
        "max_tokens": 100,
        "presence_penalty": 0.5,
        "frequency_penalty": -0.5,
        "user": "test-user",
    }
    built = client.completions(
        model="gpt-4",
        messages=sample_messages,
        return_request=True,
        **optional_args,
    )
    assert built.json == {
        "model": "gpt-4",
        "messages": sample_messages,
        **optional_args,
    }
def test_completions_mock_response(client, sample_messages, requests_mock):
    """A successful /chat/completions reply is handed back to the caller as parsed JSON."""
    greeting = "Hello! How can I help you today?"
    canned_reply = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677858242,
        "model": "gpt-4",
        "usage": {"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
        "choices": [
            {
                "message": {"role": "assistant", "content": greeting},
                "finish_reason": "stop",
                "index": 0,
            }
        ],
    }
    endpoint = f"{client._base_url}/chat/completions"
    # Intercept the POST so no real network call is made.
    requests_mock.post(endpoint, json=canned_reply)

    result = client.completions(model="gpt-4", messages=sample_messages)

    assert result == canned_reply
    first_message = result["choices"][0]["message"]
    assert first_message["content"] == greeting
def test_completions_unauthorized_error(client, sample_messages, requests_mock):
    """An HTTP 401 from the server surfaces as UnauthorizedError."""
    endpoint = f"{client._base_url}/chat/completions"
    requests_mock.post(endpoint, status_code=401, json={"error": "Unauthorized"})

    with pytest.raises(UnauthorizedError):
        client.completions(model="gpt-4", messages=sample_messages)
def test_completions_other_errors(client, sample_messages, requests_mock):
    """Other error statuses (here a 500) propagate as requests' HTTPError."""
    endpoint = f"{client._base_url}/chat/completions"
    requests_mock.post(
        endpoint,
        status_code=500,
        json={"error": "Internal Server Error"},
    )

    with pytest.raises(requests.exceptions.HTTPError) as exc_info:
        client.completions(model="gpt-4", messages=sample_messages)

    # Inspect the captured exception only after the with-block; placed inside it,
    # this check would never run because the call above raises.
    failed_response = exc_info.value.response
    assert failed_response.status_code == 500