"""Vision-language model (VLM) service layer.

Defines the supported model-type enum, an abstract service interface, and a
manager that registers services and generates captions with random model
selection and automatic fallback when a model is unavailable.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import logging
from enum import Enum
logger = logging.getLogger(__name__)
class ModelType(Enum):
"""Enum for different VLM model types"""
GPT4V = "gpt4v"
CLAUDE_3_5_SONNET = "claude_3_5_sonnet"
GEMINI_PRO_VISION = "gemini_pro_vision"
LLAMA_VISION = "llama_vision"
CUSTOM = "custom"
class VLMService(ABC):
"""Abstract base class for VLM services"""
def __init__(self, model_name: str, model_type: ModelType):
self.model_name = model_name
self.model_type = model_type
self.is_available = True
@abstractmethod
async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "") -> Dict[str, Any]:
"""Generate caption for an image"""
pass
def get_model_info(self) -> Dict[str, Any]:
"""Get model information"""
return {
"name": self.model_name,
"type": self.model_type.value,
"available": self.is_available,
}
class VLMServiceManager:
"""Manager for multiple VLM services"""
def __init__(self):
self.services: Dict[str, VLMService] = {}
self.default_service: Optional[str] = None
def register_service(self, service: VLMService):
"""Register a VLM service"""
self.services[service.model_name] = service
if not self.default_service:
self.default_service = service.model_name
logger.info(f"Registered VLM service: {service.model_name}")
def get_service(self, model_name: str) -> Optional[VLMService]:
"""Get a specific VLM service"""
return self.services.get(model_name)
def get_default_service(self) -> Optional[VLMService]:
"""Get the default VLM service"""
if self.default_service:
return self.services.get(self.default_service)
return None
def get_available_models(self) -> list:
"""Get list of available model names"""
return list(self.services.keys())
async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "", model_name: str | None = None, db_session = None) -> dict:
"""Generate caption using the specified model or fallback to available service."""
service = None
if model_name and model_name != "random":
service = self.services.get(model_name)
if not service:
print(f"Model '{model_name}' not found, using fallback")
if not service and self.services:
# If random is selected or no specific model, choose a random available service
if db_session:
# Check database availability for random selection
try:
from .. import crud
available_models = crud.get_models(db_session)
available_model_codes = [m.m_code for m in available_models if m.is_available]
print(f"DEBUG: Available models in database: {available_model_codes}")
print(f"DEBUG: Registered services: {list(self.services.keys())}")
# Filter services to only those marked as available in database
available_services = [s for s in self.services.values() if s.model_name in available_model_codes]
print(f"DEBUG: Available services after filtering: {[s.model_name for s in available_services]}")
if available_services:
import random
import time
# Use current time as seed for better randomness
random.seed(int(time.time() * 1000000) % 1000000)
# Shuffle the list first for better randomization
shuffled_services = available_services.copy()
random.shuffle(shuffled_services)
service = shuffled_services[0]
print(f"Randomly selected service: {service.model_name} (from {len(available_services)} available)")
print(f"DEBUG: All available services were: {[s.model_name for s in available_services]}")
print(f"DEBUG: Shuffled order: {[s.model_name for s in shuffled_services]}")
else:
# Fallback to any service
service = next(iter(self.services.values()))
print(f"Using fallback service: {service.model_name}")
except Exception as e:
print(f"Error checking database availability: {e}, using fallback")
service = next(iter(self.services.values()))
print(f"Using fallback service: {service.model_name}")
else:
# No database session, use service property
available_services = [s for s in self.services.values() if s.is_available]
if available_services:
import random
service = random.choice(available_services)
print(f"Randomly selected service: {service.model_name}")
else:
# Fallback to any service
service = next(iter(self.services.values()))
print(f"Using fallback service: {service.model_name}")
if not service:
raise ValueError("No VLM services available")
# Track attempts to avoid infinite loops
attempted_services = set()
max_attempts = len(self.services)
while len(attempted_services) < max_attempts:
try:
result = await service.generate_caption(image_bytes, prompt, metadata_instructions)
if isinstance(result, dict):
result["model"] = service.model_name
result["fallback_used"] = len(attempted_services) > 0
if len(attempted_services) > 0:
result["original_model"] = model_name
result["fallback_reason"] = "model_unavailable"
return result
except Exception as e:
error_str = str(e)
print(f"Error with service {service.model_name}: {error_str}")
# Check if it's a model unavailable error (any type of error)
if "MODEL_UNAVAILABLE" in error_str:
attempted_services.add(service.model_name)
print(f"Model {service.model_name} is unavailable, trying another service...")
# Try to find another available service
if db_session:
try:
from .. import crud
available_models = crud.get_models(db_session)
available_model_codes = [m.m_code for m in available_models if m.is_available]
# Find next available service that hasn't been attempted
for next_service in self.services.values():
if (next_service.model_name in available_model_codes and
next_service.model_name not in attempted_services):
service = next_service
print(f"Switching to fallback service: {service.model_name}")
break
else:
# No more available services, use any untried service
for next_service in self.services.values():
if next_service.model_name not in attempted_services:
service = next_service
print(f"Using untried service as fallback: {service.model_name}")
break
except Exception as db_error:
print(f"Error checking database availability: {db_error}")
# Fallback to any untried service
for next_service in self.services.values():
if next_service.model_name not in attempted_services:
service = next_service
print(f"Using untried service as fallback: {service.model_name}")
break
else:
# No database session, use any untried service
for next_service in self.services.values():
if next_service.model_name not in attempted_services:
service = next_service
print(f"Using untried service as fallback: {service.model_name}")
break
if not service:
raise ValueError("No more VLM services available after model failures")
continue # Try again with new service
else:
# Non-model-unavailable error, don't retry
print(f"Non-model-unavailable error, not retrying: {error_str}")
raise
# If we get here, we've tried all services
raise ValueError("All VLM services failed due to model unavailability")
vlm_manager = VLMServiceManager() |