rdune71 committed
Commit 59f10cb · Parent: adf8222

Enhance modular LLM provider interface with HF endpoint integration

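The diff below reworks provider selection around a shared config object and a factory with fallback. As a quick orientation, here is a minimal usage sketch; it is not part of the commit. The accessor name `get_provider` and direct construction of `LLMFactory` are assumptions (only the docstring and the `_instance` attribute are visible in the hunks), while the chat-style message format mirrors what the providers expect.

# Minimal sketch, not part of the commit. Assumed names: LLMFactory() construction
# and the get_provider() accessor; everything else mirrors the diff below.
from src.llm.factory import LLMFactory, ProviderNotAvailableError

history = [{"role": "user", "content": "Hello!"}]

try:
    # Prefers Hugging Face when HF_TOKEN is set; with USE_FALLBACK=true it
    # falls back to Ollama when OLLAMA_HOST is configured.
    provider = LLMFactory().get_provider(preferred_provider="huggingface")
    print(provider.generate("Hello!", history))
except ProviderNotAvailableError:
    print("No LLM provider configured - set HF_TOKEN or OLLAMA_HOST.")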
src/config/llm_config.py CHANGED
@@ -4,23 +4,23 @@ from typing import Optional
 class LLMConfig:
     """Configuration loader for LLM providers"""
 
-    @staticmethod
-    def get_active_provider() -> Optional[str]:
-        """Get the name of the active provider based on environment variables"""
-        if os.getenv("HF_TOKEN"):
-            return "huggingface"
-        # elif os.getenv("OLLAMA_HOST"):
-        #     return "ollama"
-        # elif os.getenv("OPENAI_API_KEY"):
-        #     return "openai"
-        return None
+    def __init__(self):
+        # Load all environment variables
+        self.hf_token = os.getenv("HF_TOKEN")
+        self.ollama_host = os.getenv("OLLAMA_HOST")
+        self.local_model_name = os.getenv("LOCAL_MODEL_NAME", "mistral:latest")
+        self.hf_api_url = os.getenv("HF_API_ENDPOINT_URL", "https://zxzbfrlg3ssrk7d9.us-east-1.aws.endpoints.huggingface.cloud/v1/")
+        self.use_fallback = os.getenv("USE_FALLBACK", "true").lower() == "true"
+        self.openweather_api_key = os.getenv("OPENWEATHER_API_KEY")
+        self.nasa_api_key = os.getenv("NASA_API_KEY")
+        self.tavily_api_key = os.getenv("TAVILY_API_KEY")
+        self.redis_host = os.getenv("REDIS_HOST")
+        self.redis_port = os.getenv("REDIS_PORT")
+        self.redis_username = os.getenv("REDIS_USERNAME")
+        self.redis_password = os.getenv("REDIS_PASSWORD")
+
+        # Detect if running on HF Spaces
+        self.is_hf_space = bool(os.getenv("SPACE_ID"))
 
-    @staticmethod
-    def get_provider_model(provider: str) -> str:
-        """Get the model name for a given provider"""
-        model_map = {
-            "huggingface": os.getenv("HF_MODEL_NAME", "meta-llama/Llama-2-7b-chat-hf"),
-            # "ollama": os.getenv("LOCAL_MODEL_NAME", "mistral:latest"),
-            # "openai": "gpt-3.5-turbo"
-        }
-        return model_map.get(provider, "unknown-model")
+# Global config instance
+config = LLMConfig()
src/llm/factory.py CHANGED
@@ -1,17 +1,18 @@
-import os
+import logging
 from typing import Optional
 from src.llm.base_provider import LLMProvider
 from src.llm.hf_provider import HuggingFaceProvider
-# Import other providers as they are implemented
-# from src.llm.ollama_provider import OllamaProvider
-# from src.llm.openai_provider import OpenAIProvider
+from src.llm.ollama_provider import OllamaProvider
+from utils.config import config
+
+logger = logging.getLogger(__name__)
 
 class ProviderNotAvailableError(Exception):
     """Raised when no provider is available"""
     pass
 
 class LLMFactory:
-    """Factory for creating LLM providers"""
+    """Factory for creating LLM providers with fallback support"""
 
     _instance = None
 
@@ -25,7 +26,7 @@ class LLMFactory:
         Get an LLM provider based on preference and availability.
 
         Args:
-            preferred_provider: Preferred provider name ('huggingface', 'ollama', 'openai')
+            preferred_provider: Preferred provider name ('huggingface', 'ollama')
 
         Returns:
             LLMProvider instance
@@ -33,18 +34,43 @@
         Raises:
            ProviderNotAvailableError: When no providers are available
         """
-        # For now, we only have HF provider implemented
-        if preferred_provider == "huggingface" or (preferred_provider is None and os.getenv("HF_TOKEN")):
-            return HuggingFaceProvider(
-                model_name=os.getenv("HF_MODEL_NAME", "meta-llama/Llama-2-7b-chat-hf")
-            )
-
-        # Add other providers as they are implemented
-        # elif preferred_provider == "ollama" or (preferred_provider is None and os.getenv("OLLAMA_HOST")):
-        #     return OllamaProvider(model_name=os.getenv("LOCAL_MODEL_NAME", "mistral:latest"))
-        # elif preferred_provider == "openai" or (preferred_provider is None and os.getenv("OPENAI_API_KEY")):
-        #     return OpenAIProvider(model_name="gpt-3.5-turbo")
-
+        # Check preferred provider first
+        if preferred_provider == "huggingface" and config.hf_token:
+            try:
+                return HuggingFaceProvider(
+                    model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
+                )
+            except Exception as e:
+                logger.warning(f"Failed to initialize HF provider: {e}")
+
+        elif preferred_provider == "ollama" and config.ollama_host:
+            try:
+                return OllamaProvider(
+                    model_name=config.local_model_name
+                )
+            except Exception as e:
+                logger.warning(f"Failed to initialize Ollama provider: {e}")
+
+        # Fallback logic based on configuration
+        if config.use_fallback:
+            # Try HF first if configured
+            if config.hf_token:
+                try:
+                    return HuggingFaceProvider(
+                        model_name="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf"
+                    )
+                except Exception as e:
+                    logger.warning(f"Failed to initialize HF provider: {e}")
+
+            # Then try Ollama if configured
+            if config.ollama_host:
+                try:
+                    return OllamaProvider(
+                        model_name=config.local_model_name
+                    )
+                except Exception as e:
+                    logger.warning(f"Failed to initialize Ollama provider: {e}")
+
         raise ProviderNotAvailableError("No LLM providers are available or configured")
 
 # Global factory instance
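Summarized as a sketch, the rewritten selection logic resolves as follows (the environment variables are those read by the new LLMConfig; `get_provider` is the assumed accessor name, since the method definition is outside the hunks shown):

# Resolution order of the reworked factory, given the config flags it consults:
#
#   preferred_provider == "huggingface" and HF_TOKEN set   -> HuggingFaceProvider
#   preferred_provider == "ollama" and OLLAMA_HOST set     -> OllamaProvider
#   otherwise, if USE_FALLBACK is true (the default):
#       HF_TOKEN set     -> HuggingFaceProvider is tried first
#       OLLAMA_HOST set  -> OllamaProvider is tried next
#   nothing configured, or every initialization raised     -> ProviderNotAvailableError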
src/llm/hf_provider.py CHANGED
@@ -1,20 +1,101 @@
+import time
+import logging
 from typing import List, Dict, Optional, Union
 from src.llm.base_provider import LLMProvider
+from utils.config import config
+
+logger = logging.getLogger(__name__)
+
+try:
+    from openai import OpenAI
+    HF_SDK_AVAILABLE = True
+except ImportError:
+    HF_SDK_AVAILABLE = False
+    OpenAI = None
 
 class HuggingFaceProvider(LLMProvider):
-    """Hugging Face LLM provider stub implementation"""
+    """Hugging Face LLM provider for your custom endpoint"""
 
-    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
+    def __init__(self, model_name: str, timeout: int = 60, max_retries: int = 3):
         super().__init__(model_name, timeout, max_retries)
-        # Placeholder for actual client initialization
-        self.client = None
+
+        if not HF_SDK_AVAILABLE:
+            raise ImportError("Hugging Face provider requires 'openai' package")
+
+        if not config.hf_token:
+            raise ValueError("HF_TOKEN not set - required for Hugging Face provider")
+
+        # Use your specific endpoint URL
+        self.client = OpenAI(
+            base_url=config.hf_api_url,
+            api_key=config.hf_token
+        )
+        logger.info(f"Initialized HF provider with endpoint: {config.hf_api_url}")
 
     def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
-        """Stub for synchronous generation"""
-        # In a real implementation, this would call the HF API
-        return f"[HuggingFace Stub] Response to: {prompt}"
+        """Generate a response synchronously"""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=conversation_history,
+                max_tokens=8192,
+                temperature=0.7,
+                stream=False
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            logger.error(f"HF generation failed: {e}")
+            # Handle scale-to-zero behavior
+            if "503" in str(e) or "service unavailable" in str(e).lower():
+                logger.info("HF endpoint is scaling up, waiting...")
+                time.sleep(60)  # Wait for endpoint to initialize
+                # Retry once
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=conversation_history,
+                    max_tokens=8192,
+                    temperature=0.7,
+                    stream=False
+                )
+                return response.choices[0].message.content
+            raise
 
     def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
-        """Stub for streaming generation"""
-        # In a real implementation, this would stream from the HF API
-        return ["[HuggingFace", " Streaming", " Stub]"]
+        """Generate a response with streaming support"""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=conversation_history,
+                max_tokens=8192,
+                temperature=0.7,
+                stream=True
+            )
+
+            chunks = []
+            for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content:
+                    chunks.append(content)
+            return chunks
+        except Exception as e:
+            logger.error(f"HF stream generation failed: {e}")
+            # Handle scale-to-zero behavior
+            if "503" in str(e) or "service unavailable" in str(e).lower():
+                logger.info("HF endpoint is scaling up, waiting...")
+                time.sleep(60)  # Wait for endpoint to initialize
+                # Retry once
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=conversation_history,
+                    max_tokens=8192,
+                    temperature=0.7,
+                    stream=True
+                )
+
+                chunks = []
+                for chunk in response:
+                    content = chunk.choices[0].delta.content
+                    if content:
+                        chunks.append(content)
+                return chunks
+            raise
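The scale-to-zero handling above is duplicated between generate() and stream_generate(). A possible consolidation is sketched below; the helper name _call_with_scale_up_retry and the fixed 60-second wait are choices made here for illustration, not part of the commit.

import time
import logging

logger = logging.getLogger(__name__)

def _call_with_scale_up_retry(make_request, wait_seconds: int = 60):
    """Run make_request(); on a 503 / 'service unavailable' error, wait for the
    endpoint to scale up and retry exactly once, mirroring the logic above."""
    try:
        return make_request()
    except Exception as e:
        if "503" in str(e) or "service unavailable" in str(e).lower():
            logger.info("HF endpoint is scaling up, waiting...")
            time.sleep(wait_seconds)
            return make_request()
        raise

# Example use inside generate(): the request becomes a closure passed to the helper.
# response = _call_with_scale_up_retry(lambda: self.client.chat.completions.create(
#     model=self.model_name, messages=conversation_history,
#     max_tokens=8192, temperature=0.7, stream=False))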
src/llm/ollama_provider.py ADDED
@@ -0,0 +1,88 @@
+import requests
+import logging
+import json
+import re
+from typing import List, Dict, Optional, Union
+from src.llm.base_provider import LLMProvider
+from utils.config import config
+
+logger = logging.getLogger(__name__)
+
+class OllamaProvider(LLMProvider):
+    """Ollama LLM provider implementation"""
+
+    def __init__(self, model_name: str, timeout: int = 60, max_retries: int = 3):
+        super().__init__(model_name, timeout, max_retries)
+        self.host = self._sanitize_host(config.ollama_host or "http://localhost:11434")
+        self.headers = {
+            "ngrok-skip-browser-warning": "true",
+            "User-Agent": "CosmicCat-AI-Assistant"
+        }
+
+    def _sanitize_host(self, host: str) -> str:
+        """Sanitize host URL by removing whitespace and control characters"""
+        if not host:
+            return "http://localhost:11434"
+        host = host.strip()
+        host = re.sub(r'[\r\n\t\0]+', '', host)
+        if not host.startswith(('http://', 'https://')):
+            host = 'http://' + host
+        return host
+
+    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
+        """Generate a response synchronously"""
+        try:
+            url = f"{self.host}/api/chat"
+            payload = {
+                "model": self.model_name,
+                "messages": conversation_history,
+                "stream": False
+            }
+
+            response = requests.post(
+                url,
+                json=payload,
+                headers=self.headers,
+                timeout=self.timeout
+            )
+            response.raise_for_status()
+            result = response.json()
+            return result["message"]["content"]
+        except Exception as e:
+            logger.error(f"Ollama generation failed: {e}")
+            raise
+
+    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
+        """Generate a response with streaming support"""
+        try:
+            url = f"{self.host}/api/chat"
+            payload = {
+                "model": self.model_name,
+                "messages": conversation_history,
+                "stream": True
+            }
+
+            response = requests.post(
+                url,
+                json=payload,
+                headers=self.headers,
+                timeout=self.timeout,
+                stream=True
+            )
+            response.raise_for_status()
+
+            chunks = []
+            for line in response.iter_lines():
+                if line:
+                    # Each streamed line is a JSON object (NDJSON); parse it as JSON
+                    try:
+                        data = json.loads(line.decode('utf-8'))
+                        content = data.get("message", {}).get("content", "")
+                        if content:
+                            chunks.append(content)
+                    except json.JSONDecodeError:
+                        continue
+            return chunks
+        except Exception as e:
+            logger.error(f"Ollama stream generation failed: {e}")
+            raise
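For reference, Ollama's /api/chat streams newline-delimited JSON objects, which is why the parsing loop above decodes and parses each line individually rather than evaluating it. A small sketch, with illustrative field values (the exact payload shape beyond "message"/"content"/"done" is not shown in this commit):

import json

# One illustrative streamed line from /api/chat (values made up):
line = b'{"model": "mistral:latest", "message": {"role": "assistant", "content": "Hel"}, "done": false}'

data = json.loads(line.decode("utf-8"))
print(data["message"]["content"])  # -> "Hel"; the final line arrives with "done": true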