Spaces:
Runtime error
Runtime error
| from typing import List, Dict, Any, Union, Optional | |
| import io | |
| import os | |
| import base64 | |
| from PIL import Image | |
| import mimetypes | |
| import google.generativeai as genai | |
| import tempfile | |
| import time | |
| from urllib.parse import urlparse | |
| import requests | |
| from io import BytesIO | |
class GeminiWrapper:
    """Wrapper around a Gemini ``GenerativeModel`` that accepts mixed-media messages.

    Text messages are passed to the model unchanged; image/audio/video messages
    are uploaded with ``genai.upload_file`` (and polled until processing ends)
    before ``generate_content`` is called.
    """

    # Message types whose content is uploaded as a file rather than sent inline.
    _MEDIA_TYPES = ("image", "audio", "video")

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro-002",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False
    ):
        """
        Initialize the Gemini wrapper.

        Args:
            model_name: Name of the model to use. Anything up to and including
                the last '/' (e.g. a "gemini/" provider prefix) is stripped.
            temperature: Sampling temperature for completion.
            print_cost: Whether to print the cost of the completion
                (stored, but not used elsewhere in this class).
            verbose: Whether to print progress output while uploaded files
                are being processed.
            use_langfuse: Whether to enable Langfuse logging
                (stored, but not used elsewhere in this class).

        Raises:
            ValueError: If neither GEMINI_API_KEY nor GOOGLE_API_KEY is set.
        """
        # rsplit keeps the text after the last '/', and is a no-op when there is none.
        self.model_name = model_name.rsplit('/', 1)[-1]
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.use_langfuse = use_langfuse
        self.accumulated_cost = 0

        api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("No API_KEY found. Please set the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable.")
        genai.configure(api_key=api_key)

        generation_config = {
            "temperature": self.temperature,
            "top_p": 0.95,
            "response_mime_type": "text/plain",
        }
        # All four listed harm categories are set to BLOCK_NONE.
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]
        self.model = genai.GenerativeModel(
            model_name=self.model_name,
            safety_settings=safety_settings,
            generation_config=generation_config,
        )

    @staticmethod
    def _is_url(value: str) -> bool:
        """Return True if *value* is an http:// or https:// URL (scheme is case-insensitive)."""
        return urlparse(value).scheme in ("http", "https")

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file or URL based on its extension.

        Args:
            file_path: Local path or http(s) URL.

        Returns:
            MIME type as a string (e.g. "image/jpeg").

        Raises:
            ValueError: If the type cannot be guessed from the extension.
        """
        # For URLs, guess from the path component only, so a query string or
        # fragment (e.g. "photo.png?size=large") does not hide the extension.
        target = urlparse(file_path).path if self._is_url(file_path) else file_path
        mime_type, _ = mimetypes.guess_type(target)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def _download_file(self, url: str, timeout: float = 60) -> str:
        """
        Download a file from a URL and save it as a temporary file.

        The caller owns the returned file and is responsible for deleting it.

        Args:
            url: URL of the file to download.
            timeout: Seconds to wait for the server. Without a timeout,
                ``requests`` can block indefinitely on an unresponsive host.

        Returns:
            Path to the temporary file. Its suffix mirrors the URL path's extension.

        Raises:
            ValueError: If the server does not answer with HTTP 200.
        """
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            raise ValueError(f"Failed to download file from URL: {url}")
        suffix = os.path.splitext(urlparse(url).path)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file.write(response.content)
        return temp_file.name

    def _save_image_to_temp(self, image: Image.Image) -> str:
        """
        Save a PIL Image as a PNG temporary file.

        The caller owns the returned file and is responsible for deleting it.

        Args:
            image: PIL Image object.

        Returns:
            Path to the temporary ``.png`` file.
        """
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            image.save(temp_file, format="PNG")
        return temp_file.name

    def _upload_to_gemini(self, file_path: str, mime_type: Optional[str] = None):
        """
        Upload the given file to Gemini.

        Args:
            file_path: Path to the file.
            mime_type: MIME type of the file.

        Returns:
            The object returned by ``genai.upload_file``.
        """
        return genai.upload_file(file_path, mime_type=mime_type)

    def _wait_until_processed(self, uploaded_file, poll_interval: float = 3):
        """Poll ``genai.get_file`` while the upload is PROCESSING, then return the final file object.

        Raises:
            ValueError: If the file ends in the FAILED state.
        """
        while uploaded_file.state.name == "PROCESSING":
            if self.verbose:
                print('.', end='', flush=True)
            time.sleep(poll_interval)
            uploaded_file = genai.get_file(uploaded_file.name)
        if uploaded_file.state.name == "FAILED":
            raise ValueError(uploaded_file.state.name)
        if self.verbose:
            print("Upload successfully")
        return uploaded_file

    def _upload_media(self, content: Union[str, Image.Image]):
        """Upload one media item (PIL image, http(s) URL or local path) and return the processed file.

        Raises:
            ValueError: For unsupported content, unknown file types, failed
                downloads, or uploads that end in the FAILED state.
        """
        temp_path = None  # set only for files this method creates and must clean up
        if isinstance(content, Image.Image):
            temp_path = file_path = self._save_image_to_temp(content)
            mime_type = "image/png"
        elif isinstance(content, str):
            if self._is_url(content):
                # Resolve the MIME type first so an unsupported URL fails
                # before anything is downloaded.
                mime_type = self._get_mime_type(content)
                temp_path = file_path = self._download_file(content)
            else:
                file_path = content
                mime_type = self._get_mime_type(file_path)
        else:
            raise ValueError("Unsupported content type")

        try:
            uploaded_file = self._upload_to_gemini(file_path, mime_type)
        finally:
            # The temporary copy is only needed for the upload. Remove it so
            # repeated calls don't accumulate files in the temp directory.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass  # best-effort cleanup; never mask the upload result
        return self._wait_until_processed(uploaded_file)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process messages and return the model's completion.

        Args:
            messages: List of dicts with 'type' and 'content' keys. 'type' is
                "text", "image", "audio" or "video". Media content may be a
                PIL Image, an http(s) URL, or a local file path.
            metadata: Currently unused; accepted for interface compatibility.

        Returns:
            Generated text. If ``response.text`` cannot be read, the string form
            of ``response.prompt_feedback`` is returned instead.

        Raises:
            ValueError: For an unsupported message type, or any media error
                raised while preparing/uploading content.
        """
        contents = []
        for msg in messages:
            if msg["type"] == "text":
                contents.append(msg["content"])
            elif msg["type"] in self._MEDIA_TYPES:
                contents.append(self._upload_media(msg["content"]))
            else:
                raise ValueError("Unsupported message type")

        response = self.model.generate_content(contents, request_options={"timeout": 600})
        try:
            return response.text
        except Exception as e:
            # NOTE(review): presumably reached when the response has no usable
            # candidate (e.g. a blocked prompt); the feedback is returned so
            # callers still get a string.
            print(e)
            print(response.prompt_feedback)
            return str(response.prompt_feedback)
if __name__ == "__main__":
    # Library module: running it directly as a script intentionally does nothing.
    ...