Spaces:
Runtime error
Runtime error
| import os | |
| from typing import List, Dict, Any, Optional | |
| import vertexai | |
| from vertexai.generative_models import GenerativeModel, Part | |
| from google.auth import default | |
| from google.auth.transport import requests | |
| # TODO: check if this is the correct way to use Vertex AI | |
| # TODO: add langfuse support | |
class VertexAIWrapper:
    """Wrapper around Vertex AI generative models (Gemini family)."""

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False,
    ):
        """Initialize the Vertex AI wrapper.

        Args:
            model_name: Name of the model to use (e.g. "gemini-1.5-pro").
            temperature: Temperature for generation between 0 and 1.
            print_cost: Whether to print the cost of the completion.
            verbose: Whether to print verbose output.
            use_langfuse: Whether to enable Langfuse logging (stored for the
                pending integration noted in the module TODO; not yet used).

        Raises:
            ValueError: If GOOGLE_CLOUD_PROJECT is not set in the environment.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        # Was previously accepted but dropped; keep it so the Langfuse
        # integration (module-level TODO) has the flag available.
        self.use_langfuse = use_langfuse

        # Vertex AI requires an explicit project; the location falls back to
        # a common default region when unset.
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if not project_id:
            raise ValueError("No GOOGLE_CLOUD_PROJECT found in environment variables")

        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """Process messages and return completion.

        Args:
            messages: List of message dicts, each with a "type" key
                ("text", "image", or "video") and a "content" key (text,
                a GCS URI string, or raw media bytes).
            metadata: Optional metadata dictionary; currently unused by
                this implementation.

        Returns:
            Generated text response from the model.

        Raises:
            ValueError: If a message type is not supported.
        """
        parts = []
        for msg in messages:
            if msg["type"] == "text":
                parts.append(Part.from_text(msg["content"]))
            elif msg["type"] in ("image", "video"):
                mime_type = "video/mp4" if msg["type"] == "video" else "image/jpeg"
                if isinstance(msg["content"], str):
                    # String content is treated as a GCS URI.
                    parts.append(Part.from_uri(
                        msg["content"],
                        mime_type=mime_type,
                    ))
                else:
                    # Raw bytes. NOTE(review): Part.from_data expects bytes,
                    # not a file path — callers must read files themselves.
                    parts.append(Part.from_data(
                        msg["content"],
                        mime_type=mime_type,
                    ))
            else:
                # Previously unsupported types were silently dropped even
                # though the docstring promised ValueError; raise explicitly.
                raise ValueError(f"Unsupported message type: {msg['type']}")

        response = self.model.generate_content(
            parts,
            generation_config={
                "temperature": self.temperature,
                "top_p": 0.95,
            },
        )
        return response.text