Implement modular LLM provider interface with Hugging Face support
- src/config/llm_config.py +26 -0
- src/llm/base_provider.py +20 -0
- src/llm/factory.py +51 -0
- src/llm/hf_provider.py +20 -0
- src/services/chat_service.py +40 -0
src/config/llm_config.py
ADDED
@@ -0,0 +1,26 @@
import os
from typing import Optional

class LLMConfig:
    """Configuration loader for LLM providers"""

    @staticmethod
    def get_active_provider() -> Optional[str]:
        """Get the name of the active provider based on environment variables"""
        if os.getenv("HF_TOKEN"):
            return "huggingface"
        # elif os.getenv("OLLAMA_HOST"):
        #     return "ollama"
        # elif os.getenv("OPENAI_API_KEY"):
        #     return "openai"
        return None

    @staticmethod
    def get_provider_model(provider: str) -> str:
        """Get the model name for a given provider"""
        model_map = {
            "huggingface": os.getenv("HF_MODEL_NAME", "meta-llama/Llama-2-7b-chat-hf"),
            # "ollama": os.getenv("LOCAL_MODEL_NAME", "mistral:latest"),
            # "openai": "gpt-3.5-turbo"
        }
        return model_map.get(provider, "unknown-model")
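For orientation, a minimal usage sketch (not part of this commit) of how LLMConfig is expected to be queried; the environment values shown are placeholders.

import os
from src.config.llm_config import LLMConfig

# Placeholder values for illustration only.
os.environ["HF_TOKEN"] = "hf_xxx"
os.environ["HF_MODEL_NAME"] = "meta-llama/Llama-2-7b-chat-hf"

provider = LLMConfig.get_active_provider()       # "huggingface"
model = LLMConfig.get_provider_model(provider)   # "meta-llama/Llama-2-7b-chat-hf"
print(provider, model)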
src/llm/base_provider.py
ADDED
@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union

class LLMProvider(ABC):
    """Abstract base class for all LLM providers"""

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        self.model_name = model_name
        self.timeout = timeout
        self.max_retries = max_retries

    @abstractmethod
    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response synchronously"""
        pass

    @abstractmethod
    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Generate a response with streaming support"""
        pass
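To illustrate how this abstract interface is meant to be extended, here is a hypothetical provider subclass; EchoProvider is invented for this sketch and does not exist in the commit.

from typing import List, Dict, Optional, Union

from src.llm.base_provider import LLMProvider

class EchoProvider(LLMProvider):
    """Illustrative provider that simply echoes the prompt."""

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        # A real provider would call its backend here.
        return f"echo: {prompt}"

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        # A list of chunks mirrors the shape returned by the Hugging Face stub.
        return ["echo: ", prompt]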
src/llm/factory.py
ADDED
@@ -0,0 +1,51 @@
import os
from typing import Optional
from src.llm.base_provider import LLMProvider
from src.llm.hf_provider import HuggingFaceProvider
# Import other providers as they are implemented
# from src.llm.ollama_provider import OllamaProvider
# from src.llm.openai_provider import OpenAIProvider

class ProviderNotAvailableError(Exception):
    """Raised when no provider is available"""
    pass

class LLMFactory:
    """Factory for creating LLM providers"""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(LLMFactory, cls).__new__(cls)
        return cls._instance

    def get_provider(self, preferred_provider: Optional[str] = None) -> LLMProvider:
        """
        Get an LLM provider based on preference and availability.

        Args:
            preferred_provider: Preferred provider name ('huggingface', 'ollama', 'openai')

        Returns:
            LLMProvider instance

        Raises:
            ProviderNotAvailableError: When no providers are available
        """
        # For now, we only have HF provider implemented
        if preferred_provider == "huggingface" or (preferred_provider is None and os.getenv("HF_TOKEN")):
            return HuggingFaceProvider(
                model_name=os.getenv("HF_MODEL_NAME", "meta-llama/Llama-2-7b-chat-hf")
            )

        # Add other providers as they are implemented
        # elif preferred_provider == "ollama" or (preferred_provider is None and os.getenv("OLLAMA_HOST")):
        #     return OllamaProvider(model_name=os.getenv("LOCAL_MODEL_NAME", "mistral:latest"))
        # elif preferred_provider == "openai" or (preferred_provider is None and os.getenv("OPENAI_API_KEY")):
        #     return OpenAIProvider(model_name="gpt-3.5-turbo")

        raise ProviderNotAvailableError("No LLM providers are available or configured")

# Global factory instance
llm_factory = LLMFactory()
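Caller-side usage, as implied by the factory's API (the prompt text is illustrative):

from src.llm.factory import llm_factory, ProviderNotAvailableError

try:
    # Explicitly request the Hugging Face provider; passing None falls back
    # to whichever provider the environment variables enable.
    provider = llm_factory.get_provider("huggingface")
    print(provider.generate("Hello", conversation_history=[]))
except ProviderNotAvailableError as exc:
    print(f"No LLM provider configured: {exc}")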
src/llm/hf_provider.py
ADDED
@@ -0,0 +1,20 @@
from typing import List, Dict, Optional, Union
from src.llm.base_provider import LLMProvider

class HuggingFaceProvider(LLMProvider):
    """Hugging Face LLM provider stub implementation"""

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        super().__init__(model_name, timeout, max_retries)
        # Placeholder for actual client initialization
        self.client = None

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Stub for synchronous generation"""
        # In a real implementation, this would call the HF API
        return f"[HuggingFace Stub] Response to: {prompt}"

    def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
        """Stub for streaming generation"""
        # In a real implementation, this would stream from the HF API
        return ["[HuggingFace", " Streaming", " Stub]"]
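One way the stub could later be filled in, sketched against huggingface_hub's InferenceClient. This is an assumption about the eventual implementation rather than part of the commit, and the exact chat_completion parameters should be verified against the installed huggingface_hub version.

import os
from typing import List, Dict, Optional

from huggingface_hub import InferenceClient  # assumed dependency, not added by this commit

from src.llm.base_provider import LLMProvider

class HuggingFaceProviderSketch(LLMProvider):
    """Hypothetical non-stub variant of HuggingFaceProvider."""

    def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
        super().__init__(model_name, timeout, max_retries)
        # Authenticate with the same HF_TOKEN the config layer checks for.
        self.client = InferenceClient(model=model_name, token=os.getenv("HF_TOKEN"), timeout=timeout)

    def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        # Append the new prompt to the prior turns and ask for a chat completion.
        messages = conversation_history + [{"role": "user", "content": prompt}]
        response = self.client.chat_completion(messages=messages, max_tokens=512)
        return response.choices[0].message.content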
src/services/chat_service.py
ADDED
@@ -0,0 +1,40 @@
from typing import List, Dict, Optional
from src.llm.factory import llm_factory, ProviderNotAvailableError
import logging

logger = logging.getLogger(__name__)

class ChatService:
    """Service for handling chat interactions with LLM providers"""

    def __init__(self):
        try:
            self.provider = llm_factory.get_provider()
        except ProviderNotAvailableError:
            self.provider = None
            logger.error("No LLM providers available")

    def generate_response(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
        """Generate a response using the configured provider"""
        if not self.provider:
            raise ProviderNotAvailableError("No LLM provider available")

        try:
            return self.provider.generate(prompt, conversation_history)
        except Exception as e:
            logger.error(f"LLM generation failed: {e}")
            raise

    def stream_response(self, prompt: str, conversation_history: List[Dict]):
        """Stream a response using the configured provider"""
        if not self.provider:
            raise ProviderNotAvailableError("No LLM provider available")

        try:
            return self.provider.stream_generate(prompt, conversation_history)
        except Exception as e:
            logger.error(f"LLM stream generation failed: {e}")
            raise

# Global chat service instance
chat_service = ChatService()
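A minimal caller-side sketch, assuming HF_TOKEN is set so a provider is available and that a route or CLI layer wires the service up; the history and prompts are illustrative.

from src.services.chat_service import chat_service

history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]

reply = chat_service.generate_response("Which providers are supported?", history)
print(reply)

# The committed stub provider returns a list of chunks, so iteration works here.
for chunk in chat_service.stream_response("Stream this, please", history) or []:
    print(chunk, end="")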