rdune71 committed
Commit 83ce746 · 1 Parent(s): 59f10cb

Implement circuit breaker pattern and enhanced fallback logic

src/llm/base_provider.py CHANGED
@@ -1,13 +1,23 @@
+import time
+import logging
 from abc import ABC, abstractmethod
 from typing import List, Dict, Optional, Union
 
+logger = logging.getLogger(__name__)
+
 class LLMProvider(ABC):
-    """Abstract base class for all LLM providers"""
+    """Abstract base class for all LLM providers with circuit breaker"""
 
     def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
         self.model_name = model_name
         self.timeout = timeout
         self.max_retries = max_retries
+
+        # Circuit breaker properties
+        self.failure_count = 0
+        self.last_failure_time = None
+        self.circuit_open = False
+        self.reset_timeout = 60  # Reset circuit after 60 seconds
 
     @abstractmethod
     def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
@@ -18,3 +28,58 @@ class LLMProvider(ABC):
     def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
         """Generate a response with streaming support"""
         pass
+
+    def _check_circuit_breaker(self) -> bool:
+        """Check if circuit breaker is open (preventing calls)"""
+        if not self.circuit_open:
+            return True
+
+        # Check if enough time has passed to reset
+        if self.last_failure_time and (time.time() - self.last_failure_time) > self.reset_timeout:
+            logger.info("Circuit breaker reset - allowing call")
+            self.circuit_open = False
+            self.failure_count = 0
+            return True
+
+        logger.warning("Circuit breaker is OPEN - preventing call")
+        return False
+
+    def _handle_failure(self, error: Exception):
+        """Handle failure and update circuit breaker"""
+        self.failure_count += 1
+        self.last_failure_time = time.time()
+
+        # Open circuit after 3 failures
+        if self.failure_count >= 3:
+            self.circuit_open = True
+            logger.warning(f"Circuit breaker OPEN for {self.__class__.__name__} after {self.failure_count} failures")
+
+        raise error
+
+    def _retry_with_backoff(self, func, *args, **kwargs):
+        """Retry logic with exponential backoff"""
+        last_exception = None
+
+        for attempt in range(self.max_retries):
+            try:
+                if not self._check_circuit_breaker():
+                    raise Exception("Circuit breaker is open")
+
+                result = func(*args, **kwargs)
+                # Reset failure count on success
+                self.failure_count = 0
+                self.circuit_open = False
+                return result
+
+            except Exception as e:
+                last_exception = e
+                self._handle_failure(e)
+
+                if attempt < self.max_retries - 1:
+                    sleep_time = min((2 ** attempt) * 1.0, 10.0)  # Cap at 10 seconds
+                    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {sleep_time}s...")
+                    time.sleep(sleep_time)
+                else:
+                    logger.error(f"All {self.max_retries} attempts failed. Last error: {str(e)}")
+
+        raise last_exception
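A minimal sketch of how the new circuit breaker behaves, assuming the repo layout above. The `FlakyProvider` class and its always-failing backend are hypothetical and exist only for illustration; they are not part of the commit:

```python
# Hypothetical sketch: a provider whose backend always fails, used to show the
# circuit-breaker state added to LLMProvider above.
from src.llm.base_provider import LLMProvider

class FlakyProvider(LLMProvider):
    """Illustrative provider whose backend is unreachable."""

    def generate(self, prompt, conversation_history):
        return self._retry_with_backoff(self._call_backend, prompt, conversation_history)

    def stream_generate(self, prompt, conversation_history):
        return self._retry_with_backoff(self._call_backend, prompt, conversation_history)

    def _call_backend(self, prompt, conversation_history):
        raise RuntimeError("backend unreachable")

provider = FlakyProvider(model_name="dummy-model")
for call in range(1, 5):
    try:
        provider.generate("hello", [])
    except Exception as exc:
        # Calls 1-3 surface the backend error and increment failure_count;
        # call 4 fails fast with "Circuit breaker is open" because the breaker
        # opened after the third failure.
        print(f"call {call}: {exc} (circuit_open={provider.circuit_open})")
```

Once `reset_timeout` (60 seconds by default) has elapsed since the last failure, `_check_circuit_breaker()` closes the breaker again and calls are allowed through.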
src/llm/factory.py CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional
+from typing import Optional, List
 from src.llm.base_provider import LLMProvider
 from src.llm.hf_provider import HuggingFaceProvider
 from src.llm.ollama_provider import OllamaProvider
@@ -12,7 +12,7 @@ class ProviderNotAvailableError(Exception):
     pass
 
 class LLMFactory:
-    """Factory for creating LLM providers with fallback support"""
+    """Factory for creating LLM providers with intelligent fallback"""
 
     _instance = None
 
@@ -34,44 +34,77 @@ class LLMFactory:
         Raises:
             ProviderNotAvailableError: When no providers are available
         """
-        # Check preferred provider first
-        if preferred_provider == "huggingface" and config.hf_token:
-            try:
-                return HuggingFaceProvider(
-                    model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
-                )
-            except Exception as e:
-                logger.warning(f"Failed to initialize HF provider: {e}")
-
-        elif preferred_provider == "ollama" and config.ollama_host:
-            try:
-                return OllamaProvider(
-                    model_name=config.local_model_name
-                )
-            except Exception as e:
-                logger.warning(f"Failed to initialize Ollama provider: {e}")
-
-        # Fallback logic based on configuration
-        if config.use_fallback:
-            # Try HF first if configured
-            if config.hf_token:
-                try:
-                    return HuggingFaceProvider(
-                        model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
-                    )
-                except Exception as e:
-                    logger.warning(f"Failed to initialize HF provider: {e}")
-
-            # Then try Ollama if configured
-            if config.ollama_host:
-                try:
-                    return OllamaProvider(
-                        model_name=config.local_model_name
-                    )
-                except Exception as e:
-                    logger.warning(f"Failed to initialize Ollama provider: {e}")
-
-        raise ProviderNotAvailableError("No LLM providers are available or configured")
+        # Build provider chain based on configuration and preference
+        provider_chain = self._build_provider_chain(preferred_provider)
+
+        # Try providers in order
+        for provider_name, provider_class, model_name in provider_chain:
+            try:
+                logger.info(f"Attempting to initialize {provider_name} provider...")
+                provider = provider_class(model_name=model_name)
+                # Test that provider is working
+                if self._test_provider(provider):
+                    logger.info(f"Successfully initialized {provider_name} provider")
+                    return provider
+                else:
+                    logger.warning(f"{provider_name} provider failed validation test")
+            except Exception as e:
+                logger.warning(f"Failed to initialize {provider_name} provider: {e}")
+                continue
+
+        raise ProviderNotAvailableError("No LLM providers are available or configured")
+
+    def _build_provider_chain(self, preferred_provider: Optional[str]) -> List[tuple]:
+        """Build provider chain based on preference and configuration"""
+        chain = []
+
+        # Add preferred provider first if specified
+        if preferred_provider:
+            provider_info = self._get_provider_info(preferred_provider)
+            if provider_info:
+                chain.append(provider_info)
+
+        # Add fallback providers based on configuration
+        if config.use_fallback:
+            # Add HF if configured
+            if config.hf_token:
+                chain.append((
+                    "huggingface",
+                    HuggingFaceProvider,
+                    "DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
+                ))
+
+            # Add Ollama if configured
+            if config.ollama_host:
+                chain.append((
+                    "ollama",
+                    OllamaProvider,
+                    config.local_model_name
+                ))
+
+        return chain
+
+    def _get_provider_info(self, provider_name: str) -> Optional[tuple]:
+        """Get provider class and model info"""
+        provider_map = {
+            "huggingface": (
+                "huggingface",
+                HuggingFaceProvider,
+                "DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
+            ),
+            "ollama": (
+                "ollama",
+                OllamaProvider,
+                config.local_model_name
+            )
+        }
+        return provider_map.get(provider_name)
+
+    def _test_provider(self, provider: LLMProvider) -> bool:
+        """Test if provider is working (stub implementation)"""
+        # In a real implementation, you might want to do a lightweight test
+        # For now, we'll assume initialization success means it's working
+        return True
 
 # Global factory instance
 llm_factory = LLMFactory()
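A caller-side sketch of consuming the fallback chain. The factory method name and signature used below (`get_provider(preferred_provider=...)`) are assumptions: the method definition sits above the hunk shown, so only its docstring and body appear in the diff.

```python
# Hypothetical caller-side usage; get_provider is an assumed method name.
from src.llm.factory import llm_factory, ProviderNotAvailableError

def ask(prompt: str) -> str:
    try:
        # Walks the provider chain: preferred provider first, then the HF and
        # Ollama fallbacks, each validated via _test_provider().
        provider = llm_factory.get_provider(preferred_provider="ollama")
    except ProviderNotAvailableError:
        return "No LLM provider is configured or reachable."
    return provider.generate(prompt, conversation_history=[]) or ""

print(ask("Summarize the circuit breaker pattern in one sentence."))
```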
src/llm/hf_provider.py CHANGED
@@ -34,6 +34,14 @@ class HuggingFaceProvider(LLMProvider):
 
     def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
         """Generate a response synchronously"""
+        return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
+
+    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
+        """Generate a response with streaming support"""
+        return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
+
+    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
+        """Implementation of synchronous generation"""
         try:
             response = self.client.chat.completions.create(
                 model=self.model_name,
@@ -44,9 +52,8 @@ class HuggingFaceProvider(LLMProvider):
             )
             return response.choices[0].message.content
         except Exception as e:
-            logger.error(f"HF generation failed: {e}")
             # Handle scale-to-zero behavior
-            if "503" in str(e) or "service unavailable" in str(e).lower():
+            if self._is_scale_to_zero_error(e):
                 logger.info("HF endpoint is scaling up, waiting...")
                 time.sleep(60)  # Wait for endpoint to initialize
                 # Retry once
@@ -60,8 +67,8 @@ class HuggingFaceProvider(LLMProvider):
                 return response.choices[0].message.content
             raise
 
-    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
-        """Generate a response with streaming support"""
+    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
+        """Implementation of streaming generation"""
         try:
             response = self.client.chat.completions.create(
                 model=self.model_name,
@@ -78,9 +85,8 @@ class HuggingFaceProvider(LLMProvider):
                 chunks.append(content)
             return chunks
         except Exception as e:
-            logger.error(f"HF stream generation failed: {e}")
             # Handle scale-to-zero behavior
-            if "503" in str(e) or "service unavailable" in str(e).lower():
+            if self._is_scale_to_zero_error(e):
                 logger.info("HF endpoint is scaling up, waiting...")
                 time.sleep(60)  # Wait for endpoint to initialize
                 # Retry once
@@ -99,3 +105,14 @@ class HuggingFaceProvider(LLMProvider):
                 chunks.append(content)
             return chunks
         raise
+
+    def _is_scale_to_zero_error(self, error: Exception) -> bool:
+        """Check if the error is related to scale-to-zero initialization"""
+        error_str = str(error).lower()
+        scale_to_zero_indicators = [
+            "503",
+            "service unavailable",
+            "initializing",
+            "cold start"
+        ]
+        return any(indicator in error_str for indicator in scale_to_zero_indicators)
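A quick, hypothetical check of which errors the new `_is_scale_to_zero_error` helper treats as a cold-start signal. The check is purely string-based, so it can be exercised without a live endpoint:

```python
# Illustrative only; passing None for self works because the helper never uses it.
from src.llm.hf_provider import HuggingFaceProvider

check = HuggingFaceProvider._is_scale_to_zero_error
print(check(None, Exception("503 Service Unavailable")))      # True  (cold start)
print(check(None, Exception("Endpoint is initializing")))     # True  (cold start)
print(check(None, Exception("401 Unauthorized: bad token")))  # False (real failure)
```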
src/llm/ollama_provider.py CHANGED
@@ -30,6 +30,14 @@ class OllamaProvider(LLMProvider):
 
     def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
         """Generate a response synchronously"""
+        return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
+
+    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
+        """Generate a response with streaming support"""
+        return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
+
+    def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
+        """Implementation of synchronous generation"""
         try:
             url = f"{self.host}/api/chat"
             payload = {
@@ -51,8 +59,8 @@ class OllamaProvider(LLMProvider):
             logger.error(f"Ollama generation failed: {e}")
             raise
 
-    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
-        """Generate a response with streaming support"""
+    def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
+        """Implementation of streaming generation"""
         try:
             url = f"{self.host}/api/chat"
             payload = {
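The hunk is cut off before the rest of `_stream_generate_impl`. For orientation only, a rough sketch of how streaming from Ollama's `/api/chat` endpoint typically looks (newline-delimited JSON chunks); this is an assumption about the surrounding code, not the committed implementation:

```python
# Rough sketch only (not the committed code): stream NDJSON chunks from Ollama.
import json
import requests

def stream_ollama(host: str, model: str, messages: list) -> list:
    payload = {"model": model, "messages": messages, "stream": True}
    chunks = []
    with requests.post(f"{host}/api/chat", json=payload, stream=True, timeout=30) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if not line:
                continue
            data = json.loads(line)
            # Each line carries a partial assistant message until "done" is true.
            chunks.append(data.get("message", {}).get("content", ""))
            if data.get("done"):
                break
    return chunks
```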