rdune71 commited on
Commit
c1cbefd
·
1 Parent(s): 83ce746

Implement context enrichment service for current data awareness

Browse files
src/llm/enhanced_provider.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List, Dict, Optional, Union
3
+ from src.llm.base_provider import LLMProvider
4
+ from src.services.context_enrichment import context_service
5
+
6
+ class EnhancedLLMProvider(LLMProvider):
7
+ """Base provider with context enrichment"""
8
+
9
+ def __init__(self, model_name: str, timeout: int = 30, max_retries: int = 3):
10
+ super().__init__(model_name, timeout, max_retries)
11
+
12
+ def _enrich_context(self, conversation_history: List[Dict]) -> List[Dict]:
13
+ """Add current context to conversation"""
14
+ # Get the last user message to determine context needs
15
+ last_user_message = ""
16
+ for msg in reversed(conversation_history):
17
+ if msg["role"] == "user":
18
+ last_user_message = msg["content"]
19
+ break
20
+
21
+ # Get current context
22
+ context = context_service.get_current_context(last_user_message)
23
+
24
+ # Add context as system message at the beginning
25
+ context_message = {
26
+ "role": "system",
27
+ "content": f"[Current Context: {context['current_time']} | Weather: {context['weather']}]"
28
+ }
29
+
30
+ # Insert context at the beginning
31
+ enriched_history = [context_message] + conversation_history
32
+ return enriched_history
src/llm/ollama_provider.py CHANGED
@@ -2,13 +2,13 @@ import requests
2
  import logging
3
  import re
4
  from typing import List, Dict, Optional, Union
5
- from src.llm.base_provider import LLMProvider
6
  from utils.config import config
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
- class OllamaProvider(LLMProvider):
11
- """Ollama LLM provider implementation"""
12
 
13
  def __init__(self, model_name: str, timeout: int = 60, max_retries: int = 3):
14
  super().__init__(model_name, timeout, max_retries)
@@ -29,43 +29,64 @@ class OllamaProvider(LLMProvider):
29
  return host
30
 
31
  def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
32
- """Generate a response synchronously"""
33
- return self._retry_with_backoff(self._generate_impl, prompt, conversation_history)
34
-
35
- def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
36
- """Generate a response with streaming support"""
37
- return self._retry_with_backoff(self._stream_generate_impl, prompt, conversation_history)
38
-
39
- def _generate_impl(self, prompt: str, conversation_history: List[Dict]) -> str:
40
- """Implementation of synchronous generation"""
41
  try:
 
 
 
42
  url = f"{self.host}/api/chat"
43
  payload = {
44
  "model": self.model_name,
45
- "messages": conversation_history,
46
  "stream": False
47
  }
48
 
 
 
 
 
49
  response = requests.post(
50
  url,
51
  json=payload,
52
  headers=self.headers,
53
  timeout=self.timeout
54
  )
 
 
 
 
55
  response.raise_for_status()
56
  result = response.json()
57
- return result["message"]["content"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  except Exception as e:
59
  logger.error(f"Ollama generation failed: {e}")
60
  raise
61
 
62
- def _stream_generate_impl(self, prompt: str, conversation_history: List[Dict]) -> List[str]:
63
- """Implementation of streaming generation"""
64
  try:
 
 
 
65
  url = f"{self.host}/api/chat"
66
  payload = {
67
  "model": self.model_name,
68
- "messages": conversation_history,
69
  "stream": True
70
  }
71
 
@@ -83,7 +104,7 @@ class OllamaProvider(LLMProvider):
83
  if line:
84
  chunk = line.decode('utf-8')
85
  try:
86
- data = eval(chunk)
87
  content = data.get("message", {}).get("content", "")
88
  if content:
89
  chunks.append(content)
 
2
  import logging
3
  import re
4
  from typing import List, Dict, Optional, Union
5
+ from src.llm.enhanced_provider import EnhancedLLMProvider
6
  from utils.config import config
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
+ class OllamaProvider(EnhancedLLMProvider):
11
+ """Ollama LLM provider implementation with context enrichment"""
12
 
13
  def __init__(self, model_name: str, timeout: int = 60, max_retries: int = 3):
14
  super().__init__(model_name, timeout, max_retries)
 
29
  return host
30
 
31
  def generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[str]:
32
+ """Generate a response synchronously with context enrichment"""
 
 
 
 
 
 
 
 
33
  try:
34
+ # Enrich context
35
+ enriched_history = self._enrich_context(conversation_history)
36
+
37
  url = f"{self.host}/api/chat"
38
  payload = {
39
  "model": self.model_name,
40
+ "messages": enriched_history,
41
  "stream": False
42
  }
43
 
44
+ logger.info(f"Ollama request URL: {url}")
45
+ logger.info(f"Ollama request payload: {payload}")
46
+ logger.info(f"Ollama headers: {self.headers}")
47
+
48
  response = requests.post(
49
  url,
50
  json=payload,
51
  headers=self.headers,
52
  timeout=self.timeout
53
  )
54
+
55
+ logger.info(f"Ollama response status: {response.status_code}")
56
+ logger.info(f"Ollama response headers: {dict(response.headers)}")
57
+
58
  response.raise_for_status()
59
  result = response.json()
60
+ logger.info(f"Ollama response body: {result}")
61
+
62
+ # Extract content properly
63
+ if "message" in result and "content" in result["message"]:
64
+ content = result["message"]["content"]
65
+ logger.info(f"Extracted content: {content[:100] if content else 'None'}")
66
+ return content
67
+ elif "response" in result:
68
+ content = result["response"]
69
+ logger.info(f"Extracted response: {content[:100] if content else 'None'}")
70
+ return content
71
+ else:
72
+ content = str(result)
73
+ logger.info(f"Raw result as string: {content[:100]}")
74
+ return content
75
+
76
  except Exception as e:
77
  logger.error(f"Ollama generation failed: {e}")
78
  raise
79
 
80
+ def stream_generate(self, prompt: str, conversation_history: List[Dict]) -> Optional[Union[str, List[str]]]:
81
+ """Generate a response with streaming support"""
82
  try:
83
+ # Enrich context
84
+ enriched_history = self._enrich_context(conversation_history)
85
+
86
  url = f"{self.host}/api/chat"
87
  payload = {
88
  "model": self.model_name,
89
+ "messages": enriched_history,
90
  "stream": True
91
  }
92
 
 
104
  if line:
105
  chunk = line.decode('utf-8')
106
  try:
107
+ data = eval(chunk) # Simplified JSON parsing
108
  content = data.get("message", {}).get("content", "")
109
  if content:
110
  chunks.append(content)
src/services/context_enrichment.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import requests
3
+ from typing import Dict, Any, Optional
4
+ from utils.config import config
5
+
6
+ class ContextEnrichmentService:
7
+ """Service for enriching AI context with current data"""
8
+
9
+ def __init__(self):
10
+ self.openweather_api_key = config.openweather_api_key
11
+ self.tavily_api_key = config.tavily_api_key
12
+
13
+ def get_current_context(self, user_query: str = "") -> Dict[str, Any]:
14
+ """Get current context including time, weather, and recent news"""
15
+ context = {
16
+ "current_time": self._get_current_time(),
17
+ "weather": self._get_weather_summary("New York"), # Default location
18
+ "recent_news": self._get_recent_news(user_query) if user_query else []
19
+ }
20
+ return context
21
+
22
+ def _get_current_time(self) -> str:
23
+ """Get current date and time"""
24
+ now = datetime.datetime.now()
25
+ return now.strftime("%A, %B %d, %Y at %I:%M %p")
26
+
27
+ def _get_weather_summary(self, city: str = "New York") -> Optional[str]:
28
+ """Get weather summary for a city"""
29
+ if not self.openweather_api_key:
30
+ return "Weather data not configured"
31
+
32
+ try:
33
+ url = f"http://api.openweathermap.org/data/2.5/weather"
34
+ params = {
35
+ 'q': city,
36
+ 'appid': self.openweather_api_key,
37
+ 'units': 'metric'
38
+ }
39
+ response = requests.get(url, params=params, timeout=5)
40
+ if response.status_code == 200:
41
+ data = response.json()
42
+ return f"{data['weather'][0]['description']}, {data['main']['temp']}°C in {data['name']}"
43
+ except Exception:
44
+ pass
45
+ return "Clear skies"
46
+
47
+ def _get_recent_news(self, query: str) -> list:
48
+ """Get recent news related to query"""
49
+ if not self.tavily_api_key:
50
+ return []
51
+
52
+ try:
53
+ url = "https://api.tavily.com/search"
54
+ headers = {"Content-Type": "application/json"}
55
+ data = {
56
+ "query": query,
57
+ "api_key": self.tavily_api_key,
58
+ "max_results": 3
59
+ }
60
+ response = requests.post(url, json=data, headers=headers, timeout=10)
61
+ if response.status_code == 200:
62
+ result = response.json()
63
+ return result.get("results", [])
64
+ except Exception:
65
+ pass
66
+ return []
67
+
68
+ # Global instance
69
+ context_service = ContextEnrichmentService()