Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| from typing import Dict, Any, Optional | |
| import logging | |
| from enum import Enum | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ModelType(Enum):
    """Kinds of vision-language model (VLM) backends.

    Each member's ``value`` is a lowercase string identifier; it is the
    string reported under the ``"type"`` key by
    ``VLMService.get_model_info()``.
    """

    # Named backends.
    GPT4V = "gpt4v"
    CLAUDE_3_5_SONNET = "claude_3_5_sonnet"
    GEMINI_PRO_VISION = "gemini_pro_vision"
    LLAMA_VISION = "llama_vision"

    # Catch-all for anything not listed above.
    CUSTOM = "custom"
class VLMService(ABC):
    """Abstract base class for VLM services.

    A concrete service wraps one vision-language model and must implement
    :meth:`generate_caption`. The base class only stores the identifying
    attributes and exposes them through :meth:`get_model_info`.

    Attributes:
        model_name: Name of the service; ``VLMServiceManager`` uses it as
            the registry key.
        model_type: The :class:`ModelType` member describing the backend.
        is_available: Availability flag, initialised to ``True``.
    """

    def __init__(self, model_name: str, model_type: ModelType):
        """Store the service's name and type and mark it as available."""
        self.model_name = model_name
        self.model_type = model_type
        self.is_available = True

    # Marked abstract so that the base class cannot be instantiated and a
    # subclass that forgets to override this fails at construction time
    # instead of silently returning ``None`` when awaited.
    @abstractmethod
    async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "") -> Dict[str, Any]:
        """Generate a caption for an image.

        Args:
            image_bytes: Raw bytes of the image to caption.
            prompt: Prompt text for the model.
            metadata_instructions: Optional extra instructions; defaults to
                an empty string.

        Returns:
            A dictionary describing the generated caption. The exact keys
            are defined by the concrete implementation.
        """

    def get_model_info(self) -> Dict[str, Any]:
        """Return the service's name, type value and availability flag.

        Returns:
            A dict with the keys ``"name"``, ``"type"`` (the ``value`` of
            ``model_type``) and ``"available"``.
        """
        return {
            "name": self.model_name,
            "type": self.model_type.value,
            "available": self.is_available,
        }
class VLMServiceManager:
    """Registry and dispatcher for multiple VLM services.

    Services are stored by their ``model_name``. ``generate_caption`` picks
    a service (the requested one, or a random available one) and falls back
    to other registered services when a model reports itself unavailable.

    Attributes:
        services: Mapping of ``model_name`` to the registered service.
        default_service: Name of the first service that was registered, or
            ``None`` while the registry is empty.
    """

    def __init__(self):
        """Create an empty registry with no default service."""
        self.services: Dict[str, VLMService] = {}
        self.default_service: Optional[str] = None

    def register_service(self, service: VLMService):
        """Register a VLM service under its ``model_name``.

        The first service registered becomes the default. Registering a
        second service with the same name replaces the earlier one.
        """
        self.services[service.model_name] = service
        if not self.default_service:
            self.default_service = service.model_name
        logger.info("Registered VLM service: %s", service.model_name)

    def get_service(self, model_name: str) -> Optional[VLMService]:
        """Return the service registered as ``model_name``, or ``None``."""
        return self.services.get(model_name)

    def get_default_service(self) -> Optional[VLMService]:
        """Return the default service, or ``None`` if none is registered."""
        if self.default_service:
            return self.services.get(self.default_service)
        return None

    def get_available_models(self) -> list:
        """Return the names of all registered services.

        Note: this lists every registered name in registration order; it
        does not filter on ``is_available``.
        """
        return list(self.services.keys())

    async def generate_caption(
        self,
        image_bytes: bytes,
        prompt: str,
        metadata_instructions: str = "",
        model_name: str | None = None,
        db_session=None,
    ) -> dict:
        """Generate a caption with the requested model, falling back if needed.

        Args:
            image_bytes: Raw bytes of the image to caption.
            prompt: Prompt text passed to the service.
            metadata_instructions: Extra instructions passed to the service.
            model_name: Name of the service to use. ``None``, an empty
                string, ``"random"`` or an unknown name all cause a service
                to be selected automatically.
            db_session: Optional database session. When given, availability
                is read from ``crud.get_models(db_session)`` instead of
                each service's ``is_available`` flag.

        Returns:
            Whatever the chosen service returned. If it is a dict, the keys
            ``"model"`` and ``"fallback_used"`` are added, plus
            ``"original_model"`` and ``"fallback_reason"`` when at least one
            service was skipped because its model was unavailable.

        Raises:
            ValueError: If no service is registered, or if every service
                failed with a ``MODEL_UNAVAILABLE`` error.
            Exception: Any other error raised by a service is propagated
                unchanged, without trying further services.
        """
        service = None
        if model_name and model_name != "random":
            service = self.services.get(model_name)
            if service is None:
                logger.warning("Model '%s' not found, using fallback", model_name)

        if service is None and self.services:
            service = self._pick_service(db_session)
        if service is None:
            raise ValueError("No VLM services available")

        # Each service is tried at most once, which also bounds the loop.
        attempted: set = set()
        max_attempts = len(self.services)

        while len(attempted) < max_attempts:
            try:
                result = await service.generate_caption(image_bytes, prompt, metadata_instructions)
            except Exception as exc:
                detail = str(exc)
                # Services signal a retriable outage by putting this marker
                # in the exception message; anything else is a real failure.
                if "MODEL_UNAVAILABLE" not in detail:
                    logger.error(
                        "Error with service %s: %s (non-model-unavailable error, not retrying)",
                        service.model_name,
                        detail,
                    )
                    raise

                attempted.add(service.model_name)
                logger.warning(
                    "Model %s is unavailable (%s), trying another service...",
                    service.model_name,
                    detail,
                )
                replacement = self._next_untried_service(attempted, db_session)
                if replacement is None:
                    break
                service = replacement
                continue

            if isinstance(result, dict):
                result["model"] = service.model_name
                result["fallback_used"] = len(attempted) > 0
                if attempted:
                    result["original_model"] = model_name
                    result["fallback_reason"] = "model_unavailable"
            return result

        raise ValueError("All VLM services failed due to model unavailability")

    @staticmethod
    def _db_available_codes(db_session) -> set:
        """Return the ``m_code`` of every model the database marks available."""
        # Imported lazily, as in the original code, so that this module can
        # be imported without the package's ``crud`` module being loaded.
        from .. import crud

        return {m.m_code for m in crud.get_models(db_session) if m.is_available}

    def _pick_service(self, db_session=None):
        """Pick a random available service, else the first registered one.

        Must only be called while at least one service is registered.
        """
        import random

        if db_session:
            try:
                codes = self._db_available_codes(db_session)
            except Exception as exc:
                logger.warning("Error checking database availability: %s, using fallback", exc)
                candidates = []
            else:
                candidates = [s for s in self.services.values() if s.model_name in codes]
                logger.debug("Available models in database: %s", sorted(codes))
                logger.debug("Registered services: %s", list(self.services))
                logger.debug(
                    "Available services after filtering: %s",
                    [s.model_name for s in candidates],
                )
        else:
            # No database session: trust each service's own flag.
            candidates = [s for s in self.services.values() if s.is_available]

        if candidates:
            # random.choice draws from the shared generator without
            # reseeding it, so other users of ``random`` are not disturbed.
            service = random.choice(candidates)
            logger.info(
                "Randomly selected service: %s (from %d available)",
                service.model_name,
                len(candidates),
            )
            return service

        service = next(iter(self.services.values()))
        logger.info("Using fallback service: %s", service.model_name)
        return service

    def _next_untried_service(self, attempted, db_session=None):
        """Return the next service not yet attempted, or ``None`` if none is left.

        Services the database marks as available are preferred; otherwise
        the first untried service in registration order is used.
        """
        untried = [s for s in self.services.values() if s.model_name not in attempted]
        if not untried:
            return None

        if db_session:
            try:
                codes = self._db_available_codes(db_session)
            except Exception as exc:
                logger.warning("Error checking database availability: %s", exc)
            else:
                for candidate in untried:
                    if candidate.model_name in codes:
                        logger.info("Switching to fallback service: %s", candidate.model_name)
                        return candidate

        logger.info("Using untried service as fallback: %s", untried[0].model_name)
        return untried[0]
# Manager instance created once, at import time. Presumably the shared
# registry that other modules import -- confirm against callers.
vlm_manager = VLMServiceManager()