rdune71 committed
Commit 5c1efea · Parent: 5adc6a4

Implement hybrid AI architecture with HF Endpoint heavy lifting and Ollama local caching

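In practice the new wiring is used like this. A minimal end-to-end sketch condensed from the diffs below; the `answer()` helper and its call shape are illustrative, while the imports and names are the project's own:

```python
from typing import Dict, List

from src.llm.factory import llm_factory, ProviderNotAvailableError

def answer(prompt: str, history: List[Dict]) -> str:
    """Illustrative call path: factory -> HybridProvider -> generate()."""
    try:
        # The factory now prefers HybridProvider: the HF Endpoint does the
        # heavy lifting, while Ollama and the Redis session keep a local backup.
        provider = llm_factory.get_provider()
        return provider.generate(prompt, history) or ""
    except ProviderNotAvailableError:
        # Raised only when neither the HF Endpoint nor Ollama could be initialized.
        return "No LLM providers are available or configured"
```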
src/llm/factory.py CHANGED
@@ -1,6 +1,7 @@
 import logging
 from typing import Optional
 from src.llm.base_provider import LLMProvider
+from src.llm.hybrid_provider import HybridProvider
 from src.llm.hf_provider import HuggingFaceProvider
 from src.llm.ollama_provider import OllamaProvider
 from utils.config import config
@@ -13,7 +14,7 @@ class ProviderNotAvailableError(Exception):
     pass

 class LLMFactory:
-    """Factory for creating LLM providers with intelligent fallback"""
+    """Factory for creating LLM providers with hybrid approach"""

     _instance = None

@@ -25,29 +26,37 @@ class LLMFactory:
     def get_provider(self, preferred_provider: Optional[str] = None) -> LLMProvider:
         """
         Get an LLM provider based on preference and availability.
-        Priority: HF Endpoint > Ollama > Error
+        Default: Hybrid approach (HF primary + Ollama backup/cache)
         """
-        # Check if HF token is available and endpoint is ready
-        if config.hf_token:
-            status = hf_monitor.get_endpoint_status()
-            if status["available"]:
-                try:
-                    logger.info("Using HF Endpoint as primary provider")
-                    return HuggingFaceProvider(
-                        model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
-                    )
-                except Exception as e:
-                    logger.warning(f"Failed to initialize HF provider: {e}")
-
-        # Try Ollama as fallback
-        if config.ollama_host:
-            try:
-                logger.info("Using Ollama as provider")
-                return OllamaProvider(
-                    model_name=config.local_model_name
-                )
-            except Exception as e:
-                logger.warning(f"Failed to initialize Ollama provider: {e}")
+        try:
+            # Always try hybrid provider first (uses both HF and Ollama intelligently)
+            logger.info("Initializing Hybrid Provider (HF + Ollama)")
+            return HybridProvider(
+                model_name="hybrid_model"
+            )
+        except Exception as e:
+            logger.warning(f"Failed to initialize Hybrid provider: {e}")
+
+        # Fallback to individual providers
+        if config.hf_token:
+            status = hf_monitor.get_endpoint_status()
+            if status["available"]:
+                try:
+                    logger.info("Falling back to HF Endpoint")
+                    return HuggingFaceProvider(
+                        model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
+                    )
+                except Exception as hf_error:
+                    logger.warning(f"Failed to initialize HF provider: {hf_error}")
+
+        if config.ollama_host:
+            try:
+                logger.info("Falling back to Ollama")
+                return OllamaProvider(
+                    model_name=config.local_model_name
+                )
+            except Exception as ollama_error:
+                logger.warning(f"Failed to initialize Ollama provider: {ollama_error}")

         raise ProviderNotAvailableError("No LLM providers are available or configured")

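One thing worth flagging in the new get_provider(): the preferred_provider argument is kept in the signature but is no longer consulted before the hybrid attempt, so a caller that must pin a specific backend would have to construct it directly. A minimal sketch under that assumption, reusing the model names and config fields from the hunk above (the helper itself is not part of this commit):

```python
from typing import Optional

from src.llm.base_provider import LLMProvider
from src.llm.hf_provider import HuggingFaceProvider
from src.llm.ollama_provider import OllamaProvider
from utils.config import config

def build_explicit_provider(name: str) -> Optional[LLMProvider]:
    """Hypothetical helper: bypass the hybrid-first factory path for one backend."""
    if name == "huggingface" and config.hf_token:
        return HuggingFaceProvider(
            model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
        )
    if name == "ollama" and config.ollama_host:
        return OllamaProvider(model_name=config.local_model_name)
    # No explicit match: let llm_factory.get_provider() supply the hybrid default.
    return None
```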
src/llm/hybrid_provider.py ADDED
@@ -0,0 +1,159 @@
+import time
+import logging
+from typing import List, Dict, Optional, Union
+from src.llm.base_provider import LLMProvider
+from src.llm.hf_provider import HuggingFaceProvider
+from src.llm.ollama_provider import OllamaProvider
+from core.session import session_manager
+from utils.config import config
+
+logger = logging.getLogger(__name__)
+
+class HybridProvider(LLMProvider):
+    """Hybrid provider that uses HF for heavy lifting and Ollama for local caching/summarization"""
+
+    def __init__(self, model_name: str, timeout: int = 120, max_retries: int = 2):
+        super().__init__(model_name, timeout, max_retries)
+        self.hf_provider = None
+        self.ollama_provider = None
+
+        # Initialize providers
+        try:
+            if config.hf_token:
+                self.hf_provider = HuggingFaceProvider(
+                    model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf",
+                    timeout=120
+                )
+        except Exception as e:
+            logger.warning(f"Failed to initialize HF provider: {e}")
+
+        try:
+            if config.ollama_host:
+                self.ollama_provider = OllamaProvider(
+                    model_name=config.local_model_name,
+                    timeout=60
+                )
+        except Exception as e:
+            logger.warning(f"Failed to initialize Ollama provider: {e}")
+
+    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
+        """Generate response using hybrid approach"""
+        try:
+            # Step 1: Get heavy lifting from HF Endpoint
+            hf_response = self._get_hf_response(prompt, conversation_history)
+
+            if not hf_response:
+                raise Exception("HF Endpoint failed to provide response")
+
+            # Step 2: Store HF response in local cache via Ollama
+            self._cache_response_locally(prompt, hf_response, conversation_history)
+
+            # Step 3: Optionally create local summary (if needed)
+            # For now, return HF response directly but with local backup
+            return hf_response
+
+        except Exception as e:
+            logger.error(f"Hybrid generation failed: {e}")
+
+            # Fallback to Ollama if available
+            if self.ollama_provider:
+                try:
+                    logger.info("Falling back to Ollama for local response")
+                    return self.ollama_provider.generate(prompt, conversation_history)
+                except Exception as fallback_error:
+                    logger.error(f"Ollama fallback also failed: {fallback_error}")
+
+            raise Exception(f"Both HF Endpoint and Ollama failed: {str(e)}")
+
+    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
+        """Stream response using hybrid approach"""
+        try:
+            # Get streaming response from HF
+            if self.hf_provider:
+                return self.hf_provider.stream_generate(prompt, conversation_history)
+            elif self.ollama_provider:
+                return self.ollama_provider.stream_generate(prompt, conversation_history)
+            else:
+                raise Exception("No providers available")
+        except Exception as e:
+            logger.error(f"Hybrid stream generation failed: {e}")
+            raise
+
+    def _get_hf_response(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
+        """Get response from HF Endpoint with fallback handling"""
+        if not self.hf_provider:
+            return None
+
+        try:
+            logger.info("🚀 Getting detailed response from HF Endpoint...")
+            response = self.hf_provider.generate(prompt, conversation_history)
+            logger.info("✅ HF Endpoint response received")
+            return response
+        except Exception as e:
+            logger.error(f"HF Endpoint failed: {e}")
+            # Don't raise here, let hybrid provider handle fallback
+            return None
+
+    def _cache_response_locally(self, prompt: str, response: str, conversation_history: List[Dict]):
+        """Cache HF response locally using Ollama for future reference"""
+        if not self.ollama_provider:
+            return
+
+        try:
+            # Create a simplified cache entry
+            cache_prompt = f"Cache this response for future reference:\n\nQuestion: {prompt}\n\nResponse: {response[:500]}..."
+
+            # Store in local Ollama for quick retrieval
+            # This helps if HF connection fails later
+            logger.info("💾 Caching response locally with Ollama...")
+            self.ollama_provider.generate(cache_prompt, [])
+
+            # Also store in Redis session for persistence
+            self._store_in_session_cache(prompt, response)
+
+        except Exception as e:
+            logger.warning(f"Failed to cache response locally: {e}")
+
+    def _store_in_session_cache(self, prompt: str, response: str):
+        """Store response in Redis session cache"""
+        try:
+            user_session = session_manager.get_session("default_user")
+            cache = user_session.get("response_cache", {})
+
+            # Simple cache key
+            cache_key = hash(prompt) % 1000000
+            cache[str(cache_key)] = {
+                "prompt": prompt,
+                "response": response,
+                "timestamp": time.time()
+            }
+
+            # Keep only last 50 cached responses
+            if len(cache) > 50:
+                # Remove oldest entries
+                sorted_keys = sorted(cache.keys(), key=lambda k: cache[k]["timestamp"])
+                for key in sorted_keys[:-50]:
+                    del cache[key]
+
+            user_session["response_cache"] = cache
+            session_manager.update_session("default_user", user_session)
+
+        except Exception as e:
+            logger.warning(f"Failed to store in session cache: {e}")
+
+    def get_cached_response(self, prompt: str) -> Optional[str]:
+        """Get cached response if available"""
+        try:
+            user_session = session_manager.get_session("default_user")
+            cache = user_session.get("response_cache", {})
+
+            cache_key = str(hash(prompt) % 1000000)
+            if cache_key in cache:
+                cached_entry = cache[cache_key]
+                # Check if cache is still valid (1 hour)
+                if time.time() - cached_entry["timestamp"] < 3600:
+                    return cached_entry["response"]
+        except Exception as e:
+            logger.warning(f"Failed to retrieve cached response: {e}")
+
+        return None
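
A caveat on the cache key used above: Python salts the built-in hash() for strings per interpreter process, so hash(prompt) % 1000000 produces different keys after a restart even though the cache itself lives in the Redis-backed session. If keys need to survive restarts, a digest-based key is one option; a minimal sketch, not part of this commit:

```python
import hashlib

def stable_cache_key(prompt: str) -> str:
    """Deterministic across processes and restarts, unlike built-in hash()."""
    return hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]

# Example usage with the same cache structure as _store_in_session_cache():
# cache[stable_cache_key(prompt)] = {"prompt": prompt, "response": response, "timestamp": time.time()}
```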
src/ui/chat_handler.py CHANGED
@@ -4,12 +4,11 @@ import logging
 from typing import Optional
 from src.llm.factory import llm_factory, ProviderNotAvailableError
 from core.session import session_manager
-from utils.config import config

 logger = logging.getLogger(__name__)

 class ChatHandler:
-    """Handles chat interactions with better UI feedback"""
+    """Handles chat interactions with hybrid AI approach"""

     def __init__(self):
         self.is_processing = False
@@ -54,7 +53,7 @@ class ChatHandler:
         st.session_state.last_processed_message = ""

     def process_ai_response(self, user_input: str, selected_model: str):
-        """Process AI response after user message is displayed"""
+        """Process AI response with hybrid approach"""
        if not user_input or not user_input.strip():
            return

@@ -65,14 +64,25 @@ class ChatHandler:
        response_placeholder = st.empty()

        try:
-            # Determine provider based on selection and availability
-            provider_name = self._get_best_provider(selected_model)
-            status_placeholder.info(f"🚀 Contacting {self._get_provider_display_name(provider_name)}...")
+            # Get hybrid provider
+            status_placeholder.info("🚀 Contacting AI providers...")
+            provider = llm_factory.get_provider()

-            # Get response with timeout handling
+            # Show which approach is being used
+            if hasattr(provider, 'hf_provider') and provider.hf_provider:
+                status_placeholder.info("🧠 Getting detailed response from HF Endpoint...")
+            else:
+                status_placeholder.info("🦙 Getting response from local Ollama...")
+
+            # Get response
            response = None
            try:
-                response = self._get_ai_response(user_input, provider_name)
+                # Get session and conversation history
+                user_session = session_manager.get_session("default_user")
+                conversation_history = user_session.get("conversation", []).copy()
+                conversation_history.append({"role": "user", "content": user_input})
+
+                response = provider.generate(user_input, conversation_history)
            except Exception as e:
                logger.error(f"AI response error: {e}")
                raise
@@ -81,13 +91,19 @@ class ChatHandler:
                status_placeholder.success("✅ Response received!")
                response_placeholder.markdown(response)

-                # Add to session history
+                # Add to session history with provider info
                timestamp = time.strftime("%H:%M:%S")
+                provider_info = "hybrid"
+                if hasattr(provider, 'hf_provider') and provider.hf_provider:
+                    provider_info = "hf_endpoint"
+                elif hasattr(provider, 'ollama_provider') and provider.ollama_provider:
+                    provider_info = "ollama"
+
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response,
                    "timestamp": timestamp,
-                    "provider": provider_name
+                    "provider": provider_info
                })
            else:
                status_placeholder.warning("⚠️ Empty response received")
@@ -97,7 +113,7 @@ class ChatHandler:
                    "role": "assistant",
                    "content": "*No response generated. Please try again.*",
                    "timestamp": timestamp,
-                    "provider": provider_name
+                    "provider": "unknown"
                })

        except ProviderNotAvailableError as e:
@@ -112,24 +128,15 @@ class ChatHandler:
            logger.error(f"Provider not available: {e}")

        except Exception as e:
-            # Better user-friendly error messages
            status_placeholder.error("❌ Request failed")

-            # More specific error messages
+            # User-friendly error messages
            if "timeout" in str(e).lower() or "500" in str(e):
-                error_message = ("⏰ Request failed. This might be because:\n"
-                                 "• The AI model is taking too long to respond\n"
-                                 "• The provider is overloaded\n\n"
-                                 "**Try one of these solutions:**\n"
-                                 "1. Use the HF Endpoint (🟢 HF Endpoint: Available and ready)\n"
-                                 "2. Wait a moment and try again\n"
-                                 "3. Simplify your question")
-            elif "connection" in str(e).lower():
-                error_message = ("🔌 Connection failed. This might be because:\n"
-                                 "• Your Ollama server is offline\n"
-                                 "• Incorrect Ollama URL\n"
-                                 "• Network firewall blocking connection\n"
-                                 "• Try using the HF Endpoint instead")
+                error_message = ("⏰ Request timed out. The AI model is taking too long to respond.\n\n"
+                                 "**Current setup:**\n"
+                                 "• 🤖 HF Endpoint: Doing heavy lifting\n"
+                                 "• 🦙 Ollama: Providing local backup\n\n"
+                                 "Please try again or simplify your question.")
            else:
                error_message = f"Sorry, I encountered an error: {str(e)}"

@@ -151,65 +158,14 @@ class ChatHandler:
        st.session_state.last_processed_message = ""
        time.sleep(0.1)

-    def _get_best_provider(self, selected_model: str) -> str:
-        """Determine the best available provider"""
-        from src.services.hf_monitor import hf_monitor
-
-        # If user selected specific provider, try that
-        if selected_model == "ollama" and config.ollama_host:
-            return "ollama"
-        elif selected_model == "huggingface" and config.hf_token:
-            return "huggingface"
-
-        # Auto-select based on availability
-        if config.hf_token:
-            status = hf_monitor.get_endpoint_status()
-            if status["available"]:
-                return "huggingface"
-
-        if config.ollama_host:
-            return "ollama"
-
-        return "ollama"  # Default fallback
-
    def _get_provider_display_name(self, provider_name: str) -> str:
        """Get display name for provider"""
        display_names = {
-            "ollama": "🦙 Ollama",
-            "huggingface": "🤗 HF Endpoint"
+            "ollama": "🦙 Ollama (Local)",
+            "hf_endpoint": "🤗 HF Endpoint (Heavy Lifting)",
+            "hybrid": "🔄 Hybrid (HF + Ollama)"
        }
        return display_names.get(provider_name, provider_name)
-
-    def _get_ai_response(self, user_input: str, provider_name: str) -> Optional[str]:
-        """Get AI response from specified provider"""
-        try:
-            # Get session and conversation history
-            user_session = session_manager.get_session("default_user")
-            conversation_history = user_session.get("conversation", []).copy()
-
-            # Add current user message
-            conversation_history.append({"role": "user", "content": user_input})
-
-            # Get provider (with intelligent fallback)
-            provider = llm_factory.get_provider(provider_name)
-
-            # Generate response with timeout
-            logger.info(f"Generating response with {provider_name} provider")
-            response = provider.generate(user_input, conversation_history)
-            logger.info(f"Received response from {provider_name}: {response[:100] if response else 'None'}")
-
-            # Update session with conversation
-            if response:
-                conversation = user_session.get("conversation", []).copy()
-                conversation.append({"role": "user", "content": user_input})
-                conversation.append({"role": "assistant", "content": response})
-                session_manager.update_session("default_user", {"conversation": conversation})
-
-            return response
-
-        except Exception as e:
-            logger.error(f"AI response generation failed: {e}", exc_info=True)
-            raise

 # Global instance
 chat_handler = ChatHandler()
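
Note that the handler above always calls provider.generate(); the get_cached_response() helper added in hybrid_provider.py is never consulted. If cache hits should skip a round trip to the HF Endpoint, the call site could check the cache first. A sketch under that assumption, reusing names from this commit (the wrapper function is hypothetical):

```python
def generate_with_cache(provider, user_input: str, conversation_history: list):
    """Return a cached reply when the hybrid provider has one, else generate."""
    # Only HybridProvider exposes get_cached_response(); guard with hasattr.
    if hasattr(provider, "get_cached_response"):
        cached = provider.get_cached_response(user_input)
        if cached:
            return cached  # served from the session cache (valid for one hour)
    return provider.generate(user_input, conversation_history)
```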