akryldigital committed on
Commit 92633a7 · verified · 1 Parent(s): 26449fc
Dockerfile CHANGED
@@ -1,20 +1,29 @@
-FROM python:3.13.5-slim
+FROM python:3.11-slim
 
 WORKDIR /app
 
+# Install system dependencies
 RUN apt-get update && apt-get install -y \
     build-essential \
     curl \
     git \
     && rm -rf /var/lib/apt/lists/*
 
+# Copy requirements first (for better Docker layer caching)
 COPY requirements.txt ./
-COPY src/ ./src/
 
-RUN pip3 install -r requirements.txt
+# Install Python dependencies
+RUN pip3 install --no-cache-dir -r requirements.txt
 
+# Copy all application files (excluding .dockerignore patterns)
+COPY . .
+
+# Expose Streamlit port (HF Spaces maps to 7860 automatically)
 EXPOSE 8501
 
-HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
+# Health check for Streamlit
+HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
+    CMD curl --fail http://localhost:8501/_stcore/health || exit 1
 
-ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+# Run Streamlit app
+ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0", "--server.headless", "true"]
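The "COPY . ." comment assumes a .dockerignore in the build context; its contents are not shown in this commit, but a minimal sketch that keeps runtime artifacts (the conversations/ and feedback/ directories created by app.py) out of the image could look like:

    .git
    __pycache__/
    conversations/
    feedback/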
app.py ADDED
@@ -0,0 +1,694 @@
1
+ """
2
+ Intelligent Audit Report Chatbot UI
3
+ """
4
+
5
+
6
+ import os
7
+ import sys
8
+ import time
9
+ import json
10
+ import uuid
11
+ import logging
12
+ from pathlib import Path
13
+
14
+ import argparse
15
+ import streamlit as st
16
+ from langchain_core.messages import HumanMessage, AIMessage
17
+
18
+ from multi_agent_chatbot import get_multi_agent_chatbot
19
+ from smart_chatbot import get_chatbot as get_smart_chatbot
20
+ from src.reporting.feedback_schema import create_feedback_from_dict
21
+
22
+ # Configure logging
23
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Page config
27
+ st.set_page_config(
28
+ layout="wide",
29
+ page_icon="🤖",
30
+ initial_sidebar_state="expanded",
31
+ page_title="Intelligent Audit Report Chatbot"
32
+ )
33
+
34
+ # Custom CSS
35
+ st.markdown("""
36
+ <style>
37
+ .main-header {
38
+ font-size: 2.5rem;
39
+ font-weight: bold;
40
+ color: #1f77b4;
41
+ text-align: center;
42
+ margin-bottom: 1rem;
43
+ }
44
+
45
+ .subtitle {
46
+ font-size: 1.2rem;
47
+ color: #666;
48
+ text-align: center;
49
+ margin-bottom: 2rem;
50
+ }
51
+
52
+ .session-info {
53
+ background-color: #f0f2f6;
54
+ padding: 10px;
55
+ border-radius: 5px;
56
+ margin-bottom: 20px;
57
+ font-size: 0.9rem;
58
+ }
59
+
60
+ .user-message {
61
+ background-color: #007bff;
62
+ color: white;
63
+ padding: 12px 16px;
64
+ border-radius: 18px 18px 4px 18px;
65
+ margin: 8px 0;
66
+ margin-left: 20%;
67
+ word-wrap: break-word;
68
+ }
69
+
70
+ .bot-message {
71
+ background-color: #f1f3f4;
72
+ color: #333;
73
+ padding: 12px 16px;
74
+ border-radius: 18px 18px 18px 4px;
75
+ margin: 8px 0;
76
+ margin-right: 20%;
77
+ word-wrap: break-word;
78
+ border: 1px solid #e0e0e0;
79
+ }
80
+
81
+ .filter-section {
82
+ margin-bottom: 20px;
83
+ padding: 15px;
84
+ background-color: #f8f9fa;
85
+ border-radius: 8px;
86
+ border: 1px solid #e9ecef;
87
+ }
88
+
89
+ .filter-title {
90
+ font-weight: bold;
91
+ margin-bottom: 10px;
92
+ color: #495057;
93
+ }
94
+
95
+ .feedback-section {
96
+ background-color: #f8f9fa;
97
+ padding: 20px;
98
+ border-radius: 10px;
99
+ margin-top: 30px;
100
+ border: 2px solid #dee2e6;
101
+ }
102
+
103
+ .retrieval-history {
104
+ background-color: #ffffff;
105
+ padding: 15px;
106
+ border-radius: 5px;
107
+ margin: 10px 0;
108
+ border-left: 4px solid #007bff;
109
+ }
110
+ </style>
111
+ """, unsafe_allow_html=True)
112
+
113
+ def get_system_type():
114
+ """Get the current system type"""
115
+ system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
116
+ if system == 'smart':
117
+ return "Smart Chatbot System"
118
+ else:
119
+ return "Multi-Agent System"
120
+
121
+ def get_chatbot():
122
+ """Initialize and return the chatbot based on system type"""
123
+ # Check environment variable for system type
124
+ system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
125
+ if system == 'smart':
126
+ return get_smart_chatbot()
127
+ else:
128
+ return get_multi_agent_chatbot()
129
+
130
+ def serialize_messages(messages):
131
+ """Serialize LangChain messages to dictionaries"""
132
+ serialized = []
133
+ for msg in messages:
134
+ if hasattr(msg, 'content'):
135
+ serialized.append({
136
+ "type": type(msg).__name__,
137
+ "content": str(msg.content)
138
+ })
139
+ return serialized
140
+
141
+ def serialize_documents(sources):
142
+ """Serialize document objects to dictionaries with deduplication"""
143
+ serialized = []
144
+ seen_content = set()
145
+
146
+ for doc in sources:
147
+ content = getattr(doc, 'page_content', getattr(doc, 'content', ''))
148
+
149
+ # Skip if we've seen this exact content before
150
+ if content in seen_content:
151
+ continue
152
+
153
+ seen_content.add(content)
154
+
155
+ doc_dict = {
156
+ "content": content,
157
+ "metadata": getattr(doc, 'metadata', {}),
158
+ "score": getattr(doc, 'metadata', {}).get('reranked_score', getattr(doc, 'metadata', {}).get('original_score', 0.0)),
159
+ "id": getattr(doc, 'metadata', {}).get('_id', 'unknown'),
160
+ "source": getattr(doc, 'metadata', {}).get('source', 'unknown'),
161
+ "year": getattr(doc, 'metadata', {}).get('year', 'unknown'),
162
+ "district": getattr(doc, 'metadata', {}).get('district', 'unknown'),
163
+ "page": getattr(doc, 'metadata', {}).get('page', 'unknown'),
164
+ "chunk_id": getattr(doc, 'metadata', {}).get('chunk_id', 'unknown'),
165
+ "page_label": getattr(doc, 'metadata', {}).get('page_label', 'unknown'),
166
+ "original_score": getattr(doc, 'metadata', {}).get('original_score', 0.0),
167
+ "reranked_score": getattr(doc, 'metadata', {}).get('reranked_score', None)
168
+ }
169
+ serialized.append(doc_dict)
170
+
171
+ return serialized
172
+
173
+ @st.cache_data
174
+ def load_filter_options():
175
+ try:
176
+ with open("filter_options.json", "r") as f:
177
+ return json.load(f)
178
+ except FileNotFoundError:
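+ # Debug aid: show which JSON files are actually present in the working directory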
179
+ st.info([x for x in os.listdir() if x.endswith('.json')])
180
+ st.error("filter_options.json not found. Please run the metadata analysis script.")
181
+ return {"sources": [], "years": [], "districts": [], 'filenames': []}
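For reference, a minimal filter_options.json that this loader (and the whitelist loading in multi_agent_chatbot.py) expects might look like the following; the values are illustrative and mirror the fallback defaults used elsewhere in this commit:

    {
      "sources": ["Consolidated", "Local Government", "Ministry, Department and Agency"],
      "years": ["2018", "2019", "2020", "2021", "2022", "2023", "2024"],
      "districts": ["Kampala", "Gulu", "Kalangala"],
      "filenames": []
    }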
182
+
183
+ def main():
184
+ # Initialize session state
185
+ if 'messages' not in st.session_state:
186
+ st.session_state.messages = []
187
+ if 'conversation_id' not in st.session_state:
188
+ st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
189
+ if 'session_start_time' not in st.session_state:
190
+ st.session_state.session_start_time = time.time()
191
+ if 'active_filters' not in st.session_state:
192
+ st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
193
+ # Track RAG retrieval history for feedback
194
+ if 'rag_retrieval_history' not in st.session_state:
195
+ st.session_state.rag_retrieval_history = []
196
+ # Initialize chatbot only once per app session (cached)
197
+ if 'chatbot' not in st.session_state:
198
+ with st.spinner("🔄 Loading AI models and connecting to database..."):
199
+ st.session_state.chatbot = get_chatbot()
200
+ st.success("✅ AI system ready!")
201
+
202
+ # Reset conversation history if needed (but keep chatbot cached)
203
+ if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
204
+ st.session_state.messages = []
205
+ st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
206
+ st.session_state.session_start_time = time.time()
207
+ st.session_state.rag_retrieval_history = []
208
+ st.session_state.feedback_submitted = False
209
+ st.session_state.reset_conversation = False
210
+ st.rerun()
211
+
212
+ # Header with system indicator
213
+ col1, col2 = st.columns([3, 1])
214
+ with col1:
215
+ st.markdown('<h1 class="main-header">🤖 Intelligent Audit Report Chatbot</h1>', unsafe_allow_html=True)
216
+ with col2:
217
+ system_type = get_system_type()
218
+ if "Multi-Agent" in system_type:
219
+ st.success(f"🔧 {system_type}")
220
+ else:
221
+ st.info(f"🔧 {system_type}")
222
+ st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
223
+
224
+ # Session info
225
+ duration = int(time.time() - st.session_state.session_start_time)
226
+ duration_str = f"{duration // 60}m {duration % 60}s"
227
+ st.markdown(f'''
228
+ <div class="session-info">
229
+ <strong>Session Info:</strong> Messages: {len(st.session_state.messages)} | Duration: {duration_str} | Status: Active | ID: {st.session_state.conversation_id}
230
+ </div>
231
+ ''', unsafe_allow_html=True)
232
+
233
+ # Load filter options
234
+ filter_options = load_filter_options()
235
+
236
+ # Sidebar for filters
237
+ with st.sidebar:
238
+ st.markdown("### 🔍 Search Filters")
239
+ st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
240
+
241
+ st.markdown('<div class="filter-section">', unsafe_allow_html=True)
242
+ st.markdown('<div class="filter-title">📄 Specific Reports (Filename Filter)</div>', unsafe_allow_html=True)
243
+ st.markdown('<p style="font-size: 0.85em; color: #666;">⚠️ Selecting specific reports will ignore all other filters</p>', unsafe_allow_html=True)
244
+ selected_filenames = st.multiselect(
245
+ "Select specific reports:",
246
+ options=filter_options.get('filenames', []),
247
+ default=st.session_state.active_filters.get('filenames', []),
248
+ key="filenames_filter",
249
+ help="Choose specific reports to search. When enabled, all other filters are ignored."
250
+ )
251
+ st.markdown('</div>', unsafe_allow_html=True)
252
+
253
+ # Determine if filename filter is active
254
+ filename_mode = len(selected_filenames) > 0
255
+ # Sources filter
256
+ st.markdown('<div class="filter-section">', unsafe_allow_html=True)
257
+ st.markdown('<div class="filter-title">📊 Sources</div>', unsafe_allow_html=True)
258
+ selected_sources = st.multiselect(
259
+ "Select sources:",
260
+ options=filter_options['sources'],
261
+ default=st.session_state.active_filters['sources'],
262
+ disabled=filename_mode,
263
+ key="sources_filter",
264
+ help="Choose which types of reports to search"
265
+ )
266
+ st.markdown('</div>', unsafe_allow_html=True)
267
+
268
+ # Years filter
269
+ st.markdown('<div class="filter-section">', unsafe_allow_html=True)
270
+ st.markdown('<div class="filter-title">📅 Years</div>', unsafe_allow_html=True)
271
+ selected_years = st.multiselect(
272
+ "Select years:",
273
+ options=filter_options['years'],
274
+ default=st.session_state.active_filters['years'],
275
+ disabled=filename_mode,
276
+ key="years_filter",
277
+ help="Choose which years to search"
278
+ )
279
+ st.markdown('</div>', unsafe_allow_html=True)
280
+
281
+ # Districts filter
282
+ st.markdown('<div class="filter-section">', unsafe_allow_html=True)
283
+ st.markdown('<div class="filter-title">🏘️ Districts</div>', unsafe_allow_html=True)
284
+ selected_districts = st.multiselect(
285
+ "Select districts:",
286
+ options=filter_options['districts'],
287
+ default=st.session_state.active_filters['districts'],
288
+ disabled=filename_mode,
289
+ key="districts_filter",
290
+ help="Choose which districts to search"
291
+ )
292
+ st.markdown('</div>', unsafe_allow_html=True)
293
+
294
+ # Update active filters
295
+ st.session_state.active_filters = {
296
+ 'sources': selected_sources if not filename_mode else [],
297
+ 'years': selected_years if not filename_mode else [],
298
+ 'districts': selected_districts if not filename_mode else [],
299
+ 'filenames': selected_filenames
300
+ }
301
+
302
+ # Clear filters button
303
+ if st.button("🗑️ Clear All Filters", key="clear_filters_button"):
304
+ st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
305
+ st.rerun()
306
+
307
+ # Main content area with tabs
308
+ tab1, tab2 = st.tabs(["💬 Chat", "📄 Retrieved Documents"])
309
+
310
+ with tab1:
311
+ # Chat container
312
+ chat_container = st.container()
313
+
314
+ with chat_container:
315
+ # Display conversation history
316
+ for message in st.session_state.messages:
317
+ if isinstance(message, HumanMessage):
318
+ st.markdown(f'<div class="user-message">{message.content}</div>', unsafe_allow_html=True)
319
+ elif isinstance(message, AIMessage):
320
+ st.markdown(f'<div class="bot-message">{message.content}</div>', unsafe_allow_html=True)
321
+
322
+ # Input area
323
+ st.markdown("<br>", unsafe_allow_html=True)
324
+
325
+ # Create two columns for input and button
326
+ col1, col2 = st.columns([4, 1])
327
+
328
+ with col1:
329
+ # Use a counter to force input clearing
330
+ if 'input_counter' not in st.session_state:
331
+ st.session_state.input_counter = 0
332
+
333
+ user_input = st.text_input(
334
+ "Type your message here...",
335
+ placeholder="Ask about budget allocations, expenditures, or audit findings...",
336
+ key=f"user_input_{st.session_state.input_counter}",
337
+ label_visibility="collapsed"
338
+ )
339
+
340
+ with col2:
341
+ send_button = st.button("Send", key="send_button", use_container_width=True)
342
+
343
+ # Clear chat button
344
+ if st.button("🗑️ Clear Chat", key="clear_chat_button"):
345
+ st.session_state.reset_conversation = True
346
+ # Clear all conversation files
347
+ import os
348
+ conversations_dir = "conversations"
349
+ if os.path.exists(conversations_dir):
350
+ for file in os.listdir(conversations_dir):
351
+ if file.endswith('.json'):
352
+ os.remove(os.path.join(conversations_dir, file))
353
+ st.rerun()
354
+
355
+ # Handle user input
356
+ if send_button and user_input:
357
+ # Construct filter context string
358
+ filter_context_str = ""
359
+ if selected_filenames:
360
+ filter_context_str += "FILTER CONTEXT:\n"
361
+ filter_context_str += f"Filenames: {', '.join(selected_filenames)}\n"
362
+ filter_context_str += "USER QUERY:\n"
363
+ elif selected_sources or selected_years or selected_districts:
364
+ filter_context_str += "FILTER CONTEXT:\n"
365
+ if selected_sources:
366
+ filter_context_str += f"Sources: {', '.join(selected_sources)}\n"
367
+ if selected_years:
368
+ filter_context_str += f"Years: {', '.join(selected_years)}\n"
369
+ if selected_districts:
370
+ filter_context_str += f"Districts: {', '.join(selected_districts)}\n"
371
+ filter_context_str += "USER QUERY:\n"
372
+
373
+ full_query = filter_context_str + user_input
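As a sketch of what this prefix produces, a query sent with the Sources and Years filters set (values illustrative) is passed to the chatbot as:

    FILTER CONTEXT:
    Sources: Local Government
    Years: 2023
    USER QUERY:
    What were the main audit findings on salary payments?

The multi-agent system parses this prefix back out in _extract_ui_filters (multi_agent_chatbot.py).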
374
+
375
+ # Add user message to history
376
+ st.session_state.messages.append(HumanMessage(content=user_input))
377
+
378
+ # Get chatbot response
379
+ with st.spinner("🤔 Thinking..."):
380
+ try:
381
+ # Pass the full query with filter context
382
+ chat_result = st.session_state.chatbot.chat(full_query, st.session_state.conversation_id)
383
+
384
+ # Handle both old format (string) and new format (dict)
385
+ if isinstance(chat_result, dict):
386
+ response = chat_result['response']
387
+ rag_result = chat_result.get('rag_result')
388
+ st.session_state.last_rag_result = rag_result
389
+
390
+ # Track RAG retrieval for feedback
391
+ if rag_result:
392
+ sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
393
+
394
+ # Get the actual RAG query
395
+ actual_rag_query = chat_result.get('actual_rag_query', '')
396
+ if actual_rag_query:
397
+ # Format it like the log message
398
+ timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
399
+ formatted_query = f"{timestamp} - INFO - 🔍 ACTUAL RAG QUERY: '{actual_rag_query}'"
400
+ else:
401
+ formatted_query = "No RAG query available"
402
+
403
+ retrieval_entry = {
404
+ "conversation_up_to": serialize_messages(st.session_state.messages),
405
+ "rag_query_expansion": formatted_query,
406
+ "docs_retrieved": serialize_documents(sources)
407
+ }
408
+ st.session_state.rag_retrieval_history.append(retrieval_entry)
409
+ else:
410
+ response = chat_result
411
+ st.session_state.last_rag_result = None
412
+
413
+ # Add bot response to history
414
+ st.session_state.messages.append(AIMessage(content=response))
415
+
416
+ except Exception as e:
417
+ error_msg = f"Sorry, I encountered an error: {str(e)}"
418
+ st.session_state.messages.append(AIMessage(content=error_msg))
419
+
420
+ # Clear input and rerun
421
+ st.session_state.input_counter += 1 # This will clear the input
422
+ st.rerun()
423
+
424
+ with tab2:
425
+ # Document retrieval panel
426
+ if hasattr(st.session_state, 'last_rag_result') and st.session_state.last_rag_result:
427
+ rag_result = st.session_state.last_rag_result
428
+
429
+ # Handle both PipelineResult object and dictionary formats
430
+ sources = None
431
+ if hasattr(rag_result, 'sources'):
432
+ # PipelineResult object format
433
+ sources = rag_result.sources
434
+ elif isinstance(rag_result, dict) and 'sources' in rag_result:
435
+ # Dictionary format from multi-agent system
436
+ sources = rag_result['sources']
437
+
438
+ if sources and len(sources) > 0:
439
+ # Count unique filenames
440
+ unique_filenames = set()
441
+ for doc in sources:
442
+ filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
443
+ unique_filenames.add(filename)
444
+
445
+ st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 10):**")
446
+ if len(unique_filenames) < len(sources):
447
+ st.info(f"💡 **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
448
+
449
+ for i, doc in enumerate(sources[:10]): # Show top 10
450
+ # Get relevance score and ID if available
451
+ metadata = getattr(doc, 'metadata', {})
452
+ score = metadata.get('reranked_score', metadata.get('original_score', None))
453
+ chunk_id = metadata.get('_id', 'Unknown')
454
+ score_text = f" (Score: {score:.3f}, ID: {chunk_id[:8]}...)" if score is not None else f" (ID: {chunk_id[:8]}...)"
455
+
456
+ with st.expander(f"📄 Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
457
+ # Display document metadata with emojis
458
+ metadata = getattr(doc, 'metadata', {})
459
+ col1, col2, col3, col4 = st.columns([2, 1.5, 1, 1])
460
+
461
+ with col1:
462
+ st.write(f"📄 **File:** {metadata.get('filename', 'Unknown')}")
463
+ with col2:
464
+ st.write(f"🏛️ **Source:** {metadata.get('source', 'Unknown')}")
465
+ with col3:
466
+ st.write(f"📅 **Year:** {metadata.get('year', 'Unknown')}")
467
+ with col4:
468
+ # Display page number and chunk ID
469
+ page = metadata.get('page_label', metadata.get('page', 'Unknown'))
470
+ chunk_id = metadata.get('_id', 'Unknown')
471
+ st.write(f"📖 **Page:** {page}")
472
+ st.write(f"🆔 **ID:** {chunk_id}")
473
+
474
+ # Display full content (no truncation)
475
+ content = getattr(doc, 'page_content', 'No content available')
476
+ st.write(f"**Full Content:**")
477
+ st.text_area("Full Content", value=content, height=300, disabled=True, label_visibility="collapsed", key=f"preview_{i}")
478
+ else:
479
+ st.info("No documents were retrieved for the last query.")
480
+ else:
481
+ st.info("No documents have been retrieved yet. Start a conversation to see retrieved documents here.")
482
+
483
+ # Feedback Dashboard Section
484
+ st.markdown("---")
485
+ st.markdown("### 💬 Feedback Dashboard")
486
+
487
+ # Check if there's any conversation to provide feedback on
488
+ has_conversation = len(st.session_state.messages) > 0
489
+ has_retrievals = len(st.session_state.rag_retrieval_history) > 0
490
+
491
+ if not has_conversation:
492
+ st.info("💡 Start a conversation to provide feedback!")
493
+ st.markdown("The feedback dashboard will be enabled once you begin chatting.")
494
+ else:
495
+ st.markdown("Help us improve by providing feedback on this conversation.")
496
+
497
+ # Initialize feedback state if not exists
498
+ if 'feedback_submitted' not in st.session_state:
499
+ st.session_state.feedback_submitted = False
500
+
501
+ # Feedback form
502
+ with st.form("feedback_form", clear_on_submit=False):
503
+ col1, col2 = st.columns([1, 1])
504
+
505
+ with col1:
506
+ feedback_score = st.slider(
507
+ "Rate this conversation (1-5)",
508
+ min_value=1,
509
+ max_value=5,
510
+ help="How satisfied are you with the conversation?"
511
+ )
512
+
513
+ with col2:
514
+ is_feedback_about_last_retrieval = st.checkbox(
515
+ "Feedback about last retrieval only",
516
+ value=True,
517
+ help="If checked, feedback applies to the most recent document retrieval"
518
+ )
519
+
520
+ open_ended_feedback = st.text_area(
521
+ "Your feedback (optional)",
522
+ placeholder="Tell us what went well or what could be improved...",
523
+ height=100
524
+ )
525
+
526
+ # Guard against a missing score (st.slider always returns a value, so in practice the button stays enabled)
527
+ submit_disabled = feedback_score is None
528
+
529
+ submitted = st.form_submit_button(
530
+ "📤 Submit Feedback",
531
+ use_container_width=True,
532
+ disabled=submit_disabled
533
+ )
534
+
535
+ if submitted and not st.session_state.feedback_submitted:
536
+ # Log the feedback data being submitted
537
+ print("=" * 80)
538
+ print("🔄 FEEDBACK SUBMISSION: Starting...")
539
+ print("=" * 80)
540
+ st.write("🔍 **Debug: Feedback Data Being Submitted:**")
541
+
542
+ # Create feedback data dictionary
543
+ feedback_dict = {
544
+ "open_ended_feedback": open_ended_feedback,
545
+ "score": feedback_score,
546
+ "is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
547
+ "retrieved_data": st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
548
+ "conversation_id": st.session_state.conversation_id,
549
+ "timestamp": time.time(),
550
+ "message_count": len(st.session_state.messages),
551
+ "has_retrievals": has_retrievals,
552
+ "retrieval_count": len(st.session_state.rag_retrieval_history)
553
+ }
554
+
555
+ print(f"📝 FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
556
+
557
+ # Create UserFeedback dataclass instance
558
+ feedback_obj = None # Initialize outside try block
559
+ try:
560
+ feedback_obj = create_feedback_from_dict(feedback_dict)
561
+ print(f"✅ FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
562
+ st.write(f"✅ **Feedback Object Created**")
563
+ st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
564
+ st.write(f"- Score: {feedback_obj.score}/5")
565
+ st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
566
+
567
+ # Convert back to dict for JSON serialization
568
+ feedback_data = feedback_obj.to_dict()
569
+ except Exception as e:
570
+ print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
571
+ st.error(f"Failed to create feedback object: {e}")
572
+ feedback_data = feedback_dict
573
+
574
+ # Display the data being submitted
575
+ st.json(feedback_data)
576
+
577
+ # Save feedback to file
578
+ feedback_dir = Path("feedback")
579
+ feedback_dir.mkdir(exist_ok=True)
580
+
581
+ feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
582
+
583
+ try:
584
+ # Save to local file
585
+ print(f"💾 FEEDBACK SAVE: Saving to local file: {feedback_file}")
586
+ with open(feedback_file, 'w') as f:
587
+ json.dump(feedback_data, f, indent=2, default=str)
588
+
589
+ print(f"✅ FEEDBACK SAVE: Local file saved successfully")
590
+ st.success("✅ Thank you for your feedback! It has been saved locally.")
591
+ st.balloons()
592
+
593
+ # Save to Snowflake if enabled and credentials available
594
+ logger.info("🔄 FEEDBACK SAVE: Starting Snowflake save process...")
595
+ logger.info(f"📊 FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
596
+
597
+ try:
598
+ import os
599
+ snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
600
+ logger.info(f"🔍 SNOWFLAKE CHECK: enabled={snowflake_enabled}")
601
+
602
+ if snowflake_enabled:
603
+ if feedback_obj:
604
+ try:
605
+ from auditqa.reporting.snowflake_connector import save_to_snowflake
606
+ logger.info("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
607
+ print("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...") # Also print to terminal
608
+
609
+ if save_to_snowflake(feedback_obj):
610
+ logger.info("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
611
+ print("✅ SNOWFLAKE UI: Successfully saved to Snowflake") # Also print to terminal
612
+ st.success("✅ Feedback also saved to Snowflake!")
613
+ else:
614
+ logger.warning("⚠️ SNOWFLAKE UI: Save failed")
615
+ print("⚠️ SNOWFLAKE UI: Save failed") # Also print to terminal
616
+ st.warning("⚠️ Snowflake save failed, but local save succeeded")
617
+ except Exception as e:
618
+ logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
619
+ print(f"❌ SNOWFLAKE UI ERROR: {e}") # Also print to terminal
620
+ import traceback
621
+ traceback.print_exc() # Print full traceback to terminal
622
+ st.warning(f"⚠️ Could not save to Snowflake: {e}")
623
+ else:
624
+ logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
625
+ print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)") # Also print to terminal
626
+ st.warning("⚠️ Skipping Snowflake save (feedback object not created)")
627
+ else:
628
+ logger.info("💡 SNOWFLAKE UI: Integration disabled")
629
+ print("💡 SNOWFLAKE UI: Integration disabled") # Also print to terminal
630
+ st.info("💡 Snowflake integration disabled (set SNOWFLAKE_ENABLED=true to enable)")
631
+ except NameError as e:
632
+ import traceback
633
+ traceback.print_exc()
634
+ logger.error(f"❌ NameError in Snowflake save: {e}")
635
+ print(f"❌ NameError in Snowflake save: {e}") # Also print to terminal
636
+ st.warning(f"⚠️ Snowflake save error: {e}")
637
+ except Exception as e:
638
+ logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
639
+ print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}") # Also print to terminal
640
+ st.warning(f"⚠️ Snowflake save error: {e}")
641
+
642
+ # Mark feedback as submitted to prevent resubmission
643
+ st.session_state.feedback_submitted = True
644
+
645
+ print("=" * 80)
646
+ print(f"✅ FEEDBACK SUBMISSION: Completed successfully")
647
+ print("=" * 80)
648
+
649
+ # Log file location
650
+ st.info(f"📁 Feedback saved to: {feedback_file}")
651
+
652
+ except Exception as e:
653
+ print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
654
+ print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
655
+ import traceback
656
+ traceback.print_exc()
657
+ st.error(f"❌ Error saving feedback: {e}")
658
+ st.write(f"Debug error: {str(e)}")
659
+
660
+ elif st.session_state.feedback_submitted:
661
+ st.success("✅ Feedback already submitted for this conversation!")
662
+ if st.button("🔄 Submit New Feedback", key="new_feedback_button"):
663
+ st.session_state.feedback_submitted = False
664
+ st.rerun()
665
+
666
+ # Display retrieval history stats
667
+ if st.session_state.rag_retrieval_history:
668
+ st.markdown("---")
669
+ st.markdown("#### 📊 Retrieval History")
670
+
671
+ with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
672
+ for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
673
+ st.markdown(f"**Retrieval #{idx}**")
674
+
675
+ # Display the actual RAG query
676
+ rag_query_expansion = entry.get("rag_query_expansion", "No query available")
677
+ st.code(rag_query_expansion, language="text")
678
+
679
+ # Display summary stats
680
+ st.json({
681
+ "conversation_length": len(entry.get("conversation_up_to", [])),
682
+ "documents_retrieved": len(entry.get("docs_retrieved", []))
683
+ })
684
+ st.markdown("---")
685
+
686
+ # Auto-scroll to bottom
687
+ st.markdown("""
688
+ <script>
689
+ window.scrollTo(0, document.body.scrollHeight);
690
+ </script>
691
+ """, unsafe_allow_html=True)
692
+
693
+ if __name__ == "__main__":
694
+ main()
multi_agent_chatbot.py ADDED
@@ -0,0 +1,1167 @@
1
+ """
2
+ Multi-Agent RAG Chatbot using LangGraph
3
+
4
+ This system implements a 3-agent architecture:
5
+ 1. Main Agent: Handles conversation flow, follow-ups, and determines when to call RAG
6
+ 2. RAG Agent: Rewrites queries and applies filters for document retrieval
7
+ 3. Response Agent: Generates final answers from retrieved documents
8
+
9
+ Each agent has specialized prompts and responsibilities.
10
+ """
11
+
12
+ import os
13
+ import json
14
+ import time
15
+ import logging
16
+ from pathlib import Path
17
+ from datetime import datetime
18
+ from dataclasses import dataclass
19
+ from typing import Dict, List, Any, Optional, TypedDict
20
+
21
+
22
+ import re
23
+ from langchain_core.tools import tool
24
+ from langgraph.graph import StateGraph, END
25
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
26
+ from langchain_core.prompts import ChatPromptTemplate
27
+
28
+
29
+ from src.pipeline import PipelineManager
30
+ from src.config.loader import load_config
31
+ from src.llm.adapters import get_llm_client
32
+
33
+
34
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ @dataclass
39
+ class QueryContext:
40
+ """Context extracted from conversation"""
41
+ has_district: bool = False
42
+ has_source: bool = False
43
+ has_year: bool = False
44
+ extracted_district: Optional[str] = None
45
+ extracted_source: Optional[str] = None
46
+ extracted_year: Optional[str] = None
47
+ ui_filters: Optional[Dict[str, List[str]]] = None
48
+ confidence_score: float = 0.0
49
+ needs_follow_up: bool = False
50
+ follow_up_question: Optional[str] = None
51
+
52
+ class MultiAgentState(TypedDict):
53
+ """State for the multi-agent conversation flow"""
54
+ conversation_id: str
55
+ messages: List[Any]
56
+ current_query: str
57
+ query_context: Optional[QueryContext]
58
+ rag_query: Optional[str]
59
+ rag_filters: Optional[Dict[str, Any]]
60
+ retrieved_documents: Optional[List[Any]]
61
+ final_response: Optional[str]
62
+ agent_logs: List[str]
63
+ conversation_context: Dict[str, Any]
64
+ session_start_time: float
65
+ last_ai_message_time: float
66
+
67
+ class MultiAgentRAGChatbot:
68
+ """Multi-agent RAG chatbot with specialized agents"""
69
+
70
+ def __init__(self, config_path: str = "auditqa/config/settings.yaml"):
71
+ """Initialize the multi-agent chatbot"""
72
+ self.config = load_config(config_path)
73
+
74
+ # Get LLM provider from config
75
+ reader_config = self.config.get("reader", {})
76
+ default_type = reader_config.get("default_type", "INF_PROVIDERS")
77
+ provider_name = default_type.lower()
78
+
79
+ self.llm_adapter = get_llm_client(provider_name, self.config)
80
+
81
+ # Create a simple wrapper for LangChain compatibility
82
+ class LLMWrapper:
83
+ def __init__(self, adapter):
84
+ self.adapter = adapter
85
+
86
+ def invoke(self, messages):
87
+ # Convert LangChain messages to the format expected by the adapter
88
+ if isinstance(messages, list):
89
+ formatted_messages = []
90
+ for msg in messages:
91
+ if hasattr(msg, 'content'):
92
+ # Map each message class to a chat role so system prompts are not sent as assistant turns
+ if msg.__class__.__name__ == "HumanMessage":
+     role = "user"
+ elif msg.__class__.__name__ == "SystemMessage":
+     role = "system"
+ else:
+     role = "assistant"
93
+ formatted_messages.append({"role": role, "content": msg.content})
94
+ else:
95
+ formatted_messages.append({"role": "user", "content": str(msg)})
96
+ else:
97
+ formatted_messages = [{"role": "user", "content": str(messages)}]
98
+
99
+ # Use the adapter to get response
100
+ response = self.adapter.generate(formatted_messages)
101
+
102
+ # Return a mock response object
103
+ class MockResponse:
104
+ def __init__(self, content):
105
+ self.content = content
106
+
107
+ return MockResponse(response.content)
108
+
109
+ self.llm = LLMWrapper(self.llm_adapter)
110
+
111
+ # Initialize pipeline manager early to load models
112
+ logger.info("🔄 Initializing pipeline manager and loading models...")
113
+ self.pipeline_manager = PipelineManager(self.config)
114
+ logger.info("✅ Pipeline manager initialized and models loaded")
115
+
116
+ # Connect to vector store
117
+ logger.info("🔄 Connecting to vector store...")
118
+ if not self.pipeline_manager.connect_vectorstore():
119
+ logger.error("❌ Failed to connect to vector store")
120
+ raise RuntimeError("Vector store connection failed")
121
+ logger.info("✅ Vector store connected successfully")
122
+
123
+ # Load dynamic data
124
+ self._load_dynamic_data()
125
+
126
+ # Build the multi-agent graph
127
+ self.graph = self._build_graph()
128
+
129
+ # Conversations directory
130
+ self.conversations_dir = Path("conversations")
131
+ self.conversations_dir.mkdir(exist_ok=True)
132
+
133
+ logger.info("🤖 Multi-Agent RAG Chatbot initialized")
134
+
135
+ def _load_dynamic_data(self):
136
+ """Load dynamic data from filter_options.json and add_district_metadata.py"""
137
+ # Load filter options
138
+ try:
139
+ fo = Path("filter_options.json")
140
+ if fo.exists():
141
+ with open(fo) as f:
142
+ data = json.load(f)
143
+ self.year_whitelist = [str(y).strip() for y in data.get("years", [])]
144
+ self.source_whitelist = [str(s).strip() for s in data.get("sources", [])]
145
+ self.district_whitelist = [str(d).strip() for d in data.get("districts", [])]
146
+ else:
147
+ # Fallback to default values
148
+ self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
149
+ self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
150
+ self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
151
+ except Exception as e:
152
+ logger.warning(f"Could not load filter options: {e}")
153
+ self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
154
+ self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
155
+ self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
156
+
157
+ # Enrich district list from add_district_metadata.py
158
+ try:
159
+ from add_district_metadata import DistrictMetadataProcessor
160
+ proc = DistrictMetadataProcessor()
161
+ names = set()
162
+ for key, mapping in proc.district_mappings.items():
163
+ if getattr(mapping, 'is_district', True):
164
+ names.add(mapping.name)
165
+ if names:
166
+ merged = list(self.district_whitelist)
167
+ for n in sorted(names):
168
+ if n not in merged:
169
+ merged.append(n)
170
+ self.district_whitelist = merged
171
+ logger.info(f"🧭 District whitelist enriched: {len(self.district_whitelist)} entries")
172
+ except Exception as e:
173
+ logger.info(f"ℹ️ Could not enrich districts: {e}")
174
+
175
+ # Calculate current year dynamically
176
+ self.current_year = str(datetime.now().year)
177
+ self.previous_year = str(datetime.now().year - 1)
178
+
179
+ # Log the actual filter values for debugging
180
+ logger.info(f"📊 ACTUAL FILTER VALUES:")
181
+ logger.info(f" Years: {self.year_whitelist}")
182
+ logger.info(f" Sources: {self.source_whitelist}")
183
+ logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
184
+
185
+ def _build_graph(self) -> StateGraph:
186
+ """Build the multi-agent LangGraph"""
187
+ graph = StateGraph(MultiAgentState)
188
+
189
+ # Add nodes for each agent
190
+ graph.add_node("main_agent", self._main_agent)
191
+ graph.add_node("rag_agent", self._rag_agent)
192
+ graph.add_node("response_agent", self._response_agent)
193
+
194
+ # Define the flow
195
+ graph.set_entry_point("main_agent")
196
+
197
+ # Main agent decides next step
198
+ graph.add_conditional_edges(
199
+ "main_agent",
200
+ self._should_call_rag,
201
+ {
202
+ "follow_up": END,
203
+ "call_rag": "rag_agent"
204
+ }
205
+ )
206
+
207
+ # RAG agent calls response agent
208
+ graph.add_edge("rag_agent", "response_agent")
209
+
210
+ # Response agent returns to main agent for potential follow-ups
211
+ graph.add_edge("response_agent", "main_agent")
212
+
213
+ return graph.compile()
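+ # Compiled flow: main_agent either ends the turn (follow-up needed or response already set)
+ # or hands off to rag_agent -> response_agent, which loops back to main_agent to finish.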
214
+
215
+ def _should_call_rag(self, state: MultiAgentState) -> str:
216
+ """Determine if we should call RAG or ask follow-up"""
217
+ # If we already have a final response (from response agent), end
218
+ if state.get("final_response"):
219
+ return "follow_up"
220
+
221
+ context = state["query_context"]
222
+ if context and context.needs_follow_up:
223
+ return "follow_up"
224
+ return "call_rag"
225
+
226
+ def _main_agent(self, state: MultiAgentState) -> MultiAgentState:
227
+ """Main Agent: Handles conversation flow and follow-ups"""
228
+ logger.info("🎯 MAIN AGENT: Starting analysis")
229
+
230
+ # If we already have a final response from response agent, end gracefully
231
+ if state.get("final_response"):
232
+ logger.info("🎯 MAIN AGENT: Final response already exists, ending conversation flow")
233
+ return state
234
+
235
+ query = state["current_query"]
236
+ messages = state["messages"]
237
+
238
+ logger.info(f"🎯 MAIN AGENT: Extracting UI filters from query")
239
+ ui_filters = self._extract_ui_filters(query)
240
+ logger.info(f"🎯 MAIN AGENT: UI filters extracted: {ui_filters}")
241
+
242
+ # Analyze query context
243
+ logger.info(f"🎯 MAIN AGENT: Analyzing query context")
244
+ context = self._analyze_query_context(query, messages, ui_filters)
245
+
246
+ # Log agent decision
247
+ state["agent_logs"].append(f"MAIN AGENT: Context analyzed - district={context.has_district}, source={context.has_source}, year={context.has_year}")
248
+ logger.info(f"🎯 MAIN AGENT: Context analysis complete - district={context.has_district}, source={context.has_source}, year={context.has_year}")
249
+
250
+ # Store context
251
+ state["query_context"] = context
252
+
253
+ # If follow-up needed, generate response
254
+ if context.needs_follow_up:
255
+ logger.info(f"🎯 MAIN AGENT: Follow-up needed, generating question")
256
+ response = context.follow_up_question
257
+ state["final_response"] = response
258
+ state["last_ai_message_time"] = time.time()
259
+ logger.info(f"🎯 MAIN AGENT: Follow-up question generated: {response[:100]}...")
260
+ else:
261
+ logger.info("🎯 MAIN AGENT: No follow-up needed, proceeding to RAG")
262
+
263
+ return state
264
+
265
+ def _rag_agent(self, state: MultiAgentState) -> MultiAgentState:
266
+ """RAG Agent: Rewrites queries and applies filters"""
267
+ logger.info("🔍 RAG AGENT: Starting query rewriting and filter preparation")
268
+
269
+ context = state["query_context"]
270
+ messages = state["messages"]
271
+
272
+ logger.info(f"🔍 RAG AGENT: Context received - district={context.has_district}, source={context.has_source}, year={context.has_year}")
273
+
274
+ # Rewrite query for RAG
275
+ logger.info(f"🔍 RAG AGENT: Rewriting query for optimal retrieval")
276
+ rag_query = self._rewrite_query_for_rag(messages, context)
277
+ logger.info(f"🔍 RAG AGENT: Query rewritten: '{rag_query}'")
278
+
279
+ # Build filters
280
+ logger.info(f"🔍 RAG AGENT: Building filters from context")
281
+ filters = self._build_filters(context)
282
+ logger.info(f"🔍 RAG AGENT: Filters built: {filters}")
283
+
284
+ # Log RAG preparation
285
+ state["agent_logs"].append(f"RAG AGENT: Query='{rag_query}', Filters={filters}")
286
+
287
+ # Store for response agent
288
+ state["rag_query"] = rag_query
289
+ state["rag_filters"] = filters
290
+
291
+ logger.info(f"🔍 RAG AGENT: Preparation complete, ready for retrieval")
292
+
293
+ return state
294
+
295
+ def _response_agent(self, state: MultiAgentState) -> MultiAgentState:
296
+ """Response Agent: Generates final answer from retrieved documents"""
297
+ logger.info("📝 RESPONSE AGENT: Starting document retrieval and answer generation")
298
+
299
+ rag_query = state["rag_query"]
300
+ filters = state["rag_filters"]
301
+
302
+ logger.info(f"📝 RESPONSE AGENT: Starting RAG retrieval with query: '{rag_query}'")
303
+ logger.info(f"📝 RESPONSE AGENT: Using filters: {filters}")
304
+
305
+ # Perform RAG retrieval
306
+ logger.info(f"📝 RESPONSE AGENT: Calling pipeline manager for retrieval")
307
+ logger.info(f"🔍 ACTUAL RAG QUERY: '{rag_query}'")
308
+ logger.info(f"🔍 ACTUAL FILTERS: {filters}")
309
+ try:
310
+ # Extract filenames from filters if present
311
+ filenames = filters.get("filenames") if filters else None
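+ # Filename filters are carried inside the `filters` dict passed below; this variable is not passed separately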
312
+
313
+ result = self.pipeline_manager.run(
314
+ query=rag_query,
315
+ sources=filters.get("sources") if filters else None,
316
+ auto_infer_filters=False,
317
+ filters=filters if filters else None
318
+ )
319
+
320
+ logger.info(f"📝 RESPONSE AGENT: RAG retrieval completed - {len(result.sources)} documents retrieved")
321
+ logger.info(f"🔍 RETRIEVAL DEBUG: Result type: {type(result)}")
322
+ logger.info(f"🔍 RETRIEVAL DEBUG: Result sources type: {type(result.sources)}")
323
+ # logger.info(f"🔍 RETRIEVAL DEBUG: Result metadata: {getattr(result, 'metadata', 'No metadata')}")
324
+
325
+ if len(result.sources) == 0:
326
+ logger.warning(f"⚠️ NO DOCUMENTS RETRIEVED: Query='{rag_query}', Filters={filters}")
327
+ logger.warning(f"⚠️ RETRIEVAL DEBUG: This could be due to:")
328
+ logger.warning(f" - Query too specific for available documents")
329
+ logger.warning(f" - Filters too restrictive")
330
+ logger.warning(f" - Vector store connection issues")
331
+ logger.warning(f" - Embedding model issues")
332
+ else:
333
+ logger.info(f"✅ DOCUMENTS RETRIEVED: {len(result.sources)} documents found")
334
+ for i, doc in enumerate(result.sources[:3]): # Log first 3 docs
335
+ logger.info(f" Doc {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...")
336
+
337
+ state["retrieved_documents"] = result.sources
338
+ state["agent_logs"].append(f"RESPONSE AGENT: Retrieved {len(result.sources)} documents")
339
+
340
+ # Check highest similarity score
341
+ highest_score = 0.0
342
+ if result.sources:
343
+ # Check reranked_score first (more accurate), fallback to original_score
344
+ for doc in result.sources:
345
+ score = doc.metadata.get('reranked_score') or doc.metadata.get('original_score', 0.0)
346
+ if score > highest_score:
347
+ highest_score = score
348
+
349
+ logger.info(f"📝 RESPONSE AGENT: Highest similarity score: {highest_score:.4f}")
350
+
351
+ # If highest score is too low, don't use retrieved documents
352
+ if highest_score <= 0.15:
353
+ logger.warning(f"⚠️ RESPONSE AGENT: Low similarity score ({highest_score:.4f} <= 0.15), using LLM knowledge only")
354
+ response = self._generate_conversational_response_without_docs(
355
+ state["current_query"],
356
+ state["messages"]
357
+ )
358
+ else:
359
+ # Generate conversational response with documents
360
+ logger.info(f"📝 RESPONSE AGENT: Generating conversational response from {len(result.sources)} documents")
361
+ response = self._generate_conversational_response(
362
+ state["current_query"],
363
+ result.sources,
364
+ result.answer,
365
+ state["messages"]
366
+ )
367
+
368
+ logger.info(f"📝 RESPONSE AGENT: Response generated: {response[:100]}...")
369
+
370
+ state["final_response"] = response
371
+ state["last_ai_message_time"] = time.time()
372
+
373
+ logger.info(f"📝 RESPONSE AGENT: Answer generation complete")
374
+
375
+ except Exception as e:
376
+ logger.error(f"❌ RESPONSE AGENT ERROR: {e}")
377
+ state["final_response"] = "I apologize, but I encountered an error while retrieving information. Please try again."
378
+ state["last_ai_message_time"] = time.time()
379
+
380
+ return state
381
+
382
+ def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
383
+ """Extract UI filters from query"""
384
+ filters = {}
385
+
386
+ # Look for FILTER CONTEXT in query
387
+ if "FILTER CONTEXT:" in query:
388
+ # Extract the entire filter section (until USER QUERY: or end of query)
389
+ filter_section = query.split("FILTER CONTEXT:")[1]
390
+ if "USER QUERY:" in filter_section:
391
+ filter_section = filter_section.split("USER QUERY:")[0]
392
+ filter_section = filter_section.strip()
393
+
394
+ # Parse sources
395
+ if "Sources:" in filter_section:
396
+ sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')][0]
397
+ sources_str = sources_line.split("Sources:")[1].strip()
398
+ if sources_str and sources_str != "None":
399
+ filters["sources"] = [s.strip() for s in sources_str.split(",")]
400
+
401
+ # Parse years
402
+ if "Years:" in filter_section:
403
+ years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')][0]
404
+ years_str = years_line.split("Years:")[1].strip()
405
+ if years_str and years_str != "None":
406
+ filters["years"] = [y.strip() for y in years_str.split(",")]
407
+
408
+ # Parse districts
409
+ if "Districts:" in filter_section:
410
+ districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')][0]
411
+ districts_str = districts_line.split("Districts:")[1].strip()
412
+ if districts_str and districts_str != "None":
413
+ filters["districts"] = [d.strip() for d in districts_str.split(",")]
414
+
415
+ # Parse filenames
416
+ if "Filenames:" in filter_section:
417
+ filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')][0]
418
+ filenames_str = filenames_line.split("Filenames:")[1].strip()
419
+ if filenames_str and filenames_str != "None":
420
+ filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
421
+
422
+ return filters
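To illustrate the round trip with the UI, a query prefixed as shown in app.py (values illustrative), such as:

    FILTER CONTEXT:
    Sources: Local Government
    Years: 2022, 2023
    USER QUERY:
    ...

would yield {"sources": ["Local Government"], "years": ["2022", "2023"]} from this method.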
423
+
424
+ def _analyze_query_context(self, query: str, messages: List[Any], ui_filters: Dict[str, List[str]]) -> QueryContext:
425
+ """Analyze query context using LLM"""
426
+ logger.info(f"🔍 QUERY ANALYSIS: '{query[:50]}...' | UI filters: {ui_filters} | Messages: {len(messages)}")
427
+
428
+ # Build conversation context
429
+ conversation_context = ""
430
+ for i, msg in enumerate(messages[-6:]): # Last 6 messages
431
+ if isinstance(msg, HumanMessage):
432
+ conversation_context += f"User: {msg.content}\n"
433
+ elif isinstance(msg, AIMessage):
434
+ conversation_context += f"Assistant: {msg.content}\n"
435
+
436
+ # Create analysis prompt
437
+ analysis_prompt = ChatPromptTemplate.from_messages([
438
+ SystemMessage(content=f"""You are the Main Agent in an advanced multi-agent RAG system for audit report analysis.
439
+
440
+ 🎯 PRIMARY GOAL: Intelligently analyze user queries and determine the optimal conversation flow, whether that's answering directly, asking follow-ups, or proceeding to RAG retrieval.
441
+
442
+ 🧠 INTELLIGENCE LEVEL: You are a sophisticated conversational AI that can handle any type of user interaction - from greetings to complex audit queries.
443
+
444
+ 📊 YOUR EXPERTISE: You specialize in analyzing audit reports from various sources (Local Government, Ministry, Hospital, etc.) across different years and districts in Uganda.
445
+
446
+ 🔍 AVAILABLE FILTERS:
447
+ - Years: {', '.join(self.year_whitelist)}
448
+ - Current year: {self.current_year}, Previous year: {self.previous_year}
449
+ - Sources: {', '.join(self.source_whitelist)}
450
+ - Districts: {', '.join(self.district_whitelist[:50])}... (and {len(self.district_whitelist)-50} more)
451
+
452
+ 🎛️ UI FILTERS PROVIDED: {ui_filters}
453
+
454
+ 📋 UI FILTER HANDLING:
455
+ - If UI filters contain multiple values (e.g., districts: ['Lwengo', 'Kiboga']), extract ALL values
456
+ - For multiple districts: extract each district separately and validate each one
457
+ - For multiple years: extract each year separately and validate each one
458
+ - For multiple sources: extract each source separately and validate each one
459
+ - UI filters take PRIORITY over conversation context - use them first
460
+
461
+ 🧭 CONVERSATION FLOW INTELLIGENCE:
462
+
463
+ 1. **GREETINGS & GENERAL CHAT**:
464
+ - If user greets you ("Hi", "Hello", "How are you"), respond warmly and guide them to audit-related questions
465
+ - Example: "Hello! I'm here to help you analyze audit reports. What would you like to know about budget allocations, expenditures, or audit findings?"
466
+
467
+ 2. **EDGE CASES**:
468
+ - Handle "What can you do?", "Help", "I don't know what to ask" with helpful guidance
469
+ - Example: "I can help you analyze audit reports! Try asking about budget allocations, salary management, PDM implementation, or any specific audit findings."
470
+
471
+ 3. **AUDIT QUERIES**:
472
+ - Extract ONLY values that EXACTLY match the available lists above
473
+ - DO NOT hallucinate or infer values not in the lists
474
+ - If user mentions "salary payroll management" - this is NOT a valid source filter
475
+
476
+ **YEAR EXTRACTION**:
477
+ - If user mentions "2023" and it's in the years list - extract "2023"
478
+ - If user mentions "2022 / 23" - extract ["2022", "2023"] (as a JSON array)
479
+ - If user mentions "2022-2023" - extract ["2022", "2023"] (as a JSON array)
480
+ - If user mentions "latest couple of years" - extract the 2 most recent years from available data as JSON array
481
+ - Always return years as JSON arrays when multiple years are mentioned
482
+
483
+ **DISTRICT EXTRACTION**:
484
+ - If user mentions "Kampala" and it's in the districts list - extract "Kampala"
485
+ - If user mentions "Pader District" - extract "Pader" (remove "District" suffix)
486
+ - If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
487
+ - If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
488
+ - Always return districts as JSON arrays when multiple districts are mentioned
489
+ - If no exact matches found, set extracted values to null
490
+
491
+ 4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
492
+ - If UI provides filenames filter - ONLY use that, ignore all other filters (year, district, source)
493
+ - With filenames filter, no follow-ups needed - proceed directly to RAG
494
+ - When filenames are specified, skip filter inference entirely
495
+
496
+ 5. **HALLUCINATION PREVENTION**:
497
+ - If user asks about a specific report but NO filename is selected in UI and NONE is extracted from conversation - DO NOT hallucinate
498
+ - Clearly state: "I don't have any specific report selected. Could you please select a report from the list or tell me which report you'd like to analyze?"
499
+ - DO NOT pretend to know which report they mean
500
+ - DO NOT infer reports from context alone - only use explicitly mentioned reports
501
+
502
+ 6. **CONVERSATION CONTEXT AWARENESS**:
503
+ - ALWAYS consider the full conversation context when extracting filters
504
+ - If district was mentioned in previous messages, include it in current analysis
505
+ - If year was mentioned in previous messages, include it in current analysis
506
+ - If source was mentioned in previous messages, include it in current analysis
507
+ - Example: If conversation shows "User: Tell me about Pader District" then "User: 2023", extract both: district="Pader" and year="2023"
508
+
509
+ 7. **SMART FOLLOW-UP STRATEGY**:
510
+ - NEVER ask the same question twice in a row
511
+ - If user provides source info, ask for year or district next
512
+ - If user provides year info, ask for source or district next
513
+ - If user provides district info, ask for year or source next
514
+ - If user provides 2+ pieces of info, proceed to RAG instead of asking more
515
+ - Make follow-ups conversational and contextual, not robotic
516
+
517
+ 8. **DYNAMIC FOLLOW-UP EXAMPLES**:
518
+ - Budget queries: "What year are you interested in?" or "Which department - Local Government or Ministry?"
519
+ - PDM queries: "Which district are you interested in?" or "What year?"
520
+ - General queries: "Could you be more specific about what you'd like to know?"
521
+
522
+ 🎯 DECISION LOGIC:
523
+ - If query is a greeting/general chat → needs_follow_up: true, provide helpful guidance
524
+ - If query has 2+ pieces of info → needs_follow_up: false, proceed to RAG
525
+ - If query has 1 piece of info → needs_follow_up: true, ask for missing piece
526
+ - If query has 0 pieces of info → needs_follow_up: true, ask for clarification
527
+
528
+ RESPOND WITH JSON ONLY:
529
+ {{
530
+ "has_district": boolean,
531
+ "has_source": boolean,
532
+ "has_year": boolean,
533
+ "extracted_district": "single district name or JSON array of districts or null",
534
+ "extracted_source": "single source name or JSON array of sources or null",
535
+ "extracted_year": "single year or JSON array of years or null",
536
+ "confidence_score": 0.0-1.0,
537
+ "needs_follow_up": boolean,
538
+ "follow_up_question": "conversational question or helpful guidance or null"
539
+ }}"""),
540
+ HumanMessage(content=f"""Query: {query}
541
+
542
+ Conversation Context:
543
+ {conversation_context}
544
+
545
+ CRITICAL: You MUST analyze the FULL conversation context above, not just the current query.
546
+ - If ANY district was mentioned in previous messages, extract it
547
+ - If ANY year was mentioned in previous messages, extract it
548
+ - If ANY source was mentioned in previous messages, extract it
549
+ - Combine information from ALL messages in the conversation
550
+
551
+ Analyze this query using ONLY the exact values provided above:""")
552
+ ])
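+ # Illustrative shape of a well-formed reply, e.g. for "How were PDM funds
+ # used in Lwengo and Kiboga in 2023?" the model should return roughly:
+ # {"has_district": true, "has_source": false, "has_year": true,
+ #  "extracted_district": ["Lwengo", "Kiboga"], "extracted_source": null,
+ #  "extracted_year": "2023", "confidence_score": 0.9,
+ #  "needs_follow_up": false, "follow_up_question": null}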
553
+
554
+ try:
555
+ response = self.llm.invoke(analysis_prompt.format_messages())
556
+
557
+ # Clean the response to extract JSON
558
+ content = response.content.strip()
559
+ if content.startswith("```json"):
560
+ # Remove markdown formatting
561
+ content = content.replace("```json", "").replace("```", "").strip()
562
+ elif content.startswith("```"):
563
+ # Remove generic markdown formatting
564
+ content = content.replace("```", "").strip()
565
+
566
+ # Clean and parse JSON with better error handling
567
+ try:
568
+ # Remove comments (// and /* */) from JSON
569
+ import re
570
+ # Remove single-line comments
571
+ content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
572
+ # Remove multi-line comments
573
+ content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL)
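+ # e.g. a reply line such as '"has_district": true, // mentioned Lwengo'
+ # is reduced to '"has_district": true,' so json.loads() can parse it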
574
+
575
+ analysis = json.loads(content)
576
+ logger.info(f"🔍 QUERY ANALYSIS: ✅ Parsed successfully")
577
+ except json.JSONDecodeError as e:
578
+ logger.error(f"❌ JSON parsing failed: {e}")
579
+ logger.error(f"❌ Raw content: {content[:200]}...")
580
+
581
+ # Try to extract JSON from text if embedded
582
+ import re
583
+ json_match = re.search(r'\{.*\}', content, re.DOTALL)
584
+ if json_match:
585
+ try:
586
+ # Clean the extracted JSON
587
+ cleaned_json = json_match.group()
588
+ cleaned_json = re.sub(r'//.*?$', '', cleaned_json, flags=re.MULTILINE)
589
+ cleaned_json = re.sub(r'/\*.*?\*/', '', cleaned_json, flags=re.DOTALL)
590
+ analysis = json.loads(cleaned_json)
591
+ logger.info(f"🔍 QUERY ANALYSIS: ✅ Extracted and cleaned JSON from text")
592
+ except json.JSONDecodeError as e2:
593
+ logger.error(f"❌ Failed to extract JSON from text: {e2}")
594
+ # Return fallback context
595
+ context = QueryContext(
596
+ has_district=False,
597
+ has_source=False,
598
+ has_year=False,
599
+ extracted_district=None,
600
+ extracted_source=None,
601
+ extracted_year=None,
602
+ confidence_score=0.0,
603
+ needs_follow_up=True,
604
+ follow_up_question="I apologize, but I'm having trouble processing your request. Could you please rephrase it or ask for help?"
605
+ )
606
+ return context
607
+ else:
608
+ # Return fallback context
609
+ context = QueryContext(
610
+ has_district=False,
611
+ has_source=False,
612
+ has_year=False,
613
+ extracted_district=None,
614
+ extracted_source=None,
615
+ extracted_year=None,
616
+ confidence_score=0.0,
617
+ needs_follow_up=True,
618
+ follow_up_question="I apologize, but I'm having trouble processing your request. Could you please rephrase it or ask for help?"
619
+ )
620
+ return context
621
+
622
+ # Validate extracted values against whitelists
623
+ extracted_district = analysis.get("extracted_district")
624
+ extracted_source = analysis.get("extracted_source")
625
+ extracted_year = analysis.get("extracted_year")
626
+
627
+ logger.info(f"🔍 QUERY ANALYSIS: Raw extracted values - district: {extracted_district}, source: {extracted_source}, year: {extracted_year}")
628
+
629
+ # Validate district (handle both single values and arrays)
630
+ if extracted_district:
631
+ if isinstance(extracted_district, list):
632
+ # Validate each district in the array
633
+ valid_districts = []
634
+ for district in extracted_district:
635
+ if district in self.district_whitelist:
636
+ valid_districts.append(district)
637
+ else:
638
+ # Try removing "District" suffix
639
+ district_name = district.replace(" District", "").replace(" district", "")
640
+ if district_name in self.district_whitelist:
641
+ valid_districts.append(district_name)
642
+
643
+ if valid_districts:
644
+ extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
645
+ logger.info(f"🔍 QUERY ANALYSIS: Extracted districts: {extracted_district}")
646
+ else:
647
+ logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
648
+ extracted_district = None
649
+ else:
650
+ # Single district validation
651
+ if extracted_district not in self.district_whitelist:
652
+ # Try removing "District" suffix
653
+ district_name = extracted_district.replace(" District", "").replace(" district", "")
654
+ if district_name in self.district_whitelist:
655
+ logger.info(f"🔍 QUERY ANALYSIS: Normalized district '{extracted_district}' to '{district_name}'")
656
+ extracted_district = district_name
657
+ else:
658
+ logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
659
+ extracted_district = None
660
+
661
+ # Validate source (handle both single values and arrays)
662
+ if extracted_source:
663
+ if isinstance(extracted_source, list):
664
+ # Validate each source in the array
665
+ valid_sources = []
666
+ for source in extracted_source:
667
+ if source in self.source_whitelist:
668
+ valid_sources.append(source)
669
+ else:
670
+ logger.warning(f"⚠️ Invalid source in array: '{source}' not in whitelist")
671
+
672
+ if valid_sources:
673
+ extracted_source = valid_sources[0] if len(valid_sources) == 1 else valid_sources
674
+ logger.info(f"🔍 QUERY ANALYSIS: Extracted sources: {extracted_source}")
675
+ else:
676
+ logger.warning(f"⚠️ No valid sources found in: '{extracted_source}'")
677
+ extracted_source = None
678
+ else:
679
+ # Single source validation
680
+ if extracted_source not in self.source_whitelist:
681
+ logger.warning(f"⚠️ Invalid source extracted: '{extracted_source}' not in whitelist")
682
+ extracted_source = None
683
+
684
+ # Validate year (handle both single values and arrays)
685
+ if extracted_year:
686
+ if isinstance(extracted_year, list):
687
+ # Validate each year in the array
688
+ valid_years = []
689
+ for year in extracted_year:
690
+ year_str = str(year)
691
+ if year_str in self.year_whitelist:
692
+ valid_years.append(year_str)
693
+
694
+ if valid_years:
695
+ extracted_year = valid_years[0] if len(valid_years) == 1 else valid_years
696
+ logger.info(f"🔍 QUERY ANALYSIS: Extracted years: {extracted_year}")
697
+ else:
698
+ logger.warning(f"⚠️ No valid years found in: '{extracted_year}'")
699
+ extracted_year = None
700
+ else:
701
+ # Single year validation
702
+ year_str = str(extracted_year)
703
+ if year_str not in self.year_whitelist:
704
+ logger.warning(f"⚠️ Invalid year extracted: '{extracted_year}' not in whitelist")
705
+ extracted_year = None
706
+ else:
707
+ extracted_year = year_str
708
+
709
+ logger.info(f"🔍 QUERY ANALYSIS: Validated values - district: {extracted_district}, source: {extracted_source}, year: {extracted_year}")
710
+
711
+ # Create QueryContext object
712
+ context = QueryContext(
713
+ has_district=bool(extracted_district),
714
+ has_source=bool(extracted_source),
715
+ has_year=bool(extracted_year),
716
+ extracted_district=extracted_district,
717
+ extracted_source=extracted_source,
718
+ extracted_year=extracted_year,
719
+ ui_filters=ui_filters,
720
+ confidence_score=analysis.get("confidence_score", 0.0),
721
+ needs_follow_up=analysis.get("needs_follow_up", False),
722
+ follow_up_question=analysis.get("follow_up_question")
723
+ )
724
+
725
+ logger.info(f"🔍 QUERY ANALYSIS: Analysis complete - needs_follow_up: {context.needs_follow_up}, confidence: {context.confidence_score}")
726
+
727
+ # If filenames are provided in UI, skip follow-ups and proceed to RAG
728
+ if ui_filters and ui_filters.get("filenames"):
729
+ logger.info(f"🔍 QUERY ANALYSIS: Filenames provided, skipping follow-ups, proceeding to RAG")
730
+ context.needs_follow_up = False
731
+ context.follow_up_question = None
732
+
733
+ # Additional smart decision logic
734
+ if context.needs_follow_up:
735
+ # Check if we have enough information to proceed
736
+ info_count = sum([
737
+ bool(context.extracted_district),
738
+ bool(context.extracted_source),
739
+ bool(context.extracted_year)
740
+ ])
741
+
742
+ # Check if user is asking for more info vs providing it
743
+ query_lower = query.lower()
744
+ is_requesting_info = any(phrase in query_lower for phrase in [
745
+ "please provide", "could you provide", "can you provide",
746
+ "what is", "what are", "how much", "which", "what year",
747
+ "what district", "what source", "tell me about"
748
+ ])
749
+
750
+ # If we have 2+ pieces of info AND user is not requesting more info, proceed to RAG
751
+ if info_count >= 2 and not is_requesting_info:
752
+ logger.info(f"🔍 QUERY ANALYSIS: Smart override - have {info_count} pieces of info and user not requesting more, proceeding to RAG")
753
+ context.needs_follow_up = False
754
+ context.follow_up_question = None
755
+ elif info_count >= 2 and is_requesting_info:
756
+ logger.info(f"🔍 QUERY ANALYSIS: User requesting more info despite having {info_count} pieces, proceeding to RAG with comprehensive answer")
757
+ context.needs_follow_up = False
758
+ context.follow_up_question = None
759
+
760
+ return context
761
+
762
+ except Exception as e:
763
+ logger.error(f"❌ Query analysis failed: {e}")
764
+ # Fallback: proceed with RAG
765
+ return QueryContext(
766
+ has_district=bool(ui_filters.get("districts")),
767
+ has_source=bool(ui_filters.get("sources")),
768
+ has_year=bool(ui_filters.get("years")),
769
+ ui_filters=ui_filters,
770
+ confidence_score=0.5,
771
+ needs_follow_up=False
772
+ )
773
+
774
+ def _rewrite_query_for_rag(self, messages: List[Any], context: QueryContext) -> str:
775
+ """Rewrite query for optimal RAG retrieval"""
776
+ logger.info("🔄 QUERY REWRITING: Starting query rewrite for RAG")
777
+ logger.info(f"🔄 QUERY REWRITING: Processing {len(messages)} messages")
778
+
779
+ # Build conversation context
780
+ logger.info(f"🔄 QUERY REWRITING: Building conversation context from last 6 messages")
781
+ conversation_lines = []
782
+ for i, msg in enumerate(messages[-6:]):
783
+ if isinstance(msg, HumanMessage):
784
+ conversation_lines.append(f"User: {msg.content}")
785
+ logger.info(f"🔄 QUERY REWRITING: Message {i+1}: User - {msg.content[:50]}...")
786
+ elif isinstance(msg, AIMessage):
787
+ conversation_lines.append(f"Assistant: {msg.content}")
788
+ logger.info(f"🔄 QUERY REWRITING: Message {i+1}: Assistant - {msg.content[:50]}...")
789
+
790
+ convo_text = "\n".join(conversation_lines)
791
+ logger.info(f"🔄 QUERY REWRITING: Conversation context built ({len(convo_text)} chars)")
792
+
793
+ # Create rewrite prompt
794
+ rewrite_prompt = ChatPromptTemplate.from_messages([
795
+ SystemMessage(content=f"""You are a query rewriter for RAG retrieval.
796
+
797
+ GOAL: Create the best possible search query for document retrieval.
798
+
799
+ CRITICAL RULES:
800
+ 1. Focus on the core information need from the conversation
801
+ 2. Remove meta-verbs like "summarize", "list", "compare", "how much", "what" - keep the content focus
802
+ 3. DO NOT include filter details (years, districts, sources) - these are applied separately as filters
803
+ 4. DO NOT include specific years, district names, or source types in the query
804
+ 5. Output ONE clear sentence suitable for vector search
805
+ 6. Keep it generic and focused on the topic/subject matter
806
+
807
+ EXAMPLES:
808
+ - "What are the top challenges in budget allocation?" → "budget allocation challenges"
809
+ - "How were PDM administrative costs utilized in 2023?" → "PDM administrative costs utilization"
810
+ - "Compare salary management across districts" → "salary management"
811
+ - "How much was budget allocation for Local Government in 2023?" → "budget allocation"
812
+
813
+ OUTPUT FORMAT:
814
+ Provide your response in this exact format:
815
+
816
+ EXPLANATION: [Your reasoning here]
817
+ QUERY: [One clean sentence for retrieval]
818
+
819
+ The QUERY line will be extracted and used directly for RAG retrieval."""),
820
+ HumanMessage(content=f"""Conversation:
821
+ {convo_text}
822
+
823
+ Rewrite the best retrieval query:""")
824
+ ])
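+ # A typical reply is parsed line by line, e.g.:
+ #   EXPLANATION: The user wants spending details; year and district are handled by filters.
+ #   QUERY: PDM administrative costs utilization
+ # Only the text after "QUERY:" is kept and sent to retrieval.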
825
+
826
+ try:
827
+ logger.info(f"🔄 QUERY REWRITING: Calling LLM for query rewrite")
828
+ response = self.llm.invoke(rewrite_prompt.format_messages())
829
+ logger.info(f"🔄 QUERY REWRITING: LLM response received: {response.content[:100]}...")
830
+
831
+ rewritten = response.content.strip()
832
+
833
+ # Extract only the QUERY line from the structured response
834
+ lines = rewritten.split('\n')
835
+ query_line = None
836
+ for line in lines:
837
+ if line.strip().startswith('QUERY:'):
838
+ query_line = line.replace('QUERY:', '').strip()
839
+ break
840
+
841
+ if query_line and len(query_line) > 5:
842
+ logger.info(f"🔄 QUERY REWRITING: Query rewritten successfully: '{query_line[:50]}...'")
843
+ return query_line
844
+ else:
845
+ logger.info(f"🔄 QUERY REWRITING: No QUERY line found or too short, using fallback")
846
+ # Fallback to last user message
847
+ for msg in reversed(messages):
848
+ if isinstance(msg, HumanMessage):
849
+ logger.info(f"🔄 QUERY REWRITING: Using fallback message: '{msg.content[:50]}...'")
850
+ return msg.content
851
+ logger.info(f"🔄 QUERY REWRITING: Using default fallback")
852
+ return "audit report information"
853
+
854
+ except Exception as e:
855
+ logger.error(f"❌ QUERY REWRITING: Error during rewrite: {e}")
856
+ # Fallback
857
+ for msg in reversed(messages):
858
+ if isinstance(msg, HumanMessage):
859
+ logger.info(f"🔄 QUERY REWRITING: Using error fallback message: '{msg.content[:50]}...'")
860
+ return msg.content
861
+ logger.info(f"🔄 QUERY REWRITING: Using default error fallback")
862
+ return "audit report information"
863
+
864
+ def _build_filters(self, context: QueryContext) -> Dict[str, Any]:
865
+ """Build filters for RAG retrieval"""
866
+ logger.info("🔧 FILTER BUILDING: Starting filter construction")
867
+ filters = {}
868
+
869
+ # Check for filename filtering first (mutually exclusive)
870
+ if context.ui_filters and context.ui_filters.get("filenames"):
871
+ logger.info(f"🔧 FILTER BUILDING: Filename filtering requested (mutually exclusive mode)")
872
+ filters["filenames"] = context.ui_filters["filenames"]
873
+ logger.info(f"🔧 FILTER BUILDING: Added filenames filter: {context.ui_filters['filenames']}")
874
+ logger.info(f"🔧 FILTER BUILDING: Final filters: {filters}")
875
+ return filters # Return early, skip all other filters
876
+
877
+ # UI filters take priority, but merge with extracted context if UI filters are incomplete
878
+ if context.ui_filters:
879
+ logger.info(f"🔧 FILTER BUILDING: UI filters present: {context.ui_filters}")
880
+
881
+ # Add UI filters first
882
+ if context.ui_filters.get("sources"):
883
+ filters["sources"] = context.ui_filters["sources"]
884
+ logger.info(f"🔧 FILTER BUILDING: Added sources filter from UI: {context.ui_filters['sources']}")
885
+
886
+ if context.ui_filters.get("years"):
887
+ filters["year"] = context.ui_filters["years"]
888
+ logger.info(f"🔧 FILTER BUILDING: Added years filter from UI: {context.ui_filters['years']}")
889
+
890
+ if context.ui_filters.get("districts"):
891
+ # Normalize district names to title case (match Qdrant metadata format)
892
+ normalized_districts = [d.title() for d in context.ui_filters['districts']]
893
+ filters["district"] = normalized_districts
894
+ logger.info(f"🔧 FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} → normalized: {normalized_districts}")
895
+
896
+ # Merge with extracted context for missing filters
897
+ if not filters.get("year") and context.extracted_year:
898
+ # Handle both single values and arrays
899
+ if isinstance(context.extracted_year, list):
900
+ filters["year"] = context.extracted_year
901
+ else:
902
+ filters["year"] = [context.extracted_year]
903
+ logger.info(f"🔧 FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
904
+
905
+ if not filters.get("district") and context.extracted_district:
906
+ # Handle both single values and arrays
907
+ if isinstance(context.extracted_district, list):
908
+ # Normalize district names to title case (match Qdrant metadata format)
909
+ normalized = [d.title() for d in context.extracted_district]
910
+ filters["district"] = normalized
911
+ else:
912
+ filters["district"] = [context.extracted_district.title()]
913
+ logger.info(f"🔧 FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
914
+
915
+ if not filters.get("sources") and context.extracted_source:
916
+ # Handle both single values and arrays
917
+ if isinstance(context.extracted_source, list):
918
+ filters["sources"] = context.extracted_source
919
+ else:
920
+ filters["sources"] = [context.extracted_source]
921
+ logger.info(f"🔧 FILTER BUILDING: Added extracted source filter (UI missing): {context.extracted_source}")
922
+ else:
923
+ logger.info(f"🔧 FILTER BUILDING: No UI filters, using extracted context")
924
+ # Use extracted context
925
+ if context.extracted_source:
926
+ # Handle both single values and arrays
927
+ if isinstance(context.extracted_source, list):
928
+ filters["sources"] = context.extracted_source
929
+ else:
930
+ filters["sources"] = [context.extracted_source]
931
+ logger.info(f"🔧 FILTER BUILDING: Added extracted source filter: {context.extracted_source}")
932
+
933
+ if context.extracted_year:
934
+ # Handle both single values and arrays
935
+ if isinstance(context.extracted_year, list):
936
+ filters["year"] = context.extracted_year
937
+ else:
938
+ filters["year"] = [context.extracted_year]
939
+ logger.info(f"🔧 FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
940
+
941
+ if context.extracted_district:
942
+ # Handle both single values and arrays
943
+ if isinstance(context.extracted_district, list):
944
+ filters["district"] = context.extracted_district
945
+ else:
946
+ filters["district"] = [context.extracted_district]
947
+ logger.info(f"🔧 FILTER BUILDING: Added extracted district filter: {context.extracted_district}")
948
+
949
+ logger.info(f"🔧 FILTER BUILDING: Final filters: {filters}")
950
+ return filters
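+ # For instance, with no UI filters and an extracted district list plus a year,
+ # the result is roughly {"year": ["2023"], "district": ["Lwengo", "Kiboga"]};
+ # with UI filenames selected it would instead be just {"filenames": [...]}.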
951
+
952
+ def _generate_conversational_response(self, query: str, documents: List[Any], rag_answer: str, messages: List[Any]) -> str:
953
+ """Generate conversational response from RAG results"""
954
+ logger.info("💬 RESPONSE GENERATION: Starting conversational response generation")
955
+ logger.info(f"💬 RESPONSE GENERATION: Processing {len(documents)} documents")
956
+ logger.info(f"💬 RESPONSE GENERATION: Query: '{query[:50]}...'")
957
+
958
+ # Create response prompt
959
+ logger.info(f"💬 RESPONSE GENERATION: Building response prompt")
960
+ response_prompt = ChatPromptTemplate.from_messages([
961
+ SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
962
+
963
+ RULES:
964
+ 1. Answer the user's question directly and clearly
965
+ 2. Use the retrieved documents as evidence
966
+ 3. Be conversational, not technical
967
+ 4. Don't mention scores, retrieval details, or technical implementation
968
+ 5. If relevant documents were found, reference them naturally
969
+ 6. If no relevant documents, explain based on your knowledge (if you have it) or just say you do not have enough information.
970
+ 7. If the passages have useful facts or numbers, use them in your answer.
971
+ 8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
972
+ 9. Do not use the sentence 'Doc i says ...' to say where information came from.
973
+ 10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
974
+ 11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
975
+ 12. If it makes sense, use bullet points and lists to make your answers easier to understand.
976
+ 13. You do not need to use every passage. Only use the ones that help answer the question.
977
+ 14. If the documents do not have the information needed to answer the question, just say you do not have enough information.
978
+
979
+
980
+ TONE: Professional but friendly, like talking to a colleague."""),
981
+ HumanMessage(content=f"""User Question: {query}
982
+
983
+ Retrieved Documents: {len(documents)} documents found
984
+
985
+ RAG Answer: {rag_answer}
986
+
987
+ Generate a conversational response:""")
988
+ ])
989
+
990
+ try:
991
+ logger.info(f"💬 RESPONSE GENERATION: Calling LLM for final response")
992
+ response = self.llm.invoke(response_prompt.format_messages())
993
+ logger.info(f"💬 RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
994
+ return response.content.strip()
995
+ except Exception as e:
996
+ logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
997
+ logger.info(f"💬 RESPONSE GENERATION: Using RAG answer as fallback")
998
+ return rag_answer # Fallback to RAG answer
999
+
1000
+ def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
1001
+ """Generate conversational response using only LLM knowledge and conversation history"""
1002
+ logger.info("💬 RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
1003
+ logger.info(f"💬 RESPONSE GENERATION (NO DOCS): Query: '{query[:50]}...'")
1004
+
1005
+ # Build conversation context
1006
+ conversation_context = ""
1007
+ for i, msg in enumerate(messages[-6:]): # Last 6 messages for context
1008
+ if isinstance(msg, HumanMessage):
1009
+ conversation_context += f"User: {msg.content}\n"
1010
+ elif isinstance(msg, AIMessage):
1011
+ conversation_context += f"Assistant: {msg.content}\n"
1012
+
1013
+ # Create response prompt
1014
+ logger.info(f"💬 RESPONSE GENERATION (NO DOCS): Building response prompt")
1015
+ response_prompt = ChatPromptTemplate.from_messages([
1016
+ SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
1017
+
1018
+ RULES:
1019
+ 1. Answer the user's question directly and clearly based on your knowledge
1020
+ 2. Use conversation history for context
1021
+ 3. Be conversational, not technical
1022
+ 4. Acknowledge if the answer is based on general knowledge rather than specific documents
1023
+ 5. Stay professional but friendly
1024
+
1025
+ TONE: Professional but friendly, like talking to a colleague."""),
1026
+ HumanMessage(content=f"""Current Question: {query}
1027
+
1028
+ Conversation History:
1029
+ {conversation_context}
1030
+
1031
+ Generate a conversational response based on your knowledge:""")
1032
+ ])
1033
+
1034
+ try:
1035
+ logger.info(f"💬 RESPONSE GENERATION (NO DOCS): Calling LLM")
1036
+ response = self.llm.invoke(response_prompt.format_messages())
1037
+ logger.info(f"💬 RESPONSE GENERATION (NO DOCS): LLM response received: {response.content[:100]}...")
1038
+ return response.content.strip()
1039
+ except Exception as e:
1040
+ logger.error(f"❌ RESPONSE GENERATION (NO DOCS): Error during generation: {e}")
1041
+ return "I apologize, but I encountered an error. Please try asking your question differently."
1042
+
1043
+ def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
1044
+ """Main chat interface"""
1045
+ logger.info(f"💬 MULTI-AGENT CHAT: Processing '{user_input[:50]}...'")
1046
+
1047
+ # Load conversation
1048
+ logger.info(f"💬 MULTI-AGENT CHAT: Loading conversation {conversation_id}")
1049
+ conversation_file = self.conversations_dir / f"{conversation_id}.json"
1050
+ conversation = self._load_conversation(conversation_file)
1051
+ logger.info(f"💬 MULTI-AGENT CHAT: Loaded {len(conversation['messages'])} previous messages")
1052
+
1053
+ # Add user message
1054
+ conversation["messages"].append(HumanMessage(content=user_input))
1055
+ logger.info(f"💬 MULTI-AGENT CHAT: Added user message to conversation")
1056
+
1057
+ # Prepare state
1058
+ logger.info(f"💬 MULTI-AGENT CHAT: Preparing state for graph execution")
1059
+ state = MultiAgentState(
1060
+ conversation_id=conversation_id,
1061
+ messages=conversation["messages"],
1062
+ current_query=user_input,
1063
+ query_context=None,
1064
+ rag_query=None,
1065
+ rag_filters=None,
1066
+ retrieved_documents=None,
1067
+ final_response=None,
1068
+ agent_logs=[],
1069
+ conversation_context=conversation.get("context", {}),
1070
+ session_start_time=conversation["session_start_time"],
1071
+ last_ai_message_time=conversation["last_ai_message_time"]
1072
+ )
1073
+
1074
+ # Run multi-agent graph
1075
+ logger.info(f"💬 MULTI-AGENT CHAT: Executing multi-agent graph")
1076
+ final_state = self.graph.invoke(state)
1077
+ logger.info(f"💬 MULTI-AGENT CHAT: Graph execution completed")
1078
+
1079
+ # Add AI response to conversation
1080
+ if final_state["final_response"]:
1081
+ conversation["messages"].append(AIMessage(content=final_state["final_response"]))
1082
+ logger.info(f"💬 MULTI-AGENT CHAT: Added AI response to conversation")
1083
+
1084
+ # Update conversation
1085
+ conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
1086
+ conversation["context"] = final_state["conversation_context"]
1087
+
1088
+ # Save conversation
1089
+ logger.info(f"💬 MULTI-AGENT CHAT: Saving conversation")
1090
+ self._save_conversation(conversation_file, conversation)
1091
+
1092
+ logger.info("✅ MULTI-AGENT CHAT: Completed")
1093
+
1094
+ # Return response and RAG results
1095
+ return {
1096
+ 'response': final_state["final_response"],
1097
+ 'rag_result': {
1098
+ 'sources': final_state["retrieved_documents"] or [],
1099
+ 'answer': final_state["final_response"]
1100
+ },
1101
+ 'agent_logs': final_state["agent_logs"],
1102
+ 'actual_rag_query': final_state.get("rag_query", "")
1103
+ }
1104
+
1105
+ def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
1106
+ """Load conversation from file"""
1107
+ if conversation_file.exists():
1108
+ try:
1109
+ with open(conversation_file) as f:
1110
+ data = json.load(f)
1111
+ # Convert message dicts back to LangChain messages
1112
+ messages = []
1113
+ for msg_data in data.get("messages", []):
1114
+ if msg_data["type"] == "human":
1115
+ messages.append(HumanMessage(content=msg_data["content"]))
1116
+ elif msg_data["type"] == "ai":
1117
+ messages.append(AIMessage(content=msg_data["content"]))
1118
+ data["messages"] = messages
1119
+ return data
1120
+ except Exception as e:
1121
+ logger.warning(f"Could not load conversation: {e}")
1122
+
1123
+ # Return default conversation
1124
+ return {
1125
+ "messages": [],
1126
+ "session_start_time": time.time(),
1127
+ "last_ai_message_time": time.time(),
1128
+ "context": {}
1129
+ }
1130
+
1131
+ def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
1132
+ """Save conversation to file"""
1133
+ try:
1134
+ # Convert messages to serializable format
1135
+ messages_data = []
1136
+ for msg in conversation["messages"]:
1137
+ if isinstance(msg, HumanMessage):
1138
+ messages_data.append({"type": "human", "content": msg.content})
1139
+ elif isinstance(msg, AIMessage):
1140
+ messages_data.append({"type": "ai", "content": msg.content})
1141
+
1142
+ conversation_data = {
1143
+ "messages": messages_data,
1144
+ "session_start_time": conversation["session_start_time"],
1145
+ "last_ai_message_time": conversation["last_ai_message_time"],
1146
+ "context": conversation.get("context", {})
1147
+ }
1148
+
1149
+ with open(conversation_file, 'w') as f:
1150
+ json.dump(conversation_data, f, indent=2)
1151
+
1152
+ except Exception as e:
1153
+ logger.error(f"Could not save conversation: {e}")
1154
+
1155
+
1156
+ def get_multi_agent_chatbot():
1157
+ """Get multi-agent chatbot instance"""
1158
+ return MultiAgentRAGChatbot()
1159
+
1160
+ if __name__ == "__main__":
1161
+ # Test the multi-agent system
1162
+ chatbot = MultiAgentRAGChatbot()
1163
+
1164
+ # Test conversation
1165
+ result = chatbot.chat("List me top 10 challenges in budget allocation for the last 3 years")
1166
+ print("Response:", result['response'])
1167
+ print("Agent Logs:", result['agent_logs'])
requirements.txt CHANGED
@@ -1,3 +1,9 @@
1
- altair
2
- pandas
3
- streamlit
1
+ streamlit>=1.28.0
2
+ langchain>=0.1.0
3
+ langchain-core>=0.1.0
4
+ langgraph>=0.0.20
5
+ qdrant-client>=1.7.0
6
+ python-dotenv>=1.0.0
7
+ openai>=1.0.0
8
+ snowflake-connector-python>=4.0.0
9
+ pydantic>=2.0.0
smart_chatbot.py ADDED
@@ -0,0 +1,1098 @@
1
+ """
2
+ Intelligent RAG Chatbot with Smart Query Analysis and Conversation Management
3
+
4
+ This chatbot provides intelligent conversation flow with:
5
+ - Smart query analysis and expansion
6
+ - Single LangSmith conversation traces
7
+ - Local conversation logging
8
+ - Context-aware RAG retrieval
9
+ - Natural conversation without technical jargon
10
+ """
11
+
12
+ import os
13
+ import json
14
+ import time
15
+ import logging
16
+ from pathlib import Path
17
+ from dataclasses import dataclass
18
+ from datetime import datetime, timedelta
19
+ from typing import Dict, List, Any, Optional, TypedDict
20
+
21
+
22
+ import re
23
+ from langgraph.graph import StateGraph, END
24
+ from langchain_core.prompts import ChatPromptTemplate
25
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
26
+
27
+ from src.pipeline import PipelineManager
28
+ from src.config.loader import load_config
29
+
30
+
31
+ @dataclass
32
+ class QueryAnalysis:
33
+ """Analysis result of a user query"""
34
+ has_district: bool
35
+ has_source: bool
36
+ has_year: bool
37
+ extracted_district: Optional[str]
38
+ extracted_source: Optional[str]
39
+ extracted_year: Optional[str]
40
+ confidence_score: float
41
+ can_answer_directly: bool
42
+ missing_filters: List[str]
43
+ suggested_follow_up: Optional[str]
44
+ expanded_query: Optional[str] = None # Query expansion for better RAG
45
+
46
+
47
+ class ConversationState(TypedDict):
48
+ """State for the conversation flow"""
49
+ conversation_id: str
50
+ messages: List[Any]
51
+ current_query: str
52
+ query_analysis: Optional[QueryAnalysis]
53
+ rag_query: Optional[str]
54
+ rag_result: Optional[Any]
55
+ final_response: Optional[str]
56
+ conversation_context: Dict[str, Any] # Store conversation context
57
+ session_start_time: float
58
+ last_ai_message_time: float
59
+
60
+
61
+ class IntelligentRAGChatbot:
62
+ """Intelligent chatbot with smart query analysis and conversation management"""
63
+
64
+ def __init__(self, suppress_logs=False):
65
+ """Initialize the intelligent chatbot"""
66
+ # Setup logger to avoid cluttering UI
67
+ self.logger = logging.getLogger(__name__)
68
+ if suppress_logs:
69
+ self.logger.setLevel(logging.CRITICAL) # Suppress all logs
70
+ else:
71
+ self.logger.setLevel(logging.INFO)
72
+ if not self.logger.handlers:
73
+ handler = logging.StreamHandler()
74
+ formatter = logging.Formatter('%(message)s')
75
+ handler.setFormatter(formatter)
76
+ self.logger.addHandler(handler)
77
+
78
+ self.logger.info("🤖 INITIALIZING: Intelligent RAG Chatbot")
79
+
80
+ # Load configuration first
81
+ self.config = load_config()
82
+
83
+ # Use the same LLM configuration as the existing system
84
+ from auditqa.llm.adapters import get_llm_client
85
+
86
+ # Get LLM client using the same configuration
87
+ reader_config = self.config.get("reader", {})
88
+ default_type = reader_config.get("default_type", "INF_PROVIDERS")
89
+
90
+ # Convert to lowercase as that's how it's registered
91
+ provider_name = default_type.lower()
92
+
93
+ self.llm_adapter = get_llm_client(provider_name, self.config)
94
+
95
+ # Create a simple wrapper for LangChain compatibility
96
+ class LLMWrapper:
97
+ def __init__(self, adapter):
98
+ self.adapter = adapter
99
+
100
+ def invoke(self, messages):
101
+ # Convert LangChain messages to the format expected by the adapter
102
+ if isinstance(messages, list):
103
+ # Convert LangChain messages to dict format
104
+ message_dicts = []
105
+ for msg in messages:
106
+ if hasattr(msg, 'content'):
107
+ role = "user" if isinstance(msg, HumanMessage) else "assistant"
108
+ message_dicts.append({"role": role, "content": msg.content})
109
+ else:
110
+ message_dicts.append({"role": "user", "content": str(msg)})
111
+ else:
112
+ # Single message
113
+ message_dicts = [{"role": "user", "content": str(messages)}]
114
+
115
+ # Use the adapter to generate response
116
+ llm_response = self.adapter.generate(message_dicts)
117
+
118
+ # Return in LangChain format
119
+ class MockResponse:
120
+ def __init__(self, content):
121
+ self.content = content
122
+
123
+ return MockResponse(llm_response.content)
124
+
125
+ self.llm = LLMWrapper(self.llm_adapter)
126
+
127
+ # Initialize pipeline manager for RAG
128
+ self.logger.info("🔧 PIPELINE: Initializing PipelineManager...")
129
+ self.pipeline_manager = PipelineManager(self.config)
130
+
131
+ # Ensure vectorstore is connected
132
+ self.logger.info("🔗 VECTORSTORE: Connecting to Qdrant...")
133
+ try:
134
+ vectorstore = self.pipeline_manager.vectorstore_manager.connect_to_existing()
135
+ self.logger.info("✅ VECTORSTORE: Connected successfully")
136
+ except Exception as e:
137
+ self.logger.error(f"❌ VECTORSTORE: Connection failed: {e}")
138
+
139
+ # Fix LLM client to use the same provider as chatbot
140
+ self.logger.info("🔧 LLM: Fixing PipelineManager LLM client...")
141
+ self.pipeline_manager.llm_client = self.llm_adapter
142
+ self.logger.info("✅ LLM: PipelineManager now uses same LLM as chatbot")
143
+
144
+ self.logger.info("✅ PIPELINE: PipelineManager initialized")
145
+
146
+ # Available metadata for filtering
147
+ self.available_metadata = {
148
+ 'sources': [
149
+ 'KCCA', 'MAAIF', 'MWTS', 'Gulu DLG', 'Kalangala DLG', 'Namutumba DLG',
150
+ 'Lwengo DLG', 'Kiboga DLG', 'Annual Consolidated OAG', 'Consolidated',
151
+ 'Hospital', 'Local Government', 'Ministry, Department and Agency',
152
+ 'Project', 'Thematic', 'Value for Money'
153
+ ],
154
+ 'years': ['2018', '2019', '2020', '2021', '2022', '2023', '2024', '2025'],
155
+ 'districts': [
156
+ 'Gulu', 'Kalangala', 'Kampala', 'Namutumba', 'Lwengo', 'Kiboga',
157
+ 'Fort Portal', 'Arua', 'Kasese', 'Kabale', 'Masindi', 'Mbale', 'Jinja', 'Masaka', 'Mbarara',
158
+ 'KCCA'
159
+ ]
160
+ }
161
+
162
+ # Try to load district whitelist from filter_options.json
163
+ try:
164
+ fo = Path("filter_options.json")
165
+ if fo.exists():
166
+ with open(fo) as f:
167
+ data = json.load(f)
168
+ if isinstance(data, dict) and data.get("districts"):
169
+ self.district_whitelist = [d.strip() for d in data["districts"] if d]
170
+ else:
171
+ self.district_whitelist = self.available_metadata['districts']
172
+ else:
173
+ self.district_whitelist = self.available_metadata['districts']
174
+ except Exception:
175
+ self.district_whitelist = self.available_metadata['districts']
176
+
177
+ # Enrich whitelist from add_district_metadata.py if available
178
+ try:
179
+ from add_district_metadata import DistrictMetadataProcessor
180
+ proc = DistrictMetadataProcessor()
181
+ names = set()
182
+ for key, mapping in proc.district_mappings.items():
183
+ if getattr(mapping, 'is_district', True):
184
+ names.add(mapping.name)
185
+ if names:
186
+ # Merge while preserving order: existing first, then new ones not present
187
+ merged = list(self.district_whitelist)
188
+ for n in sorted(names):
189
+ if n not in merged:
190
+ merged.append(n)
191
+ self.district_whitelist = merged
192
+ self.logger.info(f"🧭 District whitelist enriched: {len(self.district_whitelist)} entries")
193
+ except Exception as e:
194
+ self.logger.info(f"ℹ️ Could not enrich districts from add_district_metadata: {e}")
195
+
196
+ # Get dynamic year list from filter_options.json
197
+ try:
198
+ fo = Path("filter_options.json")
199
+ if fo.exists():
200
+ with open(fo) as f:
201
+ data = json.load(f)
202
+ if isinstance(data, dict) and data.get("years"):
203
+ self.year_whitelist = [str(y).strip() for y in data["years"] if y]
204
+ else:
205
+ self.year_whitelist = self.available_metadata['years']
206
+ else:
207
+ self.year_whitelist = self.available_metadata['years']
208
+ except Exception:
209
+ self.year_whitelist = self.available_metadata['years']
210
+
211
+ # Calculate current year dynamically
212
+ from datetime import datetime
213
+ self.current_year = str(datetime.now().year)
214
+ self.previous_year = str(datetime.now().year - 1)
215
+
216
+ # Data context for system prompt
217
+ self.data_context = self._load_data_context()
218
+
219
+ # Build the LangGraph
220
+ self.graph = self._build_graph()
221
+
222
+ # Conversation logging
223
+ self.conversations_dir = Path("conversations")
224
+ self.conversations_dir.mkdir(exist_ok=True)
225
+
226
+ def _load_data_context(self) -> str:
227
+ """Load and analyze data context for system prompt"""
228
+ try:
229
+ # Try to load from generated context file
230
+ context_file = Path("data_context.md")
231
+ if context_file.exists():
232
+ with open(context_file) as f:
233
+ return f.read()
234
+
235
+ # Fallback to basic analysis
236
+ reports_dir = Path("reports")
237
+ testset_dir = Path("outputs/datasets/testset")
238
+
239
+ context_parts = []
240
+
241
+ # Report analysis
242
+ if reports_dir.exists():
243
+ report_folders = [d for d in reports_dir.iterdir() if d.is_dir()]
244
+ context_parts.append(f"📊 Available Reports: {len(report_folders)} audit report folders")
245
+
246
+ # Get year range
247
+ years = []
248
+ for folder in report_folders:
249
+ if "2018" in folder.name:
250
+ years.append("2018")
251
+ elif "2019" in folder.name:
252
+ years.append("2019")
253
+ elif "2020" in folder.name:
254
+ years.append("2020")
255
+ elif "2021" in folder.name:
256
+ years.append("2021")
257
+ elif "2022" in folder.name:
258
+ years.append("2022")
259
+ elif "2023" in folder.name:
260
+ years.append("2023")
261
+
262
+ if years:
263
+ context_parts.append(f"📅 Years covered: {', '.join(sorted(set(years)))}")
264
+
265
+ # Test dataset analysis
266
+ if testset_dir.exists():
267
+ test_files = list(testset_dir.glob("*.json"))
268
+ context_parts.append(f"🧪 Test dataset: {len(test_files)} files with sample questions")
269
+
270
+ return "\n".join(context_parts) if context_parts else "📊 Audit report database with comprehensive coverage"
271
+
272
+ except Exception as e:
273
+ self.logger.warning(f"⚠️ Could not load data context: {e}")
274
+ return "📊 Comprehensive audit report database"
275
+
276
+ def _build_graph(self) -> StateGraph:
277
+ """Build the LangGraph for intelligent conversation flow"""
278
+
279
+ # Define the graph
280
+ workflow = StateGraph(ConversationState)
281
+
282
+ # Add nodes
283
+ workflow.add_node("analyze_query", self._analyze_query)
284
+ workflow.add_node("decide_action", self._decide_action)
285
+ workflow.add_node("perform_rag", self._perform_rag)
286
+ workflow.add_node("ask_follow_up", self._ask_follow_up)
287
+ workflow.add_node("generate_response", self._generate_response)
288
+
289
+ # Add edges
290
+ workflow.add_edge("analyze_query", "decide_action")
291
+
292
+ # Conditional edges from decide_action
293
+ workflow.add_conditional_edges(
294
+ "decide_action",
295
+ self._should_perform_rag,
296
+ {
297
+ "rag": "perform_rag",
298
+ "follow_up": "ask_follow_up"
299
+ }
300
+ )
301
+
302
+ # From perform_rag, go to generate_response
303
+ workflow.add_edge("perform_rag", "generate_response")
304
+
305
+ # From ask_follow_up, end
306
+ workflow.add_edge("ask_follow_up", END)
307
+
308
+ # From generate_response, end
309
+ workflow.add_edge("generate_response", END)
310
+
311
+ # Set entry point
312
+ workflow.set_entry_point("analyze_query")
313
+
314
+ return workflow.compile()
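+ # Resulting flow: analyze_query -> decide_action -> perform_rag -> generate_response -> END,
+ # or analyze_query -> decide_action -> ask_follow_up -> END when clarification is needed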
315
+
316
+ def _extract_districts_list(self, text: str) -> List[str]:
317
+ """Extract one or more districts from free text using whitelist matching.
318
+ - Case-insensitive substring match for each known district name
319
+ - Handles multi-district inputs like "Lwengo Kiboga District & Namutumba"
320
+ """
321
+ if not text:
322
+ return []
323
+ q = text.lower()
324
+ found: List[str] = []
325
+ for name in self.district_whitelist:
326
+ n = name.lower()
327
+ if n in q:
328
+ # Map Kampala -> KCCA canonical
329
+ canonical = 'KCCA' if name.lower() == 'kampala' else name
330
+ if canonical not in found:
331
+ found.append(canonical)
332
+ return found
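+ # e.g. _extract_districts_list("Lwengo Kiboga District & Namutumba") should yield
+ # the three names Lwengo, Kiboga and Namutumba (in whitelist order), and any
+ # "Kampala" mention comes back as its canonical "KCCA", assuming these names
+ # are present in the loaded whitelist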
333
+
334
+ def _extract_years_list(self, text: str) -> List[str]:
335
+ """Extract year list from text, supporting forms like '2022 / 23', '2022-2023', '2022–23'."""
336
+ if not text:
337
+ return []
338
+ years: List[str] = []
339
+ q = text
340
+ # Full 4-digit years
341
+ for y in re.findall(r"\b(20\d{2})\b", q):
342
+ if y not in years:
343
+ years.append(y)
344
+ # Shorthand like 2022/23 or 2022-23
345
+ for m in re.finditer(r"\b(20\d{2})\s*[\-/–]\s*(\d{2})\b", q):
346
+ y1 = m.group(1)
347
+ y2_short = int(m.group(2))
348
+ y2 = f"20{y2_short:02d}"
349
+ for y in [y1, y2]:
350
+ if y not in years:
351
+ years.append(y)
352
+ return years
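+ # e.g. _extract_years_list("PDM releases for FY 2022/23") -> ["2022", "2023"],
+ # and _extract_years_list("compare 2021 and 2023") -> ["2021", "2023"]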
353
+
354
+ def _analyze_query(self, state: ConversationState) -> ConversationState:
355
+ """Analyze the user query with conversation context"""
356
+
357
+ query = state["current_query"]
358
+ conversation_context = state.get("conversation_context", {})
359
+
360
+ self.logger.info(f"🧠 QUERY ANALYSIS: Starting analysis for: '{query[:50]}...'")
361
+
362
+ # Build conversation context for analysis
363
+ context_info = ""
364
+ if conversation_context:
365
+ context_info = f"\n\nConversation context:\n"
366
+ for key, value in conversation_context.items():
367
+ if value:
368
+ context_info += f"- {key}: {value}\n"
369
+
370
+ # Also include recent conversation messages for better context
371
+ recent_messages = state.get("messages", [])
372
+ if recent_messages and len(recent_messages) > 1:
373
+ context_info += f"\nRecent conversation:\n"
374
+ # Get last 3 messages for context
375
+ for msg in recent_messages[-3:]:
376
+ if hasattr(msg, 'content'):
377
+ role = "User" if isinstance(msg, HumanMessage) else "Assistant"
378
+ context_info += f"- {role}: {msg.content[:100]}...\n"
379
+
380
+ # Create analysis prompt with data context
381
+ analysis_prompt = ChatPromptTemplate.from_messages([
382
+ SystemMessage(content=f"""You are an expert at analyzing audit report queries. Your job is to extract specific information and determine if a query can be answered directly.
383
+
384
+ {self.data_context}
385
+
386
+ DISTRICT RECOGNITION RULES:
387
+ - Kampala = KCCA (Kampala Capital City Authority)
388
+ - Available districts: {', '.join(self.district_whitelist[:15])}... (and {len(self.district_whitelist)-15} more)
389
+ - DLG = District Local Government
390
+ - The whitelist above covers {len(self.district_whitelist)} districts - recognize common ones
391
+
392
+ SOURCE RECOGNITION RULES:
393
+ - KCCA = Kampala Capital City Authority
394
+ - MAAIF = Ministry of Agriculture, Animal Industry and Fisheries
395
+ - MWTS = Ministry of Works and Transport
396
+ - OAG = Office of the Auditor General
397
+ - Consolidated = Annual Consolidated reports
398
+
399
+ YEAR RECOGNITION RULES:
400
+ - Available years: {', '.join(self.year_whitelist)}
401
+ - Current year is {self.current_year} - use this to reason about relative years
402
+ - If user mentions "last year", "previous year" - infer {self.previous_year}
403
+ - If user mentions "this year", "current year" - infer {self.current_year}
404
+
405
+ Analysis rules:
406
+ 1. Be SMART - if you have enough context to search, do it
407
+ 2. Use conversation context to fill in missing information
408
+ 3. For budget/expenditure queries, try to infer missing details from context
409
+ 4. Current year is {self.current_year} - use this to reason about relative years
410
+ 5. If user mentions "last year", "previous year" - infer {self.previous_year}
411
+ 6. If user mentions "this year", "current year" - infer {self.current_year}
412
+ 7. If user mentions a department/ministry, infer the source
413
+ 8. If user is getting frustrated or asking for results, proceed with RAG even if not perfect
414
+ 9. Recognize Kampala as a district (KCCA)
415
+
416
+ IMPORTANT: You must respond with ONLY valid JSON. No additional text.
417
+
418
+ Return your analysis as JSON with these exact fields:
419
+ {{
420
+ "has_district": boolean,
421
+ "has_source": boolean,
422
+ "has_year": boolean,
423
+ "extracted_district": "string or null",
424
+ "extracted_source": "string or null",
425
+ "extracted_year": "string or null",
426
+ "confidence_score": 0.0-1.0,
427
+ "can_answer_directly": boolean,
428
+ "missing_filters": ["list", "of", "missing", "filters"],
429
+ "suggested_follow_up": "string or null",
430
+ "expanded_query": "string or null"
431
+ }}
432
+
433
+ The expanded_query should be a natural language query that combines the original question with any inferred context for better RAG retrieval."""),
434
+ HumanMessage(content=f"Analyze this query: '{query}'{context_info}")
435
+ ])
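+ # Illustrative reply for "How much did Gulu DLG spend on PDM in 2023?":
+ # {"has_district": true, "has_source": true, "has_year": true,
+ #  "extracted_district": "Gulu", "extracted_source": "Gulu DLG", "extracted_year": "2023",
+ #  "confidence_score": 0.9, "can_answer_directly": true, "missing_filters": [],
+ #  "suggested_follow_up": null,
+ #  "expanded_query": "PDM expenditure for Gulu district from Gulu DLG in 2023"}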
436
+
437
+ # Get analysis from LLM
438
+ response = self.llm.invoke(analysis_prompt.format_messages())
439
+
440
+ try:
441
+ # Clean the response content to extract JSON
442
+ content = response.content.strip()
443
+
444
+ # Try to find JSON in the response
445
+ if content.startswith('{') and content.endswith('}'):
446
+ json_content = content
447
+ else:
448
+ # Try to extract JSON from the response
449
+ import re
450
+ json_match = re.search(r'\{.*\}', content, re.DOTALL)
451
+ if json_match:
452
+ json_content = json_match.group()
453
+ else:
454
+ raise json.JSONDecodeError("No JSON found in response", content, 0)
455
+
456
+ # Parse JSON response
457
+ analysis_data = json.loads(json_content)
458
+
459
+ query_analysis = QueryAnalysis(
460
+ has_district=analysis_data.get("has_district", False),
461
+ has_source=analysis_data.get("has_source", False),
462
+ has_year=analysis_data.get("has_year", False),
463
+ extracted_district=analysis_data.get("extracted_district"),
464
+ extracted_source=analysis_data.get("extracted_source"),
465
+ extracted_year=analysis_data.get("extracted_year"),
466
+ confidence_score=analysis_data.get("confidence_score", 0.0),
467
+ can_answer_directly=analysis_data.get("can_answer_directly", False),
468
+ missing_filters=analysis_data.get("missing_filters", []),
469
+ suggested_follow_up=analysis_data.get("suggested_follow_up"),
470
+ expanded_query=analysis_data.get("expanded_query")
471
+ )
472
+
473
+ except (json.JSONDecodeError, KeyError, AttributeError) as e:
474
+ self.logger.info(f"⚠️ JSON parsing failed: {e}")
475
+ # Fallback analysis - be more permissive
476
+ query_lower = query.lower()
477
+
478
+ # Simple keyword matching - improved district recognition
479
+ has_district = any(district.lower() in query_lower for district in [
480
+ 'gulu', 'kalangala', 'kampala', 'namutumba', 'lwengo', 'kiboga', 'kcca', 'maaif', 'mwts'
481
+ ])
482
+
483
+ # Special case: Kampala = KCCA
484
+ if 'kampala' in query_lower and not has_district:
485
+ has_district = True
486
+
487
+ has_source = any(source.lower() in query_lower for source in [
488
+ 'kcca', 'maaif', 'mwts', 'gulu', 'kalangala', 'consolidated', 'oag', 'government'
489
+ ])
490
+
491
+ # Check for year mentions using dynamic year list
492
+ has_year = any(year in query_lower for year in self.year_whitelist)
493
+
494
+ # Also check for explicit relative year terms
495
+ has_year = has_year or any(term in query_lower for term in [
496
+ 'this year', 'last year', 'previous year', 'current year'
497
+ ])
498
+
499
+ # Extract specific values
500
+ extracted_district = None
501
+ extracted_source = None
502
+ extracted_year = None
503
+
504
+ # Extract districts using comprehensive whitelist
505
+ for district_name in self.district_whitelist:
506
+ if district_name.lower() in query_lower:
507
+ extracted_district = district_name
508
+ break
509
+
510
+ # Also check common aliases
511
+ district_aliases = {
512
+ 'kampala': 'Kampala',
513
+ 'kcca': 'Kampala',
514
+ 'gulu': 'Gulu',
515
+ 'kalangala': 'Kalangala'
516
+ }
517
+ for alias, full_name in district_aliases.items():
518
+ if alias in query_lower and not extracted_district:
519
+ extracted_district = full_name
520
+ break
521
+
522
+ for source in ['kcca', 'maaif', 'mwts', 'consolidated', 'oag']:
523
+ if source in query_lower:
524
+ extracted_source = source.upper()
525
+ break
526
+
527
+ # Extract year using dynamic year list
528
+ for year in self.year_whitelist:
529
+ if year in query_lower:
530
+ extracted_year = year
531
+ has_year = True
532
+ break
533
+
534
+ # Only handle relative year terms if explicitly mentioned
535
+ if not extracted_year:
536
+ if 'last year' in query_lower or 'previous year' in query_lower:
537
+ extracted_year = self.previous_year
538
+ has_year = True
539
+ elif 'this year' in query_lower or 'current year' in query_lower:
540
+ extracted_year = self.current_year
541
+ has_year = True
542
+ elif 'recent' in query_lower and 'year' in query_lower:
543
+ # Use the most recent year from available data
544
+ extracted_year = max(self.year_whitelist) if self.year_whitelist else self.previous_year
545
+ has_year = True
546
+
547
+ # Be more permissive - if we have some context, try to answer
548
+ missing_filters = []
549
+ if not has_district:
550
+ missing_filters.append("district")
551
+ if not has_source:
552
+ missing_filters.append("source")
553
+ if not has_year:
554
+ missing_filters.append("year")
555
+
556
+ # If user seems frustrated or asking for results, be more permissive
557
+ frustration_indicators = ['already', 'just said', 'specified', 'provided', 'crazy', 'answer']
558
+ is_frustrated = any(indicator in query_lower for indicator in frustration_indicators)
559
+
560
+ can_answer_directly = len(missing_filters) <= 1 or is_frustrated # More permissive
561
+ confidence_score = 0.8 if can_answer_directly else 0.3
562
+
563
+ # Generate follow-up suggestion
564
+ if missing_filters and not is_frustrated:
565
+ if "district" in missing_filters and "source" in missing_filters:
566
+ suggested_follow_up = "I'd be happy to help you with that information! Could you please specify which district and department/ministry you're asking about?"
567
+ elif "district" in missing_filters:
568
+ suggested_follow_up = "Thanks for your question! Could you please specify which district you're asking about?"
569
+ elif "source" in missing_filters:
570
+ suggested_follow_up = "I can help you with that! Could you please specify which department or ministry you're asking about?"
571
+ elif "year" in missing_filters:
572
+ suggested_follow_up = "Great question! Could you please specify which year you're interested in?"
573
+ else:
574
+ suggested_follow_up = "Could you please provide more specific details to help me give you a precise answer?"
575
+ else:
576
+ suggested_follow_up = None
577
+
578
+ # Create expanded query
579
+ expanded_query = query
580
+ if extracted_district:
581
+ expanded_query += f" for {extracted_district} district"
582
+ if extracted_source:
583
+ expanded_query += f" from {extracted_source}"
584
+ if extracted_year:
585
+ expanded_query += f" in {extracted_year}"
586
+
587
+ query_analysis = QueryAnalysis(
588
+ has_district=has_district,
589
+ has_source=has_source,
590
+ has_year=has_year,
591
+ extracted_district=extracted_district,
592
+ extracted_source=extracted_source,
593
+ extracted_year=extracted_year,
594
+ confidence_score=confidence_score,
595
+ can_answer_directly=can_answer_directly,
596
+ missing_filters=missing_filters,
597
+ suggested_follow_up=suggested_follow_up,
598
+ expanded_query=expanded_query
599
+ )
600
+
601
+ # Update conversation context
602
+ if query_analysis.extracted_district:
603
+ conversation_context["district"] = query_analysis.extracted_district
604
+ if query_analysis.extracted_source:
605
+ conversation_context["source"] = query_analysis.extracted_source
606
+ if query_analysis.extracted_year:
607
+ conversation_context["year"] = query_analysis.extracted_year
608
+
609
+ state["query_analysis"] = query_analysis
610
+ state["conversation_context"] = conversation_context
611
+
612
+ self.logger.info(f"✅ ANALYSIS COMPLETE: district={query_analysis.has_district}, source={query_analysis.has_source}, year={query_analysis.has_year}")
613
+ self.logger.info(f"📈 Confidence: {query_analysis.confidence_score:.2f}, Can answer directly: {query_analysis.can_answer_directly}")
614
+ if query_analysis.expanded_query:
615
+ self.logger.info(f"🔄 Expanded query: {query_analysis.expanded_query}")
616
+
617
+ return state
618
+
619
+ def _decide_action(self, state: ConversationState) -> ConversationState:
620
+ """Decide what action to take based on query analysis"""
621
+
622
+ analysis = state["query_analysis"]
623
+
624
+ # Add decision reasoning
625
+ if analysis.can_answer_directly and analysis.confidence_score > 0.7:
626
+ self.logger.info(f"🚀 DECISION: Query is complete, proceeding with RAG")
627
+ self.logger.info(f"📊 REASONING: Confidence={analysis.confidence_score:.2f}, Missing filters={len(analysis.missing_filters or [])}")
628
+ if analysis.missing_filters:
629
+ self.logger.info(f"📋 Missing: {', '.join(analysis.missing_filters)}")
630
+ else:
631
+ self.logger.info(f"✅ All required information available")
632
+ else:
633
+ self.logger.info(f"❓ DECISION: Query incomplete, asking follow-up")
634
+ self.logger.info(f"📊 REASONING: Confidence={analysis.confidence_score:.2f}, Missing filters={len(analysis.missing_filters or [])}")
635
+ if analysis.missing_filters:
636
+ self.logger.info(f"📋 Missing: {', '.join(analysis.missing_filters)}")
637
+ self.logger.info(f"💡 Follow-up needed: {analysis.suggested_follow_up}")
638
+
639
+ return state
640
+
641
+ def _should_perform_rag(self, state: ConversationState) -> str:
642
+ """Determine whether to perform RAG or ask follow-up"""
643
+
644
+ analysis = state["query_analysis"]
645
+ conversation_context = state.get("conversation_context", {})
646
+ recent_messages = state.get("messages", [])
647
+
648
+ # Check if we have enough context from conversation history
649
+ has_district_context = analysis.has_district or conversation_context.get("district")
650
+ has_source_context = analysis.has_source or conversation_context.get("source")
651
+ has_year_context = analysis.has_year or conversation_context.get("year")
652
+
653
+ # Count how many context pieces we have
654
+ context_count = sum([bool(has_district_context), bool(has_source_context), bool(has_year_context)])
655
+
656
+ # For PDM queries, we need more specific information
657
+ current_query = state["current_query"].lower()
658
+ recent_messages = state.get("messages", [])
659
+
660
+ # Check if this is a PDM query by looking at current query OR recent conversation
661
+ is_pdm_query = "pdm" in current_query or "parish development" in current_query
662
+
663
+ # Also check recent messages for PDM context
664
+ if not is_pdm_query and recent_messages:
665
+ for msg in recent_messages[-3:]: # Check last 3 messages
666
+ if isinstance(msg, HumanMessage) and ("pdm" in msg.content.lower() or "parish development" in msg.content.lower()):
667
+ is_pdm_query = True
668
+ break
669
+
670
+ if is_pdm_query:
671
+ # For PDM queries, we need district AND year to be specific enough
672
+ # But we need them to be explicitly provided in the current conversation, not just inferred
673
+ if has_district_context and has_year_context:
674
+ # Check if both district and year are explicitly mentioned in recent messages
675
+ explicit_district = False
676
+ explicit_year = False
677
+
678
+ for msg in recent_messages[-3:]: # Check last 3 messages
679
+ if isinstance(msg, HumanMessage):
680
+ content = msg.content.lower()
681
+ if any(district in content for district in ["gulu", "kalangala", "kampala", "namutumba"]):
682
+ explicit_district = True
683
+ if any(year in content for year in ["2022", "2023", "2022/23", "2023/24"]):
684
+ explicit_year = True
685
+
686
+ if explicit_district and explicit_year:
687
+ self.logger.info(f"🚀 DECISION: PDM query with explicit district and year, proceeding with RAG")
688
+ self.logger.info(f"📊 REASONING: PDM query - explicit_district={explicit_district}, explicit_year={explicit_year}")
689
+ return "rag"
690
+ else:
691
+ self.logger.info(f"❓ DECISION: PDM query needs explicit district and year, asking follow-up")
692
+ self.logger.info(f"📊 REASONING: PDM query - explicit_district={explicit_district}, explicit_year={explicit_year}")
693
+ return "follow_up"
694
+ else:
695
+ self.logger.info(f"❓ DECISION: PDM query needs more specific info, asking follow-up")
696
+ self.logger.info(f"📊 REASONING: PDM query - district={has_district_context}, year={has_year_context}")
697
+ return "follow_up"
698
+
699
+ # For general queries, be more conservative - need at least 2 pieces AND high confidence
700
+ if context_count >= 2 and analysis.confidence_score > 0.8:
701
+ self.logger.info(f"🚀 DECISION: Sufficient context with high confidence, proceeding with RAG")
702
+ self.logger.info(f"📊 REASONING: Context pieces: district={has_district_context}, source={has_source_context}, year={has_year_context}, confidence={analysis.confidence_score}")
703
+ return "rag"
704
+
705
+ # If user seems frustrated (short responses like "no"), proceed with RAG
706
+ if recent_messages and len(recent_messages) >= 3: # Need more messages to detect frustration
707
+ last_user_message = None
708
+ for msg in reversed(recent_messages):
709
+ if isinstance(msg, HumanMessage):
710
+ last_user_message = msg.content.lower().strip()
711
+ break
712
+
713
+ if last_user_message and len(last_user_message) < 10 and any(word in last_user_message for word in ["no", "yes", "ok", "sure"]):
714
+ self.logger.info(f"🚀 DECISION: User seems frustrated with short response, proceeding with RAG")
715
+ return "rag"
716
+
717
+ # Original logic for direct answers
718
+ if analysis.can_answer_directly and analysis.confidence_score > 0.7:
719
+ return "rag"
720
+ else:
721
+ return "follow_up"
722
+
723
+ def _ask_follow_up(self, state: ConversationState) -> ConversationState:
724
+ """Generate a follow-up question to clarify missing information"""
725
+
726
+ analysis = state["query_analysis"]
727
+ current_query = state["current_query"].lower()
728
+ conversation_context = state.get("conversation_context", {})
729
+
730
+ # Check if this is a PDM query
731
+ is_pdm_query = "pdm" in current_query or "parish development" in current_query
732
+
733
+ if is_pdm_query:
734
+ # Generate PDM-specific follow-up questions
735
+ missing_info = []
736
+
737
+ if not analysis.has_district and not conversation_context.get("district"):
738
+ missing_info.append("district (e.g., Gulu, Kalangala)")
739
+
740
+ if not analysis.has_year and not conversation_context.get("year"):
741
+ missing_info.append("year (e.g., 2022, 2023)")
742
+
743
+ if missing_info:
744
+ follow_up_message = f"For PDM administrative costs information, I need to know the {', '.join(missing_info)}. Could you please specify these details?"
745
+ else:
746
+ follow_up_message = "Could you please provide more specific details about the PDM administrative costs you're looking for?"
747
+ else:
748
+ # Use the original follow-up logic
749
+ if analysis.suggested_follow_up:
750
+ follow_up_message = analysis.suggested_follow_up
751
+ else:
752
+ follow_up_message = "Could you please provide more specific details to help me give you a precise answer?"
753
+
754
+ state["final_response"] = follow_up_message
755
+ state["last_ai_message_time"] = time.time()
756
+
757
+ return state
758
+
759
+ def _build_comprehensive_query(self, current_query: str, analysis, conversation_context: dict, recent_messages: list) -> str:
760
+ """Build a better RAG query from conversation.
761
+ - If latest message is a short modifier (e.g., "financial"), merge it into the last substantive question.
762
+ - If latest message looks like filters (district/year), keep the last question unchanged.
763
+ - Otherwise, use the current message.
764
+ """
765
+
766
+ def is_interrogative(text: str) -> bool:
767
+ t = text.lower().strip()
768
+ return any(t.startswith(w) for w in ["what", "how", "why", "when", "where", "which", "who"]) or t.endswith("?")
769
+
770
+ def is_filter_like(text: str) -> bool:
771
+ t = text.lower()
772
+ if "district" in t:
773
+ return True
774
+ if re.search(r"\b20\d{2}\b", t) or re.search(r"20\d{2}\s*[\-/–]\s*\d{2}\b", t):
775
+ return True
776
+ if self._extract_districts_list(text):
777
+ return True
778
+ return False
779
+
780
+ # Find last substantive user question
781
+ last_question = None
782
+ for msg in reversed(recent_messages[:-1] if recent_messages else []):
783
+ if isinstance(msg, HumanMessage):
784
+ if is_interrogative(msg.content) and len(msg.content.strip()) > 15:
785
+ last_question = msg.content.strip()
786
+ break
787
+
788
+ cq = current_query.strip()
789
+ words = cq.split()
790
+ is_short_modifier = (not is_interrogative(cq)) and (len(words) <= 3)
791
+
792
+ if is_filter_like(cq) and last_question:
793
+ comprehensive_query = last_question
794
+ elif is_short_modifier and last_question:
795
+ modifier = cq
796
+ if modifier.lower() in last_question.lower():
797
+ comprehensive_query = last_question
798
+ else:
799
+ if last_question.endswith('?'):
800
+ comprehensive_query = last_question[:-1] + f" for {modifier}?"
801
+ else:
802
+ comprehensive_query = last_question + f" for {modifier}"
803
+ else:
804
+ comprehensive_query = current_query
805
+
806
+ self.logger.info(f"🔄 COMPREHENSIVE QUERY: '{comprehensive_query}'")
807
+ return comprehensive_query
808
+
809
+ def _rewrite_query_with_llm(self, recent_messages: list, draft_query: str) -> str:
810
+ """Use the LLM to rewrite a clean, focused RAG query from the conversation.
811
+ Rules enforced in prompt:
812
+ - Keep the user's main information need from the last substantive question
813
+ - Integrate short modifiers (e.g., "financial") into that question when appropriate
814
+ - Do NOT include filter text (years/districts/sources) in the query; those are handled separately
815
+ - Return a single plain sentence only (no quotes, no markdown)
816
+ """
817
+ try:
818
+ # Build a compact conversation transcript (last 6 messages max)
819
+ convo_lines = []
820
+ for msg in recent_messages[-6:]:
821
+ if isinstance(msg, HumanMessage):
822
+ convo_lines.append(f"User: {msg.content}")
823
+ elif isinstance(msg, AIMessage):
824
+ convo_lines.append(f"Assistant: {msg.content}")
825
+
826
+ convo_text = "\n".join(convo_lines)
827
+
828
+ """
829
+ "DECISION GUIDANCE:\n"
830
+ "- If the latest user message looks like a modifier (e.g., 'financial'), merge it into the best prior question.\n"
831
+ "- If the latest message provides filters (e.g., districts, years), DO NOT embed them; keep the base question.\n"
832
+ "- If the latest message itself is a full, clear question, use it.\n"
833
+ "- If the draft query is already good, you may refine its clarity but keep the same intent.\n\n"
834
+ """
835
+
836
+
837
+ prompt = ChatPromptTemplate.from_messages([
838
+ SystemMessage(content=(
839
+ "ROLE: Query Rewriter for a RAG system.\n\n"
840
+ "PRIMARY OBJECTIVE:\n- Produce ONE retrieval-focused sentence that best represents the user's information need.\n"
841
+ "- Maximize recall of relevant evidence; be specific but not overconstrained.\n\n"
842
+ "INPUTS:\n- Conversation with User and Assistant turns (latest last).\n- A draft query (heuristic).\n\n"
843
+ "OPERATING PRINCIPLES:\n"
844
+ "1) Use the last substantive USER question as the backbone of intent.\n"
845
+ "2) Merge helpful domain modifiers from any USER turns (financial, procurement, risk) when they sharpen focus; ignore if not helpful.\n"
846
+ "3) Treat Assistant messages as guidance only; if the user later provided filters (years, districts, sources), DO NOT embed them in the query (filters are applied separately).\n"
847
+ "4) Remove meta-verbs like 'summarize', 'list', 'explain', 'compare' from the query.\n"
848
+ "5) Prefer content-bearing terms (topics, programs, outcomes) over task phrasing.\n"
849
+ "6) If the latest user message is filters-only, keep the prior substantive question unchanged.\n"
850
+ "7) If the draft query is already strong, refine wording for clarity but keep the same intent.\n\n"
851
+ "EXAMPLES (multi-turn):\n"
852
+ "A)\nUser: What are the top 5 priorities for improving audit procedures?\nAssistant: Could you specify the scope (e.g., financial, procurement)?\nUser: Financial\n→ Output: Top priorities for improving financial audit procedures.\n\n"
853
+ "B)\nUser: How were PDM administrative costs utilized and what was the impact of shortfalls?\nAssistant: Please specify district/year for precision.\nUser: Namutumba and Lwengo Districts (2022/23)\n→ Output: How were PDM administrative costs utilized and what was the impact of shortfalls.\n(Exclude districts/years; they are filters.)\n\n"
854
+ "C)\nUser: Summarize risk management issues in audit reports.\n→ Output: Key risk management issues in audit reports.\n\n"
855
+ "CONSTRAINTS:\n- Do NOT include filters (years, districts, sources, filenames).\n- Do NOT include quotes/markdown/bullets or multiple sentences.\n- Return exactly one plain sentence."
856
+ )),
857
+ HumanMessage(content=(
858
+ f"Conversation (most recent last):\n{convo_text}\n\n"
859
+ f"Draft query: {draft_query}\n\n"
860
+ "Rewrite the single best retrieval query sentence now:"
861
+ )),
862
+ ])
863
+
864
+ # Add timeout for LLM call
865
+ import signal
866
+
867
+ def timeout_handler(signum, frame):
868
+ raise TimeoutError("LLM rewrite timeout")
869
+
870
+ # Set a 10-second timeout (SIGALRM works only on Unix and in the main thread; if it cannot be installed, the outer except falls back to the draft query)
871
+ signal.signal(signal.SIGALRM, timeout_handler)
872
+ signal.alarm(10)
873
+
874
+ try:
875
+ resp = self.llm.invoke(prompt.format_messages())
876
+ signal.alarm(0) # Cancel timeout
877
+
878
+ rewritten = getattr(resp, 'content', '').strip()
879
+ # Basic sanitization: keep it one line
880
+ rewritten = rewritten.replace('\n', ' ').strip()
881
+ if rewritten and len(rewritten) > 5: # Basic quality check
882
+ self.logger.info(f"🛠️ LLM REWRITER: '{rewritten}'")
883
+ return rewritten
884
+ else:
885
+ self.logger.info(f"⚠️ LLM rewrite too short/empty, using draft query")
886
+ return draft_query
887
+ except TimeoutError:
888
+ signal.alarm(0)
889
+ self.logger.info(f"⚠️ LLM rewrite timeout after 10s, using draft query")
890
+ return draft_query
891
+ except Exception as e:
892
+ signal.alarm(0)
893
+ self.logger.info(f"⚠️ LLM rewrite failed, using draft query. Error: {e}")
894
+ return draft_query
895
+ except Exception as e:
896
+ self.logger.info(f"⚠️ LLM rewrite setup failed, using draft query. Error: {e}")
897
+ return draft_query
898
+
899
+ def _perform_rag(self, state: ConversationState) -> ConversationState:
900
+ """Perform RAG retrieval with smart query expansion"""
901
+
902
+ query = state["current_query"]
903
+ analysis = state["query_analysis"]
904
+ conversation_context = state.get("conversation_context", {})
905
+ recent_messages = state.get("messages", [])
906
+
907
+ # Build comprehensive query from conversation history
908
+ draft_query = self._build_comprehensive_query(query, analysis, conversation_context, recent_messages)
909
+ # Let LLM rewrite a clean, focused search query
910
+ search_query = self._rewrite_query_with_llm(recent_messages, draft_query)
911
+
912
+ self.logger.info(f"🔍 RAG RETRIEVAL: Starting for query: '{search_query[:50]}...'")
913
+ self.logger.info(f"📊 Analysis: district={analysis.has_district}, source={analysis.has_source}, year={analysis.has_year}")
914
+
915
+ try:
916
+ # Build filters from analysis and conversation context
917
+ filters = {}
918
+
919
+ # Use conversation context to fill in missing filters
920
+ source = analysis.extracted_source or conversation_context.get("source")
921
+ district = analysis.extracted_district or conversation_context.get("district")
922
+ year = analysis.extracted_year or conversation_context.get("year")
923
+
924
+ if source:
925
+ filters["source"] = [source] # Qdrant expects lists
926
+ self.logger.info(f"🎯 Filter: source={source}")
927
+
928
+ if year:
929
+ filters["year"] = [year]
930
+ self.logger.info(f"🎯 Filter: year={year}")
931
+
932
+ if district:
933
+ # Map district to source if needed
934
+ if district.upper() == "KAMPALA":
935
+ filters["source"] = ["KCCA"]
936
+ self.logger.info(f"🎯 Filter: district={district} -> source=KCCA")
937
+ elif district.upper() in ["GULU", "KALANGALA"]:
938
+ filters["source"] = [f"{district.upper()} DLG"]
939
+ self.logger.info(f"🎯 Filter: district={district} -> source={district.upper()} DLG")
940
+
941
+ # Run RAG pipeline with correct parameters
942
+ result = self.pipeline_manager.run(
943
+ query=search_query, # Use expanded query
944
+ sources=filters.get("source") if filters.get("source") else None,
945
+ auto_infer_filters=False, # Our agent already handled filter inference
946
+ filters=filters if filters else None
947
+ )
948
+
949
+ self.logger.info(f"✅ RAG completed: Found {len(result.sources)} sources")
950
+ self.logger.info(f"⏱️ Execution time: {result.execution_time:.2f}s")
951
+
952
+ # Store RAG result in state
953
+ state["rag_result"] = result
954
+ state["rag_query"] = search_query
955
+
956
+ except Exception as e:
957
+ self.logger.info(f"❌ RAG retrieval failed: {e}")
958
+ state["rag_result"] = None
959
+
960
+ return state
961
+
962
+ def _generate_response(self, state: ConversationState) -> ConversationState:
963
+ """Generate final response using RAG results"""
964
+
965
+ rag_result = state["rag_result"]
966
+
967
+ self.logger.info(f"📝 RESPONSE: Using RAG result ({len(rag_result.answer)} chars)")
968
+
969
+ # Store the final response directly from RAG
970
+ state["final_response"] = rag_result.answer
971
+ state["last_ai_message_time"] = time.time()
972
+
973
+ return state
974
+
975
+ def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
976
+ """Main chat interface; returns a dict with 'response', 'rag_result' and 'actual_rag_query'"""
977
+
978
+ self.logger.info(f"💬 CHAT: Processing user input: '{user_input[:50]}...'")
979
+ self.logger.info(f"📊 Session: {conversation_id}")
980
+
981
+ # Load conversation history
982
+ conversation_file = self.conversations_dir / f"{conversation_id}.json"
983
+ conversation = self._load_conversation(conversation_file)
984
+
985
+ # Add user message to conversation
986
+ conversation["messages"].append(HumanMessage(content=user_input))
987
+
988
+ self.logger.info(f"🔄 LANGGRAPH: Starting graph execution")
989
+
990
+ # Prepare state for LangGraph with conversation context
991
+ state = ConversationState(
992
+ conversation_id=conversation_id,
993
+ messages=conversation["messages"],
994
+ current_query=user_input,
995
+ query_analysis=None,
996
+ conversation_context=conversation.get("context", {}),
997
+ rag_result=None,
998
+ final_response=None,
999
+ session_start_time=conversation["session_start_time"],
1000
+ last_ai_message_time=conversation["last_ai_message_time"]
1001
+ )
1002
+
1003
+ # Run the graph
1004
+ final_state = self.graph.invoke(state)
1005
+
1006
+ # Add the AI response to conversation
1007
+ if final_state["final_response"]:
1008
+ conversation["messages"].append(AIMessage(content=final_state["final_response"]))
1009
+
1010
+ # Update conversation state
1011
+ conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
1012
+ conversation["context"] = final_state["conversation_context"]
1013
+
1014
+ # Save conversation
1015
+ self._save_conversation(conversation_file, conversation)
1016
+
1017
+ self.logger.info(f"✅ LANGGRAPH: Graph execution completed")
1018
+ self.logger.info(f"🎯 CHAT COMPLETE: Response ready")
1019
+
1020
+ # Return both response and RAG result for UI
1021
+ return {
1022
+ 'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
1023
+ 'rag_result': final_state["rag_result"],
1024
+ 'actual_rag_query': final_state.get("rag_query", "")
1025
+ }
1026
+
1027
+ def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
1028
+ """Load conversation from file"""
1029
+ if conversation_file.exists():
1030
+ try:
1031
+ with open(conversation_file) as f:
1032
+ data = json.load(f)
1033
+ # Convert message dicts back to LangChain messages
1034
+ messages = []
1035
+ for msg_data in data.get("messages", []):
1036
+ if msg_data["type"] == "human":
1037
+ messages.append(HumanMessage(content=msg_data["content"]))
1038
+ elif msg_data["type"] == "ai":
1039
+ messages.append(AIMessage(content=msg_data["content"]))
1040
+ data["messages"] = messages
1041
+ return data
1042
+ except Exception as e:
1043
+ self.logger.info(f"⚠️ Could not load conversation: {e}")
1044
+
1045
+ # Return default conversation
1046
+ return {
1047
+ "messages": [],
1048
+ "session_start_time": time.time(),
1049
+ "last_ai_message_time": time.time(),
1050
+ "context": {}
1051
+ }
1052
+
1053
+ def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
1054
+ """Save conversation to file"""
1055
+ try:
1056
+ # Convert LangChain messages to serializable format
1057
+ messages_data = []
1058
+ for msg in conversation["messages"]:
1059
+ if isinstance(msg, HumanMessage):
1060
+ messages_data.append({"type": "human", "content": msg.content})
1061
+ elif isinstance(msg, AIMessage):
1062
+ messages_data.append({"type": "ai", "content": msg.content})
1063
+
1064
+ data = {
1065
+ "messages": messages_data,
1066
+ "session_start_time": conversation["session_start_time"],
1067
+ "last_ai_message_time": conversation["last_ai_message_time"],
1068
+ "context": conversation.get("context", {}),
1069
+ "last_updated": datetime.now().isoformat()
1070
+ }
1071
+
1072
+ with open(conversation_file, "w") as f:
1073
+ json.dump(data, f, indent=2)
1074
+
1075
+ except Exception as e:
1076
+ self.logger.info(f"⚠️ Could not save conversation: {e}")
1077
+
1078
+
1079
+ def get_chatbot():
1080
+ """Get chatbot instance"""
1081
+ return IntelligentRAGChatbot()
1082
+
1083
+ if __name__ == "__main__":
1084
+ # Test the chatbot
1085
+ chatbot = IntelligentRAGChatbot()
1086
+
1087
+ # Test conversation
1088
+ test_queries = [
1089
+ "How much was the budget allocation for government salary payroll management?",
1090
+ "Namutumba district in 2023",
1091
+ "KCCA"
1092
+ ]
1093
+
1094
+ for query in test_queries:
1095
+ self.logger.info(f"\n{'='*50}")
1096
+ self.logger.info(f"User: {query}")
1097
+ response = chatbot.chat(query)
1098
+ self.logger.info(f"Bot: {response}")
src/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audit QA Refactored Module
3
+ A modular and maintainable RAG pipeline for audit report analysis.
4
+ """
5
+
6
+ from .pipeline import PipelineManager
7
+ from .config.loader import load_config
8
+
9
+ __version__ = "2.0.0"
10
+ __all__ = ["PipelineManager", "load_config"]
src/config/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Configuration management for Audit QA."""
2
+
3
+ from .loader import load_config, get_nested_config
4
+
5
+ __all__ = ["load_config", "get_nested_config"]
src/config/collections.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "docling": {
3
+ "model": "BAAI/bge-m3",
4
+ "description": "Default collection with BGE-M3 embedding model"
5
+ },
6
+ "modernbert-embed-base-akryl-matryoshka": {
7
+ "model": "Akryl/modernbert-embed-base-akryl-matryoshka",
8
+ "description": "ModernBERT embedding model with matryoshka representation"
9
+ },
10
+ "sentence-transformers-all-MiniLM-L6-v2": {
11
+ "model": "sentence-transformers/all-MiniLM-L6-v2",
12
+ "description": "Sentence transformers MiniLM model"
13
+ },
14
+ "sentence-transformers-all-mpnet-base-v2": {
15
+ "model": "sentence-transformers/all-mpnet-base-v2",
16
+ "description": "Sentence transformers MPNet model"
17
+ },
18
+ "BAAI-bge-m3": {
19
+ "model": "BAAI/bge-m3",
20
+ "description": "BAAI BGE-M3 multilingual embedding model"
21
+ }
22
+ }
src/config/loader.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration loader for YAML settings."""
2
+
3
+ import yaml
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Dict, Any, Optional
7
+ from dotenv import load_dotenv
8
+ import os
9
+
10
+ load_dotenv()
11
+
12
+ def load_config(config_path: str = None) -> Dict[str, Any]:
13
+ """
14
+ Load configuration from YAML file.
15
+
16
+ Args:
17
+ config_path: Path to config file. If None, uses default settings.yaml
18
+
19
+ Returns:
20
+ Dictionary containing configuration settings
21
+ """
22
+ if config_path is None:
23
+ # Default to settings.yaml in the same directory as this file
24
+ config_path = Path(__file__).parent / "settings.yaml"
25
+
26
+ config_path = Path(config_path)
27
+
28
+ if not config_path.exists():
29
+ raise FileNotFoundError(f"Configuration file not found: {config_path}")
30
+
31
+ with open(config_path, 'r', encoding='utf-8') as f:
32
+ content = f.read()
33
+
34
+ # Replace environment variables in the content
35
+ import os
36
+ import re
37
+
38
+ def replace_env_vars(match):
39
+ env_var = match.group(1)
40
+ return os.getenv(env_var, match.group(0)) # Return original if env var not found
41
+
42
+ # Replace ${VAR} patterns with environment variables
43
+ content = re.sub(r'\$\{([^}]+)\}', replace_env_vars, content)
44
+
45
+ config = yaml.safe_load(content)
46
+
47
+ # Override with environment variables if they exist
48
+ config = _override_with_env_vars(config)
49
+
50
+ return config
51
+
52
+
53
+ def _override_with_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
54
+ """Override config values with environment variables where available."""
55
+
56
+ # Map environment variables to config paths
57
+ env_mappings = {
58
+ 'QDRANT_URL': ['qdrant', 'url'],
59
+ 'QDRANT_COLLECTION': ['qdrant', 'collection_name'],
60
+ 'QDRANT_API_KEY': ['qdrant', 'api_key'],
61
+ 'RETRIEVER_MODEL': ['retriever', 'model'],
62
+ 'RANKER_MODEL': ['ranker', 'model'],
63
+ 'READER_TYPE': ['reader', 'default_type'],
64
+ 'MAX_TOKENS': ['reader', 'max_tokens'],
65
+ 'MISTRAL_API_KEY': ['reader', 'MISTRAL', 'api_key'],
66
+ 'OPENAI_API_KEY': ['reader', 'OPENAI', 'api_key'],
67
+ 'NEBIUS_API_KEY': ['reader', 'INF_PROVIDERS', 'api_key'],
68
+ 'NVIDIA_SERVER_API_KEY': ['reader', 'NVIDIA', 'api_key'],
69
+ 'SERVERLESS_API_KEY': ['reader', 'SERVERLESS', 'api_key'],
70
+ 'DEDICATED_API_KEY': ['reader', 'DEDICATED', 'api_key'],
71
+ 'OPENROUTER_API_KEY': ['reader', 'OPENROUTER', 'api_key'],
72
+ }
73
+
74
+ for env_var, config_path in env_mappings.items():
75
+ env_value = os.getenv(env_var)
76
+ if env_value:
77
+ # Navigate to the nested config location
78
+ current = config
79
+ for key in config_path[:-1]:
80
+ if key not in current:
81
+ current[key] = {}
82
+ current = current[key]
83
+
84
+ # Set the final value, converting to appropriate type
85
+ final_key = config_path[-1]
86
+ if final_key in ['top_k', 'max_tokens', 'num_predict']:
87
+ current[final_key] = int(env_value)
88
+ elif final_key in ['normalize', 'prefer_grpc']:
89
+ current[final_key] = env_value.lower() in ('true', '1', 'yes')
90
+ elif final_key == 'temperature':
91
+ current[final_key] = float(env_value)
92
+ else:
93
+ current[final_key] = env_value
94
+
95
+ return config
96
+
97
+
98
+ def get_nested_config(config: Dict[str, Any], path: str, default=None):
99
+ """
100
+ Get a nested configuration value using dot notation.
101
+
102
+ Args:
103
+ config: Configuration dictionary
104
+ path: Dot-separated path (e.g., 'reader.MISTRAL.model')
105
+ default: Default value if path not found
106
+
107
+ Returns:
108
+ Configuration value or default
109
+ """
110
+ keys = path.split('.')
111
+ current = config
112
+
113
+ try:
114
+ for key in keys:
115
+ current = current[key]
116
+ return current
117
+ except (KeyError, TypeError):
118
+ return default
119
+
120
+
121
+ def load_collections_mapping() -> Dict[str, Dict[str, str]]:
122
+ """Load collections mapping from JSON file."""
123
+ collections_file = Path(__file__).parent / "collections.json"
124
+
125
+ if not collections_file.exists():
126
+ # Return default mapping if file doesn't exist
127
+ return {
128
+ "docling": {
129
+ "model": "sentence-transformers/all-MiniLM-L6-v2",
130
+ "description": "Default collection"
131
+ }
132
+ }
133
+
134
+ with open(collections_file, 'r') as f:
135
+ return json.load(f)
136
+
137
+
138
+ def get_embedding_model_for_collection(collection_name: str) -> Optional[str]:
139
+ """Get embedding model for a specific collection name."""
140
+ collections = load_collections_mapping()
141
+
142
+ if collection_name in collections:
143
+ return collections[collection_name]["model"]
144
+
145
+ # Try to infer from collection name patterns
146
+ if "modernbert" in collection_name.lower():
147
+ return "Akryl/modernbert-embed-base-akryl-matryoshka"
148
+ elif "minilm" in collection_name.lower():
149
+ return "sentence-transformers/all-MiniLM-L6-v2"
150
+ elif "mpnet" in collection_name.lower():
151
+ return "sentence-transformers/all-mpnet-base-v2"
152
+ elif "bge" in collection_name.lower():
153
+ return "BAAI/bge-m3"
154
+
155
+ return None
156
+
157
+
158
+ def get_collection_info(collection_name: str) -> Dict[str, str]:
159
+ """Get full collection information including model and description."""
160
+ collections = load_collections_mapping()
161
+
162
+ if collection_name in collections:
163
+ return collections[collection_name]
164
+
165
+ # Return inferred info for unknown collections
166
+ model = get_embedding_model_for_collection(collection_name)
167
+ return {
168
+ "model": model or "unknown",
169
+ "description": f"Auto-inferred collection: {collection_name}"
170
+ }
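A short usage sketch of the loaders defined above; the collection names and expected values are taken from collections.json and settings.yaml in this commit.

from src.config.loader import (
    load_config,
    get_nested_config,
    get_embedding_model_for_collection,
)

config = load_config()  # reads settings.yaml, expands ${VARS}, applies env overrides
print(get_nested_config(config, "reader.OPENAI.model", default="gpt-4o-mini"))
print(get_nested_config(config, "retriever.top_k"))           # 20 in settings.yaml
print(get_embedding_model_for_collection("docling"))          # "BAAI/bge-m3"
print(get_embedding_model_for_collection("my-minilm-index"))  # inferred: all-MiniLM-L6-v2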
src/config/settings.yaml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Audit QA Configuration
2
+ # Converted from model_params.cfg to YAML format
3
+
4
+ qdrant:
5
+ # url: "http://10.1.4.192:8803"`
6
+ url: "https://2c6d0136-b6ca-4400-bac5-1703f58abc43.europe-west3-0.gcp.cloud.qdrant.io"
7
+ collection_name: "docling"
8
+ prefer_grpc: true
9
+ api_key: "${QDRANT_API_KEY}" # Load from environment variable
10
+
11
+ retriever:
12
+ model: "BAAI/bge-m3"
13
+ normalize: true
14
+ top_k: 20
15
+
16
+ retrieval:
17
+ use_reranking: true
18
+ reranker_model: "BAAI/bge-reranker-v2-m3"
19
+ reranker_top_k: 5
20
+
21
+ ranker:
22
+ model: "BAAI/bge-reranker-v2-m3"
23
+ top_k: 5
24
+
25
+ bm25:
26
+ top_k: 20
27
+
28
+ hybrid:
29
+ default_mode: "vector_only" # Options: vector_only, sparse_only, hybrid
30
+ default_alpha: 0.5 # Weight for vector scores (0.5 = equal weight)
31
+
32
+ reader:
33
+ default_type: "OPENAI"
34
+ max_tokens: 768
35
+
36
+ # Different LLM provider configurations
37
+ INF_PROVIDERS:
38
+ model: "meta-llama/Llama-3.1-8B-Instruct"
39
+ provider: "nebius"
40
+
41
+ # Not working
42
+ NVIDIA:
43
+ model: "meta-llama/Llama-3.1-8B-Instruct"
44
+ endpoint: "https://huggingface.co/api/integrations/dgx/v1"
45
+
46
+ # Not working
47
+ DEDICATED:
48
+ model: "meta-llama/Llama-3.1-8B-Instruct"
49
+ endpoint: "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud"
50
+
51
+ MISTRAL:
52
+ model: "mistral-medium-latest"
53
+
54
+ OPENAI:
55
+ model: "gpt-4o-mini"
56
+
57
+ OLLAMA:
58
+ model: "mistral-small3.1:24b-instruct-2503-q8_0"
59
+ base_url: "http://10.1.4.192:11434/"
60
+ temperature: 0.8
61
+ num_predict: 256
62
+
63
+ OPENROUTER:
64
+ model: "moonshotai/kimi-k2:free"
65
+ base_url: "https://openrouter.ai/api/v1"
66
+ temperature: 0.7
67
+ max_tokens: 1000
68
+ # site_url: "https://your-site.com" # optional, for OpenRouter ranking
69
+ # site_name: "Your Site Name" # optional, for OpenRouter ranking
70
+
71
+ app:
72
+ dropdown_default: "Annual Consolidated OAG 2024"
73
+
74
+ # File paths
75
+ paths:
76
+ chunks_file: "reports/docling_chunks.json"
77
+ reports_dir: "reports"
78
+
79
+ # Feature toggles
80
+ features:
81
+ enable_session: true
82
+ enable_logging: true
83
+
84
+ # Logging and HuggingFace scheduler configuration
85
+ logging:
86
+ json_dataset_dir: "json_dataset"
87
+ huggingface:
88
+ repo_id: "GIZ/spaces_logs"
89
+ repo_type: "dataset"
90
+ folder_path: "json_dataset"
91
+ path_in_repo: "audit_chatbot"
92
+ token_env_var: "SPACES_LOG"
src/llm/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """LLM adapters and utilities."""
2
+
3
+ from .adapters import LLMRegistry, get_llm_client
4
+ from .templates import get_message_template, PromptTemplate, create_audit_prompt
5
+
6
+ __all__ = ["LLMRegistry", "get_llm_client", "get_message_template", "PromptTemplate", "create_audit_prompt"]
src/llm/adapters.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM client adapters for different providers."""
2
+
3
+ from typing import Dict, Any, List, Optional, Union
4
+ from abc import ABC, abstractmethod
5
+ from dataclasses import dataclass
6
+
7
+ # LangChain imports
8
+ from langchain_mistralai.chat_models import ChatMistralAI
9
+ from langchain_openai.chat_models import ChatOpenAI
10
+ from langchain_ollama import ChatOllama
11
+
12
+ # Legacy client dependencies
13
+ from huggingface_hub import InferenceClient
14
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
15
+ from langchain_community.llms import HuggingFaceEndpoint
16
+ from langchain_community.chat_models.huggingface import ChatHuggingFace
17
+
18
+ # Configuration loader
19
+ from ..config.loader import load_config
20
+
21
+ # Load configuration once at module level
22
+ _config = load_config()
23
+
24
+
25
+ # Legacy client factory functions (inlined from auditqa_old.reader)
26
+ def _create_inf_provider_client():
27
+ """Create INF_PROVIDERS client."""
28
+ reader_config = _config.get("reader", {})
29
+ inf_config = reader_config.get("INF_PROVIDERS", {})
30
+
31
+ api_key = inf_config.get("api_key")
32
+ if not api_key:
33
+ raise ValueError("INF_PROVIDERS api_key not found in configuration")
34
+
35
+ provider = inf_config.get("provider")
36
+ if not provider:
37
+ raise ValueError("INF_PROVIDERS provider not found in configuration")
38
+
39
+ return InferenceClient(
40
+ provider=provider,
41
+ api_key=api_key,
42
+ bill_to="GIZ",
43
+ )
44
+
45
+
46
+ def _create_nvidia_client():
47
+ """Create NVIDIA client."""
48
+ reader_config = _config.get("reader", {})
49
+ nvidia_config = reader_config.get("NVIDIA", {})
50
+
51
+ api_key = nvidia_config.get("api_key")
52
+ if not api_key:
53
+ raise ValueError("NVIDIA api_key not found in configuration")
54
+
55
+ endpoint = nvidia_config.get("endpoint")
56
+ if not endpoint:
57
+ raise ValueError("NVIDIA endpoint not found in configuration")
58
+
59
+ return InferenceClient(
60
+ base_url=endpoint,
61
+ api_key=api_key
62
+ )
63
+
64
+
65
+ def _create_serverless_client():
66
+ """Create serverless API client."""
67
+ reader_config = _config.get("reader", {})
68
+ serverless_config = reader_config.get("SERVERLESS", {})
69
+
70
+ api_key = serverless_config.get("api_key")
71
+ if not api_key:
72
+ raise ValueError("SERVERLESS api_key not found in configuration")
73
+
74
+ model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")
75
+
76
+ return InferenceClient(
77
+ model=model_id,
78
+ api_key=api_key,
79
+ )
80
+
81
+
82
+ def _create_dedicated_endpoint_client():
83
+ """Create dedicated endpoint client."""
84
+ reader_config = _config.get("reader", {})
85
+ dedicated_config = reader_config.get("DEDICATED", {})
86
+
87
+ api_key = dedicated_config.get("api_key")
88
+ if not api_key:
89
+ raise ValueError("DEDICATED api_key not found in configuration")
90
+
91
+ endpoint = dedicated_config.get("endpoint")
92
+ if not endpoint:
93
+ raise ValueError("DEDICATED endpoint not found in configuration")
94
+
95
+ max_tokens = dedicated_config.get("max_tokens", 768)
96
+
97
+ # Set up the streaming callback handler
98
+ callback = StreamingStdOutCallbackHandler()
99
+
100
+ # Initialize the HuggingFaceEndpoint with streaming enabled
101
+ llm_qa = HuggingFaceEndpoint(
102
+ endpoint_url=endpoint,
103
+ max_new_tokens=int(max_tokens),
104
+ repetition_penalty=1.03,
105
+ timeout=70,
106
+ huggingfacehub_api_token=api_key,
107
+ streaming=True,
108
+ callbacks=[callback]
109
+ )
110
+
111
+ # Create a ChatHuggingFace instance with the streaming-enabled endpoint
112
+ return ChatHuggingFace(llm=llm_qa)
113
+
114
+
115
+ @dataclass
116
+ class LLMResponse:
117
+ """Standardized LLM response format."""
118
+ content: str
119
+ model: str
120
+ provider: str
121
+ metadata: Dict[str, Any] = None
122
+
123
+
124
+ class BaseLLMAdapter(ABC):
125
+ """Base class for LLM adapters."""
126
+
127
+ def __init__(self, config: Dict[str, Any]):
128
+ self.config = config
129
+
130
+ @abstractmethod
131
+ def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
132
+ """Generate response from messages."""
133
+ pass
134
+
135
+ @abstractmethod
136
+ def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
137
+ """Generate streaming response from messages."""
138
+ pass
139
+
140
+
141
+ class MistralAdapter(BaseLLMAdapter):
142
+ """Adapter for Mistral AI models."""
143
+
144
+ def __init__(self, config: Dict[str, Any]):
145
+ super().__init__(config)
146
+ self.model = ChatMistralAI(
147
+ model=config.get("model", "mistral-medium-latest")
148
+ )
149
+
150
+ def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
151
+ """Generate response using Mistral."""
152
+ response = self.model.invoke(messages)
153
+
154
+ return LLMResponse(
155
+ content=response.content,
156
+ model=self.config.get("model", "mistral-medium-latest"),
157
+ provider="mistral",
158
+ metadata={"usage": getattr(response, 'usage_metadata', {})}
159
+ )
160
+
161
+ def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
162
+ """Generate streaming response using Mistral."""
163
+ for chunk in self.model.stream(messages):
164
+ if chunk.content:
165
+ yield chunk.content
166
+
167
+
168
+ class OpenAIAdapter(BaseLLMAdapter):
169
+ """Adapter for OpenAI models."""
170
+
171
+ def __init__(self, config: Dict[str, Any]):
172
+ super().__init__(config)
173
+ self.model = ChatOpenAI(
174
+ model=config.get("model", "gpt-4o-mini")
175
+ )
176
+
177
+ def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
178
+ """Generate response using OpenAI."""
179
+ response = self.model.invoke(messages)
180
+
181
+ return LLMResponse(
182
+ content=response.content,
183
+ model=self.config.get("model", "gpt-4o-mini"),
184
+ provider="openai",
185
+ metadata={"usage": getattr(response, 'usage_metadata', {})}
186
+ )
187
+
188
+ def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
189
+ """Generate streaming response using OpenAI."""
190
+ for chunk in self.model.stream(messages):
191
+ if chunk.content:
192
+ yield chunk.content
193
+
194
+
195
+ class OllamaAdapter(BaseLLMAdapter):
196
+ """Adapter for Ollama models."""
197
+
198
+ def __init__(self, config: Dict[str, Any]):
199
+ super().__init__(config)
200
+ self.model = ChatOllama(
201
+ model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
202
+ base_url=config.get("base_url", "http://localhost:11434/"),
203
+ temperature=config.get("temperature", 0.8),
204
+ num_predict=config.get("num_predict", 256)
205
+ )
206
+
207
+ def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
208
+ """Generate response using Ollama."""
209
+ response = self.model.invoke(messages)
210
+
211
+ return LLMResponse(
212
+ content=response.content,
213
+ model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
214
+ provider="ollama",
215
+ metadata={}
216
+ )
217
+
218
+ def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
219
+ """Generate streaming response using Ollama."""
220
+ for chunk in self.model.stream(messages):
221
+ if chunk.content:
222
+ yield chunk.content
223
+
224
+
225
+ class OpenRouterAdapter(BaseLLMAdapter):
226
+ """Adapter for OpenRouter models."""
227
+
228
+ def __init__(self, config: Dict[str, Any]):
229
+ super().__init__(config)
230
+
231
+ # Prepare custom headers for OpenRouter (optional)
232
+ headers = {}
233
+ if config.get("site_url"):
234
+ headers["HTTP-Referer"] = config["site_url"]
235
+ if config.get("site_name"):
236
+ headers["X-Title"] = config["site_name"]
237
+
238
+ # Initialize ChatOpenAI with OpenRouter configuration
239
+ self.model = ChatOpenAI(
240
+ model=config.get("model", "openai/gpt-3.5-turbo"),
241
+ api_key=config.get("api_key"),
242
+ base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
243
+ default_headers=headers if headers else {},
244
+ temperature=config.get("temperature", 0.7),
245
+ max_tokens=config.get("max_tokens", 1000)
246
+ )
247
+
248
+ def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
249
+ """Generate response using OpenRouter."""
250
+ response = self.model.invoke(messages)
251
+
252
+ return LLMResponse(
253
+ content=response.content,
254
+ model=self.config.get("model", "openai/gpt-3.5-turbo"),
255
+ provider="openrouter",
256
+ metadata={"usage": getattr(response, 'usage_metadata', {})}
257
+ )
258
+
259
+ def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
260
+ """Generate streaming response using OpenRouter."""
261
+ for chunk in self.model.stream(messages):
262
+ if chunk.content:
263
+ yield chunk.content
264
+
265
+
266
+ class LegacyAdapter(BaseLLMAdapter):
267
+ """Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, etc.)."""
268
+
269
+ def __init__(self, config: Dict[str, Any], client_type: str):
270
+ super().__init__(config)
271
+ self.client_type = client_type
272
+ self.client = self._create_client()
273
+
274
+ def _create_client(self):
275
+ """Create legacy client based on type."""
276
+ if self.client_type == "INF_PROVIDERS":
277
+ return _create_inf_provider_client()
278
+ elif self.client_type == "NVIDIA":
279
+ return _create_nvidia_client()
280
+ elif self.client_type == "DEDICATED":
281
+ return _create_dedicated_endpoint_client()
282
+ else: # SERVERLESS
283
+ return _create_serverless_client()
284
+
285
+ def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
286
+ """Generate response using legacy client."""
287
+ max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))
288
+
289
+ if self.client_type == "INF_PROVIDERS":
290
+ response = self.client.chat.completions.create(
291
+ model=self.config.get("model"),
292
+ messages=messages,
293
+ max_tokens=max_tokens
294
+ )
295
+ content = response.choices[0].message.content
296
+
297
+ elif self.client_type == "NVIDIA":
298
+ response = self.client.chat_completion(
299
+ model=self.config.get("model"),
300
+ messages=messages,
301
+ max_tokens=max_tokens
302
+ )
303
+ content = response.choices[0].message.content
304
+
305
+ else: # DEDICATED or SERVERLESS
306
+ response = self.client.chat_completion(
307
+ messages=messages,
308
+ max_tokens=max_tokens
309
+ )
310
+ content = response.choices[0].message.content
311
+
312
+ return LLMResponse(
313
+ content=content,
314
+ model=self.config.get("model", "unknown"),
315
+ provider=self.client_type.lower(),
316
+ metadata={}
317
+ )
318
+
319
+ def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
320
+ """Generate streaming response using legacy client."""
321
+ # Legacy clients may not support streaming in the same way
322
+ # This is a simplified implementation
323
+ response = self.generate(messages, **kwargs)
324
+ words = response.content.split()
325
+ for word in words:
326
+ yield word + " "
327
+
328
+
329
+ class LLMRegistry:
330
+ """Registry for managing different LLM adapters."""
331
+
332
+ def __init__(self):
333
+ self.adapters = {}
334
+ self.adapter_configs = {}
335
+
336
+ def register_adapter(self, name: str, adapter_class: type, config: Dict[str, Any]):
337
+ """Register an LLM adapter (lazy instantiation)."""
338
+ self.adapter_configs[name] = (adapter_class, config)
339
+
340
+ def get_adapter(self, name: str) -> BaseLLMAdapter:
341
+ """Get an LLM adapter by name (lazy instantiation)."""
342
+ if name not in self.adapter_configs:
343
+ raise ValueError(f"Unknown LLM adapter: {name}")
344
+
345
+ # Lazy instantiation - only create when needed
346
+ if name not in self.adapters:
347
+ adapter_class, config = self.adapter_configs[name]
348
+ self.adapters[name] = adapter_class(config)
349
+
350
+ return self.adapters[name]
351
+
352
+ def list_adapters(self) -> List[str]:
353
+ """List available adapter names."""
354
+ return list(self.adapter_configs.keys())
355
+
356
+
357
+ def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
358
+ """
359
+ Create and populate LLM registry from configuration.
360
+
361
+ Args:
362
+ config: Configuration dictionary
363
+
364
+ Returns:
365
+ Populated LLMRegistry
366
+ """
367
+ registry = LLMRegistry()
368
+ reader_config = config.get("reader", {})
369
+
370
+ # Register simple adapters
371
+ if "MISTRAL" in reader_config:
372
+ registry.register_adapter("mistral", MistralAdapter, reader_config["MISTRAL"])
373
+
374
+ if "OPENAI" in reader_config:
375
+ registry.register_adapter("openai", OpenAIAdapter, reader_config["OPENAI"])
376
+
377
+ if "OLLAMA" in reader_config:
378
+ registry.register_adapter("ollama", OllamaAdapter, reader_config["OLLAMA"])
379
+
380
+ if "OPENROUTER" in reader_config:
381
+ registry.register_adapter("openrouter", OpenRouterAdapter, reader_config["OPENROUTER"])
382
+
383
+ # Register legacy adapters
384
+ # legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
385
+ legacy_types = ["INF_PROVIDERS"]
386
+ for legacy_type in legacy_types:
387
+ if legacy_type in reader_config:
388
+ registry.register_adapter(
389
+ legacy_type.lower(),
390
+ lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
391
+ reader_config[legacy_type]
392
+ )
393
+
394
+ return registry
395
+
396
+
397
+ def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
398
+ """
399
+ Get LLM client for specified provider.
400
+
401
+ Args:
402
+ provider: Provider name (mistral, openai, ollama, etc.)
403
+ config: Configuration dictionary
404
+
405
+ Returns:
406
+ LLM adapter instance
407
+ """
408
+ registry = create_llm_registry(config)
409
+ return registry.get_adapter(provider)
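A usage sketch for the registry defined above. The "openai" key follows create_llm_registry's naming; an OPENAI_API_KEY in the environment is assumed, since ChatOpenAI reads it implicitly.

from src.config.loader import load_config
from src.llm.adapters import get_llm_client

config = load_config()                  # reader.OPENAI comes from settings.yaml
client = get_llm_client("openai", config)

messages = [
    {"role": "system", "content": "You are AuditQ&A."},
    {"role": "user", "content": "What is a PDM SACCO?"},
]
reply = client.generate(messages)       # returns an LLMResponse dataclass
print(reply.provider, reply.model)
print(reply.content)

for token in client.stream_generate(messages):  # incremental output
    print(token, end="")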
src/llm/templates.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM prompt templates and message formatting utilities."""
2
+
3
+ from typing import List, Dict, Any, Union
4
+ from dataclasses import dataclass
5
+ from langchain.schema import SystemMessage, HumanMessage
6
+
7
+
8
+ @dataclass
9
+ class PromptTemplate:
10
+ """Template for managing prompts with variables."""
11
+
12
+ system_prompt: str
13
+ user_prompt_template: str
14
+
15
+ def format(self, **kwargs) -> tuple:
16
+ """Format the template with provided variables."""
17
+ formatted_user = self.user_prompt_template.format(**kwargs)
18
+ return self.system_prompt, formatted_user
19
+
20
+
21
+ # Default system prompt for audit Q&A
22
+ DEFAULT_AUDIT_SYSTEM_PROMPT = """
23
+ You are AuditQ&A, an AI Assistant for audit reports. Answer questions directly and factually based on the provided context.
24
+
25
+ Guidelines:
26
+ - Answer directly and concisely (2-3 sentences maximum)
27
+ - Use specific facts and numbers from the context
28
+ - Cite sources using [Doc i] format
29
+ - Be factual, not opinionated
30
+ - Avoid phrases like "From my point of view", "I think", "It seems"
31
+
32
+ Examples:
33
+
34
+ Query: "What challenges arise from contradictory PDM implementation guidelines?"
35
+ Context: [Retrieved documents about PDM guidelines contradictions]
36
+ Answer: "Contradictory PDM implementation guidelines cause challenges during implementation, as entities receive numerous and often conflicting directives from different authorities. For example, guidelines on transfer of funds to PDM SACCOs differ between the PDM Secretariat and PSST, and there are conflicting directives on fund diversion from various authorities."
37
+
38
+ Query: "What was the supplementary funding obtained for the wage budget?"
39
+ Context: [Retrieved documents about wage budget funding]
40
+ Answer: "The supplementary funding obtained for the wage budget was UGX.2,208,040,656."
41
+
42
+ Now answer the following question based on the provided context:
43
+ """
44
+
45
+ # Default user prompt template
46
+ DEFAULT_USER_PROMPT_TEMPLATE = """Passages:
47
+ {context}
48
+ -----------------------
49
+ Question: {question} - explained for an audit expert
50
+ Answer in English, citing the passages:
51
+ """
52
+
53
+
54
+ def create_audit_prompt(context_list: List[str], query: str) -> List[Dict[str, str]]:
55
+ """
56
+ Create audit Q&A prompt messages from context and query.
57
+
58
+ Args:
59
+ context_list: List of context passages
60
+ query: User query
61
+
62
+ Returns:
63
+ List of message dictionaries for LLM
64
+ """
65
+ # Join context passages with numbering
66
+ numbered_context = []
67
+ for i, passage in enumerate(context_list, 1):
68
+ numbered_context.append(f"Doc {i}: {passage}")
69
+
70
+ context_str = "\n\n".join(numbered_context)
71
+
72
+ # Format user prompt
73
+ user_prompt = DEFAULT_USER_PROMPT_TEMPLATE.format(
74
+ context=context_str,
75
+ question=query
76
+ )
77
+
78
+ # Return as message format
79
+ messages = [
80
+ {"role": "system", "content": DEFAULT_AUDIT_SYSTEM_PROMPT},
81
+ {"role": "user", "content": user_prompt}
82
+ ]
83
+
84
+ return messages
85
+
86
+
87
+ def get_message_template(
88
+ provider_type: str,
89
+ system_prompt: str,
90
+ user_prompt: str
91
+ ) -> List[Union[Dict[str, str], SystemMessage, HumanMessage]]:
92
+ """
93
+ Get message template based on LLM provider type.
94
+
95
+ Args:
96
+ provider_type: Type of LLM provider
97
+ system_prompt: System prompt content
98
+ user_prompt: User prompt content
99
+
100
+ Returns:
101
+ List of messages in the appropriate format for the provider
102
+ """
103
+ provider_type = provider_type.upper()
104
+
105
+ if provider_type in ['NVIDIA', 'INF_PROVIDERS', 'MISTRAL', 'OPENAI', 'OPENROUTER']:
106
+ # Dictionary format for API-based providers
107
+ messages = [
108
+ {"role": "system", "content": system_prompt},
109
+ {"role": "user", "content": user_prompt}
110
+ ]
111
+ elif provider_type in ['DEDICATED', 'SERVERLESS', 'OLLAMA']:
112
+ # LangChain message objects for local/dedicated providers
113
+ messages = [
114
+ SystemMessage(content=system_prompt),
115
+ HumanMessage(content=user_prompt)
116
+ ]
117
+ else:
118
+ # Default to dictionary format
119
+ messages = [
120
+ {"role": "system", "content": system_prompt},
121
+ {"role": "user", "content": user_prompt}
122
+ ]
123
+
124
+ return messages
125
+
126
+
127
+ def create_custom_prompt_template(
128
+ system_prompt: str,
129
+ user_template: str
130
+ ) -> PromptTemplate:
131
+ """
132
+ Create a custom prompt template.
133
+
134
+ Args:
135
+ system_prompt: System prompt content
136
+ user_template: User prompt template with placeholders
137
+
138
+ Returns:
139
+ PromptTemplate instance
140
+ """
141
+ return PromptTemplate(
142
+ system_prompt=system_prompt,
143
+ user_prompt_template=user_template
144
+ )
145
+
146
+
147
+ def create_evaluation_prompt(context_list: List[str], query: str, expected_answer: str) -> List[Dict[str, str]]:
148
+ """
149
+ Create prompt for evaluation purposes with expected answer.
150
+
151
+ Args:
152
+ context_list: List of context passages
153
+ query: User query
154
+ expected_answer: Expected/ground truth answer
155
+
156
+ Returns:
157
+ List of message dictionaries for evaluation
158
+ """
159
+ # Join context passages
160
+ context_str = "\n\n".join([f"Doc {i}: {passage}" for i, passage in enumerate(context_list, 1)])
161
+
162
+ evaluation_system_prompt = """
163
+ You are an evaluation assistant. Given context passages, a question, and an expected answer,
164
+ evaluate how well the provided context supports answering the question accurately.
165
+
166
+ Provide your evaluation focusing on:
167
+ 1. Relevance of the context to the question
168
+ 2. Completeness of information needed to answer
169
+ 3. Quality and accuracy of supporting details
170
+ """
171
+
172
+ user_prompt = f"""Context Passages:
173
+ {context_str}
174
+
175
+ Question: {query}
176
+ Expected Answer: {expected_answer}
177
+
178
+ Evaluation:"""
179
+
180
+ return [
181
+ {"role": "system", "content": evaluation_system_prompt},
182
+ {"role": "user", "content": user_prompt}
183
+ ]
184
+
185
+
186
+ def get_prompt_variants() -> Dict[str, PromptTemplate]:
187
+ """
188
+ Get different prompt template variants for testing.
189
+
190
+ Returns:
191
+ Dictionary of named prompt templates
192
+ """
193
+ variants = {
194
+ "standard": create_custom_prompt_template(
195
+ DEFAULT_AUDIT_SYSTEM_PROMPT,
196
+ DEFAULT_USER_PROMPT_TEMPLATE
197
+ ),
198
+
199
+ "concise": create_custom_prompt_template(
200
+ """You are an audit report AI assistant. Provide clear, concise answers based on the given context passages. Always cite sources using [Doc i] format.""",
201
+ """Context:\n{context}\n\nQuestion: {question}\nAnswer:"""
202
+ ),
203
+
204
+ "detailed": create_custom_prompt_template(
205
+ DEFAULT_AUDIT_SYSTEM_PROMPT + """\n\nAdditional Instructions:
206
+ - Provide detailed explanations with specific examples
207
+ - Include relevant numbers, dates, and financial figures when available
208
+ - Structure your response with clear headings when appropriate
209
+ - Explain the significance of findings in the context of governance and accountability""",
210
+ DEFAULT_USER_PROMPT_TEMPLATE
211
+ )
212
+ }
213
+
214
+ return variants
215
+
216
+
217
+ # Backward compatibility function
218
+ def format_context_with_citations(context_list: List[str]) -> str:
219
+ """
220
+ Format context list with document citations.
221
+
222
+ Args:
223
+ context_list: List of context passages
224
+
225
+ Returns:
226
+ Formatted context string with citations
227
+ """
228
+ formatted_passages = []
229
+ for i, passage in enumerate(context_list, 1):
230
+ formatted_passages.append(f"Doc {i}: {passage}")
231
+
232
+ return "\n\n".join(formatted_passages)
src/loader.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data loading utilities for chunks and JSON files."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import List, Dict, Any
6
+ from langchain.docstore.document import Document
7
+
8
+
9
+ def load_json(filepath: Path | str) -> List[Dict[str, Any]]:
10
+ """
11
+ Load JSON data from file.
12
+
13
+ Args:
14
+ filepath: Path to JSON file
15
+
16
+ Returns:
17
+ List of dictionaries containing the JSON data
18
+ """
19
+ filepath = Path(filepath)
20
+
21
+ if not filepath.exists():
22
+ raise FileNotFoundError(f"JSON file not found: {filepath}")
23
+
24
+ with open(filepath, 'r', encoding='utf-8') as f:
25
+ data = json.load(f)
26
+
27
+ return data
28
+
29
+
30
+ def open_file(filepath: Path | str) -> str:
31
+ """
32
+ Open and read a text file.
33
+
34
+ Args:
35
+ filepath: Path to text file
36
+
37
+ Returns:
38
+ File contents as string
39
+ """
40
+ filepath = Path(filepath)
41
+
42
+ if not filepath.exists():
43
+ raise FileNotFoundError(f"File not found: {filepath}")
44
+
45
+ with open(filepath, 'r', encoding='utf-8') as f:
46
+ content = f.read()
47
+
48
+ return content
49
+
50
+
51
+ def load_chunks(chunks_file: Path | str = None) -> List[Dict[str, Any]]:
52
+ """
53
+ Load document chunks from JSON file.
54
+
55
+ Args:
56
+ chunks_file: Path to chunks JSON file. If None, uses default path.
57
+
58
+ Returns:
59
+ List of chunk dictionaries
60
+ """
61
+ if chunks_file is None:
62
+ chunks_file = Path("reports/docling_chunks.json")
63
+
64
+ return load_json(chunks_file)
65
+
66
+
67
+ def chunks_to_documents(chunks: List[Dict[str, Any]]) -> List[Document]:
68
+ """
69
+ Convert chunk dictionaries to LangChain Document objects.
70
+
71
+ Args:
72
+ chunks: List of chunk dictionaries
73
+
74
+ Returns:
75
+ List of Document objects
76
+ """
77
+ documents = []
78
+
79
+ for chunk in chunks:
80
+ doc = Document(
81
+ page_content=chunk.get("content", ""),
82
+ metadata=chunk.get("metadata", {})
83
+ )
84
+ documents.append(doc)
85
+
86
+ return documents
87
+
88
+
89
+ def validate_chunks(chunks: List[Dict[str, Any]]) -> bool:
90
+ """
91
+ Validate that chunks have required fields.
92
+
93
+ Args:
94
+ chunks: List of chunk dictionaries
95
+
96
+ Returns:
97
+ True if valid, raises ValueError if invalid
98
+ """
99
+ required_fields = ["content", "metadata"]
100
+
101
+ for i, chunk in enumerate(chunks):
102
+ for field in required_fields:
103
+ if field not in chunk:
104
+ raise ValueError(f"Chunk {i} missing required field: {field}")
105
+
106
+ # Validate metadata has required fields
107
+ metadata = chunk["metadata"]
108
+ if not isinstance(metadata, dict):
109
+ raise ValueError(f"Chunk {i} metadata must be a dictionary")
110
+
111
+ # Check for common metadata fields
112
+ if "filename" not in metadata:
113
+ raise ValueError(f"Chunk {i} metadata missing 'filename' field")
114
+
115
+ return True
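A minimal sketch of the chunk-loading flow above; the default chunks path matches paths.chunks_file in settings.yaml (reports/docling_chunks.json).

from src.loader import load_chunks, validate_chunks, chunks_to_documents

chunks = load_chunks()           # defaults to reports/docling_chunks.json
validate_chunks(chunks)          # raises ValueError on missing content/metadata/filename
documents = chunks_to_documents(chunks)
print(len(documents), documents[0].metadata.get("filename"))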
src/logging.py ADDED
@@ -0,0 +1,193 @@
1
+ """Logging utilities (placeholder for legacy compatibility)."""
2
+ import os
+ import json
3
+ import logging
4
+ from uuid import uuid4
5
+ from pathlib import Path
6
+ from threading import Lock
7
+ from datetime import datetime
8
+ from typing import Dict, Any, Optional
9
+
10
+ from .config import load_config
11
+
12
+ def save_logs(
13
+ scheduler=None,
14
+ json_dataset_path: Path = None,
15
+ logs_data: Dict[str, Any] = None,
16
+ feedback: str = None
17
+ ) -> None:
18
+ """
19
+ Append a log record to the JSONL dataset (legacy-compatible signature).
20
+
21
+ Args:
22
+ scheduler: HuggingFace scheduler (not used in refactored version)
23
+ json_dataset_path: Path to JSON dataset
24
+ logs_data: Log data dictionary
25
+ feedback: User feedback
26
+
27
+ Note:
28
+ Kept with the legacy signature for backward compatibility; each record is
29
+ appended to the JSONL file as one JSON object per line, using the scheduler's lock when available.
30
+ """
31
+ if not is_logging_enabled():
32
+ return
33
+ try:
34
+ current_time = datetime.now().timestamp()
35
+ logs_data["time"] = str(current_time)
36
+ if feedback:
37
+ logs_data["feedback"] = feedback
38
+ logs_data["record_id"] = str(uuid4())
39
+ field_order = [
40
+ "record_id",
41
+ "session_id",
42
+ "time",
43
+ "session_duration_seconds",
44
+ "client_location",
45
+ "platform",
46
+ "system_prompt",
47
+ "sources",
48
+ "reports",
49
+ "subtype",
50
+ "year",
51
+ "question",
52
+ "retriever",
53
+ "endpoint_type",
54
+ "reader",
55
+ "docs",
56
+ "answer",
57
+ "feedback"
58
+ ]
59
+ ordered_logs = {k: logs_data.get(k) for k in field_order if k in logs_data}
60
+ lock = getattr(scheduler, "lock", None)
61
+ if lock is None:
62
+ lock = Lock()
63
+ with lock:
64
+ with open(json_dataset_path, 'a') as f:
65
+ json.dump(ordered_logs, f)
66
+ f.write("\n")
67
+ logging.info("logging done")
68
+ except Exception as e:
69
+ logging.error(f"Error saving logs: {e}")
70
+ raise
71
+
72
+
73
+ def setup_logging(log_level: str = "INFO", log_file: str = None) -> None:
74
+ """
75
+ Set up logging configuration.
76
+
77
+ Args:
78
+ log_level: Logging level
79
+ log_file: Optional log file path
80
+ """
81
+ if not is_logging_enabled():
82
+ return
83
+
84
+ # Configure logging
85
+ logging.basicConfig(
86
+ level=getattr(logging, log_level.upper()),
87
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
88
+ handlers=[
89
+ logging.StreamHandler(),
90
+ logging.FileHandler(log_file) if log_file else logging.NullHandler()
91
+ ]
92
+ )
93
+
94
+
95
+ def log_query_response(
96
+ query: str,
97
+ response: str,
98
+ metadata: Dict[str, Any] = None
99
+ ) -> None:
100
+ """
101
+ Log query and response for analysis.
102
+
103
+ Args:
104
+ query: User query
105
+ response: System response
106
+ metadata: Additional metadata
107
+ """
108
+ if not is_logging_enabled():
109
+ return
110
+
111
+ logger = logging.getLogger(__name__)
112
+
113
+ log_entry = {
114
+ "query": query,
115
+ "response_length": len(response),
116
+ "metadata": metadata or {}
117
+ }
118
+
119
+ logger.info(f"Query processed: {log_entry}")
120
+
121
+
122
+ def log_error(error: Exception, context: Dict[str, Any] = None) -> None:
123
+ """
124
+ Log error with context.
125
+
126
+ Args:
127
+ error: Exception that occurred
128
+ context: Additional context information
129
+ """
130
+ if not is_logging_enabled():
131
+ return
132
+
133
+ logger = logging.getLogger(__name__)
134
+
135
+ error_info = {
136
+ "error_type": type(error).__name__,
137
+ "error_message": str(error),
138
+ "context": context or {}
139
+ }
140
+
141
+ logger.error(f"Error occurred: {error_info}")
142
+
143
+
144
+ def log_performance_metrics(
145
+ operation: str,
146
+ duration: float,
147
+ metadata: Dict[str, Any] = None
148
+ ) -> None:
149
+ """
150
+ Log performance metrics.
151
+
152
+ Args:
153
+ operation: Name of the operation
154
+ duration: Duration in seconds
155
+ metadata: Additional metadata
156
+ """
157
+ if not is_logging_enabled():
158
+ return
159
+
160
+ logger = logging.getLogger(__name__)
161
+
162
+ metrics = {
163
+ "operation": operation,
164
+ "duration_seconds": duration,
165
+ "metadata": metadata or {}
166
+ }
167
+
168
+ logger.info(f"Performance metrics: {metrics}")
169
+
170
+
171
+ def is_session_enabled() -> bool:
172
+ """
173
+ Returns True if session management is enabled, False otherwise.
174
+ Checks environment variable ENABLE_SESSION first, then config.
175
+ """
176
+ env = os.getenv("ENABLE_SESSION")
177
+ if env is not None:
178
+ return env.lower() in ("1", "true", "yes", "on")
179
+ config = load_config()
180
+ return config.get("features", {}).get("enable_session", True)
181
+
182
+
183
+ def is_logging_enabled() -> bool:
184
+ """
185
+ Returns True if logging is enabled, False otherwise.
186
+ Checks environment variable ENABLE_LOGGING first, then config.
187
+ """
188
+ env = os.getenv("ENABLE_LOGGING")
189
+ if env is not None:
190
+ return env.lower() in ("1", "true", "yes", "on")
191
+ config = load_config()
192
+ return config.get("features", {}).get("enable_logging", True)
193
+
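A short sketch of how these helpers and toggles are driven; the ENABLE_LOGGING value, the logs.jsonl path, and the log payload are illustrative, not part of this commit:

    import os
    from pathlib import Path
    from src.logging import save_logs, log_performance_metrics

    os.environ["ENABLE_LOGGING"] = "true"   # is_logging_enabled() reads this before the config

    # Append one query record to a JSONL dataset; scheduler (and its lock) is optional.
    save_logs(
        scheduler=None,
        json_dataset_path=Path("logs.jsonl"),
        logs_data={"session_id": "abc123", "question": "What did the audit find?", "answer": "..."},
        feedback="okay",
    )
    log_performance_metrics("retrieval", duration=0.42, metadata={"k": 5})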
src/pipeline.py ADDED
@@ -0,0 +1,731 @@
1
+ """Main pipeline orchestrator for the Audit QA system."""
2
+ import time
3
+ from pathlib import Path
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Any, List, Optional
6
+
7
+ from langchain.docstore.document import Document
8
+
9
+ from .logging import log_error
10
+ from .llm.adapters import LLMRegistry
11
+ from .loader import chunks_to_documents
12
+ from .vectorstore import VectorStoreManager
13
+ from .retrieval.context import ContextRetriever
14
+ from .config.loader import get_embedding_model_for_collection
15
+
16
+
17
+
18
+ @dataclass
19
+ class PipelineResult:
20
+ """Result of pipeline execution."""
21
+ answer: str
22
+ sources: List[Document]
23
+ execution_time: float
24
+ metadata: Dict[str, Any]
25
+ query: str = "" # Add default value for query
26
+
27
+ def __post_init__(self):
28
+ """Post-initialization processing."""
29
+ if not self.query:
30
+ self.query = "Unknown query"
31
+
32
+
33
+ class PipelineManager:
34
+ """Main pipeline manager for the RAG system."""
35
+
36
+ def __init__(self, config: dict = None):
37
+ """
38
+ Initialize the pipeline manager.
39
+ """
40
+ self.config = config or {}
41
+ self.vectorstore_manager = None
42
+ self.context_retriever = None # Initialize as None
43
+ self.llm_client = None
44
+ self.report_service = None
45
+ self.chunks = None
46
+
47
+ # Initialize components
48
+ self._initialize_components()
49
+
50
+ def update_config(self, new_config: dict):
51
+ """
52
+ Update the pipeline configuration.
53
+ This is useful for experiments that need different settings.
54
+ """
55
+ if not isinstance(new_config, dict):
56
+ return
57
+
58
+ # Deep merge the new config with existing config
59
+ def deep_merge(base_dict, update_dict):
60
+ for key, value in update_dict.items():
61
+ if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
62
+ deep_merge(base_dict[key], value)
63
+ else:
64
+ base_dict[key] = value
65
+
66
+ deep_merge(self.config, new_config)
67
+
68
+ # Auto-infer embedding model from collection name if not "docling"
69
+ collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
70
+ if collection_name != 'docling':
71
+ inferred_model = get_embedding_model_for_collection(collection_name)
72
+ if inferred_model:
73
+ print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
74
+ if 'retriever' not in self.config:
75
+ self.config['retriever'] = {}
76
+ self.config['retriever']['model'] = inferred_model
77
+ # Set default normalize parameter if not present
78
+ if 'normalize' not in self.config['retriever']:
79
+ self.config['retriever']['normalize'] = True
80
+
81
+ # Also update vectorstore config if it exists
82
+ if 'vectorstore' in self.config:
83
+ self.config['vectorstore']['embedding_model'] = inferred_model
84
+
85
+ print(f"🔧 CONFIG UPDATED: Pipeline config updated with experiment settings")
86
+
87
+ # Re-initialize vectorstore manager with updated config
88
+ self._reinitialize_vectorstore_manager()
89
+
90
+ def _reinitialize_vectorstore_manager(self):
91
+ """Re-initialize vectorstore manager with current config."""
92
+ try:
93
+ self.vectorstore_manager = VectorStoreManager(self.config)
94
+ print("🔄 VectorStore manager re-initialized with updated config")
95
+ except Exception as e:
96
+ print(f"❌ Error re-initializing vectorstore manager: {e}")
97
+
98
+ def _get_reranker_model_name(self) -> str:
99
+ """
100
+ Get the reranker model name from configuration.
101
+
102
+ Returns:
103
+ Reranker model name or default
104
+ """
105
+ return (
106
+ self.config.get('retrieval', {}).get('reranker_model') or
107
+ self.config.get('ranker', {}).get('model') or
108
+ self.config.get('reranker_model') or
109
+ 'BAAI/bge-reranker-v2-m3'
110
+ )
111
+
112
+ def _initialize_components(self):
113
+ """Initialize pipeline components."""
114
+ try:
115
+ # Load config if not provided
116
+ if not self.config:
117
+ from auditqa.config.loader import load_config
118
+ self.config = load_config()
119
+
120
+ # Auto-infer embedding model from collection name if not "docling"
121
+ collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
122
+ if collection_name != 'docling':
123
+ inferred_model = get_embedding_model_for_collection(collection_name)
124
+ if inferred_model:
125
+ print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
126
+ if 'retriever' not in self.config:
127
+ self.config['retriever'] = {}
128
+ self.config['retriever']['model'] = inferred_model
129
+ # Set default normalize parameter if not present
130
+ if 'normalize' not in self.config['retriever']:
131
+ self.config['retriever']['normalize'] = True
132
+
133
+ # Also update vectorstore config if it exists
134
+ if 'vectorstore' in self.config:
135
+ self.config['vectorstore']['embedding_model'] = inferred_model
136
+
137
+ self.vectorstore_manager = VectorStoreManager(self.config)
138
+
139
+ self.llm_manager = LLMRegistry()
140
+
141
+ # Try to get LLM client using the correct method
142
+ self.llm_client = None
143
+ try:
144
+ # Try using get_adapter method (most likely correct)
145
+ self.llm_client = self.llm_manager.get_adapter("openai")
146
+ print("✅ LLM CLIENT: Initialized using get_adapter method")
147
+ except Exception as e:
148
+ try:
149
+ # Try direct instantiation with config
150
+ from auditqa.llm.adapters import get_llm_client
151
+ self.llm_client = get_llm_client("openai", self.config)
152
+ print("✅ LLM CLIENT: Initialized using direct get_llm_client function with config")
153
+ except Exception as e2:
154
+ print(f"❌ LLM CLIENT: Registry methods failed - {e2}")
155
+ # Try to create a simple LLM client directly
156
+ try:
157
+ from langchain_openai import ChatOpenAI
158
+ import os
159
+ api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
160
+ if api_key:
161
+ self.llm_client = ChatOpenAI(
162
+ model="gpt-3.5-turbo",
163
+ api_key=api_key,
164
+ temperature=0.1,
165
+ max_tokens=1000
166
+ )
167
+ print("✅ LLM CLIENT: Initialized using direct ChatOpenAI")
168
+ else:
169
+ print("❌ LLM CLIENT: No API key available")
170
+ except Exception as e3:
171
+ print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
172
+ self.llm_client = None
173
+
174
+ # Load system prompt
175
+ from auditqa.llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
176
+ self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
177
+
178
+ # Initialize report service
179
+ try:
180
+ from auditqa.reporting.service import ReportService
181
+ self.report_service = ReportService()
182
+ except Exception as e:
183
+ print(f"Warning: Could not initialize report service: {e}")
184
+ self.report_service = None
185
+
186
+ except Exception as e:
187
+ print(f"Warning: Error initializing components: {e}")
188
+
189
+ def test_retrieval(
190
+ self,
191
+ query: str,
192
+ reports: List[str] = None,
193
+ sources: str = None,
194
+ subtype: List[str] = None,
195
+ k: int = None,
196
+ search_mode: str = None,
197
+ search_alpha: float = None,
198
+ use_reranking: bool = True
199
+ ) -> Dict[str, Any]:
200
+ """
201
+ Test retrieval only without LLM inference.
202
+
203
+ Args:
204
+ query: User query
205
+ reports: List of specific report filenames
206
+ sources: Source category
207
+ subtype: List of subtypes
208
+ k: Number of documents to retrieve
209
+ search_mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
210
+ search_alpha: Weight for vector scores in hybrid mode
211
+ use_reranking: Whether to use reranking
212
+
213
+ Returns:
214
+ Dictionary with retrieval results and metadata
215
+ """
216
+ start_time = time.time()
217
+
218
+ try:
219
+ # Set default search parameters if not provided
220
+ if search_mode is None:
221
+ search_mode = self.config.get("hybrid", {}).get("default_mode", "vector_only")
222
+ if search_alpha is None:
223
+ search_alpha = self.config.get("hybrid", {}).get("default_alpha", 0.5)
224
+
225
+ # Get vector store
226
+ vectorstore = self.vectorstore_manager.get_vectorstore()
227
+ if not vectorstore:
228
+ raise ValueError(
229
+ "Vector store not available. Call connect_vectorstore() or create_vectorstore() first."
230
+ )
231
+
232
+ # Retrieve context with scores for test retrieval
233
+ context_docs_with_scores = self.context_retriever.retrieve_with_scores(
234
+ vectorstore=vectorstore,
235
+ query=query,
236
+ reports=reports,
237
+ sources=sources,
238
+ subtype=subtype,
239
+ k=k,
240
+ search_mode=search_mode,
241
+ alpha=search_alpha,
242
+ )
243
+
244
+ # Extract documents and scores
245
+ context_docs = [doc for doc, score in context_docs_with_scores]
246
+ context_scores = [score for doc, score in context_docs_with_scores]
247
+
248
+ execution_time = time.time() - start_time
249
+
250
+ # Format results with actual scores
251
+ results = []
252
+ for i, (doc, score) in enumerate(zip(context_docs, context_scores)):
253
+ results.append({
254
+ "rank": i + 1,
255
+ "content": doc.page_content, # Return full content without truncation
256
+ "metadata": doc.metadata,
257
+ "score": score if score is not None else 0.0
258
+ })
259
+
260
+ return {
261
+ "results": results,
262
+ "num_results": len(results),
263
+ "execution_time": execution_time,
264
+ "search_mode": search_mode,
265
+ "search_alpha": search_alpha,
266
+ "query": query
267
+ }
268
+
269
+ except Exception as e:
270
+ print(f"❌ Error during retrieval test: {e}")
271
+ log_error(e, {"component": "retrieval_test", "query": query})
272
+ return {
273
+ "results": [],
274
+ "num_results": 0,
275
+ "execution_time": time.time() - start_time,
276
+ "error": str(e),
277
+ "search_mode": search_mode or "unknown",
278
+ "search_alpha": search_alpha or 0.5,
279
+ "query": query
280
+ }
281
+
282
+ def connect_vectorstore(self, force_recreate: bool = False) -> bool:
283
+ """
284
+ Connect to existing vector store.
285
+
286
+ Args:
287
+ force_recreate: If True, recreate the collection if dimension mismatch occurs
288
+
289
+ Returns:
290
+ True if successful, False otherwise
291
+ """
292
+ try:
293
+ vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=force_recreate)
294
+ if vectorstore:
295
+ print("✅ Connected to vector store")
296
+ return True
297
+ else:
298
+ print("❌ Failed to connect to vector store")
299
+ return False
300
+ except Exception as e:
301
+ print(f"❌ Error connecting to vector store: {e}")
302
+ log_error(e, {"component": "vectorstore_connection"})
303
+
304
+ # If it's a dimension mismatch error, try with force_recreate
305
+ if "dimensions" in str(e).lower() and not force_recreate:
306
+ print("🔄 Dimension mismatch detected, attempting to recreate collection...")
307
+ try:
308
+ vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=True)
309
+ if vectorstore:
310
+ print("✅ Connected to vector store (recreated)")
311
+ return True
312
+ except Exception as recreate_error:
313
+ print(f"❌ Failed to recreate vector store: {recreate_error}")
314
+ log_error(recreate_error, {"component": "vectorstore_recreation"})
315
+
316
+ return False
317
+
318
+ def create_vectorstore(self) -> bool:
319
+ """
320
+ Create new vector store from chunks.
321
+
322
+ Returns:
323
+ True if successful, False otherwise
324
+ """
325
+ try:
326
+ if not self.chunks:
327
+ raise ValueError("No chunks available for vector store creation")
328
+
329
+ documents = chunks_to_documents(self.chunks)
330
+ self.vectorstore_manager.create_from_documents(documents)
331
+ print("✅ Vector store created successfully")
332
+ return True
333
+ except Exception as e:
334
+ print(f"❌ Error creating vector store: {e}")
335
+ log_error(e, {"component": "vectorstore_creation"})
336
+ return False
337
+
338
+ def create_audit_prompt(self, query: str, context_docs: List[Document]) -> str:
339
+ """Create a prompt for the LLM to generate an answer."""
340
+ try:
341
+ # Ensure query is not None
342
+ if not query or not isinstance(query, str) or query.strip() == "":
343
+ return "Error: No query provided"
344
+
345
+ # Ensure context_docs is not None and is a list
346
+ if context_docs is None:
347
+ context_docs = []
348
+
349
+ # Filter out None documents and ensure they have content
350
+ valid_docs = []
351
+ for doc in context_docs:
352
+ if doc is not None:
353
+ if hasattr(doc, 'page_content') and doc.page_content and isinstance(doc.page_content, str):
354
+ valid_docs.append(doc)
355
+ elif isinstance(doc, str) and doc.strip():
356
+ valid_docs.append(doc)
357
+
358
+ # Create context string
359
+ if valid_docs:
360
+ context_parts = []
361
+ for i, doc in enumerate(valid_docs, 1):
362
+ if hasattr(doc, 'page_content') and doc.page_content:
363
+ context_parts.append(f"Doc {i}: {doc.page_content}")
364
+ elif isinstance(doc, str) and doc.strip():
365
+ context_parts.append(f"Doc {i}: {doc}")
366
+
367
+ context_string = "\n\n".join(context_parts)
368
+ else:
369
+ context_string = "No relevant context found."
370
+
371
+ # Create the prompt
372
+ prompt = f"""
373
+ {self.system_prompt}
374
+
375
+ Context:
376
+ {context_string}
377
+
378
+ Query: {query}
379
+
380
+ Answer:"""
381
+
382
+ return prompt
383
+
384
+ except Exception as e:
385
+ print(f"Error creating audit prompt: {e}")
386
+ return f"Error creating prompt: {e}"
387
+
388
+ def _generate_answer(self, prompt: str) -> str:
389
+ """Generate answer using the LLM."""
390
+ try:
391
+ if not prompt or not isinstance(prompt, str) or prompt.strip() == "":
392
+ return "Error: No prompt provided"
393
+
394
+ # Ensure LLM client is available
395
+ if not self.llm_client:
396
+ return "Error: LLM client not available"
397
+
398
+ # Generate response using the correct method
399
+ if hasattr(self.llm_client, 'generate'):
400
+ # Use the generate method (for adapters)
401
+ response = self.llm_client.generate([{"role": "user", "content": prompt}])
402
+
403
+ # Extract content from LLMResponse
404
+ if hasattr(response, 'content'):
405
+ answer = response.content
406
+ else:
407
+ answer = str(response)
408
+
409
+ elif hasattr(self.llm_client, 'invoke'):
410
+ # Use the invoke method (for direct LangChain models)
411
+ response = self.llm_client.invoke(prompt)
412
+
413
+ # Extract content safely
414
+ if hasattr(response, 'content') and response.content is not None:
415
+ answer = response.content
416
+ elif isinstance(response, str) and response.strip():
417
+ answer = response
418
+ else:
419
+ answer = str(response) if response is not None else "Error: LLM returned None response"
420
+ else:
421
+ return "Error: LLM client has no generate or invoke method"
422
+
423
+ # Ensure answer is not None and is a string
424
+ if answer is None or not isinstance(answer, str):
425
+ return "Error: LLM returned invalid response"
426
+
427
+ return answer.strip()
428
+
429
+ except Exception as e:
430
+ print(f"Error generating answer: {e}")
431
+ return f"Error generating answer: {e}"
432
+
433
+ def run(
434
+ self,
435
+ query: str,
436
+ reports: List[str] = None,
437
+ sources: List[str] = None,
438
+ subtype: List[str] = None,
439
+ llm_provider: str = None,
440
+ use_reranking: bool = True,
441
+ search_mode: str = None,
442
+ search_alpha: float = None,
443
+ auto_infer_filters: bool = True,
444
+ filters: Dict[str, Any] = None,
445
+ ) -> PipelineResult:
446
+ """
447
+ Run the complete RAG pipeline.
448
+
449
+ Args:
450
+ query: User query
451
+ reports: List of specific report filenames
452
+ sources: Source category filter
453
+ subtype: List of subtypes/filenames
454
+ llm_provider: LLM provider to use
455
+ use_reranking: Whether to use reranking
456
+ search_mode: Search mode (vector, sparse, hybrid)
457
+ search_alpha: Alpha value for hybrid search
458
+ auto_infer_filters: Whether to auto-infer filters from query
459
+
460
+ Returns:
461
+ PipelineResult object
462
+ """
463
+ try:
464
+ # Validate input
465
+ if not query or not isinstance(query, str) or query.strip() == "":
466
+ return PipelineResult(
467
+ answer="Error: Invalid query provided",
468
+ sources=[],
469
+ execution_time=0.0,
470
+ metadata={'error': 'Invalid query'},
471
+ query=query
472
+ )
473
+
474
+ # Ensure lists are not None
475
+ if reports is None:
476
+ reports = []
477
+ if subtype is None:
478
+ subtype = []
479
+
480
+ start_time = time.time()
481
+
482
+ # Auto-infer filters if enabled and no explicit filters provided
483
+ inferred_filters = {}
484
+ filters_applied = False
485
+ qdrant_filter = None  # populated by auto-inference below when no explicit filters are given
486
+
487
+ if auto_infer_filters and not any([reports, sources, subtype]):
488
+ print(f"🤖 AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
489
+ try:
490
+ # Import get_available_metadata here to avoid circular imports
491
+ from auditqa.retrieval.filter import get_available_metadata, infer_filters_from_query
492
+
493
+ # Get available metadata
494
+ available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
495
+
496
+ # Infer filters from query - this returns a Qdrant filter
497
+ qdrant_filter, filter_summary = infer_filters_from_query(
498
+ query=query,
499
+ available_metadata=available_metadata,
500
+ llm_client=self.llm_client
501
+ )
502
+
503
+ if qdrant_filter:
504
+ print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
505
+ filters_applied = True
506
+ # Don't set sources/reports/subtype - use the Qdrant filter directly
507
+ else:
508
+ print(f"⚠️ NO QDRANT FILTER: Could not build Qdrant filter from query")
509
+
510
+ except Exception as e:
511
+ print(f"❌ AUTO-INFERENCE FAILED: {e}")
512
+ qdrant_filter = None
513
+ else:
514
+ # Check if any explicit filters were provided
515
+ filters_applied = any([reports, sources, subtype])
516
+ if filters_applied:
517
+ print(f"✅ EXPLICIT FILTERS: Using provided filters")
518
+ else:
519
+ print(f"⚠️ NO FILTERS: No explicit filters and auto-inference disabled")
520
+
521
+ # Extract filter parameters from the filters parameter
522
+ reports = filters.get('reports', reports) if filters else reports
523
+ sources = filters.get('sources', sources) if filters else sources
524
+ subtype = filters.get('subtype', subtype) if filters else subtype
525
+ year = filters.get('year', []) if filters else []
526
+ district = filters.get('district', []) if filters else []
527
+ filenames = filters.get('filenames', []) if filters else [] # Support mutually exclusive filename filtering
528
+
529
+ # Get vectorstore
530
+ vectorstore = self.vectorstore_manager.get_vectorstore()
531
+ if not vectorstore:
532
+ return PipelineResult(
533
+ answer="Error: Vector store not available",
534
+ sources=[],
535
+ execution_time=0.0,
536
+ metadata={'error': 'Vector store not available'},
537
+ query=query
538
+ )
539
+
540
+ # Initialize context retriever if not already done
541
+ if not hasattr(self, 'context_retriever') or self.context_retriever is None:
542
+ # Get the actual vectorstore object
543
+ vectorstore_obj = self.vectorstore_manager.get_vectorstore()
544
+ if vectorstore_obj is None:
545
+ print("❌ ERROR: Vectorstore is None, cannot initialize ContextRetriever")
546
+ return PipelineResult(answer="Error: Vector store not available", sources=[], execution_time=0.0, metadata={'error': 'Vector store not available'}, query=query)
547
+ self.context_retriever = ContextRetriever(vectorstore_obj, self.config)
548
+ print("✅ ContextRetriever initialized successfully")
549
+
550
+ # Debug config access
551
+ print(f" CONFIG DEBUG: Full config keys: {list(self.config.keys()) if isinstance(self.config, dict) else 'Not a dict'}")
552
+ print(f"🔍 CONFIG DEBUG: Retriever config: {self.config.get('retriever', {})}")
553
+ print(f"🔍 CONFIG DEBUG: Retrieval config: {self.config.get('retrieval', {})}")
554
+ print(f"🔍 CONFIG DEBUG: use_reranking from config: {self.config.get('retrieval', {}).get('use_reranking', 'NOT_FOUND')}")
555
+
556
+ # Get the correct top_k value
557
+ # Priority: experiment config > retriever config > default
558
+ top_k = (
559
+ self.config.get('retrieval', {}).get('top_k') or
560
+ self.config.get('retriever', {}).get('top_k') or
561
+ 5
562
+ )
563
+
564
+ # Get reranking setting
565
+ use_reranking = self.config.get('retrieval', {}).get('use_reranking', False)
566
+
567
+ print(f"🔍 CONFIG DEBUG: Final top_k: {top_k}")
568
+ print(f"🔍 CONFIG DEBUG: Final use_reranking: {use_reranking}")
569
+
570
+ # Retrieve context using the context retriever
571
+ context_docs = self.context_retriever.retrieve_context(
572
+ query=query,
573
+ k=top_k,
574
+ reports=reports,
575
+ sources=sources,
576
+ subtype=subtype,
577
+ year=year,
578
+ district=district,
579
+ filenames=filenames,
580
+ use_reranking=use_reranking,
581
+ qdrant_filter=qdrant_filter
582
+ )
583
+
584
+ # Ensure context_docs is not None
585
+ if context_docs is None:
586
+ context_docs = []
587
+
588
+ # Generate answer
589
+ answer = self._generate_answer(self.create_audit_prompt(query, context_docs))
590
+
591
+ execution_time = time.time() - start_time
592
+
593
+ # Create result with comprehensive metadata
594
+ result = PipelineResult(
595
+ answer=answer,
596
+ sources=context_docs,
597
+ execution_time=execution_time,
598
+ metadata={
599
+ 'llm_provider': llm_provider,
600
+ 'use_reranking': use_reranking,
601
+ 'search_mode': search_mode,
602
+ 'search_alpha': search_alpha,
603
+ 'auto_infer_filters': auto_infer_filters,
604
+ 'filters_applied': filters_applied,
605
+ 'with_filtering': filters_applied,
606
+ 'filter_conditions': {
607
+ 'reports': reports,
608
+ 'sources': sources,
609
+ 'subtype': subtype
610
+ },
611
+ 'inferred_filters': inferred_filters,
612
+ 'applied_filters': {
613
+ 'reports': reports,
614
+ 'sources': sources,
615
+ 'subtype': subtype
616
+ },
617
+ # Store filter and reranking metadata
618
+ 'filter_details': {
619
+ 'explicit_filters': {
620
+ 'reports': reports,
621
+ 'sources': sources,
622
+ 'subtype': subtype,
623
+ 'year': year
624
+ },
625
+ 'inferred_filters': inferred_filters if auto_infer_filters else {},
626
+ 'auto_inference_enabled': auto_infer_filters,
627
+ 'qdrant_filter_applied': qdrant_filter is not None,
628
+ 'filter_summary': filter_summary if 'filter_summary' in locals() else None
629
+ },
630
+ 'reranker_model': self._get_reranker_model_name() if use_reranking else None,
631
+ 'reranker_applied': use_reranking,
632
+ 'reranking_info': {
633
+ 'model': self._get_reranker_model_name(),
634
+ 'applied': use_reranking,
635
+ 'top_k': len(context_docs) if context_docs else 0,
636
+ # 'original_documents': [
637
+ # {
638
+ # 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
639
+ # 'metadata': doc.metadata,
640
+ # 'score': getattr(doc, 'score', getattr(doc, 'original_score', 0.0))
641
+ # } for doc in context_docs
642
+ # ] if use_reranking else None,
643
+ 'reranked_documents': [
644
+ {
645
+ 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
646
+ 'metadata': doc.metadata,
647
+ 'score': doc.metadata.get('original_score', getattr(doc, 'score', 0.0)),
648
+ 'original_rank': doc.metadata.get('original_rank', None),
649
+ 'final_rank': doc.metadata.get('final_rank', None),
650
+ 'reranked_score': doc.metadata.get('reranked_score', None)
651
+ } for doc in context_docs
652
+ ] if use_reranking else None
653
+ }
654
+ },
655
+ query=query
656
+ )
657
+
658
+ return result
659
+
660
+ except Exception as e:
661
+ print(f"Error in pipeline run: {e}")
662
+ return PipelineResult(
663
+ answer=f"Error processing query: {e}",
664
+ sources=[],
665
+ execution_time=0.0,
666
+ metadata={'error': str(e)},
667
+ query=query
668
+ )
669
+
670
+
671
+
672
+ def get_system_status(self) -> Dict[str, Any]:
673
+ """
674
+ Get system status information.
675
+
676
+ Returns:
677
+ Dictionary with system status
678
+ """
679
+ status = {
680
+ "config_loaded": bool(self.config),
681
+ "chunks_loaded": bool(self.chunks),
682
+ "vectorstore_connected": bool(
683
+ self.vectorstore_manager and self.vectorstore_manager.get_vectorstore()
684
+ ),
685
+ "components_initialized": bool(
686
+ self.context_retriever and self.report_service
687
+ ),
688
+ }
689
+
690
+ if self.chunks:
691
+ status["num_chunks"] = len(self.chunks)
692
+
693
+ if self.report_service:
694
+ status["available_sources"] = self.report_service.get_available_sources()
695
+ status["available_reports"] = len(
696
+ self.report_service.get_available_reports()
697
+ )
698
+
699
+ status["overall_status"] = (
700
+ "ready"
701
+ if all(
702
+ [
703
+ status["config_loaded"],
704
+ status["chunks_loaded"],
705
+ status["vectorstore_connected"],
706
+ status["components_initialized"],
707
+ ]
708
+ )
709
+ else "not_ready"
710
+ )
711
+
712
+ return status
713
+
714
+ def get_available_llm_providers(self) -> List[str]:
715
+ """Get list of available LLM providers."""
716
+ providers = []
717
+ reader_config = self.config.get("reader", {})
718
+
719
+ for provider in [
720
+ "MISTRAL",
721
+ "OPENAI",
722
+ "OLLAMA",
723
+ "INF_PROVIDERS",
724
+ "NVIDIA",
725
+ "DEDICATED",
726
+ "OPENROUTER",
727
+ ]:
728
+ if provider in reader_config:
729
+ providers.append(provider.lower())
730
+
731
+ return providers
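End to end, the orchestrator above is meant to be used roughly as follows; the collection name and query are illustrative, and a reachable Qdrant collection plus an LLM API key are assumed:

    from src.pipeline import PipelineManager

    pipeline = PipelineManager(config={"qdrant": {"collection_name": "docling"}})
    if pipeline.connect_vectorstore():
        result = pipeline.run("What were the main audit findings for 2023?")
        print(result.answer)
        for doc in result.sources:
            print(doc.metadata.get("filename"), doc.metadata.get("page"))
    print(pipeline.get_system_status())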
src/reporting/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """Report metadata and utilities."""
2
+
3
+ from .metadata import get_report_metadata, get_available_sources
4
+ from .service import ReportService
5
+
6
+ __all__ = ["get_report_metadata", "get_available_sources", "ReportService"]
src/reporting/feedback_schema.py ADDED
@@ -0,0 +1,196 @@
1
+ """
2
+ Feedback Schema for RAG Chatbot
3
+
4
+ This module defines dataclasses for feedback data structures
5
+ and provides Snowflake schema generation.
6
+ """
7
+
8
+ from dataclasses import dataclass, asdict, field
9
+ from typing import List, Optional, Dict, Any, Union
10
+ from datetime import datetime
11
+
12
+
13
+ @dataclass
14
+ class RetrievedDocument:
15
+ """Single retrieved document metadata"""
16
+ doc_id: str
17
+ filename: str
18
+ page: int
19
+ score: float
20
+ content: str
21
+ metadata: Dict[str, Any]
22
+
23
+
24
+ @dataclass
25
+ class RetrievalEntry:
26
+ """Single retrieval operation metadata"""
27
+ rag_query: str
28
+ documents_retrieved: List[RetrievedDocument]
29
+ conversation_length: int
30
+ filters_applied: Optional[Dict[str, Any]] = None
31
+ timestamp: Optional[float] = None
32
+ _raw_data: Optional[Dict[str, Any]] = None
33
+
34
+
35
+ @dataclass
36
+ class UserFeedback:
37
+ """User feedback submission data"""
38
+ feedback_id: str
39
+ open_ended_feedback: Optional[str]
40
+ score: int
41
+ is_feedback_about_last_retrieval: bool
42
+ retrieved_data: List[RetrievalEntry]
43
+ conversation_id: str
44
+ timestamp: float
45
+ message_count: int
46
+ has_retrievals: bool
47
+ retrieval_count: int
48
+ user_query: Optional[str] = None
49
+ bot_response: Optional[str] = None
50
+ created_at: str = field(default_factory=lambda: datetime.now().isoformat())
51
+
52
+ def to_dict(self) -> Dict[str, Any]:
53
+ """Convert to dictionary with nested data structures"""
54
+ result = asdict(self)
55
+ # Handle nested objects
56
+ if self.retrieved_data:
57
+ result['retrieved_data'] = [self._serialize_retrieval_entry(entry) for entry in self.retrieved_data]
58
+ return result
59
+
60
+ def _serialize_retrieval_entry(self, entry: RetrievalEntry) -> Dict[str, Any]:
61
+ """Serialize retrieval entry to dict"""
62
+ # If raw data exists, use it (it's already properly formatted)
63
+ if hasattr(entry, '_raw_data') and entry._raw_data:
64
+ return entry._raw_data
65
+
66
+ # Otherwise, serialize the dataclass
67
+ result = asdict(entry)
68
+ if entry.documents_retrieved:
69
+ result['documents_retrieved'] = [asdict(doc) for doc in entry.documents_retrieved]
70
+ return result
71
+
72
+ def to_snowflake_schema(self) -> Dict[str, Any]:
73
+ """Generate Snowflake schema for this dataclass"""
74
+ schema = {
75
+ "feedback_id": "VARCHAR(255)",
76
+ "open_ended_feedback": "VARCHAR(16777216)", # Large text
77
+ "score": "INTEGER",
78
+ "is_feedback_about_last_retrieval": "BOOLEAN",
79
+ "conversation_id": "VARCHAR(255)",
80
+ "timestamp": "NUMBER(20, 0)",
81
+ "message_count": "INTEGER",
82
+ "has_retrievals": "BOOLEAN",
83
+ "retrieval_count": "INTEGER",
84
+ "user_query": "VARCHAR(16777216)",
85
+ "bot_response": "VARCHAR(16777216)",
86
+ "created_at": "TIMESTAMP_NTZ",
87
+ "retrieved_data": "VARIANT", # Array of retrieval entries
88
+ # retrieved_data structure:
89
+ # [
90
+ # {
91
+ # "rag_query": "...",
92
+ # "conversation_length": 5,
93
+ # "timestamp": 1234567890,
94
+ # "docs_retrieved": [
95
+ # {"filename": "...", "page": 14, "score": 0.95, ...},
96
+ # ...
97
+ # ]
98
+ # },
99
+ # ...
100
+ # ]
101
+ }
102
+ return schema
103
+
104
+ @classmethod
105
+ def get_snowflake_create_table_sql(cls, table_name: str = "user_feedback") -> str:
106
+ """Generate CREATE TABLE SQL for Snowflake"""
107
+ schema = cls.to_snowflake_schema(None)
108
+
109
+ columns = []
110
+ for col_name, col_type in schema.items():
111
+ nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
112
+ columns.append(f" {col_name} {col_type} {nullable}")
113
+
114
+ # Build SQL string properly
115
+ columns_str = ",\n".join(columns)
116
+
117
+ sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
118
+ {columns_str},
119
+ PRIMARY KEY (feedback_id)
120
+ );
121
+
122
+ -- Create index on timestamp for querying by time
123
+ CREATE INDEX IF NOT EXISTS idx_feedback_timestamp ON {table_name} (timestamp);
124
+
125
+ -- Create index on conversation_id for querying by conversation
126
+ CREATE INDEX IF NOT EXISTS idx_feedback_conversation ON {table_name} (conversation_id);
127
+
128
+ -- Create index on score for feedback analysis
129
+ CREATE INDEX IF NOT EXISTS idx_feedback_score ON {table_name} (score);
130
+ """
131
+ return sql
132
+
133
+
134
+ # Snowflake variant schema for retrieved_data array
135
+ RETRIEVAL_ENTRY_SCHEMA = {
136
+ "rag_query": "VARCHAR",
137
+ "documents_retrieved": "ARRAY", # Array of document objects
138
+ "conversation_length": "INTEGER",
139
+ "filters_applied": "OBJECT",
140
+ "timestamp": "NUMBER"
141
+ }
142
+
143
+ DOCUMENT_SCHEMA = {
144
+ "doc_id": "VARCHAR",
145
+ "filename": "VARCHAR",
146
+ "page": "INTEGER",
147
+ "score": "DOUBLE",
148
+ "content": "VARCHAR(16777216)",
149
+ "metadata": "OBJECT"
150
+ }
151
+
152
+
153
+ def generate_snowflake_schema_sql() -> str:
154
+ """Generate complete Snowflake schema SQL for feedback system"""
155
+ return UserFeedback.get_snowflake_create_table_sql("user_feedback")
156
+
157
+
158
+ def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
159
+ """Create UserFeedback instance from dictionary"""
160
+ # Parse retrieved_data if present
161
+ retrieved_data = []
162
+ if "retrieved_data" in data and data["retrieved_data"]:
163
+ for entry_dict in data.get("retrieved_data", []):
164
+ # Map the actual structure from rag_retrieval_history
165
+ # Entry has: conversation_up_to, rag_query_expansion, docs_retrieved
166
+ try:
167
+ # Try to map to expected structure
168
+ entry = RetrievalEntry(
169
+ rag_query=entry_dict.get("rag_query_expansion", ""),
170
+ documents_retrieved=[], # Empty for now, will store as raw data
171
+ conversation_length=len(entry_dict.get("conversation_up_to", [])),
172
+ filters_applied=None,
173
+ timestamp=entry_dict.get("timestamp", None)
174
+ )
175
+ # Store raw data in the entry
176
+ entry._raw_data = entry_dict # Store original for preservation
177
+ retrieved_data.append(entry)
178
+ except Exception as e:
179
+ # If mapping fails, skip this entry rather than raising
180
+ pass
181
+
182
+ return UserFeedback(
183
+ feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
184
+ open_ended_feedback=data.get("open_ended_feedback"),
185
+ score=data["score"],
186
+ is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
187
+ retrieved_data=retrieved_data,
188
+ conversation_id=data["conversation_id"],
189
+ timestamp=data["timestamp"],
190
+ message_count=data["message_count"],
191
+ has_retrievals=data["has_retrievals"],
192
+ retrieval_count=data["retrieval_count"],
193
+ user_query=data.get("user_query"),
194
+ bot_response=data.get("bot_response")
195
+ )
196
+
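A brief sketch of how a UI feedback payload is expected to pass through this schema; the field values are illustrative:

    import time
    from src.reporting.feedback_schema import (
        create_feedback_from_dict,
        generate_snowflake_schema_sql,
    )

    payload = {
        "score": 1,
        "is_feedback_about_last_retrieval": True,
        "conversation_id": "conv-001",
        "timestamp": time.time(),
        "message_count": 4,
        "has_retrievals": True,
        "retrieval_count": 1,
        "open_ended_feedback": "Helpful answer",
        "retrieved_data": [{"rag_query_expansion": "audit findings 2023", "conversation_up_to": []}],
    }

    feedback = create_feedback_from_dict(payload)     # maps the dict onto UserFeedback
    print(feedback.to_dict()["retrieval_count"])
    print(generate_snowflake_schema_sql())            # CREATE TABLE ... user_feedback DDL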
src/reporting/metadata.py ADDED
@@ -0,0 +1,216 @@
1
+ """Report metadata management."""
2
+
3
+ from typing import Dict, List, Any, Set
4
+ from pathlib import Path
5
+
6
+
7
+ def get_report_metadata(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
8
+ """
9
+ Extract metadata from chunks.
10
+
11
+ Args:
12
+ chunks: List of chunk dictionaries
13
+
14
+ Returns:
15
+ Dictionary with report metadata
16
+ """
17
+ if not chunks:
18
+ return {}
19
+
20
+ sources = set()
21
+ filenames = set()
22
+ years = set()
23
+
24
+ for chunk in chunks:
25
+ metadata = chunk.get("metadata", {})
26
+
27
+ if "source" in metadata:
28
+ sources.add(metadata["source"])
29
+
30
+ if "filename" in metadata:
31
+ filenames.add(metadata["filename"])
32
+
33
+ if "year" in metadata:
34
+ years.add(metadata["year"])
35
+
36
+ return {
37
+ "sources": sorted(list(sources)),
38
+ "filenames": sorted(list(filenames)),
39
+ "years": sorted(list(years)),
40
+ "total_chunks": len(chunks)
41
+ }
42
+
43
+
44
+ def get_available_sources() -> List[str]:
45
+ """
46
+ Get list of available report sources (legacy compatibility).
47
+
48
+ Returns:
49
+ List of source categories
50
+ """
51
+ # This would typically come from the original auditqa_old.reports module
52
+ # For now, return common categories
53
+ return [
54
+ "Consolidated",
55
+ "Ministry, Department, Agency and Projects",
56
+ "Local Government",
57
+ "Value for Money",
58
+ "Thematic",
59
+ "Hospital",
60
+ "Project"
61
+ ]
62
+
63
+
64
+ def get_source_subtypes() -> Dict[str, List[str]]:
65
+ """
66
+ Get mapping of sources to their subtypes (placeholder).
67
+
68
+ Returns:
69
+ Dictionary mapping sources to subtypes
70
+ """
71
+ # This was originally imported from auditqa_old.reports.new_files
72
+ # For now, return a placeholder structure
73
+ return {
74
+ "Consolidated": ["Annual Consolidated OAG 2024", "Annual Consolidated OAG 2023"],
75
+ "Local Government": ["District Reports", "Municipal Reports"],
76
+ "Ministry, Department, Agency and Projects": ["Ministry Reports", "Agency Reports"],
77
+ "Value for Money": ["VFM Reports 2024", "VFM Reports 2023"],
78
+ "Thematic": ["Thematic Reports 2024", "Thematic Reports 2023"],
79
+ "Hospital": ["Hospital Reports 2024", "Hospital Reports 2023"],
80
+ "Project": ["Project Reports 2024", "Project Reports 2023"]
81
+ }
82
+
83
+
84
+ def validate_report_filters(
85
+ reports: List[str] = None,
86
+ sources: str = None,
87
+ subtype: List[str] = None,
88
+ available_metadata: Dict[str, Any] = None
89
+ ) -> Dict[str, Any]:
90
+ """
91
+ Validate report filter parameters.
92
+
93
+ Args:
94
+ reports: List of specific report filenames
95
+ sources: Source category
96
+ subtype: List of subtypes
97
+ available_metadata: Available metadata for validation
98
+
99
+ Returns:
100
+ Dictionary with validation results
101
+ """
102
+ validation_result = {
103
+ "valid": True,
104
+ "warnings": [],
105
+ "errors": []
106
+ }
107
+
108
+ if not available_metadata:
109
+ validation_result["warnings"].append("No metadata available for validation")
110
+ return validation_result
111
+
112
+ available_sources = available_metadata.get("sources", [])
113
+ available_filenames = available_metadata.get("filenames", [])
114
+
115
+ # Validate sources
116
+ if sources and sources not in available_sources:
117
+ validation_result["errors"].append(f"Source '{sources}' not found in available sources")
118
+ validation_result["valid"] = False
119
+
120
+ # Validate reports
121
+ if reports:
122
+ for report in reports:
123
+ if report not in available_filenames:
124
+ validation_result["warnings"].append(f"Report '{report}' not found in available reports")
125
+
126
+ # Validate subtypes
127
+ if subtype:
128
+ for sub in subtype:
129
+ if sub not in available_filenames:
130
+ validation_result["warnings"].append(f"Subtype '{sub}' not found in available reports")
131
+
132
+ return validation_result
133
+
134
+
135
+ def get_report_statistics(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
136
+ """
137
+ Get statistics about reports in chunks.
138
+
139
+ Args:
140
+ chunks: List of chunk dictionaries
141
+
142
+ Returns:
143
+ Dictionary with report statistics
144
+ """
145
+ if not chunks:
146
+ return {}
147
+
148
+ stats = {
149
+ "total_chunks": len(chunks),
150
+ "sources": {},
151
+ "years": {},
152
+ "avg_chunk_length": 0,
153
+ "total_content_length": 0
154
+ }
155
+
156
+ total_length = 0
157
+
158
+ for chunk in chunks:
159
+ content = chunk.get("content", "")
160
+ total_length += len(content)
161
+
162
+ metadata = chunk.get("metadata", {})
163
+
164
+ # Count by source
165
+ source = metadata.get("source", "Unknown")
166
+ stats["sources"][source] = stats["sources"].get(source, 0) + 1
167
+
168
+ # Count by year
169
+ year = metadata.get("year", "Unknown")
170
+ stats["years"][year] = stats["years"].get(year, 0) + 1
171
+
172
+ stats["total_content_length"] = total_length
173
+ stats["avg_chunk_length"] = total_length / len(chunks) if chunks else 0
174
+
175
+ return stats
176
+
177
+
178
+ def filter_chunks_by_metadata(
179
+ chunks: List[Dict[str, Any]],
180
+ source_filter: str = None,
181
+ filename_filter: List[str] = None,
182
+ year_filter: List[str] = None
183
+ ) -> List[Dict[str, Any]]:
184
+ """
185
+ Filter chunks by metadata criteria.
186
+
187
+ Args:
188
+ chunks: List of chunk dictionaries
189
+ source_filter: Source to filter by
190
+ filename_filter: List of filenames to filter by
191
+ year_filter: List of years to filter by
192
+
193
+ Returns:
194
+ Filtered list of chunks
195
+ """
196
+ filtered_chunks = chunks
197
+
198
+ if source_filter:
199
+ filtered_chunks = [
200
+ chunk for chunk in filtered_chunks
201
+ if chunk.get("metadata", {}).get("source") == source_filter
202
+ ]
203
+
204
+ if filename_filter:
205
+ filtered_chunks = [
206
+ chunk for chunk in filtered_chunks
207
+ if chunk.get("metadata", {}).get("filename") in filename_filter
208
+ ]
209
+
210
+ if year_filter:
211
+ filtered_chunks = [
212
+ chunk for chunk in filtered_chunks
213
+ if chunk.get("metadata", {}).get("year") in year_filter
214
+ ]
215
+
216
+ return filtered_chunks
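A quick illustration of the metadata helpers above; the source and year filter values are illustrative:

    from src.loader import load_chunks
    from src.reporting.metadata import (
        get_report_metadata,
        get_report_statistics,
        filter_chunks_by_metadata,
    )

    chunks = load_chunks()
    print(get_report_metadata(chunks)["sources"])
    print(get_report_statistics(chunks)["avg_chunk_length"])

    # Narrow the corpus to one source and year before downstream use.
    subset = filter_chunks_by_metadata(chunks, source_filter="Local Government", year_filter=["2024"])
    print(len(subset))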
src/reporting/service.py ADDED
@@ -0,0 +1,144 @@
1
+ """Report service for managing report operations."""
2
+
3
+ from typing import Dict, List, Any, Optional
4
+ from .metadata import get_report_metadata, get_available_sources, get_source_subtypes
5
+
6
+
7
+ class ReportService:
8
+ """Service class for report operations."""
9
+
10
+ def __init__(self, chunks: List[Dict[str, Any]] = None):
11
+ """
12
+ Initialize report service.
13
+
14
+ Args:
15
+ chunks: List of chunk dictionaries
16
+ """
17
+ self.chunks = chunks or []
18
+ self.metadata = get_report_metadata(self.chunks) if self.chunks else {}
19
+
20
+ def get_available_sources(self) -> List[str]:
21
+ """Get available report sources."""
22
+ if self.metadata:
23
+ return self.metadata.get("sources", [])
24
+ return get_available_sources()
25
+
26
+ def get_available_reports(self) -> List[str]:
27
+ """Get available report filenames."""
28
+ return self.metadata.get("filenames", [])
29
+
30
+ def get_source_subtypes(self) -> Dict[str, List[str]]:
31
+ """Get source to subtype mapping."""
32
+ # For now, use the placeholder function
33
+ # In a full implementation, this would be derived from actual data
34
+ return get_source_subtypes()
35
+
36
+ def get_reports_by_source(self, source: str) -> List[str]:
37
+ """
38
+ Get reports filtered by source.
39
+
40
+ Args:
41
+ source: Source category
42
+
43
+ Returns:
44
+ List of report filenames
45
+ """
46
+ if not self.chunks:
47
+ return []
48
+
49
+ reports = set()
50
+ for chunk in self.chunks:
51
+ metadata = chunk.get("metadata", {})
52
+ if metadata.get("source") == source:
53
+ filename = metadata.get("filename")
54
+ if filename:
55
+ reports.add(filename)
56
+
57
+ return sorted(list(reports))
58
+
59
+ def get_years_by_source(self, source: str) -> List[str]:
60
+ """
61
+ Get years available for a specific source.
62
+
63
+ Args:
64
+ source: Source category
65
+
66
+ Returns:
67
+ List of years
68
+ """
69
+ if not self.chunks:
70
+ return []
71
+
72
+ years = set()
73
+ for chunk in self.chunks:
74
+ metadata = chunk.get("metadata", {})
75
+ if metadata.get("source") == source:
76
+ year = metadata.get("year")
77
+ if year:
78
+ years.add(year)
79
+
80
+ return sorted(list(years))
81
+
82
+ def search_reports(self, query: str) -> List[str]:
83
+ """
84
+ Search for reports by name.
85
+
86
+ Args:
87
+ query: Search query
88
+
89
+ Returns:
90
+ List of matching report filenames
91
+ """
92
+ if not self.chunks:
93
+ return []
94
+
95
+ query_lower = query.lower()
96
+ matching_reports = set()
97
+
98
+ for chunk in self.chunks:
99
+ metadata = chunk.get("metadata", {})
100
+ filename = metadata.get("filename", "")
101
+
102
+ if query_lower in filename.lower():
103
+ matching_reports.add(filename)
104
+
105
+ return sorted(list(matching_reports))
106
+
107
+ def get_report_info(self, filename: str) -> Dict[str, Any]:
108
+ """
109
+ Get information about a specific report.
110
+
111
+ Args:
112
+ filename: Report filename
113
+
114
+ Returns:
115
+ Dictionary with report information
116
+ """
117
+ if not self.chunks:
118
+ return {}
119
+
120
+ report_info = {
121
+ "filename": filename,
122
+ "chunk_count": 0,
123
+ "sources": set(),
124
+ "years": set(),
125
+ "total_content_length": 0
126
+ }
127
+
128
+ for chunk in self.chunks:
129
+ metadata = chunk.get("metadata", {})
130
+ if metadata.get("filename") == filename:
131
+ report_info["chunk_count"] += 1
132
+ report_info["total_content_length"] += len(chunk.get("content", ""))
133
+
134
+ if "source" in metadata:
135
+ report_info["sources"].add(metadata["source"])
136
+
137
+ if "year" in metadata:
138
+ report_info["years"].add(metadata["year"])
139
+
140
+ # Convert sets to lists
141
+ report_info["sources"] = list(report_info["sources"])
142
+ report_info["years"] = list(report_info["years"])
143
+
144
+ return report_info
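For reference, the service in use against the loader output; the report name passed to get_report_info is illustrative:

    from src.loader import load_chunks
    from src.reporting.service import ReportService

    service = ReportService(chunks=load_chunks())
    print(service.get_available_sources())
    print(service.get_reports_by_source("Local Government")[:5])
    print(service.get_report_info("Annual Consolidated OAG 2024"))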
src/reporting/snowflake_connector.py ADDED
@@ -0,0 +1,305 @@
1
+ """
2
+ Snowflake Connector for Feedback System
3
+
4
+ This module handles inserting user feedback into Snowflake.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ import logging
10
+ from typing import Dict, Any, Optional
11
+ from src.reporting.feedback_schema import UserFeedback
12
+
13
+ # Try to import snowflake connector
14
+ try:
15
+ import snowflake.connector
16
+ SNOWFLAKE_AVAILABLE = True
17
+ except ImportError:
18
+ SNOWFLAKE_AVAILABLE = False
19
+ logging.warning("⚠️ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class SnowflakeFeedbackConnector:
27
+ """Connector for inserting feedback into Snowflake"""
28
+
29
+ def __init__(
30
+ self,
31
+ user: str,
32
+ password: str,
33
+ account: str,
34
+ warehouse: str,
35
+ database: str = "SNOWFLAKE_LEARNING",
36
+ schema: str = "PUBLIC"
37
+ ):
38
+ self.user = user
39
+ self.password = password
40
+ self.account = account
41
+ self.warehouse = warehouse
42
+ self.database = database
43
+ self.schema = schema
44
+ self._connection = None
45
+
46
+ def connect(self):
47
+ """Establish Snowflake connection"""
48
+ if not SNOWFLAKE_AVAILABLE:
49
+ raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
50
+
51
+ logger.info("=" * 80)
52
+ logger.info("🔌 SNOWFLAKE CONNECTION: Attempting to connect...")
53
+ logger.info(f" - Account: {self.account}")
54
+ logger.info(f" - Warehouse: {self.warehouse}")
55
+ logger.info(f" - Database: {self.database}")
56
+ logger.info(f" - Schema: {self.schema}")
57
+ logger.info(f" - User: {self.user}")
58
+
59
+ try:
60
+ self._connection = snowflake.connector.connect(
61
+ user=self.user,
62
+ password=self.password,
63
+ account=self.account,
64
+ warehouse=self.warehouse
65
+ # Don't set database/schema in connection - we'll do it per query
66
+ )
67
+ logger.info("✅ SNOWFLAKE CONNECTION: Successfully connected")
68
+ logger.info("=" * 80)
69
+ print(f"✅ Connected to Snowflake: {self.database}.{self.schema}")
70
+ except Exception as e:
71
+ logger.error(f"❌ SNOWFLAKE CONNECTION FAILED: {e}")
72
+ logger.error("=" * 80)
73
+ print(f"❌ Failed to connect to Snowflake: {e}")
74
+ raise
75
+
76
+ def disconnect(self):
77
+ """Close Snowflake connection"""
78
+ if self._connection:
79
+ self._connection.close()
80
+ print("✅ Disconnected from Snowflake")
81
+
82
+ def insert_feedback(self, feedback: UserFeedback) -> bool:
83
+ """Insert a single feedback record into Snowflake"""
84
+ logger.info("=" * 80)
85
+ logger.info("🔄 SNOWFLAKE INSERT: Starting feedback insertion process")
86
+ logger.info(f"📝 Feedback ID: {feedback.feedback_id}")
87
+
88
+ if not self._connection:
89
+ logger.error("❌ Not connected to Snowflake. Call connect() first.")
90
+ raise RuntimeError("Not connected to Snowflake. Call connect() first.")
91
+
92
+ try:
93
+ logger.info("📊 VALIDATION: Validating feedback data structure...")
94
+
95
+ # Validate feedback object
96
+ validation_errors = []
97
+ if not feedback.feedback_id:
98
+ validation_errors.append("Missing feedback_id")
99
+ if feedback.score is None:
100
+ validation_errors.append("Missing score")
101
+ if feedback.timestamp is None:
102
+ validation_errors.append("Missing timestamp")
103
+
104
+ if validation_errors:
105
+ logger.error(f"❌ VALIDATION FAILED: {validation_errors}")
106
+ return False
107
+ else:
108
+ logger.info("✅ VALIDATION PASSED: All required fields present")
109
+
110
+ logger.info("📋 Data Summary:")
111
+ logger.info(f" - Feedback ID: {feedback.feedback_id}")
112
+ logger.info(f" - Score: {feedback.score}")
113
+ logger.info(f" - Conversation ID: {feedback.conversation_id}")
114
+ logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
115
+ logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
116
+ logger.info(f" - Message Count: {feedback.message_count}")
117
+ logger.info(f" - Timestamp: {feedback.timestamp}")
118
+
119
+ cursor = self._connection.cursor()
120
+ logger.info("✅ SNOWFLAKE CONNECTION: Cursor created")
121
+
122
+ # Set database and schema context
123
+ logger.info(f"🔧 SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
124
+ try:
125
+ cursor.execute(f'USE DATABASE "{self.database}"')
126
+ cursor.execute(f'USE SCHEMA "{self.schema}"')
127
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
128
+ current_db, current_schema = cursor.fetchone()
129
+ logger.info(f"✅ Current context verified: Database={current_db}, Schema={current_schema}")
130
+ except Exception as e:
131
+ logger.error(f"❌ Could not set context: {e}")
132
+ raise
133
+
134
+ # Prepare data
135
+ logger.info("🔧 DATA PREPARATION: Preparing retrieved_data...")
136
+ retrieved_data_raw = feedback.to_dict()['retrieved_data']
137
+
138
+ logger.info(f" - Retrieved data type (raw): {type(retrieved_data_raw).__name__}")
139
+ logger.info(f" - Retrieved data: {repr(retrieved_data_raw)[:200]}")
140
+
141
+ # If retrieved_data is already a string (from UI), parse it
142
+ if isinstance(retrieved_data_raw, str):
143
+ logger.info(" - Parsing string to Python object")
144
+ retrieved_data = json.loads(retrieved_data_raw)
145
+ elif retrieved_data_raw is None:
146
+ retrieved_data = None
147
+ else:
148
+ # It's already a Python object (list/dict)
149
+ logger.info(" - Data is already a Python object")
150
+ retrieved_data = retrieved_data_raw
151
+
152
+ logger.info(f" - Retrieved data size: {len(str(retrieved_data)) if retrieved_data else 0} characters")
153
+ logger.info(f" - Retrieved data type: {type(retrieved_data).__name__}")
154
+
155
+ # Convert to JSON string for TEXT column
156
+ if retrieved_data:
157
+ retrieved_data_for_db = json.dumps(retrieved_data)
158
+ logger.info(f" - Converting to JSON string for TEXT column")
159
+ logger.info(f" - JSON string length: {len(retrieved_data_for_db)}")
160
+ else:
161
+ logger.info(f" - Retrieved data is None, using NULL")
162
+ retrieved_data_for_db = None
163
+
164
+ # Build SQL with retrieved_data as a TEXT column parameter
165
+ sql = f"""INSERT INTO user_feedback (
166
+ feedback_id,
167
+ open_ended_feedback,
168
+ score,
169
+ is_feedback_about_last_retrieval,
170
+ conversation_id,
171
+ timestamp,
172
+ message_count,
173
+ has_retrievals,
174
+ retrieval_count,
175
+ user_query,
176
+ bot_response,
177
+ created_at,
178
+ retrieved_data
179
+ ) VALUES (
180
+ %(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
181
+ %(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
182
+ %(retrieval_count)s, %(user_query)s, %(bot_response)s, %(created_at)s,
183
+ %(retrieved_data)s
184
+ )"""
185
+
186
+ logger.info("📝 SQL PREPARATION: Building INSERT statement...")
187
+ logger.info(f" - Target table: user_feedback")
188
+ logger.info(f" - Database: {self.database}")
189
+ logger.info(f" - Schema: {self.schema}")
190
+
191
+ # Prepare parameters
192
+ params = {
193
+ 'feedback_id': feedback.feedback_id,
194
+ 'open_ended_feedback': feedback.open_ended_feedback,
195
+ 'score': feedback.score,
196
+ 'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
197
+ 'conversation_id': feedback.conversation_id,
198
+ 'timestamp': int(feedback.timestamp),
199
+ 'message_count': feedback.message_count,
200
+ 'has_retrievals': feedback.has_retrievals,
201
+ 'retrieval_count': feedback.retrieval_count,
202
+ 'user_query': feedback.user_query,
203
+ 'bot_response': feedback.bot_response,
204
+ 'created_at': feedback.created_at,
205
+ 'retrieved_data': retrieved_data_for_db
206
+ }
207
+
208
+ # Execute insert
209
+ logger.info("🚀 SQL EXECUTION: Executing INSERT query...")
210
+ cursor.execute(sql, params)
211
+
212
+ logger.info("✅ SQL EXECUTION: Query executed successfully")
213
+ logger.info(f" - Rows affected: 1")
214
+ logger.info(f" - Status: SUCCESS")
215
+
216
+ cursor.close()
217
+ logger.info("✅ SNOWFLAKE INSERT: Feedback inserted successfully")
218
+ logger.info(f"📝 Inserted feedback: {feedback.feedback_id}")
219
+ logger.info("=" * 80)
220
+ return True
221
+
222
+ except Exception as e:
223
+ # Check if it's a Snowflake error
224
+ if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
225
+ logger.error(f"❌ SQL EXECUTION ERROR: {e}")
226
+ logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
227
+ logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
228
+ else:
229
+ logger.error(f"❌ SNOWFLAKE INSERT FAILED: {type(e).__name__}")
230
+ logger.error(f" - Error: {e}")
231
+ logger.error("=" * 80)
232
+ return False
233
+
234
+ def __enter__(self):
235
+ """Context manager entry"""
236
+ self.connect()
237
+ return self
238
+
239
+ def __exit__(self, exc_type, exc_val, exc_tb):
240
+ """Context manager exit"""
241
+ self.disconnect()
242
+
243
+
244
+ def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
245
+ """Create Snowflake connector from environment variables"""
246
+ user = os.getenv("SNOWFLAKE_USER")
247
+ password = os.getenv("SNOWFLAKE_PASSWORD")
248
+ account = os.getenv("SNOWFLAKE_ACCOUNT")
249
+ warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
250
+ database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
251
+ schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
252
+
253
+ if not all([user, password, account, warehouse]):
254
+ print("⚠️ Snowflake credentials not found in environment variables")
255
+ print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
256
+ return None
257
+
258
+ return SnowflakeFeedbackConnector(
259
+ user=user,
260
+ password=password,
261
+ account=account,
262
+ warehouse=warehouse,
263
+ database=database,
264
+ schema=schema
265
+ )
266
+
267
+
268
+ def save_to_snowflake(feedback: UserFeedback) -> bool:
269
+ """Helper function to save feedback to Snowflake"""
270
+ logger.info("=" * 80)
271
+ logger.info("🔵 SNOWFLAKE SAVE: Starting save process")
272
+ logger.info(f"📝 Feedback ID: {feedback.feedback_id}")
273
+
274
+ connector = get_snowflake_connector_from_env()
275
+
276
+ if not connector:
277
+ logger.warning("⚠️ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
278
+ logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
279
+ logger.info("=" * 80)
280
+ return False
281
+
282
+ try:
283
+ logger.info("📡 SNOWFLAKE SAVE: Establishing connection...")
284
+ connector.connect()
285
+ logger.info("✅ SNOWFLAKE SAVE: Connection established")
286
+
287
+ logger.info("📥 SNOWFLAKE SAVE: Attempting to insert feedback...")
288
+ success = connector.insert_feedback(feedback)
289
+
290
+ logger.info("🔌 SNOWFLAKE SAVE: Disconnecting...")
291
+ connector.disconnect()
292
+
293
+ if success:
294
+ logger.info("✅ SNOWFLAKE SAVE: Successfully saved feedback")
295
+ else:
296
+ logger.error("❌ SNOWFLAKE SAVE: Failed to save feedback")
297
+
298
+ logger.info("=" * 80)
299
+ return success
300
+ except Exception as e:
301
+ logger.error(f"❌ SNOWFLAKE SAVE ERROR: {type(e).__name__}")
302
+ logger.error(f" - Error: {e}")
303
+ logger.info("=" * 80)
304
+ return False
305
+
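For orientation, a minimal usage sketch of the connector above, assuming the SNOWFLAKE_* environment variables are set and that `feedback` is an already-populated UserFeedback instance (variable names here are illustrative, not part of the diff):

# Minimal sketch: persist one feedback record (assumes env vars are configured
# and `feedback` is a UserFeedback instance built elsewhere).
connector = get_snowflake_connector_from_env()
if connector:
    with connector:                        # __enter__/__exit__ handle connect/disconnect
        ok = connector.insert_feedback(feedback)
        print("saved" if ok else "insert failed")

# Or the one-shot helper, which wraps the same connect/insert/disconnect flow:
save_to_snowflake(feedback)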
src/retrieval/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """Document retrieval and filtering utilities."""
2
+
3
+ from .filter import create_filter, FilterBuilder
4
+ from .context import ContextRetriever, get_context
5
+ from .hybrid import HybridRetriever, get_available_search_modes, get_search_mode_description
6
+
7
+ __all__ = [
8
+ "create_filter",
9
+ "FilterBuilder",
10
+ "ContextRetriever",
11
+ "get_context",
12
+ "HybridRetriever",
13
+ "get_available_search_modes",
14
+ "get_search_mode_description"
15
+ ]
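Because of these re-exports, callers elsewhere in the repo can import the retrieval helpers from the package root; a small illustrative example:

# Illustrative import style enabled by this __init__.py
from src.retrieval import create_filter, get_context, HybridRetriever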
src/retrieval/colbert_cache.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ ColBERT embeddings cache for test set documents.
3
+ Provides O(1) lookup for ColBERT embeddings during late interaction.
4
+ """
5
+
6
+ import json
7
+ import numpy as np
8
+ from pathlib import Path
9
+ from typing import Dict, Optional, Any
10
+
11
+
12
+ class ColBERTCache:
13
+ """Cache for ColBERT embeddings of test set documents."""
14
+
15
+ def __init__(self, cache_file: str = "test_set_colbert_cache.json"):
16
+ self.cache_file = Path("outputs/caches") / cache_file
17
+ self.embeddings_cache: Dict[str, np.ndarray] = {}
18
+ self._load_cache()
19
+
20
+ def _load_cache(self):
21
+ """Load embeddings from cache file."""
22
+ if not self.cache_file.exists():
23
+ print(f"⚠️ ColBERT cache not found: {self.cache_file}")
24
+ print("💡 Run 'python precalculate_test_set_colbert.py' to create cache")
25
+ return
26
+
27
+ print(f"📂 Loading ColBERT cache from {self.cache_file}...")
28
+
29
+ try:
30
+ with open(self.cache_file, 'r') as f:
31
+ cache_data = json.load(f)
32
+
33
+ # Reconstruct embeddings from compressed format
34
+ for doc_id, data in cache_data.items():
35
+ embedding_min = data['min']
36
+ embedding_max = data['max']
37
+ quantized_embedding = np.array(data['embedding'], dtype=np.uint8)
38
+
39
+ # Reconstruct original embedding
40
+ reconstructed = (quantized_embedding.astype(np.float32) / 255.0) * (embedding_max - embedding_min) + embedding_min
41
+ self.embeddings_cache[doc_id] = reconstructed.reshape(data['shape'])
42
+
43
+ print(f"✅ Loaded {len(self.embeddings_cache)} ColBERT embeddings from cache")
44
+
45
+ except Exception as e:
46
+ print(f"❌ Error loading ColBERT cache: {e}")
47
+ self.embeddings_cache = {}
48
+
49
+ def get_embedding(self, document_text: str) -> Optional[np.ndarray]:
50
+ """Get ColBERT embedding for a document (O(1) lookup)."""
51
+ return self.embeddings_cache.get(document_text)
52
+
53
+ def has_embedding(self, document_text: str) -> bool:
54
+ """Check if embedding exists for document."""
55
+ return document_text in self.embeddings_cache
56
+
57
+ def get_cache_stats(self) -> Dict[str, Any]:
58
+ """Get cache statistics."""
59
+ return {
60
+ 'total_embeddings': len(self.embeddings_cache),
61
+ 'cache_file': str(self.cache_file),
62
+ 'cache_exists': self.cache_file.exists()
63
+ }
64
+
65
+
66
+ # Global cache instance
67
+ _colbert_cache = None
68
+
69
+ def get_colbert_cache() -> ColBERTCache:
70
+ """Get global ColBERT cache instance."""
71
+ global _colbert_cache
72
+ if _colbert_cache is None:
73
+ _colbert_cache = ColBERTCache()
74
+ return _colbert_cache
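The loader above reverses a min/max uint8 quantization. Below is a rough sketch of the lookup side, plus what a compatible cache entry would look like; the actual precalculate_test_set_colbert.py script is not part of this commit, so the writer shown here is only an assumed inverse of the loader:

import numpy as np

# Lookup side
cache = get_colbert_cache()
doc_text = "Example audit paragraph..."        # illustrative key (full chunk text)
emb = cache.get_embedding(doc_text)            # np.ndarray or None
print(cache.get_cache_stats())

# Assumed writer: quantize an embedding into the format _load_cache() expects
def quantize_entry(embedding: np.ndarray) -> dict:
    lo, hi = float(embedding.min()), float(embedding.max())
    scale = (hi - lo) or 1.0                   # guard against constant embeddings
    q = np.round((embedding - lo) / scale * 255).astype(np.uint8)
    return {"min": lo, "max": hi, "shape": list(embedding.shape),
            "embedding": q.flatten().tolist()}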
src/retrieval/context.py ADDED
@@ -0,0 +1,881 @@
1
+ """Context retrieval with reranking capabilities."""
2
+
3
+ import os
4
+ from typing import List, Optional, Tuple, Dict, Any
5
+ from langchain.schema import Document
6
+ from langchain_community.vectorstores import Qdrant
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from sentence_transformers import CrossEncoder
9
+ import numpy as np
10
+ import torch
11
+ from qdrant_client.http import models as rest
12
+ import traceback
13
+
14
+ from .filter import create_filter
15
+
16
+ class ContextRetriever:
17
+ """
18
+ Context retriever for hybrid search with optional filtering and reranking.
19
+ """
20
+
21
+ def __init__(self, vectorstore: Qdrant, config: dict = None):
22
+ """
23
+ Initialize the context retriever.
24
+
25
+ Args:
26
+ vectorstore: Qdrant vector store instance
27
+ config: Configuration dictionary
28
+ """
29
+ self.vectorstore = vectorstore
30
+ self.config = config or {}
31
+ self.reranker = None
32
+
33
+ # BM25 attributes
34
+ self.bm25_vectorizer = None
35
+ self.bm25_matrix = None
36
+ self.bm25_documents = None
37
+
38
+ # Initialize reranker if available
39
+ # Try to get reranker model from different config paths
40
+ self.reranker_model_name = (
41
+ config.get('retrieval', {}).get('reranker_model') or
42
+ config.get('ranker', {}).get('model') or
43
+ config.get('reranker_model') or
44
+ 'BAAI/bge-reranker-v2-m3'
45
+ )
46
+ self.reranker_type = self._detect_reranker_type(self.reranker_model_name)
47
+
48
+ try:
49
+ if self.reranker_type == 'colbert':
50
+ from colbert.infra import Run, ColBERTConfig
51
+ from colbert.modeling.checkpoint import Checkpoint
52
+ # ColBERT uses late interaction - different implementation needed
53
+ print(f"✅ RERANKER: ColBERT model detected ({self.reranker_model_name})")
54
+ print(f"🔍 INTERACTION TYPE: Late interaction (token-level embeddings)")
55
+
56
+ # Create ColBERT config for CPU mode
57
+ colbert_config = ColBERTConfig(
58
+ doc_maxlen=300,
59
+ query_maxlen=32,
60
+ nbits=2,
61
+ kmeans_niters=4,
62
+ root="./colbert_data"
63
+ )
64
+
65
+ # Load checkpoint (e.g. "colbert-ir/colbertv2.0")
66
+ self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config)
67
+ self.colbert_model = self.colbert_checkpoint.model
68
+ self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer
69
+ self.reranker = self._colbert_rerank # attach wrapper function
70
+ print(f"✅ COLBERT: Model and tokenizer loaded successfully")
71
+
72
+ else:
73
+ # Standard CrossEncoder for BGE and other models
74
+ from sentence_transformers import CrossEncoder
75
+ self.reranker = CrossEncoder(self.reranker_model_name)
76
+ print(f"✅ RERANKER: Initialized {self.reranker_model_name}")
77
+ print(f"🔍 INTERACTION TYPE: Cross-encoder (single relevance score)")
78
+ except Exception as e:
79
+ print(f"⚠️ Reranker initialization failed: {e}")
80
+ self.reranker = None
81
+
82
+ def _detect_reranker_type(self, model_name: str) -> str:
83
+ """
84
+ Detect the type of reranker based on model name.
85
+
86
+ Args:
87
+ model_name: Name of the reranker model
88
+
89
+ Returns:
90
+ 'colbert' for ColBERT models, 'crossencoder' for others
91
+ """
92
+ model_name_lower = model_name.lower()
93
+
94
+ # ColBERT model patterns
95
+ colbert_patterns = [
96
+ 'colbert',
97
+ 'colbert-ir',
98
+ 'colbertv2',
99
+ 'colbert-v2'
100
+ ]
101
+
102
+ for pattern in colbert_patterns:
103
+ if pattern in model_name_lower:
104
+ return 'colbert'
105
+
106
+ # Default to cross-encoder for BGE and other models
107
+ return 'crossencoder'
108
+
109
+ def _similarity_search_with_colbert_embeddings(self, query: str, k: int = 5, **kwargs) -> List[Tuple[Document, float]]:
110
+ """
111
+ Perform similarity search and fetch ColBERT embeddings for documents.
112
+
113
+ Args:
114
+ query: Search query
115
+ k: Number of documents to retrieve
116
+ **kwargs: Additional search parameters (filter, etc.)
117
+
118
+ Returns:
119
+ List of (Document, score) tuples with ColBERT embeddings in metadata
120
+ """
121
+ try:
122
+ print(f"🔍 COLBERT RETRIEVAL: Fetching documents with ColBERT embeddings")
123
+
124
+ # Use the vectorstore's similarity_search_with_score method instead of direct client
125
+ # This ensures proper filter handling
126
+ if 'filter' in kwargs and kwargs['filter']:
127
+ # Use the vectorstore method with filter
128
+ result = self.vectorstore.similarity_search_with_score(
129
+ query,
130
+ k=k,
131
+ filter=kwargs['filter']
132
+ )
133
+ else:
134
+ # Use the vectorstore method without filter
135
+ result = self.vectorstore.similarity_search_with_score(query, k=k)
136
+
137
+ # Convert to the format we need
138
+ if isinstance(result, tuple) and len(result) == 2:
139
+ documents, scores = result
140
+ elif isinstance(result, list):
141
+ documents = []
142
+ scores = []
143
+ for item in result:
144
+ if isinstance(item, tuple) and len(item) == 2:
145
+ doc, score = item
146
+ documents.append(doc)
147
+ scores.append(score)
148
+ else:
149
+ documents.append(item)
150
+ scores.append(0.0)
151
+ else:
152
+ documents = []
153
+ scores = []
154
+
155
+ # Now we need to fetch the ColBERT embeddings for these documents
156
+ # We'll use the Qdrant client directly for this part since we need specific payload fields
157
+ from qdrant_client.http import models as rest
158
+
159
+ collection_name = self.vectorstore.collection_name
160
+
161
+ # Get document IDs from the retrieved documents
162
+ doc_ids = []
163
+ for doc in documents:
164
+ # Extract ID from document metadata or use page_content hash as fallback
165
+ doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
166
+ if not doc_id:
167
+ # Use a hash of the content as ID
168
+ import hashlib
169
+ doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
170
+ doc_ids.append(doc_id)
171
+
172
+ # Fetch documents with ColBERT embeddings from Qdrant
173
+ search_result = self.vectorstore.client.retrieve(
174
+ collection_name=collection_name,
175
+ ids=doc_ids,
176
+ with_payload=True,
177
+ with_vectors=False
178
+ )
179
+
180
+ # Convert results to Document objects with ColBERT embeddings
181
+ enhanced_documents = []
182
+ enhanced_scores = []
183
+
184
+ # Create a mapping from doc_id to original score
185
+ doc_id_to_score = {}
186
+ for i, doc in enumerate(documents):
187
+ doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
188
+ if not doc_id:
189
+ import hashlib
190
+ doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
191
+ doc_id_to_score[doc_id] = scores[i]
192
+
193
+ for point in search_result:
194
+ # Extract payload
195
+ payload = point.payload
196
+
197
+ # Get the original score for this document
198
+ doc_id = str(point.id)
199
+ original_score = doc_id_to_score.get(doc_id, 0.0)
200
+
201
+ # Create Document object with ColBERT embeddings
202
+ doc = Document(
203
+ page_content=payload.get('page_content', ''),
204
+ metadata={
205
+ **payload.get('metadata', {}),
206
+ 'colbert_embedding': payload.get('colbert_embedding'),
207
+ 'colbert_model': payload.get('colbert_model'),
208
+ 'colbert_calculated_at': payload.get('colbert_calculated_at')
209
+ }
210
+ )
211
+
212
+ enhanced_documents.append(doc)
213
+ enhanced_scores.append(original_score)
214
+
215
+ print(f"✅ COLBERT RETRIEVAL: Retrieved {len(enhanced_documents)} documents with ColBERT embeddings")
216
+
217
+ return list(zip(enhanced_documents, enhanced_scores))
218
+
219
+ except Exception as e:
220
+ print(f"❌ COLBERT RETRIEVAL ERROR: {e}")
221
+ print(f"❌ Falling back to regular similarity search")
222
+
223
+ # Fallback to regular search - handle filter parameter correctly
224
+ if 'filter' in kwargs and kwargs['filter']:
225
+ return self.vectorstore.similarity_search_with_score(query, k=k, filter=kwargs['filter'])
226
+ else:
227
+ return self.vectorstore.similarity_search_with_score(query, k=k)
228
+
229
+ def retrieve_context(
230
+ self,
231
+ query: str,
232
+ k: int = 5,
233
+ reports: Optional[List[str]] = None,
234
+ sources: Optional[List[str]] = None,
235
+ subtype: Optional[str] = None,
236
+ year: Optional[str] = None,
237
+ district: Optional[List[str]] = None,
238
+ filenames: Optional[List[str]] = None,
239
+ use_reranking: bool = False,
240
+ qdrant_filter: Optional[rest.Filter] = None
241
+ ) -> List[Document]:
242
+ """
243
+ Retrieve context documents using hybrid search with optional filtering and reranking.
244
+
245
+ Args:
246
+ query: User query
247
+ top_k: Number of documents to retrieve
248
+ reports: List of report names to filter by
249
+ sources: List of sources to filter by
250
+ subtype: Document subtype to filter by
251
+ year: Year to filter by
252
+ use_reranking: Whether to apply reranking
253
+ qdrant_filter: Pre-built Qdrant filter to use
254
+
255
+ Returns:
256
+ List of retrieved documents
257
+ """
258
+ try:
259
+ # Determine how many documents to retrieve
260
+ retrieve_k = k  # could be raised (e.g. k * 3) when reranking, to give the reranker more candidates
261
+
262
+ # Build search kwargs
263
+ search_kwargs = {}
264
+
265
+ # Use qdrant_filter if provided (this takes precedence)
266
+ if qdrant_filter:
267
+ search_kwargs = {"filter": qdrant_filter}
268
+ print(f"✅ FILTERS APPLIED: Using inferred Qdrant filter")
269
+ else:
270
+ # Build filter from individual parameters
271
+ filter_obj = create_filter(
272
+ reports=reports,
273
+ sources=sources,
274
+ subtype=subtype,
275
+ year=year,
276
+ district=district,
277
+ filenames=filenames
278
+ )
279
+
280
+ if filter_obj:
281
+ search_kwargs = {"filter": filter_obj}
282
+ print(f"✅ FILTERS APPLIED: Using built filter")
283
+ else:
284
+ search_kwargs = {}
285
+ print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")
286
+
287
+ # Perform vector search
288
+ try:
289
+ # Check if we need ColBERT embeddings for reranking
290
+ if use_reranking and self.reranker_type == 'colbert':
291
+ result = self._similarity_search_with_colbert_embeddings(
292
+ query,
293
+ k=retrieve_k,
294
+ **search_kwargs
295
+ )
296
+ else:
297
+ result = self.vectorstore.similarity_search_with_score(
298
+ query,
299
+ k=retrieve_k,
300
+ **search_kwargs
301
+ )
302
+
303
+ # Handle different return formats
304
+ if isinstance(result, tuple) and len(result) == 2:
305
+ documents, scores = result
306
+ elif isinstance(result, list) and len(result) > 0:
307
+ # Handle case where result is a list of (Document, score) tuples
308
+ documents = []
309
+ scores = []
310
+ for item in result:
311
+ if isinstance(item, tuple) and len(item) == 2:
312
+ doc, score = item
313
+ documents.append(doc)
314
+ scores.append(score)
315
+ else:
316
+ # Handle case where item is just a Document
317
+ documents.append(item)
318
+ scores.append(0.0) # Default score
319
+ else:
320
+ documents = []
321
+ scores = []
322
+
323
+ print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(documents)} documents (requested: {retrieve_k})")
324
+
325
+ # If we got fewer documents than requested, try without filters
326
+ if len(documents) < retrieve_k and search_kwargs.get('filter'):
327
+ print(f"⚠️ RETRIEVAL: Got {len(documents)} docs with filters, trying without filters...")
328
+ try:
329
+ result_no_filter = self.vectorstore.similarity_search_with_score(
330
+ query,
331
+ k=retrieve_k
332
+ )
333
+
334
+ if isinstance(result_no_filter, tuple) and len(result_no_filter) == 2:
335
+ documents_no_filter, scores_no_filter = result_no_filter
336
+ elif isinstance(result_no_filter, list):
337
+ documents_no_filter = []
338
+ scores_no_filter = []
339
+ for item in result_no_filter:
340
+ if isinstance(item, tuple) and len(item) == 2:
341
+ doc, score = item
342
+ documents_no_filter.append(doc)
343
+ scores_no_filter.append(score)
344
+ else:
345
+ documents_no_filter.append(item)
346
+ scores_no_filter.append(0.0)
347
+ else:
348
+ documents_no_filter = []
349
+ scores_no_filter = []
350
+
351
+ if len(documents_no_filter) > len(documents):
352
+ print(f"✅ RETRIEVAL: Got {len(documents_no_filter)} docs without filters")
353
+ documents = documents_no_filter
354
+ scores = scores_no_filter
355
+ except Exception as e:
356
+ print(f"⚠️ RETRIEVAL: Fallback search failed: {e}")
357
+
358
+ except Exception as e:
359
+ print(f"❌ RETRIEVAL ERROR: {str(e)}")
360
+ return []
361
+
362
+ # Apply reranking if enabled
363
+ reranking_applied = False
364
+ if use_reranking and len(documents) > 1:
365
+ print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(documents)} documents...")
366
+ try:
367
+ original_docs = documents.copy()
368
+ original_scores = scores.copy()
369
+
370
+ # Apply reranking
371
+ # print(f"🔍 ORIGINAL DOCS: {documents[0]}")
372
+ reranked_docs = self._apply_reranking(query, documents, scores)
373
+ # print(f"🔍 RERANKED DOCS: {reranked_docs[0]}")
374
+ reranking_applied = len(reranked_docs) > 0
375
+
376
+ if reranking_applied:
377
+ print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
378
+ documents = reranked_docs
379
+ # Update scores to reflect reranking
380
+ # scores = [0.0] * len(documents) # Reranked scores are not directly comparable
381
+ else:
382
+ print(f"⚠️ RERANKING FAILED: Using original order")
383
+ documents = original_docs
384
+ scores = original_scores
385
+ return documents
386
+
387
+ except Exception as e:
388
+ print(f"❌ RERANKING ERROR: {str(e)}")
389
+ print(f"⚠️ RERANKING FAILED: Using original order")
390
+ reranking_applied = False
391
+ elif use_reranking and len(documents) <= 1:
392
+ print(f"ℹ️ RERANKING: Skipped (only {len(documents)} document(s) retrieved)")
393
+ if use_reranking:
394
+ print(f"ℹ️ RERANKING: Skipped (disabled or insufficient documents)")
395
+ # Store original scores in metadata
396
+ for i, (doc, score) in enumerate(zip(documents, scores)):
397
+ doc.metadata['original_score'] = float(score)
398
+ doc.metadata['reranking_applied'] = False
399
+ return documents
400
+ else:
401
+ print(f"ℹ️ RERANKING: Skipped (disabled or insufficient documents)")
402
+
403
+ # Limit to requested number of documents
404
+ documents = documents[:k]
405
+ scores = scores[:k] if scores else [0.0] * len(documents)
406
+
407
+ # Add metadata to documents
408
+ for i, (doc, score) in enumerate(zip(documents, scores)):
409
+ if hasattr(doc, 'metadata'):
410
+ doc.metadata.update({
411
+ 'reranking_applied': reranking_applied,
412
+ 'reranker_model': self.reranker_model_name if reranking_applied else None,
413
+ 'original_rank': i + 1,
414
+ 'final_rank': i + 1,
415
+ 'original_score': float(score) if score is not None else 0.0
416
+ })
417
+
418
+ return documents
419
+
420
+ except Exception as e:
421
+ print(f"❌ CONTEXT RETRIEVAL ERROR: {str(e)}")
422
+ return []
423
+
424
+ def _apply_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
425
+ """
426
+ Apply reranking to documents using the appropriate reranker.
427
+
428
+ Args:
429
+ query: User query
430
+ documents: List of documents to rerank
431
+ scores: Original scores
432
+
433
+ Returns:
434
+ Reranked list of documents
435
+ """
436
+ if not self.reranker or len(documents) == 0:
437
+ return documents
438
+
439
+ try:
440
+ print(f"🔍 RERANKING METHOD: Starting reranking with {len(documents)} documents")
441
+ print(f"🔍 RERANKING TYPE: {self.reranker_type.upper()}")
442
+
443
+ if self.reranker_type == 'colbert':
444
+ return self._apply_colbert_reranking(query, documents, scores)
445
+ else:
446
+ return self._apply_crossencoder_reranking(query, documents, scores)
447
+
448
+ except Exception as e:
449
+ print(f"❌ RERANKING ERROR: {str(e)}")
450
+ return documents
451
+
452
+ def _apply_crossencoder_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
453
+ """
454
+ Apply reranking using CrossEncoder (BGE and other models).
455
+
456
+ Args:
457
+ query: User query
458
+ documents: List of documents to rerank
459
+ scores: Original scores
460
+
461
+ Returns:
462
+ Reranked list of documents
463
+ """
464
+ # Prepare pairs for reranking
465
+ pairs = []
466
+ for doc in documents:
467
+ pairs.append([query, doc.page_content])
468
+
469
+ print(f"🔍 CROSS-ENCODER: Prepared {len(pairs)} pairs for reranking")
470
+
471
+ # Get reranking scores using the correct CrossEncoder API
472
+ rerank_scores = self.reranker.predict(pairs)
473
+
474
+ # Handle single score case
475
+ if not isinstance(rerank_scores, (list, np.ndarray)):
476
+ rerank_scores = [rerank_scores]
477
+
478
+ # Ensure we have the right number of scores
479
+ if len(rerank_scores) != len(documents):
480
+ print(f"⚠️ RERANKING WARNING: Expected {len(documents)} scores, got {len(rerank_scores)}")
481
+ return documents
482
+
483
+ print(f"🔍 CROSS-ENCODER: Got {len(rerank_scores)} rerank scores")
484
+ print(f"🔍 CROSS-ENCODER SCORES: {rerank_scores[:5]}...") # Show first 5 scores
485
+
486
+ # Combine documents with their rerank scores
487
+ doc_scores = list(zip(documents, rerank_scores))
488
+
489
+ # Sort by rerank score (descending)
490
+ doc_scores.sort(key=lambda x: x[1], reverse=True)
491
+
492
+ # Extract reranked documents and store scores in metadata
493
+ reranked_docs = []
494
+ for i, (doc, rerank_score) in enumerate(doc_scores):
495
+ # Find original index for original score
496
+ original_idx = documents.index(doc)
497
+ original_score = scores[original_idx] if original_idx < len(scores) else 0.0
498
+
499
+ # Create new document with reranking metadata
500
+ new_doc = Document(
501
+ page_content=doc.page_content,
502
+ metadata={
503
+ **doc.metadata,
504
+ 'reranking_applied': True,
505
+ 'reranker_model': self.reranker_model_name,
506
+ 'reranker_type': self.reranker_type,
507
+ 'original_rank': original_idx + 1,
508
+ 'final_rank': i + 1,
509
+ 'original_score': float(original_score),
510
+ 'reranked_score': float(rerank_score)
511
+ }
512
+ )
513
+ reranked_docs.append(new_doc)
514
+
515
+ print(f"✅ CROSS-ENCODER: Reranked {len(reranked_docs)} documents")
516
+
517
+ return reranked_docs
518
+
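For reference, the cross-encoder path reduces to scoring (query, passage) pairs; a standalone sketch (the model weights are downloaded on first use):

from sentence_transformers import CrossEncoder

reranker = CrossEncoder("BAAI/bge-reranker-v2-m3")
pairs = [
    ["what did the 2022 audit find?", "passage text A"],
    ["what did the 2022 audit find?", "passage text B"],
]
scores = reranker.predict(pairs)   # one relevance score per pair; higher = more relevant
ranked = sorted(zip(pairs, scores), key=lambda p: p[1], reverse=True)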
519
+ def _apply_colbert_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
520
+ """
521
+ Apply reranking using ColBERT late interaction.
522
+
523
+ Args:
524
+ query: User query
525
+ documents: List of documents to rerank
526
+ scores: Original scores
527
+
528
+ Returns:
529
+ Reranked list of documents
530
+ """
531
+ # Use the actual ColBERT reranking implementation
532
+ return self._colbert_rerank(query, documents, scores)
533
+
534
+ def _colbert_rerank(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
535
+ """
536
+ ColBERT reranking using late interaction with pre-calculated embeddings support.
537
+
538
+ Args:
539
+ query: User query
540
+ documents: List of documents to rerank
541
+ scores: Original scores
542
+
543
+ Returns:
544
+ Reranked list of documents
545
+ """
546
+ try:
547
+ print(f"🔍 COLBERT: Starting late interaction reranking with {len(documents)} documents")
548
+
549
+ # Check if documents have pre-calculated ColBERT embeddings
550
+ pre_calculated_embeddings = []
551
+ documents_without_embeddings = []
552
+ documents_without_indices = []
553
+
554
+ for i, doc in enumerate(documents):
555
+ if (hasattr(doc, 'metadata') and
556
+ 'colbert_embedding' in doc.metadata and
557
+ doc.metadata['colbert_embedding'] is not None):
558
+ # Use pre-calculated embedding
559
+ colbert_embedding = doc.metadata['colbert_embedding']
560
+ if isinstance(colbert_embedding, list):
561
+ colbert_embedding = torch.tensor(colbert_embedding)
562
+ pre_calculated_embeddings.append(colbert_embedding)
563
+ else:
564
+ # Need to calculate embedding
565
+ documents_without_embeddings.append(doc)
566
+ documents_without_indices.append(i)
567
+
568
+ # Calculate query embedding
569
+ query_embeddings = self.colbert_checkpoint.queryFromText([query])
570
+
571
+ # Calculate embeddings for documents without pre-calculated ones
572
+ if documents_without_embeddings:
573
+ print(f"🔄 COLBERT: Calculating embeddings for {len(documents_without_embeddings)} documents without pre-calculated embeddings")
574
+ doc_texts = [doc.page_content for doc in documents_without_embeddings]
575
+ doc_embeddings = self.colbert_checkpoint.docFromText(doc_texts)
576
+
577
+ # Insert calculated embeddings into the right positions
578
+ for i, embedding in enumerate(doc_embeddings):
579
+ idx = documents_without_indices[i]
580
+ pre_calculated_embeddings.insert(idx, embedding)
581
+ else:
582
+ print(f"✅ COLBERT: Using pre-calculated embeddings for all {len(documents)} documents")
583
+
584
+ # Calculate late interaction scores
585
+ # ColBERT uses MaxSim: for each query token, find max similarity with document tokens
586
+ colbert_scores = []
587
+ for i, doc_embedding in enumerate(pre_calculated_embeddings):
588
+ # Calculate similarity matrix between query and document i
589
+ sim_matrix = torch.matmul(query_embeddings[0], doc_embedding.transpose(-1, -2))
590
+
591
+ # MaxSim: for each query token, take max similarity with document
592
+ max_sim_per_query_token = torch.max(sim_matrix, dim=-1)[0]
593
+
594
+ # Sum over query tokens to get final score
595
+ final_score = torch.sum(max_sim_per_query_token).item()
596
+ colbert_scores.append(final_score)
597
+
598
+ # Sort documents by ColBERT scores
599
+ doc_scores = list(zip(documents, colbert_scores))
600
+ doc_scores.sort(key=lambda x: x[1], reverse=True)
601
+
602
+ # Create reranked documents with metadata
603
+ reranked_docs = []
604
+ for i, (doc, colbert_score) in enumerate(doc_scores):
605
+ original_idx = documents.index(doc)
606
+ original_score = scores[original_idx] if original_idx < len(scores) else 0.0
607
+
608
+ new_doc = Document(
609
+ page_content=doc.page_content,
610
+ metadata={
611
+ **doc.metadata,
612
+ 'reranking_applied': True,
613
+ 'reranker_model': self.reranker_model_name,
614
+ 'reranker_type': self.reranker_type,
615
+ 'original_rank': original_idx + 1,
616
+ 'final_rank': i + 1,
617
+ 'original_score': float(original_score),
618
+ 'reranked_score': float(colbert_score),
619
+ 'colbert_score': float(colbert_score),
620
+ 'colbert_embedding_pre_calculated': 'colbert_embedding' in doc.metadata
621
+ }
622
+ )
623
+ reranked_docs.append(new_doc)
624
+
625
+ print(f"✅ COLBERT: Reranked {len(reranked_docs)} documents using late interaction")
626
+ print(f"🔍 COLBERT SCORES: {[f'{score:.4f}' for score in colbert_scores[:5]]}...")
627
+
628
+ return reranked_docs
629
+
630
+ except Exception as e:
631
+ print(f"❌ COLBERT RERANKING ERROR: {str(e)}")
632
+ print(f"❌ COLBERT TRACEBACK: {traceback.format_exc()}")
633
+ # Fallback to original order - return documents as-is
634
+ return documents
635
+
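The MaxSim scoring used above can be shown in isolation; a toy example with random tensors, independent of the ColBERT checkpoint loaded by this class:

import torch

query_tokens = torch.randn(8, 128)     # 8 query tokens, 128-dim each
doc_tokens = torch.randn(40, 128)      # 40 document tokens

sim = query_tokens @ doc_tokens.T                  # (8, 40) token-level similarities
max_per_query_token = sim.max(dim=-1).values       # best-matching doc token per query token
score = max_per_query_token.sum().item()           # document relevance under MaxSim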
636
+ def retrieve_with_scores(self, query: str, vectorstore=None, k: int = 5, reports: List[str] = None,
637
+ sources: List[str] = None, subtype: List[str] = None,
638
+ year: List[str] = None, use_reranking: bool = False,
639
+ qdrant_filter: Optional[rest.Filter] = None) -> Tuple[List[Document], List[float]]:
640
+ """
641
+ Retrieve context documents with scores using hybrid search with optional reranking.
642
+
643
+ Args:
644
+ query: User query
645
+ vectorstore: Optional vectorstore instance (for compatibility)
646
+ k: Number of documents to retrieve
647
+ reports: List of report names to filter by
648
+ sources: List of sources to filter by
649
+ subtype: Document subtype to filter by
650
+ year: List of years to filter by
651
+ use_reranking: Whether to apply reranking
652
+ qdrant_filter: Pre-built Qdrant filter
653
+
654
+ Returns:
655
+ Tuple of (documents, scores)
656
+ """
657
+ try:
658
+ # Use the provided vectorstore if available, otherwise use the instance one
659
+ if vectorstore:
660
+ self.vectorstore = vectorstore
661
+
662
+ # Determine search strategy
663
+ search_strategy = self.config.get('retrieval', {}).get('search_strategy', 'vector_only')
664
+
665
+ if search_strategy == 'vector_only':
666
+ # Vector search only
667
+ print(f"🔄 VECTOR SEARCH: Retrieving {k} documents...")
668
+
669
+ if qdrant_filter:
670
+ print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
671
+ # Pass the pre-built Qdrant filter through to the vector search
672
+ results = self.vectorstore.similarity_search_with_score(
673
+ query,
674
+ k=k,
675
+ filter=qdrant_filter
676
+ )
677
+ else:
678
+ # Build filter from individual parameters
679
+ filter_conditions = self._build_filter_conditions(reports, sources, subtype, year)
680
+ if filter_conditions:
681
+ print(f"✅ FILTER APPLIED: {filter_conditions}")
682
+ results = self.vectorstore.similarity_search_with_score(
683
+ query,
684
+ k=k,
685
+ filter=filter_conditions
686
+ )
687
+ else:
688
+ print(f"ℹ️ NO FILTERS APPLIED: All documents will be searched")
689
+ results = self.vectorstore.similarity_search_with_score(query, k=k)
690
+
691
+ print(f"🔍 SEARCH DEBUG: Raw result type: {type(results)}")
692
+ print(f"🔍 SEARCH DEBUG: Raw result length: {len(results)}")
693
+
694
+ # Handle different result formats
695
+ if results and isinstance(results[0], tuple):
696
+ documents = [doc for doc, score in results]
697
+ scores = [score for doc, score in results]
698
+ print(f"🔍 SEARCH DEBUG: After unpacking - documents: {len(documents)}, scores: {len(scores)}")
699
+ else:
700
+ documents = results
701
+ scores = [0.0] * len(documents)
702
+ print(f"🔍 SEARCH DEBUG: No scores available, using default")
703
+
704
+ print(f"🔧 CONVERTING: Converting {len(documents)} documents")
705
+
706
+ # Convert to Document objects and store original scores
707
+ final_documents = []
708
+ for i, (doc, score) in enumerate(zip(documents, scores)):
709
+ if hasattr(doc, 'page_content'):
710
+ new_doc = Document(
711
+ page_content=doc.page_content,
712
+ metadata=doc.metadata.copy()
713
+ )
714
+ # Store original score in metadata
715
+ new_doc.metadata['original_score'] = float(score) if score is not None else 0.0
716
+ final_documents.append(new_doc)
717
+ else:
718
+ print(f"⚠️ WARNING: Document {i} has no page_content")
719
+
720
+ print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(final_documents)} documents")
721
+
722
+ # Apply reranking if enabled
723
+ if use_reranking and len(final_documents) > 1:
724
+ print(f"🔄 RERANKING: Applying {self.reranker_model} to {len(final_documents)} documents...")
725
+ final_documents = self._apply_reranking(query, final_documents, scores)
726
+ print(f"✅ RERANKING APPLIED: {self.reranker_model}")
727
+ else:
728
+ print(f"ℹ️ RERANKING: Skipped (disabled or no documents)")
729
+
730
+ return final_documents, scores
731
+
732
+ else:
733
+ print(f"❌ UNSUPPORTED STRATEGY: {search_strategy}")
734
+ return [], []
735
+
736
+ except Exception as e:
737
+ print(f"❌ RETRIEVAL ERROR: {e}")
738
+ print(f"❌ RETRIEVAL TRACEBACK: {traceback.format_exc()}")
739
+ return [], []
740
+
741
+ def _build_filter_conditions(self, reports: List[str] = None, sources: List[str] = None,
742
+ subtype: List[str] = None, year: List[str] = None) -> Optional[rest.Filter]:
743
+ """
744
+ Build Qdrant filter conditions from individual parameters.
745
+
746
+ Args:
747
+ reports: List of report names
748
+ sources: List of sources
749
+ subtype: Document subtype
750
+ year: List of years
751
+
752
+ Returns:
753
+ Qdrant filter or None
754
+ """
755
+ conditions = []
756
+
757
+ if reports:
758
+ conditions.append(rest.FieldCondition(
759
+ key="metadata.filename",
760
+ match=rest.MatchAny(any=reports)
761
+ ))
762
+
763
+ if sources:
764
+ conditions.append(rest.FieldCondition(
765
+ key="metadata.source",
766
+ match=rest.MatchAny(any=sources)
767
+ ))
768
+
769
+ if subtype:
770
+ conditions.append(rest.FieldCondition(
771
+ key="metadata.subtype",
772
+ match=rest.MatchAny(any=subtype)
773
+ ))
774
+
775
+ if year:
776
+ conditions.append(rest.FieldCondition(
777
+ key="metadata.year",
778
+ match=rest.MatchAny(any=year)
779
+ ))
780
+
781
+ if conditions:
782
+ return rest.Filter(must=conditions)
783
+
784
+ return None
785
+
786
+ def get_context(
787
+ query: str,
788
+ vectorstore: Qdrant,
789
+ k: int = 5,
790
+ reports: Optional[List[str]] = None,
791
+ sources: Optional[List[str]] = None,
792
+ subtype: Optional[str] = None,
793
+ year: Optional[str] = None,
794
+ use_reranking: bool = False,
795
+ qdrant_filter: Optional[rest.Filter] = None
796
+ ) -> List[Document]:
797
+ """
798
+ Convenience function to get context documents.
799
+
800
+ Args:
801
+ query: User query
802
+ vectorstore: Qdrant vector store instance
803
+ k: Number of documents to retrieve
804
+ reports: Optional list of report names to filter by
805
+ sources: Optional list of source categories to filter by
806
+ subtype: Optional subtype to filter by
807
+ year: Optional year to filter by
808
+ use_reranking: Whether to apply reranking
809
+ qdrant_filter: Optional pre-built Qdrant filter
810
+
811
+ Returns:
812
+ List of retrieved documents
813
+ """
814
+ retriever = ContextRetriever(vectorstore)
815
+ return retriever.retrieve_context(
816
+ query=query,
817
+ k=k,
818
+ reports=reports,
819
+ sources=sources,
820
+ subtype=subtype,
821
+ year=year,
822
+ use_reranking=use_reranking,
823
+ qdrant_filter=qdrant_filter
824
+ )
825
+
826
+
827
+ def format_context_for_llm(documents: List[Document]) -> str:
828
+ """
829
+ Format retrieved documents for LLM input.
830
+
831
+ Args:
832
+ documents: List of Document objects
833
+
834
+ Returns:
835
+ Formatted string for LLM
836
+ """
837
+ if not documents:
838
+ return ""
839
+
840
+ formatted_parts = []
841
+ for i, doc in enumerate(documents, 1):
842
+ content = doc.page_content.strip()
843
+ source = doc.metadata.get('filename', 'Unknown')
844
+
845
+ formatted_parts.append(f"Document {i} (Source: {source}):\n{content}")
846
+
847
+ return "\n\n".join(formatted_parts)
848
+
849
+
850
+ def get_context_metadata(documents: List[Document]) -> Dict[str, Any]:
851
+ """
852
+ Extract metadata summary from retrieved documents.
853
+
854
+ Args:
855
+ documents: List of Document objects
856
+
857
+ Returns:
858
+ Dictionary with metadata summary
859
+ """
860
+ if not documents:
861
+ return {}
862
+
863
+ sources = set()
864
+ years = set()
865
+ doc_types = set()
866
+
867
+ for doc in documents:
868
+ metadata = doc.metadata
869
+ if 'filename' in metadata:
870
+ sources.add(metadata['filename'])
871
+ if 'year' in metadata:
872
+ years.add(metadata['year'])
873
+ if 'source' in metadata:
874
+ doc_types.add(metadata['source'])
875
+
876
+ return {
877
+ "num_documents": len(documents),
878
+ "sources": list(sources),
879
+ "years": list(years),
880
+ "document_types": list(doc_types)
881
+ }
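Putting the module together, a rough end-to-end sketch, assuming a Qdrant vectorstore has been initialized elsewhere in the app (filter values are illustrative):

docs = get_context(
    query="What challenges did PDM implementation face?",
    vectorstore=vectorstore,          # assumed: an initialized langchain Qdrant wrapper
    k=5,
    sources=["Local Government"],
    year=["2022"],
    use_reranking=True,
)
llm_context = format_context_for_llm(docs)
print(get_context_metadata(docs))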
src/retrieval/filter.py ADDED
@@ -0,0 +1,975 @@
1
+ """Document filtering utilities for Qdrant vector store."""
2
+
3
+ from typing import List, Optional, Union, Dict, Tuple, Any
4
+ from qdrant_client.http import models as rest
5
+ import time
6
+
7
+
8
+ class FilterBuilder:
9
+ """Builder class for creating Qdrant filters."""
10
+
11
+ def __init__(self):
12
+ self.conditions = []
13
+
14
+ def add_source_filter(self, source: Union[str, List[str]]) -> 'FilterBuilder':
15
+ """Add source filter condition."""
16
+ if source:
17
+ if isinstance(source, list):
18
+ condition = rest.FieldCondition(
19
+ key="metadata.source",
20
+ match=rest.MatchAny(any=source)
21
+ )
22
+ print(f"🔧 FilterBuilder: Added source filter for {source}")
23
+ else:
24
+ condition = rest.FieldCondition(
25
+ key="metadata.source",
26
+ match=rest.MatchValue(value=source)
27
+ )
28
+ print(f"🔧 FilterBuilder: Added source filter for '{source}'")
29
+ self.conditions.append(condition)
30
+ return self
31
+
32
+ def add_filename_filter(self, filenames: List[str]) -> 'FilterBuilder':
33
+ """Add filename filter condition."""
34
+ if filenames:
35
+ condition = rest.FieldCondition(
36
+ key="metadata.filename",
37
+ match=rest.MatchAny(any=filenames)
38
+ )
39
+ self.conditions.append(condition)
40
+ print(f"🔧 FilterBuilder: Added filename filter for {filenames}")
41
+ return self
42
+
43
+ def add_year_filter(self, years: List[str]) -> 'FilterBuilder':
44
+ """Add year filter condition."""
45
+ if years:
46
+ condition = rest.FieldCondition(
47
+ key="metadata.year",
48
+ match=rest.MatchAny(any=years)
49
+ )
50
+ self.conditions.append(condition)
51
+ print(f"🔧 FilterBuilder: Added year filter for {years}")
52
+ return self
53
+
54
+ def add_district_filter(self, districts: List[str]) -> 'FilterBuilder':
55
+ """Add district filter condition."""
56
+ if districts:
57
+ condition = rest.FieldCondition(
58
+ key="metadata.district",
59
+ match=rest.MatchAny(any=districts)
60
+ )
61
+ self.conditions.append(condition)
62
+ print(f"🔧 FilterBuilder: Added district filter for {districts}")
63
+ return self
64
+
65
+ def add_custom_filter(self, key: str, value: Union[str, List[str]]) -> 'FilterBuilder':
66
+ """Add custom filter condition."""
67
+ if isinstance(value, list):
68
+ condition = rest.FieldCondition(
69
+ key=key,
70
+ match=rest.MatchAny(any=value)
71
+ )
72
+ else:
73
+ condition = rest.FieldCondition(
74
+ key=key,
75
+ match=rest.MatchValue(value=value)
76
+ )
77
+ self.conditions.append(condition)
78
+ return self
79
+
80
+ def build(self) -> rest.Filter:
81
+ """Build the final filter."""
82
+ if not self.conditions:
83
+ return None
84
+
85
+ return rest.Filter(must=self.conditions)
86
+
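A short sketch of the builder in use (the metadata values are illustrative):

qfilter = (
    FilterBuilder()
    .add_source_filter("Local Government")
    .add_year_filter(["2022", "2023"])
    .add_district_filter(["Kampala"])
    .build()                 # rest.Filter with three must-conditions, or None if empty
)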
87
+
88
+ def create_filter(
89
+ reports: List[str] = None,
90
+ sources: Union[str, List[str]] = None,
91
+ subtype: List[str] = None,
92
+ year: List[str] = None,
93
+ district: List[str] = None,
94
+ filenames: List[str] = None
95
+ ) -> rest.Filter:
96
+ """
97
+ Create a search filter for Qdrant (legacy function for compatibility).
98
+
99
+ Args:
100
+ reports: List of specific report filenames
101
+ sources: Source category
102
+ subtype: List of subtypes/filenames
103
+ year: List of years
104
+ district: List of districts
105
+ filenames: List of specific filenames (mutually exclusive with other filters)
106
+
107
+ Returns:
108
+ Qdrant Filter object
109
+
110
+ Note:
111
+ If filenames are provided, ONLY filename filtering is applied (mutually exclusive)
112
+ """
113
+ builder = FilterBuilder()
114
+
115
+ # Check if filename filtering is requested (mutually exclusive)
116
+ # Both filenames and reports serve the same purpose (backward compatibility)
117
+ # Prefer filenames, fallback to reports for legacy support
118
+ target_filenames = filenames if filenames else reports
119
+
120
+ if target_filenames and len(target_filenames) > 0:
121
+ # ONLY apply filename filter, ignore all other filters
122
+ print(f"🔍 FILTER APPLIED: Filenames = {target_filenames} (mutually exclusive mode)")
123
+ builder.add_filename_filter(target_filenames)
124
+ else:
125
+ # Otherwise, filter by source and subtype
126
+ print(f"🔍 FILTER APPLIED: Sources = {sources}, Subtype = {subtype}, Year = {year}, District = {district}")
127
+ if sources:
128
+ print(f"✅ Adding source filter: metadata.source = '{sources}'")
129
+ builder.add_source_filter(sources)
130
+ if subtype:
131
+ print(f"✅ Adding subtype filter: metadata.filename IN {subtype}")
132
+ builder.add_filename_filter(subtype)
133
+ if year:
134
+ print(f"✅ Adding year filter: metadata.year IN {year}")
135
+ builder.add_year_filter(year)
136
+
137
+ if district:
138
+ print(f"✅ Adding district filter: metadata.district IN {district}")
139
+ builder.add_district_filter(district)
140
+
141
+ filter_obj = builder.build()
142
+
143
+ if filter_obj:
144
+ print(f"�� FINAL FILTER: {len(filter_obj.must)} condition(s) applied")
145
+ for i, condition in enumerate(filter_obj.must, 1):
146
+ print(f" Condition {i}: {condition.key} = {condition.match}")
147
+ else:
148
+ print("⚠️ NO FILTERS APPLIED: All documents will be searched")
149
+
150
+ return filter_obj
151
+
152
+
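Two calls that illustrate the mutually exclusive behaviour documented above (the filenames are hypothetical):

# Filename mode: every other argument is ignored
f1 = create_filter(filenames=["OAG_Annual_Report_2022.pdf"], year=["2021"])

# Metadata mode: source/subtype/year/district conditions are ANDed together
f2 = create_filter(sources="Consolidated", year=["2022"])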
153
+ def create_advanced_filter(
154
+ must_conditions: List[dict] = None,
155
+ should_conditions: List[dict] = None,
156
+ must_not_conditions: List[dict] = None
157
+ ) -> rest.Filter:
158
+ """
159
+ Create advanced filter with multiple condition types.
160
+
161
+ Args:
162
+ must_conditions: Conditions that must match
163
+ should_conditions: Conditions that should match (OR logic)
164
+ must_not_conditions: Conditions that must not match
165
+
166
+ Returns:
167
+ Qdrant Filter object
168
+ """
169
+ filter_dict = {}
170
+
171
+ if must_conditions:
172
+ filter_dict["must"] = [
173
+ _dict_to_field_condition(cond) for cond in must_conditions
174
+ ]
175
+
176
+ if should_conditions:
177
+ filter_dict["should"] = [
178
+ _dict_to_field_condition(cond) for cond in should_conditions
179
+ ]
180
+
181
+ if must_not_conditions:
182
+ filter_dict["must_not"] = [
183
+ _dict_to_field_condition(cond) for cond in must_not_conditions
184
+ ]
185
+
186
+ if not filter_dict:
187
+ return None
188
+
189
+ return rest.Filter(**filter_dict)
190
+
191
+
192
+ def _dict_to_field_condition(condition_dict: dict) -> rest.FieldCondition:
193
+ """Convert dictionary to FieldCondition."""
194
+ key = condition_dict["key"]
195
+ value = condition_dict["value"]
196
+
197
+ if isinstance(value, list):
198
+ match = rest.MatchAny(any=value)
199
+ else:
200
+ match = rest.MatchValue(value=value)
201
+
202
+ return rest.FieldCondition(key=key, match=match)
203
+
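An illustrative advanced filter combining OR and NOT conditions, using the dict format expected by _dict_to_field_condition:

adv = create_advanced_filter(
    should_conditions=[
        {"key": "metadata.source", "value": ["Ministry, Department and Agency", "Local Government"]},
    ],
    must_not_conditions=[{"key": "metadata.year", "value": "2019"}],
)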
204
+
205
+ def validate_filter(filter_obj: rest.Filter) -> bool:
206
+ """
207
+ Validate that a filter object is properly constructed.
208
+
209
+ Args:
210
+ filter_obj: Qdrant Filter object
211
+
212
+ Returns:
213
+ True if valid, raises ValueError if invalid
214
+ """
215
+ if filter_obj is None:
216
+ return True
217
+
218
+ if not isinstance(filter_obj, rest.Filter):
219
+ raise ValueError("Filter must be a rest.Filter object")
220
+
221
+ # Check that at least one condition type is present
222
+ has_conditions = any([
223
+ hasattr(filter_obj, 'must') and filter_obj.must,
224
+ hasattr(filter_obj, 'should') and filter_obj.should,
225
+ hasattr(filter_obj, 'must_not') and filter_obj.must_not
226
+ ])
227
+
228
+ if not has_conditions:
229
+ raise ValueError("Filter must have at least one condition")
230
+
231
+ return True
232
+
233
+
234
+ def infer_filters_from_query(
235
+ query: str,
236
+ available_metadata: dict,
237
+ llm_client=None
238
+ ) -> Tuple[rest.Filter, Union[dict, None]]:
239
+ """
240
+ Automatically infer filters from a query using LLM analysis.
241
+
242
+ Args:
243
+ query: User query to analyze
244
+ available_metadata: Available metadata values in the vectorstore
245
+ llm_client: LLM client for analysis (optional)
246
+
247
+ Returns:
248
+ Qdrant Filter object with inferred conditions
249
+ """
250
+ print(f"�� AUTO-INFERRING FILTERS from query: '{query[:50]}...'")
251
+
252
+ # Check if LLM client is available
253
+ if not llm_client:
254
+ print(f"❌ LLM CLIENT MISSING: Cannot use LLM analysis, falling back to rule-based")
255
+ return _infer_filters_rule_based(query, available_metadata), None
256
+
257
+ # Extract available options
258
+ available_sources = available_metadata.get('sources', [])
259
+ available_years = available_metadata.get('years', [])
260
+ available_filenames = available_metadata.get('filenames', [])
261
+
262
+ print(f"📊 Available metadata: sources={len(available_sources)}, years={len(available_years)}, filenames={len(available_filenames)}")
263
+
264
+ # Try LLM analysis first
265
+ print(f" LLM ANALYSIS: Attempting LLM-based filter inference...")
266
+ llm_result = _analyze_query_with_llm(
267
+ query=query,
268
+ available_metadata=available_metadata,
269
+ llm_client=llm_client
270
+ )
271
+
272
+ if llm_result:
273
+ print(f"✅ LLM SUCCESS: LLM successfully inferred filters")
274
+ # Use the _build_qdrant_filter function to properly build the Qdrant filter
275
+ qdrant_filter, filter_summary = _build_qdrant_filter(llm_result)
276
+ if qdrant_filter:
277
+ print(f"✅ QDRANT FILTER: Successfully built Qdrant filter")
278
+ # print(f"✅ INFERRED FILTERS: {qdrant_filter}")
279
+ return qdrant_filter, filter_summary
280
+ else:
281
+ print(f"❌ QDRANT FILTER: Failed to build Qdrant filter, trying rule-based fallback")
282
+ rule_based_result = _infer_filters_rule_based(query, available_metadata)
283
+ # Use the _build_qdrant_filter function to properly build the Qdrant filter
284
+ qdrant_filter, filter_summary = _build_qdrant_filter(rule_based_result)
285
+ if qdrant_filter:
286
+ print(f"✅ RULE-BASED QDRANT FILTER: Successfully built Qdrant filter")
287
+ return qdrant_filter, filter_summary
288
+ else:
289
+ print(f"❌ RULE-BASED QDRANT FILTER: Failed to build Qdrant filter")
290
+ return None, None
291
+ else:
292
+ print(f"⚠️ LLM FAILED: LLM could not infer filters, trying rule-based fallback")
293
+ rule_based_result = _infer_filters_rule_based(query, available_metadata)
294
+ # Use the _build_qdrant_filter function to properly build the Qdrant filter
295
+ qdrant_filter, filter_summary = _build_qdrant_filter(rule_based_result)
296
+ if qdrant_filter:
297
+ print(f"✅ RULE-BASED QDRANT FILTER: Successfully built Qdrant filter")
298
+ return qdrant_filter, filter_summary
299
+ else:
300
+ print(f"❌ RULE-BASED QDRANT FILTER: Failed to build Qdrant filter")
301
+ return None, None
302
+
303
+
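How this entry point is meant to be consumed, sketched under the assumption that a LangChain-style chat model (`llm`) and a ContextRetriever (`retriever`) are already available; metadata values are illustrative:

query = "How were administrative costs managed by districts in 2022?"
available_metadata = {                 # normally derived from the vectorstore payloads
    "sources": ["Local Government", "Consolidated"],
    "years": ["2021", "2022", "2023"],
    "filenames": [],
}
qdrant_filter, filter_summary = infer_filters_from_query(query, available_metadata, llm_client=llm)
docs = retriever.retrieve_context(query, k=5, qdrant_filter=qdrant_filter)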
304
+ def _analyze_query_with_llm(
305
+ query: str,
306
+ available_metadata: Dict[str, List[str]],
307
+ llm_client=None
308
+ ) -> dict:
309
+
310
+
311
+ """
312
+ - Filenames: {available_metadata.get('filenames', [])}
313
+
314
+ 📁 FILENAME FILTERING (Use Sparingly):
315
+ - Only if specific filename explicitly mentioned
316
+ - Prefer source/subtype over filename
317
+ - Be very conservative
318
+
319
+
320
+ "filenames": ["filename1", "filename2"] or [],
321
+ - For filenames: Only use if you have high confidence and can identify specific files
322
+ """
323
+
324
+
325
+ """
326
+ Use LLM to analyze query and infer appropriate filters.
327
+
328
+ Args:
329
+ query: User query to analyze
330
+ available_metadata: Available metadata values in the vectorstore
331
+ llm_client: LLM client for analysis
332
+
333
+ Returns:
334
+ Dictionary with inferred filters or empty dict if failed
335
+ """
336
+ if not llm_client:
337
+ print("❌ LLM CLIENT MISSING: Cannot analyze query without LLM client")
338
+ return {}
339
+
340
+ try:
341
+ print(f" LLM ANALYSIS: Analyzing query with LLM...")
342
+
343
+
344
+ """
345
+ For example: "What is the expected ... in 2024" - this references a future point in time, so retrieving documents for 2023, 2022 and 2021 can be relevant too
347
+ Another example: "What is the GDP increase now compared to 2022" - this is a relative statement referring to past data, so both the year 2022 and the current year (2025) need to be detected/marked
347
+ """
348
+
349
+ # Create prompt for LLM analysis
350
+ prompt = f"""
351
+ You are a filter inference system. Analyze this query and return ONLY a JSON object.
352
+
353
+ Query: "{query}"
354
+
355
+ Available metadata:
356
+ - Sources: {available_metadata.get('sources', [])}
357
+ - Years: {available_metadata.get('years', [])}
358
+
359
+ FILTER INFERENCE GUIDELINES:
360
+
361
+ YEAR FILTERING (Be VERY Conservative):
362
+ ✅ INFER YEARS ONLY IF:
363
+ - Explicit 4-digit years: "2022", "2023", "2021"
364
+ - Clear relative terms: "last year", "this year", "recent", "current year" (for the context, now is 2025)
365
+ - Temporal context: "annual report 2022", "audit for 2023"
366
+ - Give multiple years for complex queries.
367
+
368
+
369
+ ❌ DO NOT INFER YEARS FOR:
370
+ - Vague terms: "implementation", "activities", "costs", "challenges", "issues"
371
+ - General concepts: "PDM", "administrative", "budget", "staff"
372
+ - Process descriptions: "how were", "what challenges", "management of"
373
+
374
+ 🏛️ SOURCE FILTERING (Context-Based):
375
+ - "Ministry, Department and Agency" → Central government, ministries, departments, PS/ST
376
+ - "Local Government" → Districts, municipalities, local authorities, DLG
377
+ - "Consolidated" → Annual consolidated reports, OAG reports
378
+ - "Thematic" → Special studies, thematic reports
379
+
380
+ 📄 SUBTYPE FILTERING (Document Type):
381
+ - "audit" → Audit reports, reviews, examinations
382
+ - "report" → General reports, annual reports
383
+ - "guidance" → Guidelines, directives, circulars
384
+
385
+ CONFIDENCE SCORING:
386
+ - 0.9-1.0: Crystal clear indicators (explicit years, specific sources)
387
+ - 0.7-0.8: Good indicators (relative years, clear context)
388
+ - 0.5-0.6: Moderate indicators (some context clues)
389
+ - 0.0-0.4: Low confidence (vague or unclear)
390
+
391
+ EXAMPLES:
392
+ ✅ "What challenges arose in 2022?" → years: ["2022"], confidence: 1
393
+ ✅ "How were administrative costs managed in our government?" → sources: ["Local Government"], confidence: 0.75
394
+ ✅ "PDM implementation guidelines from last year" → years: ["2024"], confidence: 0.9
395
+ ❌ "What issues arose with budget execution?" → NO FILTERS, confidence: 0.2
396
+ ❌ "How were tools related to administrative costs?" → NO FILTERS, confidence: 0.1
397
+
398
+ RESPONSE FORMAT (JSON only):
399
+ {{
400
+ "years": ["2022", "2023"] or [],
401
+ "sources": ["Ministry, Department and Agency", "Local Government"] or [],
402
+ "subtype": ["audit", "report"] or [],
403
+ "confidence": 0.8,
404
+ "reasoning": "Very brief explanation of filter choices"
405
+ }}
406
+
407
+ Rules:
408
+ - Use OR logic (SHOULD) for multiple values
409
+ - Prefer sources over filenames
410
+ - Only include years if clearly mentioned
411
+ - Return null for unclear fields
412
+ - For sources/subtypes: Include at least 3 candidates unless confidence is high and you can identify exactly one source (MUST)
413
+ - For years: If you want to include, then include at least 2 candidates unless confidence is high and you can identify exactly one year (MUST)
414
+ """
415
+
416
+ print(f"🔄 LLM CALL: Sending prompt to LLM...")
417
+ try:
418
+ # Try different methods to call the LLM
419
+ if hasattr(llm_client, 'invoke'):
420
+ response = llm_client.invoke(prompt)
421
+ elif hasattr(llm_client, 'generate'):
422
+ response = llm_client.generate([{"role": "user", "content": prompt}])
423
+ elif hasattr(llm_client, 'call'):
424
+ response = llm_client.call(prompt)
425
+ elif hasattr(llm_client, 'predict'):
426
+ response = llm_client.predict(prompt)
427
+ else:
428
+ # Try to call it directly
429
+ response = llm_client(prompt)
430
+
431
+ print(f"✅ LLM CALL SUCCESS: Received response from LLM")
432
+
433
+ # Extract content from response
434
+ if hasattr(response, 'content'):
435
+ response_content = response.content
436
+ elif hasattr(response, 'text'):
437
+ response_content = response.text
438
+ elif isinstance(response, str):
439
+ response_content = response
440
+ else:
441
+ response_content = str(response)
442
+
443
+ print(f"🔄 LLM RESPONSE: {response_content[:200]}...")
444
+
445
+ except Exception as e:
446
+ print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
447
+ return {}
448
+
449
+ # Parse JSON response
450
+ import json
451
+ import re
452
+ try:
453
+ print(f"🔄 JSON PARSING: Attempting to parse LLM response...")
454
+
455
+ # Clean the response to extract JSON from markdown
456
+ response_text = response_content.strip()
457
+
458
+ # Remove markdown formatting if present
459
+ if "```json" in response_text:
460
+ # Extract JSON from markdown code block
461
+ start_marker = "```json"
462
+ end_marker = "```"
463
+ start_idx = response_text.find(start_marker)
464
+ if start_idx != -1:
465
+ start_idx += len(start_marker)
466
+ end_idx = response_text.find(end_marker, start_idx)
467
+ if end_idx != -1:
468
+ response_text = response_text[start_idx:end_idx].strip()
469
+
470
+ # Try to find JSON object in the response
471
+ json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
472
+ if json_match:
473
+ response_text = json_match.group(0)
474
+
475
+ print(f"🔄 JSON PARSING: Cleaned response: {response_text[:200]}...")
476
+
477
+ # Parse JSON
478
+ filters = json.loads(response_text)
479
+ print(f"✅ JSON PARSING SUCCESS: Parsed filters: {filters}")
480
+
481
+ # Validate filters
482
+ if not isinstance(filters, dict):
483
+ print(f"❌ JSON VALIDATION FAILED: Response is not a dictionary")
484
+ return {}
485
+
486
+ # Check if any filters were inferred
487
+ has_filters = any(filters.get(key) for key in ['sources', 'years', 'filenames'])
488
+ if not has_filters:
489
+ print(f"⚠️ QUERY DIFFICULT: LLM could not determine appropriate filters from query")
490
+ return {}
491
+
492
+ # print(f"✅ FILTER INFERENCE SUCCESS: Inferred filters: {filters}")
493
+ return filters
494
+
495
+ except json.JSONDecodeError as e:
496
+ print(f"❌ JSON PARSING FAILED: Invalid JSON format - {e}")
497
+ print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
498
+ return {}
499
+ except Exception as e:
500
+ print(f"❌ JSON PARSING FAILED: Unexpected error - {e}")
501
+ print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
502
+ return {}
503
+
504
+ except Exception as e:
505
+ print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
506
+ return {}
507
+
508
+
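The JSON-extraction step above has to cope with replies wrapped in markdown code fences. A minimal standalone sketch of that cleanup, using an invented raw reply (the fence marker is built programmatically only so the example stays self-contained):

```python
import json
import re

# Hypothetical raw reply: many chat models wrap their JSON in a fenced "json" block.
fence = "`" * 3  # the three-backtick fence marker
raw_reply = (
    f"{fence}json\n"
    '{"years": ["2022", "2023"], "sources": ["Consolidated"], '
    '"subtype": ["audit"], "confidence": 0.8, '
    '"reasoning": "Explicit years, consolidated-report context"}\n'
    f"{fence}"
)

text = raw_reply.strip()
if f"{fence}json" in text:
    # Keep only the body of the fenced block, mirroring the parser above.
    text = text.split(f"{fence}json", 1)[1].split(fence, 1)[0].strip()

# Fall back to grabbing the first {...} span, as the parser above also does.
match = re.search(r"\{.*\}", text, re.DOTALL)
if match:
    text = match.group(0)

filters = json.loads(text)
print(filters["years"], filters["confidence"])  # ['2022', '2023'] 0.8
```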
509
+ def _infer_filters_rule_based(
510
+ query: str,
511
+ available_metadata: dict
512
+ ) -> dict:
513
+ """
514
+ Rule-based fallback for filter inference with improved logic.
515
+
516
+ Args:
517
+ query: User query
518
+ available_metadata: Available metadata values in the vectorstore
519
+
520
+ Returns:
521
+ Dictionary of inferred filters
522
+ """
523
+ print(f" RULE-BASED ANALYSIS: Starting rule-based inference for query: '{query[:50]}...'")
524
+
525
+ inferred = {}
526
+ query_lower = query.lower()
527
+
528
+ # SEMANTIC SOURCE INFERENCE - Use semantic understanding
529
+ source_matches = []
530
+
531
+ # Define semantic mappings for better source inference
532
+ source_keywords = {
533
+ 'consolidated': ['consolidated', 'annual', 'oag', 'auditor general', 'government', 'financial statements', 'budget', 'expenditure', 'revenue'],
534
+ 'military': ['military', 'defence', 'defense', 'army', 'navy', 'air force', 'security', 'defense ministry'],
535
+ 'departmental': ['department', 'ministry', 'agency', 'authority', 'commission', 'board', 'directorate'],
536
+ 'thematic': ['thematic', 'sector', 'program', 'project', 'initiative', 'development', 'infrastructure']
537
+ }
538
+
539
+ for source in available_metadata.get('sources', []):
540
+ source_lower = source.lower()
541
+
542
+ # Direct keyword match
543
+ if source_lower in query_lower:
544
+ source_matches.append(source)
545
+ print(f"✅ DIRECT MATCH: Found direct keyword match for '{source}'")
546
+ else:
547
+ # Semantic keyword matching
548
+ if source_lower in source_keywords:
549
+ keywords = source_keywords[source_lower]
550
+ matches = sum(1 for keyword in keywords if keyword in query_lower)
551
+ if matches >= 2: # Require at least 2 keyword matches for semantic inference
552
+ source_matches.append(source)
553
+ print(f"✅ SEMANTIC MATCH: Found {matches} semantic keywords for '{source}': {[k for k in keywords if k in query_lower]}")
554
+
555
+ if source_matches:
556
+ # Use SHOULD (OR logic) for multiple sources
557
+ inferred['sources_should'] = source_matches
558
+ print(f"✅ SOURCE INFERENCE: Found {len(source_matches)} sources with OR logic: {source_matches}")
559
+ else:
560
+ print("❌ SOURCE INFERENCE: No source keywords found in query")
561
+
562
+ # Infer year filters - use SHOULD (OR logic) for multiple years
563
+ import re
564
+ year_matches = []
565
+ for year in available_metadata.get('years', []):
566
+ if year in query or f"'{year}" in query:
567
+ year_matches.append(year)
568
+
569
+ if year_matches:
570
+ # Use SHOULD (OR logic) for multiple years
571
+ inferred['years_should'] = year_matches
572
+ print(f"✅ YEAR INFERENCE: Found {len(year_matches)} years with OR logic: {year_matches}")
573
+ else:
574
+ print("❌ YEAR INFERENCE: No year references found in query")
575
+
576
+ # Only infer filename filters if no year filter was found (to avoid conflicts)
577
+ if not year_matches:
578
+ filename_matches = []
579
+ for filename in available_metadata.get('filenames', []):
580
+ # Only match if multiple words from filename appear in query
581
+ filename_words = filename.lower().split()
582
+ matches = sum(1 for word in filename_words if word in query_lower)
583
+ if matches >= 2: # High confidence threshold
584
+ filename_matches.append(filename)
585
+
586
+ if filename_matches:
587
+ # Use SHOULD (OR logic) for multiple filenames
588
+ inferred['filenames_should'] = filename_matches
589
+ print(f"✅ FILENAME INFERENCE: Found {len(filename_matches)} filenames with OR logic: {filename_matches}")
590
+ else:
591
+ print("❌ FILENAME INFERENCE: No high-confidence filename matches found")
592
+ else:
593
+ print("ℹ️ FILENAME INFERENCE: Skipped (year filter already applied to avoid conflicts)")
594
+
595
+ print(f" RULE-BASED RESULT: {inferred}")
596
+ return inferred
597
+
598
+
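A hedged usage sketch of the rule-based fallback above; the corpus metadata and the query are invented for illustration and assume the function is in scope:

```python
# Invented corpus metadata, for illustration only.
available_metadata = {
    "sources": ["Consolidated", "Local Government", "Thematic"],
    "years": ["2021", "2022", "2023"],
    "filenames": ["consolidated_audit_2022.pdf"],
}

query = "What did the consolidated annual audit of government expenditure find in 2022?"
filters = _infer_filters_rule_based(query, available_metadata)
print(filters)
# Expected shape, keyword hits permitting:
# {'sources_should': ['Consolidated'], 'years_should': ['2022']}
```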
599
+ def _validate_inferred_filters(inferred_filters: dict) -> dict:
600
+ """
601
+ Validate and normalize inferred filters to ensure they're in the expected format.
602
+
603
+ Args:
604
+ inferred_filters: Raw inferred filters dictionary
605
+
606
+ Returns:
607
+ Validated and normalized filters dictionary
608
+ """
609
+ if not isinstance(inferred_filters, dict):
610
+ print(f"⚠️ FILTER VALIDATION: Inferred filters is not a dict: {type(inferred_filters)}")
611
+ return {}
612
+
613
+ validated = {}
614
+
615
+ # Normalize field names and validate values
616
+ for field_name in ['sources', 'sources_should', 'years', 'years_should', 'filenames', 'filenames_should']:
617
+ if field_name in inferred_filters and inferred_filters[field_name]:
618
+ value = inferred_filters[field_name]
619
+ if isinstance(value, list) and len(value) > 0:
620
+ # Remove any None or empty string values
621
+ clean_value = [v for v in value if v is not None and str(v).strip()]
622
+ if clean_value:
623
+ validated[field_name] = clean_value
624
+ print(f"✅ FILTER VALIDATION: {field_name} = {clean_value}")
625
+ elif isinstance(value, str) and value.strip():
626
+ validated[field_name] = [value.strip()]
627
+ print(f"✅ FILTER VALIDATION: {field_name} = [{value.strip()}]")
628
+
629
+ return validated
630
+
631
+
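A quick illustration of the normalization performed above: bare strings become single-item lists, empty strings and None entries are stripped, and fields left empty are dropped (values here are made up):

```python
raw = {
    "sources": "Consolidated",           # bare string -> single-item list
    "years_should": ["2022", "", None],  # empties and None are stripped
    "filenames": [],                     # empty field -> dropped entirely
}
print(_validate_inferred_filters(raw))
# -> {'sources': ['Consolidated'], 'years_should': ['2022']}
```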
632
+ def _build_qdrant_filter(inferred_filters: dict) -> tuple:
633
+ """
634
+ Build Qdrant filter from inferred filters.
635
+
636
+ Args:
637
+ inferred_filters: Dictionary with inferred filter values
638
+
639
+ Returns:
640
+ Tuple of (Qdrant Filter object or None, dict summarizing the applied filters)
641
+ """
642
+ try:
643
+ from qdrant_client.http import models as rest
644
+
645
+ # Validate and normalize the inferred filters first
646
+ validated_filters = _validate_inferred_filters(inferred_filters)
647
+ if not validated_filters:
648
+ print(f"⚠️ NO VALID FILTERS: All filters were invalid or empty")
649
+ return None, {}
650
+
651
+ conditions = []
652
+ filter_summary = {}
653
+
654
+ # Handle sources (use OR logic for multiple values)
655
+ # Support both 'sources' and 'sources_should' field names
656
+ source_values = None
657
+ if 'sources' in validated_filters and validated_filters['sources']:
658
+ source_values = validated_filters['sources']
659
+ elif 'sources_should' in validated_filters and validated_filters['sources_should']:
660
+ source_values = validated_filters['sources_should']
661
+
662
+ if source_values and isinstance(source_values, list) and len(source_values) > 0:
663
+ if len(source_values) == 1:
664
+ conditions.append(rest.FieldCondition(
665
+ key="metadata.source",
666
+ match=rest.MatchValue(value=source_values[0])
667
+ ))
668
+ else:
669
+ # Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
670
+ conditions.append(rest.FieldCondition(
671
+ key="metadata.source",
672
+ match=rest.MatchAny(any=source_values)
673
+ ))
674
+ filter_summary['sources'] = f"SHOULD: {source_values}"
675
+
676
+ # Handle years (use OR logic for multiple values)
677
+ # Support both 'years' and 'years_should' field names
678
+ year_values = None
679
+ if 'years' in validated_filters and validated_filters['years']:
680
+ year_values = validated_filters['years']
681
+ elif 'years_should' in validated_filters and validated_filters['years_should']:
682
+ year_values = validated_filters['years_should']
683
+
684
+ if year_values and isinstance(year_values, list) and len(year_values) > 0:
685
+ if len(year_values) == 1:
686
+ conditions.append(rest.FieldCondition(
687
+ key="metadata.year",
688
+ match=rest.MatchValue(value=year_values[0])
689
+ ))
690
+ else:
691
+ # Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
692
+ conditions.append(rest.FieldCondition(
693
+ key="metadata.year",
694
+ match=rest.MatchAny(any=year_values)
695
+ ))
696
+ filter_summary['years'] = f"SHOULD: {year_values}"
697
+
698
+ # Handle filenames (use OR logic for multiple values)
699
+ # Support both 'filenames' and 'filenames_should' field names
700
+ filename_values = None
701
+ if 'filenames' in validated_filters and validated_filters['filenames']:
702
+ filename_values = validated_filters['filenames']
703
+ elif 'filenames_should' in validated_filters and validated_filters['filenames_should']:
704
+ filename_values = validated_filters['filenames_should']
705
+
706
+ if filename_values and isinstance(filename_values, list) and len(filename_values) > 0:
707
+ if len(filename_values) == 1:
708
+ conditions.append(rest.FieldCondition(
709
+ key="metadata.filename",
710
+ match=rest.MatchValue(value=filename_values[0])
711
+ ))
712
+ else:
713
+ # Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
714
+ conditions.append(rest.FieldCondition(
715
+ key="metadata.filename",
716
+ match=rest.MatchAny(any=filename_values)
717
+ ))
718
+ filter_summary['filenames'] = f"SHOULD: {filename_values}"
719
+
720
+ # Build final filter
721
+ if conditions:
722
+ # Always wrap conditions in a Filter object, even for single conditions
723
+ result_filter = rest.Filter(must=conditions)
724
+
725
+ # Print clean filter summary
726
+ print(f"✅ APPLIED FILTERS: {filter_summary}")
727
+ return result_filter, filter_summary
728
+ else:
729
+ print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")
730
+ return None, {}
731
+
732
+ except Exception as e:
733
+ print(f"❌ FILTER BUILD ERROR: {str(e)}")
734
+ print(f"🔍 DEBUG: Original inferred filters keys: {list(inferred_filters.keys()) if isinstance(inferred_filters, dict) else 'Not a dict'}")
735
+ print(f"🔍 DEBUG: Original inferred filters content: {inferred_filters}")
736
+ print(f"🔍 DEBUG: Validated filters keys: {list(validated_filters.keys()) if isinstance(validated_filters, dict) else 'Not a dict'}")
737
+ print(f"🔍 DEBUG: Validated filters content: {validated_filters}")
738
+ # Return a safe fallback - no filter (search all documents)
739
+ return None, {}
740
+
741
+
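A minimal sketch of how the builder above is meant to be used, assuming qdrant-client is installed and the helpers are in scope; the inferred values are made up:

```python
inferred = {"sources": ["Consolidated"], "years": ["2022", "2023"]}
qdrant_filter, summary = _build_qdrant_filter(inferred)
print(summary)
# e.g. {'sources': "SHOULD: ['Consolidated']", 'years': "SHOULD: ['2022', '2023']"}

# The resulting filter is then passed straight to the vector search, e.g.:
# docs = vectorstore.similarity_search(query, k=10, filter=qdrant_filter)
```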
742
+ class MetadataCache:
743
+ """Cache for vectorstore metadata to avoid repeated queries."""
744
+
745
+ def __init__(self):
746
+ self._cache = None
747
+ self._last_updated = None
748
+ self._cache_ttl = 3600 # 1 hour TTL
749
+
750
+ def get_metadata(self, vectorstore) -> dict:
751
+ """
752
+ Get metadata from cache or load it if not available/expired.
753
+
754
+ Args:
755
+ vectorstore: QdrantVectorStore instance
756
+
757
+ Returns:
758
+ Dictionary of available metadata values
759
+ """
760
+ import time
761
+
762
+ # Check if cache is valid
763
+ if (self._cache is not None and
764
+ self._last_updated is not None and
765
+ time.time() - self._last_updated < self._cache_ttl):
766
+ print(f"✅ METADATA CACHE: Using cached metadata")
767
+ return self._cache
768
+
769
+ try:
770
+ print(f"🔄 METADATA CACHE: Loading metadata from vectorstore...")
771
+
772
+ # Get collection info
773
+ try:
774
+ collection_info = vectorstore._client.get_collection(vectorstore.collection_name)
775
+ print(f"✅ Collection info retrieved: {getattr(collection_info, 'name', 'unknown')}")
776
+ except Exception as e:
777
+ print(f"⚠️ Could not get collection info: {e}")
778
+
779
+ # Get ALL documents to extract complete metadata
780
+ print(f"📄 Scanning entire corpus for complete metadata extraction...")
781
+
782
+ # Get collection info to determine total size
783
+ try:
784
+ collection_info = vectorstore._client.get_collection(vectorstore.collection_name)
785
+ total_points = getattr(collection_info, 'points_count', 0)
786
+ print(f"📊 Total documents in corpus: {total_points}")
787
+ except Exception as e:
788
+ print(f"⚠️ Could not get collection size: {e}")
789
+ total_points = 0
790
+
791
+ # Extract unique metadata values from ALL documents
792
+ sources = set()
793
+ years = set()
794
+ filenames = set()
795
+
796
+ # Try to use scroll to get all documents in batches
797
+ batch_size = 1000 # Process in batches to avoid memory issues
798
+ offset = None
799
+ processed_count = 0
800
+ scroll_success = False
801
+
802
+ try:
803
+ while True:
804
+ # Scroll through all documents
805
+ scroll_result = vectorstore._client.scroll(
806
+ collection_name=vectorstore.collection_name,
807
+ limit=batch_size,
808
+ offset=offset,
809
+ with_payload=True,
810
+ with_vectors=False # We only need metadata
811
+ )
812
+
813
+ points = scroll_result[0] # Get the points
814
+ if not points:
815
+ break # No more documents
816
+
817
+ # Process each document
818
+ for i, point in enumerate(points):
819
+ if hasattr(point, 'payload') and point.payload:
820
+ payload = point.payload
821
+
822
+ # Debug: Log structure of first few documents
823
+ if processed_count + i < 2: # Only log first 2 documents
824
+ print(f"🔍 DEBUG Document {processed_count + i + 1} payload structure:")
825
+ print(f" Payload keys: {list(payload.keys()) if isinstance(payload, dict) else 'Not a dict'}")
826
+ if isinstance(payload, dict) and 'metadata' in payload:
827
+ print(f" Metadata keys: {list(payload['metadata'].keys()) if isinstance(payload['metadata'], dict) else 'Not a dict'}")
828
+ elif isinstance(payload, dict):
829
+ print(f" Top-level keys: {list(payload.keys())}")
830
+ print(f" Payload type: {type(payload)}")
831
+ print(f" Payload sample: {str(payload)[:200]}...")
832
+ print()
833
+
834
+ # Try different metadata structures
835
+ found_metadata = False
836
+
837
+ # Structure 1: payload['metadata']['source']
838
+ if isinstance(payload, dict) and 'metadata' in payload:
839
+ metadata = payload['metadata']
840
+ if isinstance(metadata, dict):
841
+ if 'source' in metadata:
842
+ sources.add(metadata['source'])
843
+ found_metadata = True
844
+ if 'year' in metadata:
845
+ years.add(metadata['year'])
846
+ found_metadata = True
847
+ if 'filename' in metadata:
848
+ filenames.add(metadata['filename'])
849
+ found_metadata = True
850
+
851
+ # Structure 2: payload['source'] (direct)
852
+ if isinstance(payload, dict):
853
+ if 'source' in payload:
854
+ sources.add(payload['source'])
855
+ found_metadata = True
856
+ if 'year' in payload:
857
+ years.add(payload['year'])
858
+ found_metadata = True
859
+ if 'filename' in payload:
860
+ filenames.add(payload['filename'])
861
+ found_metadata = True
862
+
863
+ # Structure 3: Check for nested structures
864
+ if not found_metadata and isinstance(payload, dict):
865
+ # Look for any nested dict that might contain metadata
866
+ for key, value in payload.items():
867
+ if isinstance(value, dict):
868
+ if 'source' in value:
869
+ sources.add(value['source'])
870
+ found_metadata = True
871
+ if 'year' in value:
872
+ years.add(value['year'])
873
+ found_metadata = True
874
+ if 'filename' in value:
875
+ filenames.add(value['filename'])
876
+ found_metadata = True
877
+
878
+ processed_count += len(points)
879
+ progress_pct = (processed_count / total_points * 100) if total_points > 0 else 0
880
+ print(f"📄 Processed {processed_count}/{total_points} documents ({progress_pct:.1f}%)... (sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")
881
+
882
+ # Update offset for next batch
883
+ offset = scroll_result[1] # Next offset
884
+ if offset is None:
885
+ break # No more documents
886
+
887
+ scroll_success = True
888
+ print(f"✅ Scroll method successful - processed {processed_count} documents")
889
+
890
+ except Exception as e:
891
+ print(f"❌ Scroll method failed: {e}")
892
+ print(f"🔄 Falling back to similarity search method...")
893
+
894
+ # Fallback: Use similarity search with multiple queries to get more coverage
895
+ fallback_queries = [
896
+ "", # Empty query
897
+ "audit", "report", "government", "ministry", "department",
898
+ "local", "consolidated", "annual", "financial", "budget",
899
+ "2020", "2021", "2022", "2023", "2024" # Year queries
900
+ ]
901
+
902
+ processed_count = 0
903
+ for query in fallback_queries:
904
+ try:
905
+ # Get documents for this query
906
+ docs = vectorstore.similarity_search(query, k=1000) # Get more per query
907
+
908
+ for j, doc in enumerate(docs):
909
+ if hasattr(doc, 'metadata') and doc.metadata:
910
+ # Debug: Log structure of first few documents in fallback
911
+ if processed_count + j < 3: # Only log first 3 documents per query
912
+ print(f"🔍 DEBUG Fallback Document {processed_count + j + 1} (query: '{query}') metadata structure:")
913
+ print(f" Metadata keys: {list(doc.metadata.keys()) if isinstance(doc.metadata, dict) else 'Not a dict'}")
914
+ print(f" Metadata type: {type(doc.metadata)}")
915
+ print(f" Metadata sample: {str(doc.metadata)[:200]}...")
916
+ print()
917
+
918
+ if 'source' in doc.metadata:
919
+ sources.add(doc.metadata['source'])
920
+ if 'year' in doc.metadata:
921
+ years.add(doc.metadata['year'])
922
+ if 'filename' in doc.metadata:
923
+ filenames.add(doc.metadata['filename'])
924
+
925
+ processed_count += len(docs)
926
+ print(f"📄 Fallback query '{query}': {len(docs)} docs (total: {processed_count}, sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")
927
+
928
+ except Exception as query_error:
929
+ print(f"⚠️ Fallback query '{query}' failed: {query_error}")
930
+ continue
931
+
932
+ print(f"✅ Fallback method completed - processed {processed_count} documents")
933
+
934
+ print(f"✅ Completed scanning {processed_count} documents from entire corpus")
935
+
936
+ # Convert to sorted lists
937
+ metadata = {
938
+ 'sources': sorted(list(sources)),
939
+ 'years': sorted(list(years)),
940
+ 'filenames': sorted(list(filenames))
941
+ }
942
+
943
+ # Cache the results
944
+ self._cache = metadata
945
+ self._last_updated = time.time()
946
+
947
+ print(f"✅ Complete metadata extracted from entire corpus: {len(sources)} sources, {len(years)} years, {len(filenames)} files")
948
+
949
+ # Debug: Show what was actually found
950
+ if sources:
951
+ print(f"📁 Sources found: {sorted(list(sources))}")
952
+ else:
953
+ print(f"❌ No sources found - check metadata structure")
954
+
955
+ if years:
956
+ print(f"📅 Years found: {sorted(list(years))}")
957
+ else:
958
+ print(f"❌ No years found - check metadata structure")
959
+
960
+ if filenames:
961
+ print(f"📄 Filenames found: {sorted(list(filenames))[:10]}{'...' if len(filenames) > 10 else ''}")
962
+ else:
963
+ print(f"❌ No filenames found - check metadata structure")
964
+ return metadata
965
+
966
+ except Exception as e:
967
+ print(f"❌ Error extracting metadata: {e}")
968
+ return {'sources': [], 'years': [], 'filenames': []}
969
+
970
+ # Global metadata cache
971
+ _metadata_cache = MetadataCache()
972
+
973
+ def get_available_metadata(vectorstore) -> dict:
974
+ """Get available metadata values from the vectorstore efficiently."""
975
+ return _metadata_cache.get_metadata(vectorstore)
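The cache above walks the whole collection once per TTL window. For reference, a condensed standalone sketch of the same scroll pattern with qdrant-client; the URL and collection name are assumptions, and the payload handling mirrors the nested/flat structures handled above:

```python
from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")  # assumption: local Qdrant instance
collection = "audit_reports"                        # assumption: collection name

sources, years, filenames = set(), set(), set()
offset = None
while True:
    points, offset = client.scroll(
        collection_name=collection,
        limit=1000,
        offset=offset,
        with_payload=True,
        with_vectors=False,  # metadata only, embeddings are not needed
    )
    if not points:
        break
    for point in points:
        payload = point.payload or {}
        meta = payload.get("metadata", payload)  # nested or flat payload layouts
        for key, bucket in (("source", sources), ("year", years), ("filename", filenames)):
            if key in meta:
                bucket.add(meta[key])
    if offset is None:
        break

print(sorted(sources), sorted(years), len(filenames))
```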
src/retrieval/hybrid.py ADDED
@@ -0,0 +1,479 @@
1
+ """Hybrid search implementation combining vector and sparse retrieval."""
2
+
3
+ import json
4
+ import numpy as np
5
+ from typing import List, Dict, Any, Tuple
6
+ from pathlib import Path
7
+ from langchain.docstore.document import Document
8
+ from langchain_qdrant import QdrantVectorStore
9
+ from langchain_community.retrievers import BM25Retriever
10
+ from .filter import create_filter
11
+ import pickle
12
+ import os
13
+
14
+
15
+ class HybridRetriever:
16
+ """
17
+ Hybrid retrieval system combining vector search (dense) and BM25 (sparse) search.
18
+ Supports configurable search modes: vector_only, sparse_only, or hybrid.
19
+ """
20
+
21
+ def __init__(self, config: Dict[str, Any]):
22
+ """
23
+ Initialize hybrid retriever.
24
+
25
+ Args:
26
+ config: Configuration dictionary with hybrid search settings
27
+ """
28
+ self.config = config
29
+ self.bm25_retriever = None
30
+ self.documents = []
31
+ self._bm25_cache_file = None
32
+
33
+ def _get_bm25_cache_path(self) -> str:
34
+ """Get path for BM25 cache file."""
35
+ cache_dir = Path("cache/bm25")
36
+ cache_dir.mkdir(parents=True, exist_ok=True)
37
+ return str(cache_dir / "bm25_retriever.pkl")
38
+
39
+ def initialize_bm25(self, documents: List[Document], force_rebuild: bool = False) -> None:
40
+ """
41
+ Initialize BM25 retriever with documents.
42
+
43
+ Args:
44
+ documents: List of Document objects to index
45
+ force_rebuild: Whether to force rebuilding the BM25 index
46
+ """
47
+ self.documents = documents
48
+ self._bm25_cache_file = self._get_bm25_cache_path()
49
+
50
+ # Try to load cached BM25 retriever
51
+ if not force_rebuild and os.path.exists(self._bm25_cache_file):
52
+ try:
53
+ print("Loading cached BM25 retriever...")
54
+ with open(self._bm25_cache_file, 'rb') as f:
55
+ self.bm25_retriever = pickle.load(f)
56
+ print(f"✅ Loaded cached BM25 retriever with {len(self.documents)} documents")
57
+ return
58
+ except Exception as e:
59
+ print(f"⚠️ Failed to load cached BM25 retriever: {e}")
60
+ print("Building new BM25 index...")
61
+
62
+ # Build new BM25 retriever
63
+ print("Building BM25 index...")
64
+ try:
65
+ # Use langchain's BM25Retriever
66
+ self.bm25_retriever = BM25Retriever.from_documents(documents)
67
+
68
+ # Configure BM25 parameters
69
+ bm25_config = self.config.get("bm25", {})
70
+ k = bm25_config.get("top_k", 20)
71
+ self.bm25_retriever.k = k
72
+
73
+ # Cache the BM25 retriever
74
+ with open(self._bm25_cache_file, 'wb') as f:
75
+ pickle.dump(self.bm25_retriever, f)
76
+ print(f"✅ Built and cached BM25 retriever with {len(documents)} documents")
77
+
78
+ except Exception as e:
79
+ print(f"❌ Failed to build BM25 retriever: {e}")
80
+ print("BM25 search will be disabled")
81
+ self.bm25_retriever = None
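A hedged construction sketch for the retriever above; the config keys follow the get() defaults used in this class, the documents are invented, and BM25 indexing additionally requires the rank_bm25 package:

```python
from langchain.docstore.document import Document

config = {"retriever": {"top_k": 20}, "bm25": {"top_k": 20}}
retriever = HybridRetriever(config)

docs = [
    Document(page_content="PDM implementation challenges in 2022",
             metadata={"source": "Local Government", "year": "2022",
                       "filename": "dlg_audit_2022.pdf"}),
    Document(page_content="Consolidated audit of budget execution",
             metadata={"source": "Consolidated", "year": "2023",
                       "filename": "consolidated_2023.pdf"}),
]

# Builds the index and writes cache/bm25/bm25_retriever.pkl;
# pass force_rebuild=True whenever the corpus changes.
retriever.initialize_bm25(docs)
```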
82
+
83
+ def _filter_documents_by_metadata(
84
+ self,
85
+ documents: List[Document],
86
+ reports: List[str] = None,
87
+ sources: str = None,
88
+ subtype: List[str] = None,
89
+ year: List[str] = None
90
+ ) -> List[Document]:
91
+ """
92
+ Filter documents by metadata criteria.
93
+
94
+ Args:
95
+ documents: List of documents to filter
96
+ reports: List of specific report filenames
97
+ sources: Source category
98
+ subtype: List of subtypes
99
+ year: List of years
100
+
101
+ Returns:
102
+ Filtered list of documents
103
+ """
104
+ if not any([reports, sources, subtype, year]):
105
+ return documents
106
+
107
+ filtered_docs = []
108
+ for doc in documents:
109
+ metadata = doc.metadata
110
+
111
+ # Filter by reports
112
+ if reports:
113
+ filename = metadata.get('filename', '')
114
+ if not any(report in filename for report in reports):
115
+ continue
116
+
117
+ # Filter by sources
118
+ if sources:
119
+ doc_source = metadata.get('source', '')
120
+ if sources != doc_source:
121
+ continue
122
+
123
+ # Filter by subtype
124
+ if subtype:
125
+ doc_subtype = metadata.get('subtype', '')
126
+ if doc_subtype not in subtype:
127
+ continue
128
+
129
+ # Filter by year
130
+ if year:
131
+ doc_year = str(metadata.get('year', ''))
132
+ if doc_year not in year:
133
+ continue
134
+
135
+ filtered_docs.append(doc)
136
+
137
+ return filtered_docs
138
+
139
+ def _bm25_search(
140
+ self,
141
+ query: str,
142
+ k: int = 20,
143
+ reports: List[str] = None,
144
+ sources: str = None,
145
+ subtype: List[str] = None,
146
+ year: List[str] = None
147
+ ) -> List[Tuple[Document, float]]:
148
+ """
149
+ Perform BM25 sparse search.
150
+
151
+ Args:
152
+ query: Search query
153
+ k: Number of documents to retrieve
154
+ reports: List of specific report filenames
155
+ sources: Source category
156
+ subtype: List of subtypes
157
+ year: List of years
158
+
159
+ Returns:
160
+ List of (Document, score) tuples
161
+ """
162
+ if not self.bm25_retriever:
163
+ print("⚠️ BM25 retriever not available")
164
+ return []
165
+
166
+ try:
167
+ # Get BM25 results
168
+ self.bm25_retriever.k = k
169
+ bm25_docs = self.bm25_retriever.invoke(query)
170
+
171
+ # Apply metadata filtering
172
+ if any([reports, sources, subtype, year]):
173
+ bm25_docs = self._filter_documents_by_metadata(
174
+ bm25_docs, reports, sources, subtype, year
175
+ )
176
+
177
+ # BM25Retriever doesn't return scores directly, so we'll use placeholder scores
178
+ # In a production system, you'd want to access the actual BM25 scores
179
+ results = []
180
+ for i, doc in enumerate(bm25_docs):
181
+ # Assign decreasing scores based on rank (higher rank = higher score)
182
+ # Normalize to [0, 1] range for consistency with vector search
183
+ score = max(0.1, 1.0 - (i / max(len(bm25_docs), 1)))
184
+ results.append((doc, score))
185
+
186
+ return results
187
+
188
+ except Exception as e:
189
+ print(f"❌ BM25 search failed: {e}")
190
+ return []
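Because BM25Retriever does not expose raw BM25 scores, the method above maps ranks to pseudo-scores. The formula in isolation, with a worked run:

```python
# Rank 0 gets 1.0, later ranks decay linearly, floored at 0.1 so every
# returned chunk keeps some weight in the later fusion step.
def rank_to_score(rank: int, total: int) -> float:
    return max(0.1, 1.0 - (rank / max(total, 1)))

print([round(rank_to_score(i, 5), 2) for i in range(5)])
# [1.0, 0.8, 0.6, 0.4, 0.2]
```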
191
+
192
+ def _vector_search(
193
+ self,
194
+ vectorstore: QdrantVectorStore,
195
+ query: str,
196
+ k: int = 20,
197
+ reports: List[str] = None,
198
+ sources: str = None,
199
+ subtype: List[str] = None,
200
+ year: List[str] = None
201
+ ) -> List[Tuple[Document, float]]:
202
+ """
203
+ Perform vector similarity search.
204
+
205
+ Args:
206
+ vectorstore: QdrantVectorStore instance
207
+ query: Search query
208
+ k: Number of documents to retrieve
209
+ reports: List of specific report filenames
210
+ sources: Source category
211
+ subtype: List of subtypes
212
+ year: List of years
213
+
214
+ Returns:
215
+ List of (Document, score) tuples
216
+ """
217
+ try:
218
+ # Create filter
219
+ filter_obj = create_filter(
220
+ reports=reports,
221
+ sources=sources,
222
+ subtype=subtype,
223
+ year=year
224
+ )
225
+
226
+ # Perform vector search
227
+ if filter_obj:
228
+ results = vectorstore.similarity_search_with_score(
229
+ query, k=k, filter=filter_obj
230
+ )
231
+ else:
232
+ results = vectorstore.similarity_search_with_score(query, k=k)
233
+
234
+ return results
235
+
236
+ except Exception as e:
237
+ print(f"❌ Vector search failed: {e}")
238
+ return []
239
+
240
+ def _normalize_scores(self, results: List[Tuple[Document, float]], method: str = "min_max") -> List[Tuple[Document, float]]:
241
+ """
242
+ Normalize scores to [0, 1] range.
243
+
244
+ Args:
245
+ results: List of (Document, score) tuples
246
+ method: Normalization method ('min_max' or 'z_score')
247
+
248
+ Returns:
249
+ List of (Document, normalized_score) tuples
250
+ """
251
+ if not results:
252
+ return results
253
+
254
+ scores = [score for _, score in results]
255
+
256
+ if method == "min_max":
257
+ min_score = min(scores)
258
+ max_score = max(scores)
259
+ if max_score == min_score:
260
+ normalized_results = [(doc, 1.0) for doc, _ in results]
261
+ else:
262
+ normalized_results = [
263
+ (doc, (score - min_score) / (max_score - min_score))
264
+ for doc, score in results
265
+ ]
266
+ elif method == "z_score":
267
+ mean_score = np.mean(scores)
268
+ std_score = np.std(scores)
269
+ if std_score == 0:
270
+ normalized_results = [(doc, 1.0) for doc, _ in results]
271
+ else:
272
+ normalized_results = [
273
+ (doc, max(0, (score - mean_score) / std_score))
274
+ for doc, score in results
275
+ ]
276
+ else:
277
+ normalized_results = results
278
+
279
+ return normalized_results
280
+
281
+ def _combine_results(
282
+ self,
283
+ vector_results: List[Tuple[Document, float]],
284
+ bm25_results: List[Tuple[Document, float]],
285
+ alpha: float = 0.5
286
+ ) -> List[Tuple[Document, float]]:
287
+ """
288
+ Combine vector and BM25 results with weighted scoring.
289
+
290
+ Args:
291
+ vector_results: Vector search results
292
+ bm25_results: BM25 search results
293
+ alpha: Weight for vector scores (1-alpha for BM25 scores)
294
+
295
+ Returns:
296
+ Combined and ranked results
297
+ """
298
+ # Normalize scores
299
+ vector_results = self._normalize_scores(vector_results)
300
+ bm25_results = self._normalize_scores(bm25_results)
301
+
302
+ # Key by page content so a chunk retrieved by both methods is merged rather than duplicated (the two result sets hold distinct Document objects, so id() would never overlap)
303
+ vector_docs = {doc.page_content: (doc, score) for doc, score in vector_results}
304
+ bm25_docs = {doc.page_content: (doc, score) for doc, score in bm25_results}
305
+
306
+ # Combine scores
307
+ combined_scores = {}
308
+ all_doc_ids = set(vector_docs.keys()) | set(bm25_docs.keys())
309
+
310
+ for doc_id in all_doc_ids:
311
+ vector_score = vector_docs.get(doc_id, (None, 0.0))[1]
312
+ bm25_score = bm25_docs.get(doc_id, (None, 0.0))[1]
313
+
314
+ # Weighted combination
315
+ combined_score = alpha * vector_score + (1 - alpha) * bm25_score
316
+
317
+ # Get document object
318
+ doc = vector_docs.get(doc_id, bm25_docs.get(doc_id))[0]
319
+ combined_scores[doc_id] = (doc, combined_score)
320
+
321
+ # Sort by combined score (descending)
322
+ sorted_results = sorted(
323
+ combined_scores.values(),
324
+ key=lambda x: x[1],
325
+ reverse=True
326
+ )
327
+
328
+ return sorted_results
329
+
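A tiny worked example of the weighted fusion above, with made-up, already-normalized scores; chunks found by both methods tend to rise because a miss on either side contributes 0.0:

```python
alpha = 0.5          # weight of the dense (vector) score
vector_score = 0.9   # semantic similarity, min-max normalized
bm25_score = 0.4     # keyword score, min-max normalized

combined = alpha * vector_score + (1 - alpha) * bm25_score
print(round(combined, 2))  # 0.65

# A BM25-only hit with the same keyword score would land at 0.5 * 0.0 + 0.5 * 0.4 = 0.2.
```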
330
+ def retrieve(
331
+ self,
332
+ vectorstore: QdrantVectorStore,
333
+ query: str,
334
+ mode: str = "hybrid",
335
+ reports: List[str] = None,
336
+ sources: str = None,
337
+ subtype: List[str] = None,
338
+ year: List[str] = None,
339
+ alpha: float = 0.5,
340
+ k: int = None
341
+ ) -> List[Document]:
342
+ """
343
+ Retrieve documents using the specified search mode.
344
+
345
+ Args:
346
+ vectorstore: QdrantVectorStore instance
347
+ query: Search query
348
+ mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
349
+ reports: List of specific report filenames
350
+ sources: Source category
351
+ subtype: List of subtypes
352
+ year: List of years
353
+ alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
354
+ k: Number of documents to retrieve
355
+
356
+ Returns:
357
+ List of relevant Document objects
358
+ """
359
+ if k is None:
360
+ k = self.config.get("retriever", {}).get("top_k", 20)
361
+
362
+ results = []
363
+
364
+ if mode == "vector_only":
365
+ # Vector search only
366
+ vector_results = self._vector_search(
367
+ vectorstore, query, k, reports, sources, subtype, year
368
+ )
369
+ results = [(doc, score) for doc, score in vector_results]
370
+
371
+ elif mode == "sparse_only":
372
+ # BM25 search only
373
+ bm25_results = self._bm25_search(
374
+ query, k, reports, sources, subtype, year
375
+ )
376
+ results = [(doc, score) for doc, score in bm25_results]
377
+
378
+ elif mode == "hybrid":
379
+ # Hybrid search - combine both
380
+ # Get more results from each method to have better fusion
381
+ retrieval_k = min(k * 2, 50) # Get more candidates for fusion
382
+
383
+ vector_results = self._vector_search(
384
+ vectorstore, query, retrieval_k, reports, sources, subtype, year
385
+ )
386
+ bm25_results = self._bm25_search(
387
+ query, retrieval_k, reports, sources, subtype, year
388
+ )
389
+
390
+ results = self._combine_results(vector_results, bm25_results, alpha)
391
+
392
+ else:
393
+ raise ValueError(f"Unknown search mode: {mode}")
394
+
395
+ # Limit to top k results
396
+ results = results[:k]
397
+
398
+ # Return just the documents
399
+ return [doc for doc, score in results]
400
+
401
+ def retrieve_with_scores(
402
+ self,
403
+ vectorstore: QdrantVectorStore,
404
+ query: str,
405
+ mode: str = "hybrid",
406
+ reports: List[str] = None,
407
+ sources: str = None,
408
+ subtype: List[str] = None,
409
+ year: List[str] = None,
410
+ alpha: float = 0.5,
411
+ k: int = None
412
+ ) -> List[Tuple[Document, float]]:
413
+ """
414
+ Retrieve documents with scores using the specified search mode.
415
+
416
+ Args:
417
+ vectorstore: QdrantVectorStore instance
418
+ query: Search query
419
+ mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
420
+ reports: List of specific report filenames
421
+ sources: Source category
422
+ subtype: List of subtypes
423
+ year: List of years
424
+ alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
425
+ k: Number of documents to retrieve
426
+
427
+ Returns:
428
+ List of (Document, score) tuples
429
+ """
430
+ if k is None:
431
+ k = self.config.get("retriever", {}).get("top_k", 20)
432
+
433
+ results = []
434
+
435
+ if mode == "vector_only":
436
+ # Vector search only
437
+ results = self._vector_search(
438
+ vectorstore, query, k, reports, sources, subtype, year
439
+ )
440
+
441
+ elif mode == "sparse_only":
442
+ # BM25 search only
443
+ results = self._bm25_search(
444
+ query, k, reports, sources, subtype, year
445
+ )
446
+
447
+ elif mode == "hybrid":
448
+ # Hybrid search - combine both
449
+ # Get more results from each method to have better fusion
450
+ retrieval_k = min(k * 2, 50) # Get more candidates for fusion
451
+
452
+ vector_results = self._vector_search(
453
+ vectorstore, query, retrieval_k, reports, sources, subtype, year
454
+ )
455
+ bm25_results = self._bm25_search(
456
+ query, retrieval_k, reports, sources, subtype, year
457
+ )
458
+
459
+ results = self._combine_results(vector_results, bm25_results, alpha)
460
+
461
+ else:
462
+ raise ValueError(f"Unknown search mode: {mode}")
463
+
464
+ # Limit to top k results
465
+ return results[:k]
466
+
467
+
468
+ def get_available_search_modes() -> List[str]:
469
+ """Get list of available search modes."""
470
+ return ["vector_only", "sparse_only", "hybrid"]
471
+
472
+
473
+ def get_search_mode_description() -> Dict[str, str]:
474
+ """Get descriptions for each search mode."""
475
+ return {
476
+ "vector_only": "Semantic search using dense embeddings - good for conceptual matching",
477
+ "sparse_only": "Keyword search using BM25 - good for exact term matching",
478
+ "hybrid": "Combined semantic and keyword search - balanced approach"
479
+ }
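A hedged end-to-end sketch of a hybrid query with the class above; the vectorstore connection and chunk list are assumed to be prepared elsewhere, and all concrete values are illustrative:

```python
from typing import List

from langchain.docstore.document import Document
from langchain_qdrant import QdrantVectorStore


def search_audit_reports(vectorstore: QdrantVectorStore,
                         all_chunks: List[Document]) -> List[Document]:
    """Hybrid query against Local Government reports from 2022 (illustrative)."""
    config = {"retriever": {"top_k": 10}, "bm25": {"top_k": 20}}
    retriever = HybridRetriever(config)
    retriever.initialize_bm25(all_chunks)   # builds or loads the cached BM25 index
    return retriever.retrieve(
        vectorstore,
        query="PDM implementation challenges in 2022",
        mode="hybrid",
        sources="Local Government",
        year=["2022"],
        alpha=0.6,                          # lean slightly toward semantic matches
    )
```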
src/vectorstore.py ADDED
@@ -0,0 +1,266 @@
1
+ """Vector store management and operations."""
2
+ from pathlib import Path
3
+ from typing import Dict, Any, List, Optional
4
+
5
+
6
+ import torch
7
+ from langchain_qdrant import QdrantVectorStore
8
+ from langchain.docstore.document import Document
9
+ from langchain_core.embeddings import Embeddings
10
+ from sentence_transformers import SentenceTransformer
11
+ from langchain_huggingface import HuggingFaceEmbeddings
12
+
13
+
14
+ class MatryoshkaEmbeddings(Embeddings):
15
+ """Custom embeddings class that supports Matryoshka dimension truncation."""
16
+
17
+ def __init__(self, model_name: str, truncate_dim: int = None, **kwargs):
18
+ """
19
+ Initialize Matryoshka embeddings.
20
+
21
+ Args:
22
+ model_name: Name of the model
23
+ truncate_dim: Dimension to truncate to (for Matryoshka models)
24
+ **kwargs: Additional arguments (ignored for Matryoshka models)
25
+ """
26
+ self.model_name = model_name
27
+ self.truncate_dim = truncate_dim
28
+
29
+ if truncate_dim and "matryoshka" in model_name.lower():
30
+ # Use SentenceTransformer directly for Matryoshka models
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
33
+ print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
34
+ else:
35
+ # Use standard HuggingFaceEmbeddings
36
+ self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
37
+
38
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
39
+ """Embed documents."""
40
+ if self.truncate_dim and "matryoshka" in self.model_name.lower():
41
+ embeddings = self.model.encode(texts, normalize_embeddings=True)
42
+ return embeddings.tolist()
43
+ else:
44
+ return self.model.embed_documents(texts)
45
+
46
+ def embed_query(self, text: str) -> List[float]:
47
+ """Embed query."""
48
+ if self.truncate_dim and "matryoshka" in self.model_name.lower():
49
+ embedding = self.model.encode([text], normalize_embeddings=True)
50
+ return embedding[0].tolist()
51
+ else:
52
+ return self.model.embed_query(text)
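A minimal sketch of what the truncation does, assuming a recent sentence-transformers release (truncate_dim was added in 2.7) and a Matryoshka-trained checkpoint; the model name here is an assumption, not taken from the project config:

```python
from sentence_transformers import SentenceTransformer

# Matryoshka-trained models keep most of their quality when only a prefix
# of the embedding dimensions is retained.
model = SentenceTransformer("nomic-ai/modernbert-embed-base", truncate_dim=256)
vec = model.encode(["consolidated audit report"], normalize_embeddings=True)
print(vec.shape)  # (1, 256) instead of the model's full dimensionality
```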
53
+
54
+
55
+ class VectorStoreManager:
56
+ """Manages vector store operations and connections."""
57
+
58
+ def __init__(self, config: Dict[str, Any]):
59
+ """
60
+ Initialize vector store manager.
61
+
62
+ Args:
63
+ config: Configuration dictionary
64
+ """
65
+ self.config = config
66
+ self.embeddings = self._create_embeddings()
67
+ self.vectorstore = None
68
+
69
+ # Define metadata fields that need payload indexes for filtering
70
+ self.metadata_fields = [
71
+ ("metadata.year", "keyword"),
72
+ ("metadata.source", "keyword"),
73
+ ("metadata.filename", "keyword"),
74
+ # Add more metadata fields as needed
75
+ ]
76
+
77
+ def _create_embeddings(self) -> HuggingFaceEmbeddings:
78
+ """Create embeddings model from configuration."""
79
+ device = "cuda" if torch.cuda.is_available() else "cpu"
80
+
81
+ model_name = self.config["retriever"]["model"]
82
+ normalize = self.config["retriever"]["normalize"]
83
+
84
+ model_kwargs = {"device": device}
85
+ encode_kwargs = {
86
+ "normalize_embeddings": normalize,
87
+ "batch_size": 100,
88
+ }
89
+
90
+ # For Matryoshka models, check if we need to truncate dimensions
91
+ if "matryoshka" in model_name.lower():
92
+ # Check if we have a specific dimension requirement
93
+ collection_name = self.config.get("qdrant", {}).get("collection_name", "")
94
+
95
+ if "modernbert-embed-base-akryl-matryoshka" in collection_name:
96
+ # This collection expects 768 dimensions
97
+ truncate_dim = 768
98
+ print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
99
+
100
+ # Use custom MatryoshkaEmbeddings
101
+ embeddings = MatryoshkaEmbeddings(
102
+ model_name=model_name,
103
+ truncate_dim=truncate_dim,
104
+ model_kwargs=model_kwargs,
105
+ encode_kwargs=encode_kwargs,
106
+ show_progress=True,
107
+ )
108
+ return embeddings
109
+
110
+ # Use standard HuggingFaceEmbeddings for non-Matryoshka models
111
+ embeddings = HuggingFaceEmbeddings(
112
+ model_name=model_name,
113
+ model_kwargs=model_kwargs,
114
+ encode_kwargs=encode_kwargs,
115
+ show_progress=True,
116
+ )
117
+
118
+ return embeddings
119
+
120
+ def ensure_metadata_indexes(self) -> None:
121
+ """
122
+ Create payload indexes for all required metadata fields.
123
+ This ensures filtering works properly, especially in Qdrant Cloud.
124
+ """
125
+ if not self.vectorstore:
126
+ return
127
+
128
+ qdrant_config = self.config["qdrant"]
129
+ collection_name = qdrant_config["collection_name"]
130
+
131
+ for field_name, field_type in self.metadata_fields:
132
+ try:
133
+ self.vectorstore.client.create_payload_index(
134
+ collection_name=collection_name,
135
+ field_name=field_name,
136
+ field_type=field_type
137
+ )
138
+ print(f"Created payload index for {field_name} ({field_type})")
139
+ except Exception as e:
140
+ # Index might already exist or other error - log but continue
141
+ print(f"Index creation for {field_name} ({field_type}): {str(e)}")
142
+
143
+ def connect_to_existing(self, force_recreate: bool = False) -> QdrantVectorStore:
144
+ """
145
+ Connect to existing Qdrant collection.
146
+
147
+ Args:
148
+ force_recreate: If True, recreate the collection if dimension mismatch occurs
149
+
150
+ Returns:
151
+ QdrantVectorStore instance
152
+ """
153
+ qdrant_config = self.config["qdrant"]
154
+
155
+ kwargs_qdrant = {
156
+ "url": qdrant_config["url"],
157
+ "collection_name": qdrant_config["collection_name"],
158
+ "prefer_grpc": qdrant_config.get("prefer_grpc", True),
159
+ "api_key": qdrant_config.get("api_key", None),
160
+ }
161
+
162
+ if force_recreate:
163
+ kwargs_qdrant["force_recreate"] = True
164
+
165
+ self.vectorstore = QdrantVectorStore.from_existing_collection(
166
+ embedding=self.embeddings,
167
+ **kwargs_qdrant
168
+ )
169
+
170
+ # Ensure payload indexes exist for metadata filtering
171
+ self.ensure_metadata_indexes()
172
+
173
+ return self.vectorstore
174
+
175
+ def create_from_documents(self, documents: List[Document]) -> QdrantVectorStore:
176
+ """
177
+ Create new Qdrant collection from documents.
178
+
179
+ Args:
180
+ documents: List of Document objects
181
+
182
+ Returns:
183
+ QdrantVectorStore instance
184
+ """
185
+ qdrant_config = self.config["qdrant"]
186
+
187
+ kwargs_qdrant = {
188
+ "url": qdrant_config["url"],
189
+ "collection_name": qdrant_config["collection_name"],
190
+ "prefer_grpc": qdrant_config.get("prefer_grpc", True),
191
+ "api_key": qdrant_config.get("api_key", None),
192
+ }
193
+
194
+ self.vectorstore = QdrantVectorStore.from_documents(
195
+ documents=documents,
196
+ embedding=self.embeddings,
197
+ **kwargs_qdrant
198
+ )
199
+
200
+ # Ensure payload indexes exist for metadata filtering
201
+ self.ensure_metadata_indexes()
202
+
203
+ return self.vectorstore
204
+
205
+ def delete_collection(self) -> None:
206
+ """
207
+ Delete the current Qdrant collection.
208
+
209
+ Returns:
210
+ QdrantVectorStore instance
211
+ """
212
+ qdrant_config = self.config["qdrant"]
213
+ collection_name = qdrant_config.get("collection_name")
214
+
215
+ self.vectorstore.client.delete_collection(
216
+ collection_name=collection_name
217
+ )
218
+
219
+ return self.vectorstore
220
+
221
+ def get_vectorstore(self) -> Optional[QdrantVectorStore]:
222
+ """Get current vectorstore instance."""
223
+ return self.vectorstore
224
+
225
+
226
+ def get_local_qdrant(config: Dict[str, Any]) -> QdrantVectorStore:
227
+ """
228
+ Get local Qdrant vector store (legacy function for compatibility).
229
+
230
+ Args:
231
+ config: Configuration dictionary
232
+
233
+ Returns:
234
+ QdrantVectorStore instance
235
+ """
236
+ manager = VectorStoreManager(config)
237
+ return manager.connect_to_existing()
238
+
239
+
240
+ def create_vectorstore(config: Dict[str, Any], documents: List[Document]) -> QdrantVectorStore:
241
+ """
242
+ Create new vector store from documents.
243
+
244
+ Args:
245
+ config: Configuration dictionary
246
+ documents: List of Document objects
247
+
248
+ Returns:
249
+ QdrantVectorStore instance
250
+ """
251
+ manager = VectorStoreManager(config)
252
+ return manager.create_from_documents(documents)
253
+
254
+
255
+ def get_embeddings_model(config: Dict[str, Any]) -> HuggingFaceEmbeddings:
256
+ """
257
+ Create embeddings model from configuration (legacy function).
258
+
259
+ Args:
260
+ config: Configuration dictionary
261
+
262
+ Returns:
263
+ HuggingFaceEmbeddings instance
264
+ """
265
+ manager = VectorStoreManager(config)
266
+ return manager.embeddings
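A hedged configuration sketch for the manager above; the URL, API key, collection name and model name are placeholders or assumptions, not values from the repository:

```python
config = {
    "retriever": {
        "model": "nomic-ai/modernbert-embed-base",  # assumption
        "normalize": True,
        "top_k": 20,
    },
    "qdrant": {
        "url": "https://YOUR-CLUSTER.qdrant.io",    # placeholder
        "api_key": "YOUR_API_KEY",                  # placeholder
        "collection_name": "audit_reports",         # placeholder
        "prefer_grpc": True,
    },
}

manager = VectorStoreManager(config)
vectorstore = manager.connect_to_existing()  # also ensures the payload indexes exist
```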
utils.py ADDED
@@ -0,0 +1,163 @@
1
+ import json
2
+ import dataclasses
3
+ from uuid import UUID
4
+ from typing import Any
5
+ from datetime import datetime, date
6
+
7
+
8
+ import configparser
9
+ from torch import cuda
10
+ from qdrant_client.http import models as rest
11
+ from langchain_community.embeddings import HuggingFaceEmbeddings
12
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
13
+
14
+
15
+ def get_config(fp):
16
+ config = configparser.ConfigParser()
17
+ config.read_file(open(fp))
18
+ return config
19
+
20
+
21
+ def get_embeddings_model(config):
22
+ device = "cuda" if cuda.is_available() else "cpu"
23
+
24
+ # Define embedding model
25
+ model_name = config.get("retriever", "MODEL")
26
+ model_kwargs = {"device": device}
27
+ normalize_embeddings = bool(int(config.get("retriever", "NORMALIZE")))
28
+ encode_kwargs = {
29
+ "normalize_embeddings": normalize_embeddings,
30
+ "batch_size": 100,
31
+ }
32
+
33
+ embeddings = HuggingFaceEmbeddings(
34
+ show_progress=True,
35
+ model_name=model_name,
36
+ model_kwargs=model_kwargs,
37
+ encode_kwargs=encode_kwargs,
38
+ )
39
+
40
+ return embeddings
41
+
42
+ # Create a search filter for Qdrant
43
+ def create_filter(
44
+ reports: list = [], sources: str = None, subtype: str = None, year: str = None
45
+ ):
46
+ if len(reports) == 0:
47
+ print(f"defining filter for sources:{sources}, subtype:{subtype}")
48
+ filter = rest.Filter(
49
+ must=[
50
+ rest.FieldCondition(
51
+ key="metadata.source", match=rest.MatchValue(value=sources)
52
+ ),
53
+ rest.FieldCondition(
54
+ key="metadata.filename", match=rest.MatchAny(any=subtype)
55
+ ),
56
+ # rest.FieldCondition(
57
+ # key="metadata.year",
58
+ # match=rest.MatchAny(any=year)
59
+ ]
60
+ )
61
+ else:
62
+ print(f"defining filter for allreports:{reports}")
63
+ filter = rest.Filter(
64
+ must=[
65
+ rest.FieldCondition(
66
+ key="metadata.filename", match=rest.MatchAny(any=reports)
67
+ )
68
+ ]
69
+ )
70
+
71
+ return filter
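A usage sketch of this legacy helper with invented values. Note that, as written, the subtype values are matched against the metadata.filename field:

```python
# Illustrative values only; requires qdrant-client (imported above as `rest`).
f = create_filter(
    reports=[],
    sources="Consolidated",
    subtype=["consolidated_audit_2022.pdf"],
)
print(type(f).__name__)  # Filter, carrying two `must` conditions
```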
72
+
73
+
74
+ def load_json(fp):
75
+ with open(fp, "r") as f:
76
+ docs = json.load(f)
77
+ return docs
78
+
79
+ def get_timestamp():
80
+ now = datetime.now()
81
+ timestamp = now.strftime("%Y%m%d%H%M%S")
82
+ return timestamp
83
+
84
+
85
+
86
+ # A custom class to help with recursive serialization.
87
+ # This approach avoids modifying the original object.
88
+ class _RecursiveSerializer(json.JSONEncoder):
89
+ """A custom JSONEncoder that handles complex types by converting them to dicts or strings."""
90
+ def default(self, obj):
91
+ # Prefer the pydantic method if it exists for the most robust serialization.
92
+ if hasattr(obj, 'model_dump'):
93
+ return obj.model_dump()
94
+
95
+ # Handle dataclasses
96
+ if dataclasses.is_dataclass(obj):
97
+ return dataclasses.asdict(obj)
98
+
99
+ # Handle other non-serializable but common types.
100
+ if isinstance(obj, (datetime, date, UUID)):
101
+ return str(obj)
102
+
103
+ # Fallback for general objects with a __dict__
104
+ if hasattr(obj, '__dict__'):
105
+ return obj.__dict__
106
+
107
+ # Default fallback to JSONEncoder's behavior
108
+ return super().default(obj)
109
+
110
+ def to_json_string(obj: Any, **kwargs) -> str:
111
+ """
112
+ Serializes a Python object into a JSON-formatted string.
113
+
114
+ This function is a comprehensive utility that can handle:
115
+ - Standard Python types (lists, dicts, strings, numbers, bools, None).
116
+ - Pydantic models (using `model_dump()`).
117
+ - Dataclasses (using `dataclasses.asdict()`).
118
+ - Standard library types not natively JSON-serializable (e.g., datetime, UUID).
119
+ - Custom classes with a `__dict__`.
120
+
121
+ Args:
122
+ obj (Any): The Python object to serialize.
123
+ **kwargs: Additional keyword arguments to pass to `json.dumps`.
124
+
125
+ Returns:
126
+ str: A JSON-formatted string.
127
+
128
+ Example:
129
+ >>> from datetime import datetime
130
+ >>> from pydantic import BaseModel
131
+ >>> from dataclasses import dataclass
132
+
133
+ >>> class Address(BaseModel):
134
+ ... street: str
135
+ ... city: str
136
+
137
+ >>> @dataclass
138
+ ... class Product:
139
+ ... id: int
140
+ ... name: str
141
+
142
+ >>> class Order(BaseModel):
143
+ ... user_address: Address
144
+ ... item: Product
145
+
146
+ >>> order_obj = Order(
147
+ ... user_address=Address(street="123 Main St", city="Example City"),
148
+ ... item=Product(id=1, name="Laptop")
149
+ ... )
150
+
151
+ >>> print(to_json_string(order_obj, indent=2))
152
+ {
153
+ "user_address": {
154
+ "street": "123 Main St",
155
+ "city": "Example City"
156
+ },
157
+ "item": {
158
+ "id": 1,
159
+ "name": "Laptop"
160
+ }
161
+ }
162
+ """
163
+ return json.dumps(obj, cls=_RecursiveSerializer, **kwargs)