Update services.py

services.py CHANGED (+22 -67)
@@ -1,22 +1,11 @@
 # /services.py
+""" Manages interactions with all external LLM and search APIs. """
 
-"""
-Manages interactions with external services like LLM providers and web search APIs.
-This module has been refactored to support multiple LLM providers:
-- Hugging Face
-- Groq
-- Fireworks AI
-- OpenAI
-- Google Gemini
-- DeepSeek (Direct API via OpenAI client)
-"""
 import os
 import logging
 from typing import Dict, Any, Generator, List
 
 from dotenv import load_dotenv
-
-# Import all necessary clients
 from huggingface_hub import InferenceClient
 from tavily import TavilyClient
 from groq import Groq
@@ -24,9 +13,6 @@ import fireworks.client as Fireworks
 import openai
 import google.generativeai as genai
 
-# <--- FIX: REMOVED the incorrect 'from deepseek import ...' line ---
-
-# --- Setup Logging & Environment ---
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 load_dotenv()
 
@@ -39,24 +25,17 @@ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
 
-# --- Type Definitions ---
 Messages = List[Dict[str, Any]]
 
 class LLMService:
     """A multi-provider wrapper for LLM Inference APIs."""
-
     def __init__(self):
-        # Initialize clients only if their API keys are available
         self.hf_client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else None
         self.groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
         self.openai_client = openai.OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
 
-        # <--- FIX: Correctly instantiate the DeepSeek client using the OpenAI library ---
         if DEEPSEEK_API_KEY:
-            self.deepseek_client = openai.OpenAI(
-                api_key=DEEPSEEK_API_KEY,
-                base_url="https://api.deepseek.com/v1"
-            )
+            self.deepseek_client = openai.OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com/v1")
         else:
             self.deepseek_client = None
 
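The DeepSeek change above works because DeepSeek exposes an OpenAI-compatible API, so the stock openai SDK can simply be pointed at it through base_url. A minimal standalone sketch of that pattern, assuming DEEPSEEK_API_KEY is set in the environment; the model name "deepseek-chat" and the prompt are illustrative:

# Sketch of the OpenAI-compatible endpoint pattern used for DeepSeek above.
import os
import openai

# Reuse the OpenAI client class, but aim it at DeepSeek's base URL.
client = openai.OpenAI(
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url="https://api.deepseek.com/v1",
)

# Stream a completion exactly as with OpenAI; "deepseek-chat" is an illustrative model name.
stream = client.chat.completions.create(
    model="deepseek-chat",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    stream=True,
)
for chunk in stream:
    # Each streamed chunk carries an incremental text delta, as in generate_code_stream below.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")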
@@ -73,84 +52,60 @@ class LLMService:
             self.gemini_model = None
 
     def _prepare_messages_for_gemini(self, messages: Messages) -> List[Dict[str, Any]]:
-        # This function remains the same
         gemini_messages = []
         for msg in messages:
+            if msg['role'] == 'system': continue # Gemini doesn't use a system role in this way
             role = 'model' if msg['role'] == 'assistant' else 'user'
             gemini_messages.append({'role': role, 'parts': [msg['content']]})
         return gemini_messages
 
-    def generate_code_stream(
-        self, model_id: str, messages: Messages, max_tokens: int = 8192
-    ) -> Generator[str, None, None]:
-        # This function remains the same, as the dispatcher logic is already correct
+    def generate_code_stream(self, model_id: str, messages: Messages, max_tokens: int = 8192) -> Generator[str, None, None]:
         provider, model_name = model_id.split('/', 1)
         logging.info(f"Dispatching to provider: {provider} for model: {model_name}")
 
         try:
             if provider in ['openai', 'groq', 'deepseek', 'fireworks']:
-                client_map = {
-                    'openai': self.openai_client,
-                    'groq': self.groq_client,
-                    'deepseek': self.deepseek_client,
-                    'fireworks': self.fireworks_client.ChatCompletion if self.fireworks_client else None,
-                }
+                client_map = {'openai': self.openai_client, 'groq': self.groq_client, 'deepseek': self.deepseek_client, 'fireworks': self.fireworks_client.ChatCompletion if self.fireworks_client else None}
                 client = client_map.get(provider)
-                if not client:
-                    raise ValueError(f"{provider.capitalize()} API key not configured.")
-
-                if provider == 'fireworks':
-                    stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
-                else:
-                    stream = client.chat.completions.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
+                if not client: raise ValueError(f"{provider.capitalize()} API key not configured.")
 
+                stream = client.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens) if provider == 'fireworks' else client.chat.completions.create(model=model_name, messages=messages, stream=True, max_tokens=max_tokens)
                 for chunk in stream:
-                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
-
+                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: yield chunk.choices[0].delta.content
+
             elif provider == 'gemini':
-                if not self.gemini_model:
-                    raise ValueError("Gemini API key not configured.")
+                if not self.gemini_model: raise ValueError("Gemini API key not configured.")
+                system_prompt = next((msg['content'] for msg in messages if msg['role'] == 'system'), "")
                 gemini_messages = self._prepare_messages_for_gemini(messages)
+                # Prepend system prompt to first user message for Gemini
+                if system_prompt and gemini_messages and gemini_messages[0]['role'] == 'user':
+                    gemini_messages[0]['parts'][0] = f"{system_prompt}\n\n{gemini_messages[0]['parts'][0]}"
                 stream = self.gemini_model.generate_content(gemini_messages, stream=True)
-                for chunk in stream:
-                    yield chunk.text
+                for chunk in stream: yield chunk.text
 
             elif provider == 'huggingface':
-                if not self.hf_client:
-                    raise ValueError("Hugging Face API token not configured.")
+                if not self.hf_client: raise ValueError("Hugging Face API token not configured.")
                 hf_model_id = model_id.split('/', 1)[1]
                 stream = self.hf_client.chat_completion(model=hf_model_id, messages=messages, stream=True, max_tokens=max_tokens)
                 for chunk in stream:
-                    if chunk.choices[0].delta.content:
-                        yield chunk.choices[0].delta.content
+                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content: yield chunk.choices[0].delta.content
             else:
                 raise ValueError(f"Unknown provider: {provider}")
-
         except Exception as e:
             logging.error(f"LLM API Error with provider {provider}: {e}")
             yield f"Error from {provider.capitalize()}: {str(e)}"
 
-# The SearchService class remains unchanged
 class SearchService:
     def __init__(self, api_key: str = TAVILY_API_KEY):
-        if not api_key:
-            logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
-            self.client = None
-        else:
-            self.client = TavilyClient(api_key=api_key)
-    def is_available(self) -> bool:
-        return self.client is not None
+        self.client = TavilyClient(api_key=api_key) if api_key else None
+        if not self.client: logging.warning("TAVILY_API_KEY not set. Web search will be disabled.")
+    def is_available(self) -> bool: return self.client is not None
     def search(self, query: str, max_results: int = 5) -> str:
         if not self.is_available(): return "Web search is not available."
         try:
             response = self.client.search(query, search_depth="advanced", max_results=min(max(1, max_results), 10))
-            …
-        except Exception as e:
-            logging.error(f"Tavily search error: {e}")
-            return f"Search error: {str(e)}"
+            return "Web Search Results:\n\n" + "\n---\n".join([f"Title: {res.get('title', 'N/A')}\nURL: {res.get('url', 'N/A')}\nContent: {res.get('content', 'N/A')}" for res in response.get('results', [])])
+        except Exception as e: return f"Search error: {str(e)}"
 
-# --- Singleton Instances ---
 llm_service = LLMService()
 search_service = SearchService()
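A short consumption sketch for the refactored service, assuming the module imports cleanly and the relevant API key is configured. Model ids use the provider/model_name format that generate_code_stream splits on the first '/'; the Groq model name below is illustrative:

# Sketch: streaming tokens from the multi-provider service.
from services import llm_service

messages = [
    {'role': 'system', 'content': 'You are a concise coding assistant.'},
    {'role': 'user', 'content': 'Write a Python one-liner that reverses a string.'},
]

# Any configured prefix works: openai/, groq/, deepseek/, fireworks/, gemini/, huggingface/.
for token in llm_service.generate_code_stream('groq/llama-3.1-8b-instant', messages):
    print(token, end='', flush=True)

Note the design choice in the diff: on failure the generator yields a single "Error from ..." string instead of raising, so a caller can surface the message in a UI without extra exception handling.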
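The Gemini path now strips the system turn in _prepare_messages_for_gemini and folds it back into the first user part. A worked example of that transformation, mirroring the two steps in the diff:

# Sketch: what the Gemini message preparation produces for a typical history.
history = [
    {'role': 'system', 'content': 'You are a concise coding assistant.'},
    {'role': 'user', 'content': 'Write a Python one-liner that reverses a string.'},
]

# Step 1: drop system turns and map assistant -> model (as in _prepare_messages_for_gemini).
gemini_messages = [
    {'role': 'model' if m['role'] == 'assistant' else 'user', 'parts': [m['content']]}
    for m in history if m['role'] != 'system'
]
# -> [{'role': 'user', 'parts': ['Write a Python one-liner that reverses a string.']}]

# Step 2: prepend the system prompt to the first user part (as in generate_code_stream).
system_prompt = next((m['content'] for m in history if m['role'] == 'system'), '')
if system_prompt and gemini_messages and gemini_messages[0]['role'] == 'user':
    gemini_messages[0]['parts'][0] = f"{system_prompt}\n\n{gemini_messages[0]['parts'][0]}"
# -> parts[0] now begins with the system prompt, separated from the question by a blank line.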
