File size: 9,933 Bytes
d7291ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe5d98f
d7291ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65933cd
 
 
d7291ef
ba5edb0
5778774
d7291ef
 
5778774
65933cd
5778774
 
d7291ef
5778774
 
ba5edb0
5778774
ba5edb0
 
 
 
 
5778774
 
 
 
ba5edb0
5778774
 
 
ba5edb0
 
5778774
 
 
 
 
ba5edb0
 
5778774
ba5edb0
5778774
 
 
ba5edb0
5778774
ba5edb0
5778774
ba5edb0
5778774
ba5edb0
5778774
ba5edb0
5778774
ba5edb0
 
 
 
5778774
ba5edb0
5778774
ba5edb0
5778774
d7291ef
65933cd
 
d7291ef
186c8e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5778774
186c8e8
 
 
 
5778774
01ab2a4
186c8e8
 
 
 
 
 
 
 
 
 
 
 
5778774
186c8e8
 
 
 
 
 
5778774
186c8e8
 
5778774
186c8e8
 
 
 
5778774
186c8e8
 
 
 
 
 
5778774
186c8e8
 
 
 
 
 
 
 
5778774
186c8e8
 
 
 
d7291ef
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import logging
import random
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, Optional

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


class ModelType(Enum):
    """Closed set of vision-language model families.

    ``.value`` yields the lowercase string identifier of each member,
    e.g. ``ModelType.GPT4V.value == "gpt4v"``.
    """

    GPT4V = "gpt4v"
    CLAUDE_3_5_SONNET = "claude_3_5_sonnet"
    GEMINI_PRO_VISION = "gemini_pro_vision"
    LLAMA_VISION = "llama_vision"
    # NOTE(review): presumably the bucket for backends not listed above — confirm.
    CUSTOM = "custom"

class VLMService(ABC):
    """Common interface for every vision-language model backend.

    Concrete subclasses implement ``generate_caption``. This base class only
    records the backend's identity and an availability flag.
    """

    def __init__(self, model_name: str, model_type: ModelType):
        self.model_name, self.model_type = model_name, model_type
        # Every service starts out marked as available.
        self.is_available = True

    @abstractmethod
    async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "") -> Dict[str, Any]:
        """Produce a caption for ``image_bytes`` guided by ``prompt``; subclasses must override."""

    def get_model_info(self) -> Dict[str, Any]:
        """Summarize this service as ``{"name", "type", "available"}``.

        ``type`` is the string value of the service's ``ModelType``.
        """
        return dict(
            name=self.model_name,
            type=self.model_type.value,
            available=self.is_available,
        )

class VLMServiceManager:
    """Registry of VLM services that picks one per request and falls back on failure.

    Services are keyed by their ``model_name``; the first one registered becomes
    the default. ``generate_caption`` chooses a service (explicitly by name, or
    at random among available ones) and, when a service raises an error whose
    message contains ``MODEL_UNAVAILABLE``, retries with another untried service.
    """

    def __init__(self):
        self.services: Dict[str, VLMService] = {}
        self.default_service: Optional[str] = None

    def register_service(self, service: VLMService):
        """Register ``service`` under its ``model_name``, replacing any previous entry.

        The first service ever registered becomes the default.
        """
        self.services[service.model_name] = service
        if not self.default_service:
            self.default_service = service.model_name
        logger.info("Registered VLM service: %s", service.model_name)

    def get_service(self, model_name: str) -> Optional[VLMService]:
        """Return the service registered as ``model_name``, or None."""
        return self.services.get(model_name)

    def get_default_service(self) -> Optional[VLMService]:
        """Return the default service, or None if nothing has been registered."""
        if self.default_service:
            return self.services.get(self.default_service)
        return None

    def get_available_models(self) -> list:
        """Return the names of all registered services (dict insertion order)."""
        return list(self.services)

    def _db_available_codes(self, db_session) -> set:
        """Return the model codes the database marks as available; may raise."""
        # Local import kept from the original; presumably avoids an import cycle.
        from .. import crud
        codes = {m.m_code for m in crud.get_models(db_session) if m.is_available}
        logger.debug("Available models in database: %s", sorted(codes))
        return codes

    def _select_initial_service(self, model_name, db_session) -> Optional[VLMService]:
        """Pick the first service to try; None only when nothing is registered."""
        if model_name and model_name != "random":
            requested = self.services.get(model_name)
            if requested is not None:
                return requested
            logger.warning("Model '%s' not found, using fallback", model_name)

        if not self.services:
            return None
        first_registered = next(iter(self.services.values()))

        # Availability comes from the database when a session is supplied,
        # otherwise from each service's own ``is_available`` flag.
        if db_session:
            try:
                codes = self._db_available_codes(db_session)
            except Exception as e:
                logger.warning("Error checking database availability: %s, using fallback", e)
                logger.info("Using fallback service: %s", first_registered.model_name)
                return first_registered
            candidates = [s for s in self.services.values() if s.model_name in codes]
        else:
            candidates = [s for s in self.services.values() if s.is_available]

        if not candidates:
            logger.info("Using fallback service: %s", first_registered.model_name)
            return first_registered

        # random.choice draws from the already-seeded module generator; the old
        # code reseeded the process-wide RNG on every call, clobbering its state
        # for every other user of ``random`` and reducing it to 10**6 seeds.
        chosen = random.choice(candidates)
        logger.info(
            "Randomly selected service: %s (from %d available)",
            chosen.model_name,
            len(candidates),
        )
        return chosen

    def _next_fallback_service(self, attempted: set, db_session) -> Optional[VLMService]:
        """Return the next untried service, or None once every service was tried.

        With a DB session, the first untried service the database marks as
        available is preferred. Otherwise (or if the lookup fails) the first
        untried service in registration order is used.
        """
        untried = [s for s in self.services.values() if s.model_name not in attempted]
        if not untried:
            return None
        if db_session:
            try:
                codes = self._db_available_codes(db_session)
            except Exception as db_error:
                logger.warning("Error checking database availability: %s", db_error)
            else:
                preferred = next((s for s in untried if s.model_name in codes), None)
                if preferred is not None:
                    return preferred
        return untried[0]

    async def generate_caption(self, image_bytes: bytes, prompt: str, metadata_instructions: str = "", model_name: str | None = None, db_session=None) -> dict:
        """Generate a caption, falling back to other services when a model is unavailable.

        Args:
            image_bytes: Raw image data forwarded to the selected service.
            prompt: Prompt forwarded to the selected service.
            metadata_instructions: Extra instructions forwarded unchanged.
            model_name: Registered model to use. ``None``, ``"random"`` or an
                unknown name triggers a random pick among available services.
            db_session: Optional DB session. When given, availability is read
                via ``crud.get_models`` instead of each service's ``is_available``.

        Returns:
            The service's result. When it is a dict, ``model`` and
            ``fallback_used`` are added, plus ``original_model`` and
            ``fallback_reason`` when a fallback service produced it.

        Raises:
            ValueError: if no service is registered, or every registered
                service failed with a ``MODEL_UNAVAILABLE`` error.
            Exception: any other error from a service is re-raised unchanged.
        """
        service = self._select_initial_service(model_name, db_session)
        if service is None:
            raise ValueError("No VLM services available")

        # Each service is tried at most once, so the loop ends after at most
        # len(self.services) failures.
        attempted: set = set()
        while True:
            try:
                result = await service.generate_caption(image_bytes, prompt, metadata_instructions)
            except Exception as e:
                error_str = str(e)
                if "MODEL_UNAVAILABLE" not in error_str:
                    logger.error(
                        "Non-model-unavailable error from %s, not retrying: %s",
                        service.model_name,
                        error_str,
                    )
                    raise
                attempted.add(service.model_name)
                logger.warning("Model %s is unavailable, trying another service...", service.model_name)
                fallback = self._next_fallback_service(attempted, db_session)
                if fallback is None:
                    raise ValueError("All VLM services failed due to model unavailability") from e
                logger.info("Switching to fallback service: %s", fallback.model_name)
                service = fallback
                continue

            if isinstance(result, dict):
                fallback_used = bool(attempted)
                result["model"] = service.model_name
                result["fallback_used"] = fallback_used
                if fallback_used:
                    result["original_model"] = model_name
                    result["fallback_reason"] = "model_unavailable"
            return result


vlm_manager = VLMServiceManager()