# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import litellm
import time, logging
import json, traceback, ast, hashlib
from typing import Optional, Literal, List, Union, Any
from openai._models import BaseModel as OpenAIObject


def print_verbose(print_statement):
    try:
        if litellm.set_verbose:
            print(print_statement)  # noqa
    except Exception:
        pass


class BaseCache:
    def set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    def get_cache(self, key, **kwargs):
        raise NotImplementedError


class InMemoryCache(BaseCache):
    def __init__(self):
        # if users don't provide one, use the default litellm cache
        self.cache_dict = {}
        self.ttl_dict = {}

    def set_cache(self, key, value, **kwargs):
        self.cache_dict[key] = value
        if "ttl" in kwargs:
            self.ttl_dict[key] = time.time() + kwargs["ttl"]

    def get_cache(self, key, **kwargs):
        if key in self.cache_dict:
            if key in self.ttl_dict:
                if time.time() > self.ttl_dict[key]:
                    self.cache_dict.pop(key, None)
                    return None
            original_cached_response = self.cache_dict[key]
            try:
                cached_response = json.loads(original_cached_response)
            except Exception:
                cached_response = original_cached_response
            return cached_response
        return None

    def flush_cache(self):
        self.cache_dict.clear()
        self.ttl_dict.clear()
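

# Illustrative usage sketch (not part of the library API): demonstrates
# InMemoryCache's lazy TTL eviction - an expired key is dropped on the first
# read after its deadline passes.
def _example_in_memory_ttl():
    cache = InMemoryCache()
    cache.set_cache("greeting", "hello", ttl=1)  # expires ~1 second from now
    assert cache.get_cache("greeting") == "hello"
    time.sleep(2)
    assert cache.get_cache("greeting") is None  # expired entry evicted on read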


class RedisCache(BaseCache):
    def __init__(self, host=None, port=None, password=None, **kwargs):
        import redis  # fail fast here if the redis package isn't installed

        # if users don't provide one, use the default litellm cache
        from ._redis import get_redis_client

        redis_kwargs = {}
        if host is not None:
            redis_kwargs["host"] = host
        if port is not None:
            redis_kwargs["port"] = port
        if password is not None:
            redis_kwargs["password"] = password

        redis_kwargs.update(kwargs)

        self.redis_client = get_redis_client(**redis_kwargs)

    def set_cache(self, key, value, **kwargs):
        ttl = kwargs.get("ttl", None)
        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
        try:
            self.redis_client.set(name=key, value=str(value), ex=ttl)
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
            logging.debug(f"LiteLLM Caching: set() - Got exception from REDIS: {e}")

    def get_cache(self, key, **kwargs):
        try:
            print_verbose(f"Get Redis Cache: key: {key}")
            cached_response = self.redis_client.get(key)
            print_verbose(
                f"Got Redis Cache: key: {key}, cached_response {cached_response}"
            )
            if cached_response is not None:
                # cached_response is a bytes object (e.g. b'{...}') - decode and parse it
                cached_response = cached_response.decode(
                    "utf-8"
                )  # Convert bytes to string
                try:
                    cached_response = json.loads(
                        cached_response
                    )  # Convert string to dictionary
                except Exception:
                    cached_response = ast.literal_eval(cached_response)
            return cached_response
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
            traceback.print_exc()
            logging.debug(f"LiteLLM Caching: get() - Got exception from REDIS: {e}")

    def flush_cache(self):
        self.redis_client.flushall()
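

# Illustrative usage sketch (assumes a Redis server reachable at localhost:6379;
# not part of the library API): values are stored via str(value), so dicts come
# back through the json.loads / ast.literal_eval fallback in get_cache().
def _example_redis_cache():
    cache = RedisCache(host="localhost", port=6379)
    cache.set_cache("k", {"a": 1}, ttl=60)  # stored as "{'a': 1}", expires in 60s
    assert cache.get_cache("k") == {"a": 1}  # parsed back via ast.literal_eval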


class S3Cache(BaseCache):
    def __init__(
        self,
        s3_bucket_name,
        s3_region_name=None,
        s3_api_version=None,
        s3_use_ssl=True,
        s3_verify=None,
        s3_endpoint_url=None,
        s3_aws_access_key_id=None,
        s3_aws_secret_access_key=None,
        s3_aws_session_token=None,
        s3_config=None,
        **kwargs,
    ):
        import boto3

        self.bucket_name = s3_bucket_name
        # Create an S3 client with custom endpoint URL
        self.s3_client = boto3.client(
            "s3",
            region_name=s3_region_name,
            endpoint_url=s3_endpoint_url,
            api_version=s3_api_version,
            use_ssl=s3_use_ssl,
            verify=s3_verify,
            aws_access_key_id=s3_aws_access_key_id,
            aws_secret_access_key=s3_aws_secret_access_key,
            aws_session_token=s3_aws_session_token,
            config=s3_config,
            **kwargs,
        )

    def set_cache(self, key, value, **kwargs):
        try:
            print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
            ttl = kwargs.get("ttl", None)
            # Convert value to JSON before storing in S3
            serialized_value = json.dumps(value)
            if ttl is not None:
                cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
                import datetime

                # Calculate the expiration time (ttl is in seconds)
                expiration_time = datetime.datetime.now() + datetime.timedelta(
                    seconds=ttl
                )
                # Upload the data to S3 with the calculated expiration time
                self.s3_client.put_object(
                    Bucket=self.bucket_name,
                    Key=key,
                    Body=serialized_value,
                    Expires=expiration_time,
                    CacheControl=cache_control,
                    ContentType="application/json",
                    ContentLanguage="en",
                    ContentDisposition=f'inline; filename="{key}.json"',
                )
            else:
                cache_control = "immutable, max-age=31536000, s-maxage=31536000"
                # Upload the data to S3 without specifying Expires
                self.s3_client.put_object(
                    Bucket=self.bucket_name,
                    Key=key,
                    Body=serialized_value,
                    CacheControl=cache_control,
                    ContentType="application/json",
                    ContentLanguage="en",
                    ContentDisposition=f'inline; filename="{key}.json"',
                )
        except Exception as e:
            # NON blocking - notify users S3 is throwing an exception
            print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

    def get_cache(self, key, **kwargs):
        import boto3, botocore

        try:
            print_verbose(f"Get S3 Cache: key: {key}")
            # Download the data from S3
            cached_response = self.s3_client.get_object(
                Bucket=self.bucket_name, Key=key
            )
            if cached_response is not None:
                # The body is a bytes stream (e.g. b'{...}') - decode and parse it
                cached_response = (
                    cached_response["Body"].read().decode("utf-8")
                )  # Convert bytes to string
                try:
                    cached_response = json.loads(
                        cached_response
                    )  # Convert string to dictionary
                except Exception:
                    cached_response = ast.literal_eval(cached_response)
            if not isinstance(cached_response, dict):
                cached_response = dict(cached_response)
            print_verbose(
                f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
            )
            return cached_response
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                print_verbose(
                    f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
                )
                return None
        except Exception as e:
            # NON blocking - notify users S3 is throwing an exception
            traceback.print_exc()
            print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")

    def flush_cache(self):
        pass
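

# Illustrative usage sketch (assumes an existing bucket named "my-litellm-cache"
# and AWS credentials in the environment - both hypothetical; not part of the
# library API): values are serialized to JSON, and a ttl is mapped onto the
# object's Expires / Cache-Control headers rather than deleting the object.
def _example_s3_cache():
    cache = S3Cache(s3_bucket_name="my-litellm-cache", s3_region_name="us-east-1")
    cache.set_cache("k", {"a": 1}, ttl=3600)  # Cache-Control: max-age=3600
    assert cache.get_cache("k") == {"a": 1}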


class DualCache(BaseCache):
    """
    This updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache + Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """

    def __init__(
        self,
        in_memory_cache: Optional[InMemoryCache] = None,
        redis_cache: Optional[RedisCache] = None,
    ) -> None:
        super().__init__()
        # If in_memory_cache is not provided, use the default InMemoryCache
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        # If redis_cache is not provided, use the default RedisCache
        self.redis_cache = redis_cache

    def set_cache(self, key, value, local_only: bool = False, **kwargs):
        # Update both Redis and in-memory cache
        try:
            print_verbose(f"set cache: key: {key}; value: {value}")
            if self.in_memory_cache is not None:
                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None and not local_only:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

    def get_cache(self, key, local_only: bool = False, **kwargs):
        # Try to fetch from the in-memory cache first
        try:
            print_verbose(f"get cache: cache key: {key}; local_only: {local_only}")
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
                print_verbose(f"in_memory_result: {in_memory_result}")
                if in_memory_result is not None:
                    result = in_memory_result

            if result is None and self.redis_cache is not None and not local_only:
                # If not found in the in-memory cache, try fetching from Redis
                redis_result = self.redis_cache.get_cache(key, **kwargs)

                if redis_result is not None:
                    # Update the in-memory cache with the value from Redis
                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)

                result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception as e:
            traceback.print_exc()

    def flush_cache(self):
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()
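

# Illustrative usage sketch (assumes a local Redis; not part of the library API):
# a read that misses in memory but hits Redis back-fills the in-memory layer, so
# subsequent reads are served locally.
def _example_dual_cache():
    dual = DualCache(redis_cache=RedisCache(host="localhost", port=6379))
    dual.set_cache("k", {"a": 1})       # written to both layers
    dual.in_memory_cache.flush_cache()  # simulate an in-memory miss
    assert dual.get_cache("k") == {"a": 1}  # served from Redis, back-filled locally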


#### LiteLLM.Completion / Embedding Cache ####
class Cache:
    def __init__(
        self,
        type: Optional[Literal["local", "redis", "s3"]] = "local",
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        supported_call_types: Optional[
            List[Literal["completion", "acompletion", "embedding", "aembedding"]]
        ] = ["completion", "acompletion", "embedding", "aembedding"],
        # s3 Bucket, boto3 configuration
        s3_bucket_name: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        s3_api_version: Optional[str] = None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify: Optional[Union[bool, str]] = None,
        s3_endpoint_url: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_aws_session_token: Optional[str] = None,
        s3_config: Optional[Any] = None,
        **kwargs,
    ):
        """
        Initializes the cache based on the given type.

        Args:
            type (str, optional): The type of cache to initialize. Can be "local", "redis", or "s3". Defaults to "local".
            host (str, optional): The host address for the Redis cache. Required if type is "redis".
            port (str, optional): The port number for the Redis cache. Required if type is "redis".
            password (str, optional): The password for the Redis cache. Required if type is "redis".
            supported_call_types (list, optional): List of call types to cache for. Defaults to caching for all call types.
            **kwargs: Additional keyword arguments for the redis.Redis() cache

        Raises:
            ValueError: If an invalid cache type is provided.

        Returns:
            None. Cache is set as a litellm param.
        """
        if type == "redis":
            self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
        if type == "local":
            self.cache = InMemoryCache()
        if type == "s3":
            self.cache = S3Cache(
                s3_bucket_name=s3_bucket_name,
                s3_region_name=s3_region_name,
                s3_api_version=s3_api_version,
                s3_use_ssl=s3_use_ssl,
                s3_verify=s3_verify,
                s3_endpoint_url=s3_endpoint_url,
                s3_aws_access_key_id=s3_aws_access_key_id,
                s3_aws_secret_access_key=s3_aws_secret_access_key,
                s3_aws_session_token=s3_aws_session_token,
                s3_config=s3_config,
                **kwargs,
            )
        if "cache" not in litellm.input_callback:
            litellm.input_callback.append("cache")
        if "cache" not in litellm.success_callback:
            litellm.success_callback.append("cache")
        if "cache" not in litellm._async_success_callback:
            litellm._async_success_callback.append("cache")
        self.supported_call_types = supported_call_types  # default to ["completion", "acompletion", "embedding", "aembedding"]
        self.type = type
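
    # Note: constructing Cache() wires the "cache" hook into litellm's
    # input/success callback lists (see __init__ above); that hook is how
    # completion()/embedding() calls get intercepted for cache reads and writes.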

    def get_cache_key(self, *args, **kwargs):
        """
        Get the cache key for the given arguments.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments, or None if no cache key could be generated.
        """
        cache_key = ""
        print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")

        # for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens get transformed for bedrock -> max_new_tokens
        preset_cache_key = kwargs.get("litellm_params", {}).get("preset_cache_key", None)
        if preset_cache_key is not None:
            print_verbose(f"\nReturning preset cache key: {preset_cache_key}")
            return preset_cache_key

        # sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
        completion_kwargs = [
            "model",
            "messages",
            "temperature",
            "top_p",
            "n",
            "stop",
            "max_tokens",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "response_format",
            "seed",
            "tools",
            "tool_choice",
        ]
        embedding_only_kwargs = [
            "input",
            "encoding_format",
        ]  # embedding kwargs = model, input, user, encoding_format. model and user are already checked in completion_kwargs

        # combined_kwargs - NEEDS to be ordered across calls to get_cache_key(). Do not use a set()
        combined_kwargs = completion_kwargs + embedding_only_kwargs
        for param in combined_kwargs:
            # ignore litellm params here
            if param in kwargs:
                # check if param == model and model_group is passed in, then override model with model_group
                if param == "model":
                    model_group = None
                    caching_group = None
                    metadata = kwargs.get("metadata", None)
                    litellm_params = kwargs.get("litellm_params", {})
                    if metadata is not None:
                        model_group = metadata.get("model_group", None)
                        caching_groups = metadata.get("caching_groups", None)
                        if caching_groups:
                            for group in caching_groups:
                                if model_group in group:
                                    caching_group = group
                                    break
                    if litellm_params is not None:
                        metadata = litellm_params.get("metadata", None)
                        if metadata is not None:
                            model_group = metadata.get("model_group", None)
                            caching_groups = metadata.get("caching_groups", None)
                            if caching_groups:
                                for group in caching_groups:
                                    if model_group in group:
                                        caching_group = group
                                        break
                    param_value = (
                        caching_group or model_group or kwargs[param]
                    )  # use caching_group if set; then model_group if it exists; else kwargs["model"]
                else:
                    if kwargs[param] is None:
                        continue  # ignore None params
                    param_value = kwargs[param]
                cache_key += f"{str(param)}: {str(param_value)}"
        print_verbose(f"\nCreated cache key: {cache_key}")
        # Use hashlib to create a sha256 hash of the cache key
        hash_object = hashlib.sha256(cache_key.encode())
        # Hexadecimal representation of the hash
        hash_hex = hash_object.hexdigest()
        print_verbose(f"Hashed cache key (SHA-256): {hash_hex}")
        return hash_hex
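
    # Because the key is a SHA-256 hash over a fixed, ordered list of params,
    # call-site kwarg ordering doesn't matter: (model=..., temperature=...) and
    # (temperature=..., model=...) hash to the same cache key.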

    def generate_streaming_content(self, content):
        chunk_size = 5  # Adjust the chunk size as needed
        for i in range(0, len(content), chunk_size):
            yield {
                "choices": [
                    {
                        "delta": {
                            "role": "assistant",
                            "content": content[i : i + chunk_size],
                        }
                    }
                ]
            }
            time.sleep(0.02)
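
    # A cache hit can still satisfy stream=True callers: the cached text is
    # re-sliced by generate_streaming_content() into OpenAI-style delta chunks,
    # e.g. chunk["choices"][0]["delta"]["content"] holds 5-character slices.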

    def get_cache(self, *args, **kwargs):
        """
        Retrieves the cached result for the given arguments.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            The cached result if it exists, otherwise None.
        """
        try:  # never block execution
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = self.cache.get_cache(cache_key)
                # Check if a timestamp was stored with the cached response
                if (
                    cached_result is not None
                    and isinstance(cached_result, dict)
                    and "timestamp" in cached_result
                    and max_age is not None
                ):
                    timestamp = cached_result["timestamp"]
                    current_time = time.time()

                    # Calculate the age of the cached response
                    response_age = current_time - timestamp

                    # Check if the cached response is older than the max-age
                    if response_age > max_age:
                        print_verbose(
                            f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
                        )
                        return None  # Cached response is too old

                    # If the response is fresh, or there's no max-age requirement, return the cached response
                    # cached_response may be a serialized string (e.g. b'{...}') - convert it back to a dict
                    cached_response = cached_result.get("response")
                    try:
                        if isinstance(cached_response, dict):
                            pass
                        else:
                            cached_response = json.loads(
                                cached_response
                            )  # Convert string to dictionary
                    except Exception:
                        cached_response = ast.literal_eval(cached_response)
                    return cached_response
                return cached_result
        except Exception as e:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def add_cache(self, result, *args, **kwargs):
        """
        Adds a result to the cache.

        Args:
            *args: args to litellm.completion() or embedding()
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(*args, **kwargs)
            if cache_key is not None:
                if isinstance(result, OpenAIObject):
                    result = result.model_dump_json()

                ## Get Cache-Controls ##
                if kwargs.get("cache", None) is not None and isinstance(
                    kwargs.get("cache"), dict
                ):
                    for k, v in kwargs.get("cache").items():
                        if k == "ttl":
                            kwargs["ttl"] = v
                cached_data = {"timestamp": time.time(), "response": result}
                self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            print_verbose(f"LiteLLM Cache: Exception add_cache: {str(e)}")
            traceback.print_exc()

    async def _async_add_cache(self, result, *args, **kwargs):
        self.add_cache(result, *args, **kwargs)
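

# Illustrative usage sketch (not part of the library API): a manual add_cache /
# get_cache round trip, passing the same kwargs litellm.completion() would pass.
def _example_cache_round_trip():
    cache = Cache(type="local")
    call_kwargs = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
    }
    cache.add_cache({"id": "chatcmpl-123"}, **call_kwargs)  # stored with a timestamp
    assert cache.get_cache(**call_kwargs) == {"id": "chatcmpl-123"}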


def enable_cache(
    type: Optional[Literal["local", "redis", "s3"]] = "local",
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[
        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
    ] = ["completion", "acompletion", "embedding", "aembedding"],
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[Literal["local", "redis", "s3"]]): The type of cache to enable. Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
            The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
    """
    print_verbose("LiteLLM: Enabling Cache")
    if "cache" not in litellm.input_callback:
        litellm.input_callback.append("cache")
    if "cache" not in litellm.success_callback:
        litellm.success_callback.append("cache")
    if "cache" not in litellm._async_success_callback:
        litellm._async_success_callback.append("cache")

    if litellm.cache is None:
        litellm.cache = Cache(
            type=type,
            host=host,
            port=port,
            password=password,
            supported_call_types=supported_call_types,
            **kwargs,
        )
    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")


def update_cache(
    type: Optional[Literal["local", "redis"]] = "local",
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[
        List[Literal["completion", "acompletion", "embedding", "aembedding"]]
    ] = ["completion", "acompletion", "embedding", "aembedding"],
    **kwargs,
):
    """
    Update the cache for LiteLLM.

    Args:
        type (Optional[Literal["local", "redis"]]): The type of cache. Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
            The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None
    """
    print_verbose("LiteLLM: Updating Cache")
    litellm.cache = Cache(
        type=type,
        host=host,
        port=port,
        password=password,
        supported_call_types=supported_call_types,
        **kwargs,
    )
    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")


def disable_cache():
    """
    Disable the cache used by LiteLLM.

    This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.

    Parameters:
        None

    Returns:
        None
    """
    from contextlib import suppress

    print_verbose("LiteLLM: Disabling Cache")
    with suppress(ValueError):
        litellm.input_callback.remove("cache")
        litellm.success_callback.remove("cache")
        litellm._async_success_callback.remove("cache")

    litellm.cache = None
    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")