Spaces:
Runtime error
Runtime error
| import json | |
| import re | |
| from typing import List, Dict, Any, Union, Optional | |
| import io | |
| import os | |
| import base64 | |
| from PIL import Image | |
| import mimetypes | |
| import litellm | |
| from litellm import completion, completion_cost | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
class LiteLLMWrapper:
    """Callable wrapper around ``litellm.completion`` with cost tracking.

    An instance is called with a list of simple ``{"type", "content"}``
    messages; it converts them to the chat format passed to LiteLLM, sends
    the request and returns the text of the first choice.
    """

    # Model-name patterns that select the "reasoning" request shape
    # (temperature forced to None, ``reasoning_effort`` supplied).
    _REASONING_MODEL_PATTERNS = (
        re.compile(r"^o\d+.*$"),
        re.compile(r"^openai/o.*$"),
    )

    # Message types that are sent as media attachments rather than text.
    _MEDIA_TYPES = ("image", "audio", "video")

    def __init__(
        self,
        model_name: str = "gpt-4-vision-preview",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = True,
    ):
        """
        Initialize the LiteLLM wrapper

        Args:
            model_name: Name of the model to use (e.g. "azure/gpt-4", "vertex_ai/gemini-pro")
            temperature: Temperature for completion
            print_cost: Whether to print the accumulated cost after each completion
            verbose: Whether to print verbose output
            use_langfuse: Whether to enable Langfuse logging
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0
        # Effort level sent with reasoning-model requests. Defined up front so
        # the attribute always exists instead of appearing on the first call.
        self.reasoning_effort = "medium"

        if self.verbose:
            # NOTE(review): litellm is imported before this assignment runs;
            # confirm the variable is still honoured when set this late.
            os.environ['LITELLM_LOG'] = 'DEBUG'

        # Set langfuse callback only if enabled. These are module-level
        # litellm settings, so they apply to every litellm call in the process.
        if use_langfuse:
            litellm.success_callback = ["langfuse"]
            litellm.failure_callback = ["langfuse"]

    def _encode_file(self, file_path: Union[str, Image.Image]) -> str:
        """
        Encode local file or PIL Image to base64 string

        Args:
            file_path: Path to local file or PIL Image object

        Returns:
            Base64 encoded file string
        """
        if isinstance(file_path, Image.Image):
            # In-memory images are serialised as PNG before encoding.
            buffered = io.BytesIO()
            file_path.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")

        with open(file_path, "rb") as file:
            return base64.b64encode(file.read()).decode("utf-8")

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension

        Args:
            file_path: Path to the file

        Returns:
            MIME type as a string (e.g., "image/jpeg", "audio/mpeg")

        Raises:
            ValueError: If no MIME type can be guessed for the path
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def _is_reasoning_model(self) -> bool:
        """Return True when the model name matches a reasoning-model pattern."""
        return any(
            pattern.match(self.model_name)
            for pattern in self._REASONING_MODEL_PATTERNS
        )

    def _to_data_url(self, content: Any) -> Any:
        """
        Turn media content into a value usable as a URL in a message

        Args:
            content: PIL Image, path to an existing local file, or any other
                value (e.g. a remote URL), which is returned unchanged

        Returns:
            A ``data:<mime>;base64,...`` URL for images and local files,
            otherwise ``content`` itself

        Raises:
            ValueError: If a local file's MIME type cannot be determined
        """
        if isinstance(content, Image.Image):
            mime_type = "image/png"
        elif isinstance(content, (str, os.PathLike)) and os.path.isfile(content):
            # The type guard keeps non-path values (None, dicts, ...) from
            # raising TypeError inside os.path.isfile.
            mime_type = self._get_mime_type(content)
        else:
            return content

        base64_data = self._encode_file(content)
        return f"data:{mime_type};base64,{base64_data}"

    def _build_media_message(self, msg_type: str, data_url: Any) -> Dict[str, Any]:
        """
        Build one user message carrying a media attachment

        Args:
            msg_type: One of "image", "audio" or "video"
            data_url: URL or data URL of the attachment

        Returns:
            Message dictionary in the format used for the configured model

        Raises:
            ValueError: If the model is neither Gemini nor GPT, or if a GPT
                model is given a non-image attachment
        """
        if "gemini" in self.model_name:
            # Gemini branch: every media type is sent under "image_url".
            return {
                "role": "user",
                "content": [{"type": "image_url", "image_url": data_url}],
            }

        if "gpt" in self.model_name:
            if msg_type != "image":
                raise ValueError("For GPT, only text and image inferencing are supported")
            return {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": data_url, "detail": "high"},
                    }
                ],
            }

        raise ValueError("Only support Gemini and Gpt for Multimodal capability now")

    def _format_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert simple ``{"type", "content"}`` messages to LiteLLM format

        Args:
            messages: List of message dictionaries with 'type' and 'content' keys

        Returns:
            List of user-role chat messages; entries whose file cannot be
            processed or whose type is unknown are left out

        Raises:
            ValueError: If a media message is not supported by the model
        """
        formatted_messages = []
        for msg in messages:
            msg_type = msg["type"]

            if msg_type == "text":
                formatted_messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": msg["content"]}],
                })
            elif msg_type in self._MEDIA_TYPES:
                content = msg["content"]
                try:
                    data_url = self._to_data_url(content)
                except ValueError as e:
                    # Best effort: an unreadable attachment is skipped, the
                    # remaining messages are still sent.
                    print(f"Error processing file {content}: {e}")
                    continue
                formatted_messages.append(self._build_media_message(msg_type, data_url))
            else:
                # Previously dropped without any trace; still skipped, but
                # now reported so malformed input is visible.
                print(f"Skipping message with unsupported type: {msg_type!r}")

        return formatted_messages

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process messages and return completion

        Args:
            messages: List of message dictionaries with 'type' and 'content' keys
            metadata: Optional metadata to pass to litellm completion, e.g. for Langfuse tracking.
                The dictionary is copied; the caller's object is not modified.

        Returns:
            Generated text response. If the completion call fails, the
            exception message is returned as the string instead of raising.

        Raises:
            ValueError: If a media message is not supported by the model
        """
        if metadata is None:
            print("No metadata provided, using empty metadata")
            metadata = {}
        else:
            # Work on a copy so the trace name is not written into the
            # caller's dictionary.
            metadata = dict(metadata)
        metadata["trace_name"] = f"litellm-completion-{self.model_name}"

        # Convert messages to LiteLLM format
        formatted_messages = self._format_messages(messages)

        request_kwargs = {
            "model": self.model_name,
            "messages": formatted_messages,
            "temperature": self.temperature,
            "metadata": metadata,
            "max_retries": 99,
        }
        if self._is_reasoning_model():
            # Reasoning models get temperature=None plus a reasoning effort.
            # Only this request is adjusted; self.temperature stays as set.
            request_kwargs["temperature"] = None
            request_kwargs["reasoning_effort"] = getattr(self, "reasoning_effort", "medium")

        try:
            response = completion(**request_kwargs)

            if self.print_cost:
                # pass your response from completion to completion_cost
                cost = completion_cost(completion_response=response)
                self.accumulated_cost += cost
                print(f"Accumulated Cost: ${self.accumulated_cost:.10f}")

            content = response.choices[0].message.content
            if content is None:
                print(f"Got null response from model. Full response: {response}")
            return content
        except Exception as e:
            # Deliberate best effort: callers receive the error text rather
            # than an exception.
            print(f"Error in model completion: {e}")
            return str(e)
if __name__ == "__main__":
    # Intentionally empty: running this file as a script performs no work
    # beyond the module-level imports and definitions.
    pass