# Tests for litellm custom prompt management callbacks (get_chat_completion_prompt hook).
| import datetime | |
| import json | |
| import os | |
| import sys | |
| import unittest | |
| from typing import List, Optional, Tuple | |
| from unittest.mock import ANY, MagicMock, Mock, patch | |
| import httpx | |
| import pytest | |
| sys.path.insert( | |
| 0, os.path.abspath("../..") | |
| ) # Adds the parent directory to the system-path | |
| import litellm | |
| from litellm.integrations.custom_prompt_management import CustomPromptManagement | |
| from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler | |
| from litellm.types.llms.openai import AllMessageValues | |
| from litellm.types.utils import StandardCallbackDynamicParams | |
def setup_anthropic_api_key(monkeypatch):
    """Point ANTHROPIC_API_KEY at a dummy value so anthropic calls can be mocked."""
    fake_key = "sk-ant-some-key"
    monkeypatch.setenv("ANTHROPIC_API_KEY", fake_key)
class TestCustomPromptManagement(CustomPromptManagement):
    """Minimal CustomPromptManagement implementation used by the tests in this file.

    Recognizes two prompt ids and rewrites the outgoing messages accordingly;
    any other (or missing) prompt_id passes the inputs through untouched.
    """

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str],
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Swap in a canned prompt for known prompt_ids; otherwise pass through."""
        print(
            "TestCustomPromptManagement: running get_chat_completion_prompt for prompt_id: ",
            prompt_id,
        )
        if prompt_id == "test_prompt_id":
            # Fixed canned prompt; ignores the caller-supplied messages.
            messages = [
                {"role": "user", "content": "This is the prompt for test_prompt_id"},
            ]
        elif prompt_id == "prompt_with_variables":
            # Render the template with whatever variables were supplied (if any).
            template = "Hello, {name}! You are {age} years old and live in {city}."
            rendered = template.format(**(prompt_variables or {}))
            messages = [{"role": "user", "content": rendered}]
        return model, messages, non_default_params
@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id(monkeypatch):
    """A recognized prompt_id should replace the outgoing messages via the callback.

    Fix: the async test was missing ``@pytest.mark.asyncio`` (it would never
    actually run under pytest-asyncio strict mode), and ``monkeypatch`` was
    accepted but unused — the API key helper was never invoked.
    """
    setup_anthropic_api_key(monkeypatch)
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real network request is made.
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
            prompt_id="test_prompt_id",
        )

        mock_post.assert_called_once()
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"
        # The callback swapped in the canned prompt for this prompt_id.
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "This is the prompt for test_prompt_id"
        )
@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id_and_prompt_variables():
    """prompt_variables should be interpolated into the callback's template.

    Fix: the async test was missing ``@pytest.mark.asyncio``, so it would never
    actually execute under pytest-asyncio strict mode.
    """
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real network request is made.
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[],
            client=client,
            prompt_id="prompt_with_variables",
            prompt_variables={"name": "John", "age": 30, "city": "New York"},
        )

        mock_post.assert_called_once()
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"
        # The callback rendered the template with the supplied variables.
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "Hello, John! You are 30 years old and live in New York."
        )
@pytest.mark.asyncio
async def test_custom_prompt_management_without_prompt_id():
    """Without a prompt_id the callback must leave the messages untouched.

    Fix: the async test was missing ``@pytest.mark.asyncio``, so it would never
    actually execute under pytest-asyncio strict mode.
    """
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post so no real network request is made.
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
        )

        mock_post.assert_called_once()
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))
        assert request_body["model"] == "claude-3-5-sonnet"
        # No prompt_id was passed, so the original user message goes through as-is.
        assert (
            request_body["messages"][0]["content"][0]["text"] == "Hello, how are you?"
        )