Spaces:
Sleeping
Sleeping
Pilot (#2)
Browse files- add src (f5df98319255a7ba942909ae0e12791d2a5e78e4)
- update reqs (85f1ebc529a9b7f1f4a999dc8843b337bc89beb2)
- add single smart chatbot (aafcd0db8782557c9e599432d751e03b1f373c0f)
- add multi-agent system (caeff10c241774d0831c3aa69fd71a546afe5964)
- add utils (fab49c5c2eb911969d7d515898d3c8ce8a459178)
- create UI (ce77124199dfb6e08e9e7bcc82276dc965626947)
- adjust Dockerfile accordingly (87edf9841faa8a7bdc563ff4ae6f3e5c0306ea9a)
- Dockerfile +14 -5
- app.py +694 -0
- multi_agent_chatbot.py +1167 -0
- requirements.txt +9 -3
- smart_chatbot.py +1098 -0
- src/__init__.py +10 -0
- src/config/__init__.py +5 -0
- src/config/collections.json +22 -0
- src/config/loader.py +170 -0
- src/config/settings.yaml +92 -0
- src/llm/__init__.py +6 -0
- src/llm/adapters.py +409 -0
- src/llm/templates.py +232 -0
- src/loader.py +115 -0
- src/logging.py +193 -0
- src/pipeline.py +731 -0
- src/reporting/__init__.py +6 -0
- src/reporting/feedback_schema.py +196 -0
- src/reporting/metadata.py +216 -0
- src/reporting/service.py +144 -0
- src/reporting/snowflake_connector.py +305 -0
- src/retrieval/__init__.py +15 -0
- src/retrieval/colbert_cache.py +74 -0
- src/retrieval/context.py +881 -0
- src/retrieval/filter.py +975 -0
- src/retrieval/hybrid.py +479 -0
- src/vectorstore.py +266 -0
- utils.py +163 -0
Dockerfile
CHANGED
|
@@ -1,20 +1,29 @@
|
|
| 1 |
-
FROM python:3.
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
|
|
|
| 5 |
RUN apt-get update && apt-get install -y \
|
| 6 |
build-essential \
|
| 7 |
curl \
|
| 8 |
git \
|
| 9 |
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
|
|
|
|
| 11 |
COPY requirements.txt ./
|
| 12 |
-
COPY src/ ./src/
|
| 13 |
|
| 14 |
-
|
|
|
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
EXPOSE 8501
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
+
# Install system dependencies
|
| 6 |
RUN apt-get update && apt-get install -y \
|
| 7 |
build-essential \
|
| 8 |
curl \
|
| 9 |
git \
|
| 10 |
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
|
| 12 |
+
# Copy requirements first (for better Docker layer caching)
|
| 13 |
COPY requirements.txt ./
|
|
|
|
| 14 |
|
| 15 |
+
# Install Python dependencies
|
| 16 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
| 17 |
|
| 18 |
+
# Copy all application files (excluding .dockerignore patterns)
|
| 19 |
+
COPY . .
|
| 20 |
+
|
| 21 |
+
# Expose Streamlit port (HF Spaces maps to 7860 automatically)
|
| 22 |
EXPOSE 8501
|
| 23 |
|
| 24 |
+
# Health check for Streamlit
|
| 25 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
| 26 |
+
CMD curl --fail http://localhost:8501/_stcore/health || exit 1
|
| 27 |
|
| 28 |
+
# Run Streamlit app
|
| 29 |
+
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0", "--server.headless", "true"]
|
app.py
ADDED
|
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Intelligent Audit Report Chatbot UI
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
+
import uuid
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import streamlit as st
|
| 16 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 17 |
+
|
| 18 |
+
from multi_agent_chatbot import get_multi_agent_chatbot
|
| 19 |
+
from smart_chatbot import get_chatbot as get_smart_chatbot
|
| 20 |
+
from src.reporting.feedback_schema import create_feedback_from_dict
|
| 21 |
+
|
| 22 |
+
# Configure logging
|
| 23 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
# Page config
|
| 27 |
+
st.set_page_config(
|
| 28 |
+
layout="wide",
|
| 29 |
+
page_icon="🤖",
|
| 30 |
+
initial_sidebar_state="expanded",
|
| 31 |
+
page_title="Intelligent Audit Report Chatbot"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Custom CSS
|
| 35 |
+
st.markdown("""
|
| 36 |
+
<style>
|
| 37 |
+
.main-header {
|
| 38 |
+
font-size: 2.5rem;
|
| 39 |
+
font-weight: bold;
|
| 40 |
+
color: #1f77b4;
|
| 41 |
+
text-align: center;
|
| 42 |
+
margin-bottom: 1rem;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
.subtitle {
|
| 46 |
+
font-size: 1.2rem;
|
| 47 |
+
color: #666;
|
| 48 |
+
text-align: center;
|
| 49 |
+
margin-bottom: 2rem;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
.session-info {
|
| 53 |
+
background-color: #f0f2f6;
|
| 54 |
+
padding: 10px;
|
| 55 |
+
border-radius: 5px;
|
| 56 |
+
margin-bottom: 20px;
|
| 57 |
+
font-size: 0.9rem;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
.user-message {
|
| 61 |
+
background-color: #007bff;
|
| 62 |
+
color: white;
|
| 63 |
+
padding: 12px 16px;
|
| 64 |
+
border-radius: 18px 18px 4px 18px;
|
| 65 |
+
margin: 8px 0;
|
| 66 |
+
margin-left: 20%;
|
| 67 |
+
word-wrap: break-word;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
.bot-message {
|
| 71 |
+
background-color: #f1f3f4;
|
| 72 |
+
color: #333;
|
| 73 |
+
padding: 12px 16px;
|
| 74 |
+
border-radius: 18px 18px 18px 4px;
|
| 75 |
+
margin: 8px 0;
|
| 76 |
+
margin-right: 20%;
|
| 77 |
+
word-wrap: break-word;
|
| 78 |
+
border: 1px solid #e0e0e0;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.filter-section {
|
| 82 |
+
margin-bottom: 20px;
|
| 83 |
+
padding: 15px;
|
| 84 |
+
background-color: #f8f9fa;
|
| 85 |
+
border-radius: 8px;
|
| 86 |
+
border: 1px solid #e9ecef;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
.filter-title {
|
| 90 |
+
font-weight: bold;
|
| 91 |
+
margin-bottom: 10px;
|
| 92 |
+
color: #495057;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
.feedback-section {
|
| 96 |
+
background-color: #f8f9fa;
|
| 97 |
+
padding: 20px;
|
| 98 |
+
border-radius: 10px;
|
| 99 |
+
margin-top: 30px;
|
| 100 |
+
border: 2px solid #dee2e6;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
.retrieval-history {
|
| 104 |
+
background-color: #ffffff;
|
| 105 |
+
padding: 15px;
|
| 106 |
+
border-radius: 5px;
|
| 107 |
+
margin: 10px 0;
|
| 108 |
+
border-left: 4px solid #007bff;
|
| 109 |
+
}
|
| 110 |
+
</style>
|
| 111 |
+
""", unsafe_allow_html=True)
|
| 112 |
+
|
| 113 |
+
def get_system_type():
|
| 114 |
+
"""Get the current system type"""
|
| 115 |
+
system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
|
| 116 |
+
if system == 'smart':
|
| 117 |
+
return "Smart Chatbot System"
|
| 118 |
+
else:
|
| 119 |
+
return "Multi-Agent System"
|
| 120 |
+
|
| 121 |
+
def get_chatbot():
|
| 122 |
+
"""Initialize and return the chatbot based on system type"""
|
| 123 |
+
# Check environment variable for system type
|
| 124 |
+
system = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
|
| 125 |
+
if system == 'smart':
|
| 126 |
+
return get_smart_chatbot()
|
| 127 |
+
else:
|
| 128 |
+
return get_multi_agent_chatbot()
|
| 129 |
+
|
| 130 |
+
def serialize_messages(messages):
|
| 131 |
+
"""Serialize LangChain messages to dictionaries"""
|
| 132 |
+
serialized = []
|
| 133 |
+
for msg in messages:
|
| 134 |
+
if hasattr(msg, 'content'):
|
| 135 |
+
serialized.append({
|
| 136 |
+
"type": type(msg).__name__,
|
| 137 |
+
"content": str(msg.content)
|
| 138 |
+
})
|
| 139 |
+
return serialized
|
| 140 |
+
|
| 141 |
+
def serialize_documents(sources):
|
| 142 |
+
"""Serialize document objects to dictionaries with deduplication"""
|
| 143 |
+
serialized = []
|
| 144 |
+
seen_content = set()
|
| 145 |
+
|
| 146 |
+
for doc in sources:
|
| 147 |
+
content = getattr(doc, 'page_content', getattr(doc, 'content', ''))
|
| 148 |
+
|
| 149 |
+
# Skip if we've seen this exact content before
|
| 150 |
+
if content in seen_content:
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
seen_content.add(content)
|
| 154 |
+
|
| 155 |
+
doc_dict = {
|
| 156 |
+
"content": content,
|
| 157 |
+
"metadata": getattr(doc, 'metadata', {}),
|
| 158 |
+
"score": getattr(doc, 'metadata', {}).get('reranked_score', getattr(doc, 'metadata', {}).get('original_score', 0.0)),
|
| 159 |
+
"id": getattr(doc, 'metadata', {}).get('_id', 'unknown'),
|
| 160 |
+
"source": getattr(doc, 'metadata', {}).get('source', 'unknown'),
|
| 161 |
+
"year": getattr(doc, 'metadata', {}).get('year', 'unknown'),
|
| 162 |
+
"district": getattr(doc, 'metadata', {}).get('district', 'unknown'),
|
| 163 |
+
"page": getattr(doc, 'metadata', {}).get('page', 'unknown'),
|
| 164 |
+
"chunk_id": getattr(doc, 'metadata', {}).get('chunk_id', 'unknown'),
|
| 165 |
+
"page_label": getattr(doc, 'metadata', {}).get('page_label', 'unknown'),
|
| 166 |
+
"original_score": getattr(doc, 'metadata', {}).get('original_score', 0.0),
|
| 167 |
+
"reranked_score": getattr(doc, 'metadata', {}).get('reranked_score', None)
|
| 168 |
+
}
|
| 169 |
+
serialized.append(doc_dict)
|
| 170 |
+
|
| 171 |
+
return serialized
|
| 172 |
+
|
| 173 |
+
@st.cache_data
|
| 174 |
+
def load_filter_options():
|
| 175 |
+
try:
|
| 176 |
+
with open("filter_options.json", "r") as f:
|
| 177 |
+
return json.load(f)
|
| 178 |
+
except FileNotFoundError:
|
| 179 |
+
st.info([x for x in os.listdir() if x.endswith('.json')])
|
| 180 |
+
st.error("filter_options.json not found. Please run the metadata analysis script.")
|
| 181 |
+
return {"sources": [], "years": [], "districts": [], 'filenames': []}
|
| 182 |
+
|
| 183 |
+
def main():
|
| 184 |
+
# Initialize session state
|
| 185 |
+
if 'messages' not in st.session_state:
|
| 186 |
+
st.session_state.messages = []
|
| 187 |
+
if 'conversation_id' not in st.session_state:
|
| 188 |
+
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
|
| 189 |
+
if 'session_start_time' not in st.session_state:
|
| 190 |
+
st.session_state.session_start_time = time.time()
|
| 191 |
+
if 'active_filters' not in st.session_state:
|
| 192 |
+
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
|
| 193 |
+
# Track RAG retrieval history for feedback
|
| 194 |
+
if 'rag_retrieval_history' not in st.session_state:
|
| 195 |
+
st.session_state.rag_retrieval_history = []
|
| 196 |
+
# Initialize chatbot only once per app session (cached)
|
| 197 |
+
if 'chatbot' not in st.session_state:
|
| 198 |
+
with st.spinner("🔄 Loading AI models and connecting to database..."):
|
| 199 |
+
st.session_state.chatbot = get_chatbot()
|
| 200 |
+
st.success("✅ AI system ready!")
|
| 201 |
+
|
| 202 |
+
# Reset conversation history if needed (but keep chatbot cached)
|
| 203 |
+
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
|
| 204 |
+
st.session_state.messages = []
|
| 205 |
+
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
|
| 206 |
+
st.session_state.session_start_time = time.time()
|
| 207 |
+
st.session_state.rag_retrieval_history = []
|
| 208 |
+
st.session_state.feedback_submitted = False
|
| 209 |
+
st.session_state.reset_conversation = False
|
| 210 |
+
st.rerun()
|
| 211 |
+
|
| 212 |
+
# Header with system indicator
|
| 213 |
+
col1, col2 = st.columns([3, 1])
|
| 214 |
+
with col1:
|
| 215 |
+
st.markdown('<h1 class="main-header">🤖 Intelligent Audit Report Chatbot</h1>', unsafe_allow_html=True)
|
| 216 |
+
with col2:
|
| 217 |
+
system_type = get_system_type()
|
| 218 |
+
if "Multi-Agent" in system_type:
|
| 219 |
+
st.success(f"🔧 {system_type}")
|
| 220 |
+
else:
|
| 221 |
+
st.info(f"🔧 {system_type}")
|
| 222 |
+
st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
|
| 223 |
+
|
| 224 |
+
# Session info
|
| 225 |
+
duration = int(time.time() - st.session_state.session_start_time)
|
| 226 |
+
duration_str = f"{duration // 60}m {duration % 60}s"
|
| 227 |
+
st.markdown(f'''
|
| 228 |
+
<div class="session-info">
|
| 229 |
+
<strong>Session Info:</strong> Messages: {len(st.session_state.messages)} | Duration: {duration_str} | Status: Active | ID: {st.session_state.conversation_id}
|
| 230 |
+
</div>
|
| 231 |
+
''', unsafe_allow_html=True)
|
| 232 |
+
|
| 233 |
+
# Load filter options
|
| 234 |
+
filter_options = load_filter_options()
|
| 235 |
+
|
| 236 |
+
# Sidebar for filters
|
| 237 |
+
with st.sidebar:
|
| 238 |
+
st.markdown("### 🔍 Search Filters")
|
| 239 |
+
st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
|
| 240 |
+
|
| 241 |
+
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 242 |
+
st.markdown('<div class="filter-title">📄 Specific Reports (Filename Filter)</div>', unsafe_allow_html=True)
|
| 243 |
+
st.markdown('<p style="font-size: 0.85em; color: #666;">⚠️ Selecting specific reports will ignore all other filters</p>', unsafe_allow_html=True)
|
| 244 |
+
selected_filenames = st.multiselect(
|
| 245 |
+
"Select specific reports:",
|
| 246 |
+
options=filter_options.get('filenames', []),
|
| 247 |
+
default=st.session_state.active_filters.get('filenames', []),
|
| 248 |
+
key="filenames_filter",
|
| 249 |
+
help="Choose specific reports to search. When enabled, all other filters are ignored."
|
| 250 |
+
)
|
| 251 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 252 |
+
|
| 253 |
+
# Determine if filename filter is active
|
| 254 |
+
filename_mode = len(selected_filenames) > 0
|
| 255 |
+
# Sources filter
|
| 256 |
+
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 257 |
+
st.markdown('<div class="filter-title">📊 Sources</div>', unsafe_allow_html=True)
|
| 258 |
+
selected_sources = st.multiselect(
|
| 259 |
+
"Select sources:",
|
| 260 |
+
options=filter_options['sources'],
|
| 261 |
+
default=st.session_state.active_filters['sources'],
|
| 262 |
+
disabled = filename_mode,
|
| 263 |
+
key="sources_filter",
|
| 264 |
+
help="Choose which types of reports to search"
|
| 265 |
+
)
|
| 266 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 267 |
+
|
| 268 |
+
# Years filter
|
| 269 |
+
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 270 |
+
st.markdown('<div class="filter-title">📅 Years</div>', unsafe_allow_html=True)
|
| 271 |
+
selected_years = st.multiselect(
|
| 272 |
+
"Select years:",
|
| 273 |
+
options=filter_options['years'],
|
| 274 |
+
default=st.session_state.active_filters['years'],
|
| 275 |
+
disabled = filename_mode,
|
| 276 |
+
key="years_filter",
|
| 277 |
+
help="Choose which years to search"
|
| 278 |
+
)
|
| 279 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 280 |
+
|
| 281 |
+
# Districts filter
|
| 282 |
+
st.markdown('<div class="filter-section">', unsafe_allow_html=True)
|
| 283 |
+
st.markdown('<div class="filter-title">🏘️ Districts</div>', unsafe_allow_html=True)
|
| 284 |
+
selected_districts = st.multiselect(
|
| 285 |
+
"Select districts:",
|
| 286 |
+
options=filter_options['districts'],
|
| 287 |
+
default=st.session_state.active_filters['districts'],
|
| 288 |
+
disabled = filename_mode,
|
| 289 |
+
key="districts_filter",
|
| 290 |
+
help="Choose which districts to search"
|
| 291 |
+
)
|
| 292 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 293 |
+
|
| 294 |
+
# Update active filters
|
| 295 |
+
st.session_state.active_filters = {
|
| 296 |
+
'sources': selected_sources if not filename_mode else [],
|
| 297 |
+
'years': selected_years if not filename_mode else [],
|
| 298 |
+
'districts': selected_districts if not filename_mode else [],
|
| 299 |
+
'filenames': selected_filenames
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
# Clear filters button
|
| 303 |
+
if st.button("🗑️ Clear All Filters", key="clear_filters_button"):
|
| 304 |
+
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
|
| 305 |
+
st.rerun()
|
| 306 |
+
|
| 307 |
+
# Main content area with tabs
|
| 308 |
+
tab1, tab2 = st.tabs(["💬 Chat", "📄 Retrieved Documents"])
|
| 309 |
+
|
| 310 |
+
with tab1:
|
| 311 |
+
# Chat container
|
| 312 |
+
chat_container = st.container()
|
| 313 |
+
|
| 314 |
+
with chat_container:
|
| 315 |
+
# Display conversation history
|
| 316 |
+
for message in st.session_state.messages:
|
| 317 |
+
if isinstance(message, HumanMessage):
|
| 318 |
+
st.markdown(f'<div class="user-message">{message.content}</div>', unsafe_allow_html=True)
|
| 319 |
+
elif isinstance(message, AIMessage):
|
| 320 |
+
st.markdown(f'<div class="bot-message">{message.content}</div>', unsafe_allow_html=True)
|
| 321 |
+
|
| 322 |
+
# Input area
|
| 323 |
+
st.markdown("<br>", unsafe_allow_html=True)
|
| 324 |
+
|
| 325 |
+
# Create two columns for input and button
|
| 326 |
+
col1, col2 = st.columns([4, 1])
|
| 327 |
+
|
| 328 |
+
with col1:
|
| 329 |
+
# Use a counter to force input clearing
|
| 330 |
+
if 'input_counter' not in st.session_state:
|
| 331 |
+
st.session_state.input_counter = 0
|
| 332 |
+
|
| 333 |
+
user_input = st.text_input(
|
| 334 |
+
"Type your message here...",
|
| 335 |
+
placeholder="Ask about budget allocations, expenditures, or audit findings...",
|
| 336 |
+
key=f"user_input_{st.session_state.input_counter}",
|
| 337 |
+
label_visibility="collapsed"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
with col2:
|
| 341 |
+
send_button = st.button("Send", key="send_button", use_container_width=True)
|
| 342 |
+
|
| 343 |
+
# Clear chat button
|
| 344 |
+
if st.button("🗑️ Clear Chat", key="clear_chat_button"):
|
| 345 |
+
st.session_state.reset_conversation = True
|
| 346 |
+
# Clear all conversation files
|
| 347 |
+
import os
|
| 348 |
+
conversations_dir = "conversations"
|
| 349 |
+
if os.path.exists(conversations_dir):
|
| 350 |
+
for file in os.listdir(conversations_dir):
|
| 351 |
+
if file.endswith('.json'):
|
| 352 |
+
os.remove(os.path.join(conversations_dir, file))
|
| 353 |
+
st.rerun()
|
| 354 |
+
|
| 355 |
+
# Handle user input
|
| 356 |
+
if send_button and user_input:
|
| 357 |
+
# Construct filter context string
|
| 358 |
+
filter_context_str = ""
|
| 359 |
+
if selected_filenames:
|
| 360 |
+
filter_context_str += "FILTER CONTEXT:\n"
|
| 361 |
+
filter_context_str += f"Filenames: {', '.join(selected_filenames)}\n"
|
| 362 |
+
filter_context_str += "USER QUERY:\n"
|
| 363 |
+
elif selected_sources or selected_years or selected_districts:
|
| 364 |
+
filter_context_str += "FILTER CONTEXT:\n"
|
| 365 |
+
if selected_sources:
|
| 366 |
+
filter_context_str += f"Sources: {', '.join(selected_sources)}\n"
|
| 367 |
+
if selected_years:
|
| 368 |
+
filter_context_str += f"Years: {', '.join(selected_years)}\n"
|
| 369 |
+
if selected_districts:
|
| 370 |
+
filter_context_str += f"Districts: {', '.join(selected_districts)}\n"
|
| 371 |
+
filter_context_str += "USER QUERY:\n"
|
| 372 |
+
|
| 373 |
+
full_query = filter_context_str + user_input
|
| 374 |
+
|
| 375 |
+
# Add user message to history
|
| 376 |
+
st.session_state.messages.append(HumanMessage(content=user_input))
|
| 377 |
+
|
| 378 |
+
# Get chatbot response
|
| 379 |
+
with st.spinner("🤔 Thinking..."):
|
| 380 |
+
try:
|
| 381 |
+
# Pass the full query with filter context
|
| 382 |
+
chat_result = st.session_state.chatbot.chat(full_query, st.session_state.conversation_id)
|
| 383 |
+
|
| 384 |
+
# Handle both old format (string) and new format (dict)
|
| 385 |
+
if isinstance(chat_result, dict):
|
| 386 |
+
response = chat_result['response']
|
| 387 |
+
rag_result = chat_result.get('rag_result')
|
| 388 |
+
st.session_state.last_rag_result = rag_result
|
| 389 |
+
|
| 390 |
+
# Track RAG retrieval for feedback
|
| 391 |
+
if rag_result:
|
| 392 |
+
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
|
| 393 |
+
|
| 394 |
+
# Get the actual RAG query
|
| 395 |
+
actual_rag_query = chat_result.get('actual_rag_query', '')
|
| 396 |
+
if actual_rag_query:
|
| 397 |
+
# Format it like the log message
|
| 398 |
+
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
| 399 |
+
formatted_query = f"{timestamp} - INFO - 🔍 ACTUAL RAG QUERY: '{actual_rag_query}'"
|
| 400 |
+
else:
|
| 401 |
+
formatted_query = "No RAG query available"
|
| 402 |
+
|
| 403 |
+
retrieval_entry = {
|
| 404 |
+
"conversation_up_to": serialize_messages(st.session_state.messages),
|
| 405 |
+
"rag_query_expansion": formatted_query,
|
| 406 |
+
"docs_retrieved": serialize_documents(sources)
|
| 407 |
+
}
|
| 408 |
+
st.session_state.rag_retrieval_history.append(retrieval_entry)
|
| 409 |
+
else:
|
| 410 |
+
response = chat_result
|
| 411 |
+
st.session_state.last_rag_result = None
|
| 412 |
+
|
| 413 |
+
# Add bot response to history
|
| 414 |
+
st.session_state.messages.append(AIMessage(content=response))
|
| 415 |
+
|
| 416 |
+
except Exception as e:
|
| 417 |
+
error_msg = f"Sorry, I encountered an error: {str(e)}"
|
| 418 |
+
st.session_state.messages.append(AIMessage(content=error_msg))
|
| 419 |
+
|
| 420 |
+
# Clear input and rerun
|
| 421 |
+
st.session_state.input_counter += 1 # This will clear the input
|
| 422 |
+
st.rerun()
|
| 423 |
+
|
| 424 |
+
with tab2:
|
| 425 |
+
# Document retrieval panel
|
| 426 |
+
if hasattr(st.session_state, 'last_rag_result') and st.session_state.last_rag_result:
|
| 427 |
+
rag_result = st.session_state.last_rag_result
|
| 428 |
+
|
| 429 |
+
# Handle both PipelineResult object and dictionary formats
|
| 430 |
+
sources = None
|
| 431 |
+
if hasattr(rag_result, 'sources'):
|
| 432 |
+
# PipelineResult object format
|
| 433 |
+
sources = rag_result.sources
|
| 434 |
+
elif isinstance(rag_result, dict) and 'sources' in rag_result:
|
| 435 |
+
# Dictionary format from multi-agent system
|
| 436 |
+
sources = rag_result['sources']
|
| 437 |
+
|
| 438 |
+
if sources and len(sources) > 0:
|
| 439 |
+
# Count unique filenames
|
| 440 |
+
unique_filenames = set()
|
| 441 |
+
for doc in sources:
|
| 442 |
+
filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
|
| 443 |
+
unique_filenames.add(filename)
|
| 444 |
+
|
| 445 |
+
st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 10):**")
|
| 446 |
+
if len(unique_filenames) < len(sources):
|
| 447 |
+
st.info(f"💡 **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
|
| 448 |
+
|
| 449 |
+
for i, doc in enumerate(sources[:10]): # Show top 10
|
| 450 |
+
# Get relevance score and ID if available
|
| 451 |
+
metadata = getattr(doc, 'metadata', {})
|
| 452 |
+
score = metadata.get('reranked_score', metadata.get('original_score', None))
|
| 453 |
+
chunk_id = metadata.get('_id', 'Unknown')
|
| 454 |
+
score_text = f" (Score: {score:.3f}, ID: {chunk_id[:8]}...)" if score is not None else f" (ID: {chunk_id[:8]}...)"
|
| 455 |
+
|
| 456 |
+
with st.expander(f"📄 Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
|
| 457 |
+
# Display document metadata with emojis
|
| 458 |
+
metadata = getattr(doc, 'metadata', {})
|
| 459 |
+
col1, col2, col3, col4 = st.columns([2, 1.5, 1, 1])
|
| 460 |
+
|
| 461 |
+
with col1:
|
| 462 |
+
st.write(f"📄 **File:** {metadata.get('filename', 'Unknown')}")
|
| 463 |
+
with col2:
|
| 464 |
+
st.write(f"🏛️ **Source:** {metadata.get('source', 'Unknown')}")
|
| 465 |
+
with col3:
|
| 466 |
+
st.write(f"📅 **Year:** {metadata.get('year', 'Unknown')}")
|
| 467 |
+
with col4:
|
| 468 |
+
# Display page number and chunk ID
|
| 469 |
+
page = metadata.get('page_label', metadata.get('page', 'Unknown'))
|
| 470 |
+
chunk_id = metadata.get('_id', 'Unknown')
|
| 471 |
+
st.write(f"📖 **Page:** {page}")
|
| 472 |
+
st.write(f"🆔 **ID:** {chunk_id}")
|
| 473 |
+
|
| 474 |
+
# Display full content (no truncation)
|
| 475 |
+
content = getattr(doc, 'page_content', 'No content available')
|
| 476 |
+
st.write(f"**Full Content:**")
|
| 477 |
+
st.text_area("Full Content", value=content, height=300, disabled=True, label_visibility="collapsed", key=f"preview_{i}")
|
| 478 |
+
else:
|
| 479 |
+
st.info("No documents were retrieved for the last query.")
|
| 480 |
+
else:
|
| 481 |
+
st.info("No documents have been retrieved yet. Start a conversation to see retrieved documents here.")
|
| 482 |
+
|
| 483 |
+
# Feedback Dashboard Section
|
| 484 |
+
st.markdown("---")
|
| 485 |
+
st.markdown("### 💬 Feedback Dashboard")
|
| 486 |
+
|
| 487 |
+
# Check if there's any conversation to provide feedback on
|
| 488 |
+
has_conversation = len(st.session_state.messages) > 0
|
| 489 |
+
has_retrievals = len(st.session_state.rag_retrieval_history) > 0
|
| 490 |
+
|
| 491 |
+
if not has_conversation:
|
| 492 |
+
st.info("💡 Start a conversation to provide feedback!")
|
| 493 |
+
st.markdown("The feedback dashboard will be enabled once you begin chatting.")
|
| 494 |
+
else:
|
| 495 |
+
st.markdown("Help us improve by providing feedback on this conversation.")
|
| 496 |
+
|
| 497 |
+
# Initialize feedback state if not exists
|
| 498 |
+
if 'feedback_submitted' not in st.session_state:
|
| 499 |
+
st.session_state.feedback_submitted = False
|
| 500 |
+
|
| 501 |
+
# Feedback form
|
| 502 |
+
with st.form("feedback_form", clear_on_submit=False):
|
| 503 |
+
col1, col2 = st.columns([1, 1])
|
| 504 |
+
|
| 505 |
+
with col1:
|
| 506 |
+
feedback_score = st.slider(
|
| 507 |
+
"Rate this conversation (1-5)",
|
| 508 |
+
min_value=1,
|
| 509 |
+
max_value=5,
|
| 510 |
+
help="How satisfied are you with the conversation?"
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
with col2:
|
| 514 |
+
is_feedback_about_last_retrieval = st.checkbox(
|
| 515 |
+
"Feedback about last retrieval only",
|
| 516 |
+
value=True,
|
| 517 |
+
help="If checked, feedback applies to the most recent document retrieval"
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
open_ended_feedback = st.text_area(
|
| 521 |
+
"Your feedback (optional)",
|
| 522 |
+
placeholder="Tell us what went well or what could be improved...",
|
| 523 |
+
height=100
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
# Disable submit if no score selected
|
| 527 |
+
submit_disabled = feedback_score is None
|
| 528 |
+
|
| 529 |
+
submitted = st.form_submit_button(
|
| 530 |
+
"📤 Submit Feedback",
|
| 531 |
+
use_container_width=True,
|
| 532 |
+
disabled=submit_disabled
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
if submitted and not st.session_state.feedback_submitted:
|
| 536 |
+
# Log the feedback data being submitted
|
| 537 |
+
print("=" * 80)
|
| 538 |
+
print("🔄 FEEDBACK SUBMISSION: Starting...")
|
| 539 |
+
print("=" * 80)
|
| 540 |
+
st.write("🔍 **Debug: Feedback Data Being Submitted:**")
|
| 541 |
+
|
| 542 |
+
# Create feedback data dictionary
|
| 543 |
+
feedback_dict = {
|
| 544 |
+
"open_ended_feedback": open_ended_feedback,
|
| 545 |
+
"score": feedback_score,
|
| 546 |
+
"is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
|
| 547 |
+
"retrieved_data": st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
|
| 548 |
+
"conversation_id": st.session_state.conversation_id,
|
| 549 |
+
"timestamp": time.time(),
|
| 550 |
+
"message_count": len(st.session_state.messages),
|
| 551 |
+
"has_retrievals": has_retrievals,
|
| 552 |
+
"retrieval_count": len(st.session_state.rag_retrieval_history)
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
print(f"📝 FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
|
| 556 |
+
|
| 557 |
+
# Create UserFeedback dataclass instance
|
| 558 |
+
feedback_obj = None # Initialize outside try block
|
| 559 |
+
try:
|
| 560 |
+
feedback_obj = create_feedback_from_dict(feedback_dict)
|
| 561 |
+
print(f"✅ FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
|
| 562 |
+
st.write(f"✅ **Feedback Object Created**")
|
| 563 |
+
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
|
| 564 |
+
st.write(f"- Score: {feedback_obj.score}/5")
|
| 565 |
+
st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
|
| 566 |
+
|
| 567 |
+
# Convert back to dict for JSON serialization
|
| 568 |
+
feedback_data = feedback_obj.to_dict()
|
| 569 |
+
except Exception as e:
|
| 570 |
+
print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
|
| 571 |
+
st.error(f"Failed to create feedback object: {e}")
|
| 572 |
+
feedback_data = feedback_dict
|
| 573 |
+
|
| 574 |
+
# Display the data being submitted
|
| 575 |
+
st.json(feedback_data)
|
| 576 |
+
|
| 577 |
+
# Save feedback to file
|
| 578 |
+
feedback_dir = Path("feedback")
|
| 579 |
+
feedback_dir.mkdir(exist_ok=True)
|
| 580 |
+
|
| 581 |
+
feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
|
| 582 |
+
|
| 583 |
+
try:
|
| 584 |
+
# Save to local file
|
| 585 |
+
print(f"💾 FEEDBACK SAVE: Saving to local file: {feedback_file}")
|
| 586 |
+
with open(feedback_file, 'w') as f:
|
| 587 |
+
json.dump(feedback_data, f, indent=2, default=str)
|
| 588 |
+
|
| 589 |
+
print(f"✅ FEEDBACK SAVE: Local file saved successfully")
|
| 590 |
+
st.success("✅ Thank you for your feedback! It has been saved locally.")
|
| 591 |
+
st.balloons()
|
| 592 |
+
|
| 593 |
+
# Save to Snowflake if enabled and credentials available
|
| 594 |
+
logger.info("🔄 FEEDBACK SAVE: Starting Snowflake save process...")
|
| 595 |
+
logger.info(f"📊 FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
|
| 596 |
+
|
| 597 |
+
try:
|
| 598 |
+
import os
|
| 599 |
+
snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
|
| 600 |
+
logger.info(f"🔍 SNOWFLAKE CHECK: enabled={snowflake_enabled}")
|
| 601 |
+
|
| 602 |
+
if snowflake_enabled:
|
| 603 |
+
if feedback_obj:
|
| 604 |
+
try:
|
| 605 |
+
from auditqa.reporting.snowflake_connector import save_to_snowflake
|
| 606 |
+
logger.info("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
|
| 607 |
+
print("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...") # Also print to terminal
|
| 608 |
+
|
| 609 |
+
if save_to_snowflake(feedback_obj):
|
| 610 |
+
logger.info("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
|
| 611 |
+
print("✅ SNOWFLAKE UI: Successfully saved to Snowflake") # Also print to terminal
|
| 612 |
+
st.success("✅ Feedback also saved to Snowflake!")
|
| 613 |
+
else:
|
| 614 |
+
logger.warning("⚠️ SNOWFLAKE UI: Save failed")
|
| 615 |
+
print("⚠️ SNOWFLAKE UI: Save failed") # Also print to terminal
|
| 616 |
+
st.warning("⚠️ Snowflake save failed, but local save succeeded")
|
| 617 |
+
except Exception as e:
|
| 618 |
+
logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
|
| 619 |
+
print(f"❌ SNOWFLAKE UI ERROR: {e}") # Also print to terminal
|
| 620 |
+
import traceback
|
| 621 |
+
traceback.print_exc() # Print full traceback to terminal
|
| 622 |
+
st.warning(f"⚠️ Could not save to Snowflake: {e}")
|
| 623 |
+
else:
|
| 624 |
+
logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
|
| 625 |
+
print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)") # Also print to terminal
|
| 626 |
+
st.warning("⚠️ Skipping Snowflake save (feedback object not created)")
|
| 627 |
+
else:
|
| 628 |
+
logger.info("💡 SNOWFLAKE UI: Integration disabled")
|
| 629 |
+
print("💡 SNOWFLAKE UI: Integration disabled") # Also print to terminal
|
| 630 |
+
st.info("💡 Snowflake integration disabled (set SNOWFLAKE_ENABLED=true to enable)")
|
| 631 |
+
except NameError as e:
|
| 632 |
+
import traceback
|
| 633 |
+
traceback.print_exc()
|
| 634 |
+
logger.error(f"❌ NameError in Snowflake save: {e}")
|
| 635 |
+
print(f"❌ NameError in Snowflake save: {e}") # Also print to terminal
|
| 636 |
+
st.warning(f"⚠️ Snowflake save error: {e}")
|
| 637 |
+
except Exception as e:
|
| 638 |
+
logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
|
| 639 |
+
print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}") # Also print to terminal
|
| 640 |
+
st.warning(f"⚠️ Snowflake save error: {e}")
|
| 641 |
+
|
| 642 |
+
# Mark feedback as submitted to prevent resubmission
|
| 643 |
+
st.session_state.feedback_submitted = True
|
| 644 |
+
|
| 645 |
+
print("=" * 80)
|
| 646 |
+
print(f"✅ FEEDBACK SUBMISSION: Completed successfully")
|
| 647 |
+
print("=" * 80)
|
| 648 |
+
|
| 649 |
+
# Log file location
|
| 650 |
+
st.info(f"📁 Feedback saved to: {feedback_file}")
|
| 651 |
+
|
| 652 |
+
except Exception as e:
|
| 653 |
+
print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
|
| 654 |
+
print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
|
| 655 |
+
import traceback
|
| 656 |
+
traceback.print_exc()
|
| 657 |
+
st.error(f"❌ Error saving feedback: {e}")
|
| 658 |
+
st.write(f"Debug error: {str(e)}")
|
| 659 |
+
|
| 660 |
+
elif st.session_state.feedback_submitted:
|
| 661 |
+
st.success("✅ Feedback already submitted for this conversation!")
|
| 662 |
+
if st.button("🔄 Submit New Feedback", key="new_feedback_button"):
|
| 663 |
+
st.session_state.feedback_submitted = False
|
| 664 |
+
st.rerun()
|
| 665 |
+
|
| 666 |
+
# Display retrieval history stats
|
| 667 |
+
if st.session_state.rag_retrieval_history:
|
| 668 |
+
st.markdown("---")
|
| 669 |
+
st.markdown("#### 📊 Retrieval History")
|
| 670 |
+
|
| 671 |
+
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
|
| 672 |
+
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
|
| 673 |
+
st.markdown(f"**Retrieval #{idx}**")
|
| 674 |
+
|
| 675 |
+
# Display the actual RAG query
|
| 676 |
+
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
|
| 677 |
+
st.code(rag_query_expansion, language="text")
|
| 678 |
+
|
| 679 |
+
# Display summary stats
|
| 680 |
+
st.json({
|
| 681 |
+
"conversation_length": len(entry.get("conversation_up_to", [])),
|
| 682 |
+
"documents_retrieved": len(entry.get("docs_retrieved", []))
|
| 683 |
+
})
|
| 684 |
+
st.markdown("---")
|
| 685 |
+
|
| 686 |
+
# Auto-scroll to bottom
|
| 687 |
+
st.markdown("""
|
| 688 |
+
<script>
|
| 689 |
+
window.scrollTo(0, document.body.scrollHeight);
|
| 690 |
+
</script>
|
| 691 |
+
""", unsafe_allow_html=True)
|
| 692 |
+
|
| 693 |
+
if __name__ == "__main__":
|
| 694 |
+
main()
|
multi_agent_chatbot.py
ADDED
|
@@ -0,0 +1,1167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-Agent RAG Chatbot using LangGraph
|
| 3 |
+
|
| 4 |
+
This system implements a 3-agent architecture:
|
| 5 |
+
1. Main Agent: Handles conversation flow, follow-ups, and determines when to call RAG
|
| 6 |
+
2. RAG Agent: Rewrites queries and applies filters for document retrieval
|
| 7 |
+
3. Response Agent: Generates final answers from retrieved documents
|
| 8 |
+
|
| 9 |
+
Each agent has specialized prompts and responsibilities.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import json
|
| 14 |
+
import time
|
| 15 |
+
import logging
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Dict, List, Any, Optional, TypedDict
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
import re
|
| 23 |
+
from langchain_core.tools import tool
|
| 24 |
+
from langgraph.graph import StateGraph, END
|
| 25 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 26 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
from src.pipeline import PipelineManager
|
| 30 |
+
from src.config.loader import load_config
|
| 31 |
+
from src.llm.adapters import get_llm_client
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class QueryContext:
|
| 40 |
+
"""Context extracted from conversation"""
|
| 41 |
+
has_district: bool = False
|
| 42 |
+
has_source: bool = False
|
| 43 |
+
has_year: bool = False
|
| 44 |
+
extracted_district: Optional[str] = None
|
| 45 |
+
extracted_source: Optional[str] = None
|
| 46 |
+
extracted_year: Optional[str] = None
|
| 47 |
+
ui_filters: Dict[str, List[str]] = None
|
| 48 |
+
confidence_score: float = 0.0
|
| 49 |
+
needs_follow_up: bool = False
|
| 50 |
+
follow_up_question: Optional[str] = None
|
| 51 |
+
|
| 52 |
+
class MultiAgentState(TypedDict):
|
| 53 |
+
"""State for the multi-agent conversation flow"""
|
| 54 |
+
conversation_id: str
|
| 55 |
+
messages: List[Any]
|
| 56 |
+
current_query: str
|
| 57 |
+
query_context: Optional[QueryContext]
|
| 58 |
+
rag_query: Optional[str]
|
| 59 |
+
rag_filters: Optional[Dict[str, Any]]
|
| 60 |
+
retrieved_documents: Optional[List[Any]]
|
| 61 |
+
final_response: Optional[str]
|
| 62 |
+
agent_logs: List[str]
|
| 63 |
+
conversation_context: Dict[str, Any]
|
| 64 |
+
session_start_time: float
|
| 65 |
+
last_ai_message_time: float
|
| 66 |
+
|
| 67 |
+
class MultiAgentRAGChatbot:
|
| 68 |
+
"""Multi-agent RAG chatbot with specialized agents"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, config_path: str = "auditqa/config/settings.yaml"):
|
| 71 |
+
"""Initialize the multi-agent chatbot"""
|
| 72 |
+
self.config = load_config(config_path)
|
| 73 |
+
|
| 74 |
+
# Get LLM provider from config
|
| 75 |
+
reader_config = self.config.get("reader", {})
|
| 76 |
+
default_type = reader_config.get("default_type", "INF_PROVIDERS")
|
| 77 |
+
provider_name = default_type.lower()
|
| 78 |
+
|
| 79 |
+
self.llm_adapter = get_llm_client(provider_name, self.config)
|
| 80 |
+
|
| 81 |
+
# Create a simple wrapper for LangChain compatibility
|
| 82 |
+
class LLMWrapper:
|
| 83 |
+
def __init__(self, adapter):
|
| 84 |
+
self.adapter = adapter
|
| 85 |
+
|
| 86 |
+
def invoke(self, messages):
|
| 87 |
+
# Convert LangChain messages to the format expected by the adapter
|
| 88 |
+
if isinstance(messages, list):
|
| 89 |
+
formatted_messages = []
|
| 90 |
+
for msg in messages:
|
| 91 |
+
if hasattr(msg, 'content'):
|
| 92 |
+
role = "user" if msg.__class__.__name__ == "HumanMessage" else "assistant"
|
| 93 |
+
formatted_messages.append({"role": role, "content": msg.content})
|
| 94 |
+
else:
|
| 95 |
+
formatted_messages.append({"role": "user", "content": str(msg)})
|
| 96 |
+
else:
|
| 97 |
+
formatted_messages = [{"role": "user", "content": str(messages)}]
|
| 98 |
+
|
| 99 |
+
# Use the adapter to get response
|
| 100 |
+
response = self.adapter.generate(formatted_messages)
|
| 101 |
+
|
| 102 |
+
# Return a mock response object
|
| 103 |
+
class MockResponse:
|
| 104 |
+
def __init__(self, content):
|
| 105 |
+
self.content = content
|
| 106 |
+
|
| 107 |
+
return MockResponse(response.content)
|
| 108 |
+
|
| 109 |
+
self.llm = LLMWrapper(self.llm_adapter)
|
| 110 |
+
|
| 111 |
+
# Initialize pipeline manager early to load models
|
| 112 |
+
logger.info("🔄 Initializing pipeline manager and loading models...")
|
| 113 |
+
self.pipeline_manager = PipelineManager(self.config)
|
| 114 |
+
logger.info("✅ Pipeline manager initialized and models loaded")
|
| 115 |
+
|
| 116 |
+
# Connect to vector store
|
| 117 |
+
logger.info("🔄 Connecting to vector store...")
|
| 118 |
+
if not self.pipeline_manager.connect_vectorstore():
|
| 119 |
+
logger.error("❌ Failed to connect to vector store")
|
| 120 |
+
raise RuntimeError("Vector store connection failed")
|
| 121 |
+
logger.info("✅ Vector store connected successfully")
|
| 122 |
+
|
| 123 |
+
# Load dynamic data
|
| 124 |
+
self._load_dynamic_data()
|
| 125 |
+
|
| 126 |
+
# Build the multi-agent graph
|
| 127 |
+
self.graph = self._build_graph()
|
| 128 |
+
|
| 129 |
+
# Conversations directory
|
| 130 |
+
self.conversations_dir = Path("conversations")
|
| 131 |
+
self.conversations_dir.mkdir(exist_ok=True)
|
| 132 |
+
|
| 133 |
+
logger.info("🤖 Multi-Agent RAG Chatbot initialized")
|
| 134 |
+
|
| 135 |
+
def _load_dynamic_data(self):
|
| 136 |
+
"""Load dynamic data from filter_options.json and add_district_metadata.py"""
|
| 137 |
+
# Load filter options
|
| 138 |
+
try:
|
| 139 |
+
fo = Path("filter_options.json")
|
| 140 |
+
if fo.exists():
|
| 141 |
+
with open(fo) as f:
|
| 142 |
+
data = json.load(f)
|
| 143 |
+
self.year_whitelist = [str(y).strip() for y in data.get("years", [])]
|
| 144 |
+
self.source_whitelist = [str(s).strip() for s in data.get("sources", [])]
|
| 145 |
+
self.district_whitelist = [str(d).strip() for d in data.get("districts", [])]
|
| 146 |
+
else:
|
| 147 |
+
# Fallback to default values
|
| 148 |
+
self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
|
| 149 |
+
self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
|
| 150 |
+
self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
|
| 151 |
+
except Exception as e:
|
| 152 |
+
logger.warning(f"Could not load filter options: {e}")
|
| 153 |
+
self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
|
| 154 |
+
self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
|
| 155 |
+
self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
|
| 156 |
+
|
| 157 |
+
# Enrich district list from add_district_metadata.py
|
| 158 |
+
try:
|
| 159 |
+
from add_district_metadata import DistrictMetadataProcessor
|
| 160 |
+
proc = DistrictMetadataProcessor()
|
| 161 |
+
names = set()
|
| 162 |
+
for key, mapping in proc.district_mappings.items():
|
| 163 |
+
if getattr(mapping, 'is_district', True):
|
| 164 |
+
names.add(mapping.name)
|
| 165 |
+
if names:
|
| 166 |
+
merged = list(self.district_whitelist)
|
| 167 |
+
for n in sorted(names):
|
| 168 |
+
if n not in merged:
|
| 169 |
+
merged.append(n)
|
| 170 |
+
self.district_whitelist = merged
|
| 171 |
+
logger.info(f"🧭 District whitelist enriched: {len(self.district_whitelist)} entries")
|
| 172 |
+
except Exception as e:
|
| 173 |
+
logger.info(f"ℹ️ Could not enrich districts: {e}")
|
| 174 |
+
|
| 175 |
+
# Calculate current year dynamically
|
| 176 |
+
self.current_year = str(datetime.now().year)
|
| 177 |
+
self.previous_year = str(datetime.now().year - 1)
|
| 178 |
+
|
| 179 |
+
# Log the actual filter values for debugging
|
| 180 |
+
logger.info(f"📊 ACTUAL FILTER VALUES:")
|
| 181 |
+
logger.info(f" Years: {self.year_whitelist}")
|
| 182 |
+
logger.info(f" Sources: {self.source_whitelist}")
|
| 183 |
+
logger.info(f" Districts: {len(self.district_whitelist)} districts (first 10: {self.district_whitelist[:10]})")
|
| 184 |
+
|
| 185 |
+
def _build_graph(self) -> StateGraph:
|
| 186 |
+
"""Build the multi-agent LangGraph"""
|
| 187 |
+
graph = StateGraph(MultiAgentState)
|
| 188 |
+
|
| 189 |
+
# Add nodes for each agent
|
| 190 |
+
graph.add_node("main_agent", self._main_agent)
|
| 191 |
+
graph.add_node("rag_agent", self._rag_agent)
|
| 192 |
+
graph.add_node("response_agent", self._response_agent)
|
| 193 |
+
|
| 194 |
+
# Define the flow
|
| 195 |
+
graph.set_entry_point("main_agent")
|
| 196 |
+
|
| 197 |
+
# Main agent decides next step
|
| 198 |
+
graph.add_conditional_edges(
|
| 199 |
+
"main_agent",
|
| 200 |
+
self._should_call_rag,
|
| 201 |
+
{
|
| 202 |
+
"follow_up": END,
|
| 203 |
+
"call_rag": "rag_agent"
|
| 204 |
+
}
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# RAG agent calls response agent
|
| 208 |
+
graph.add_edge("rag_agent", "response_agent")
|
| 209 |
+
|
| 210 |
+
# Response agent returns to main agent for potential follow-ups
|
| 211 |
+
graph.add_edge("response_agent", "main_agent")
|
| 212 |
+
|
| 213 |
+
return graph.compile()
|
| 214 |
+
|
| 215 |
+
def _should_call_rag(self, state: MultiAgentState) -> str:
|
| 216 |
+
"""Determine if we should call RAG or ask follow-up"""
|
| 217 |
+
# If we already have a final response (from response agent), end
|
| 218 |
+
if state.get("final_response"):
|
| 219 |
+
return "follow_up"
|
| 220 |
+
|
| 221 |
+
context = state["query_context"]
|
| 222 |
+
if context and context.needs_follow_up:
|
| 223 |
+
return "follow_up"
|
| 224 |
+
return "call_rag"
|
| 225 |
+
|
| 226 |
+
def _main_agent(self, state: MultiAgentState) -> MultiAgentState:
|
| 227 |
+
"""Main Agent: Handles conversation flow and follow-ups"""
|
| 228 |
+
logger.info("🎯 MAIN AGENT: Starting analysis")
|
| 229 |
+
|
| 230 |
+
# If we already have a final response from response agent, end gracefully
|
| 231 |
+
if state.get("final_response"):
|
| 232 |
+
logger.info("🎯 MAIN AGENT: Final response already exists, ending conversation flow")
|
| 233 |
+
return state
|
| 234 |
+
|
| 235 |
+
query = state["current_query"]
|
| 236 |
+
messages = state["messages"]
|
| 237 |
+
|
| 238 |
+
logger.info(f"🎯 MAIN AGENT: Extracting UI filters from query")
|
| 239 |
+
ui_filters = self._extract_ui_filters(query)
|
| 240 |
+
logger.info(f"🎯 MAIN AGENT: UI filters extracted: {ui_filters}")
|
| 241 |
+
|
| 242 |
+
# Analyze query context
|
| 243 |
+
logger.info(f"🎯 MAIN AGENT: Analyzing query context")
|
| 244 |
+
context = self._analyze_query_context(query, messages, ui_filters)
|
| 245 |
+
|
| 246 |
+
# Log agent decision
|
| 247 |
+
state["agent_logs"].append(f"MAIN AGENT: Context analyzed - district={context.has_district}, source={context.has_source}, year={context.has_year}")
|
| 248 |
+
logger.info(f"🎯 MAIN AGENT: Context analysis complete - district={context.has_district}, source={context.has_source}, year={context.has_year}")
|
| 249 |
+
|
| 250 |
+
# Store context
|
| 251 |
+
state["query_context"] = context
|
| 252 |
+
|
| 253 |
+
# If follow-up needed, generate response
|
| 254 |
+
if context.needs_follow_up:
|
| 255 |
+
logger.info(f"🎯 MAIN AGENT: Follow-up needed, generating question")
|
| 256 |
+
response = context.follow_up_question
|
| 257 |
+
state["final_response"] = response
|
| 258 |
+
state["last_ai_message_time"] = time.time()
|
| 259 |
+
logger.info(f"🎯 MAIN AGENT: Follow-up question generated: {response[:100]}...")
|
| 260 |
+
else:
|
| 261 |
+
logger.info("🎯 MAIN AGENT: No follow-up needed, proceeding to RAG")
|
| 262 |
+
|
| 263 |
+
return state
|
| 264 |
+
|
| 265 |
+
def _rag_agent(self, state: MultiAgentState) -> MultiAgentState:
|
| 266 |
+
"""RAG Agent: Rewrites queries and applies filters"""
|
| 267 |
+
logger.info("🔍 RAG AGENT: Starting query rewriting and filter preparation")
|
| 268 |
+
|
| 269 |
+
context = state["query_context"]
|
| 270 |
+
messages = state["messages"]
|
| 271 |
+
|
| 272 |
+
logger.info(f"🔍 RAG AGENT: Context received - district={context.has_district}, source={context.has_source}, year={context.has_year}")
|
| 273 |
+
|
| 274 |
+
# Rewrite query for RAG
|
| 275 |
+
logger.info(f"🔍 RAG AGENT: Rewriting query for optimal retrieval")
|
| 276 |
+
rag_query = self._rewrite_query_for_rag(messages, context)
|
| 277 |
+
logger.info(f"🔍 RAG AGENT: Query rewritten: '{rag_query}'")
|
| 278 |
+
|
| 279 |
+
# Build filters
|
| 280 |
+
logger.info(f"🔍 RAG AGENT: Building filters from context")
|
| 281 |
+
filters = self._build_filters(context)
|
| 282 |
+
logger.info(f"🔍 RAG AGENT: Filters built: {filters}")
|
| 283 |
+
|
| 284 |
+
# Log RAG preparation
|
| 285 |
+
state["agent_logs"].append(f"RAG AGENT: Query='{rag_query}', Filters={filters}")
|
| 286 |
+
|
| 287 |
+
# Store for response agent
|
| 288 |
+
state["rag_query"] = rag_query
|
| 289 |
+
state["rag_filters"] = filters
|
| 290 |
+
|
| 291 |
+
logger.info(f"🔍 RAG AGENT: Preparation complete, ready for retrieval")
|
| 292 |
+
|
| 293 |
+
return state
|
| 294 |
+
|
| 295 |
+
def _response_agent(self, state: MultiAgentState) -> MultiAgentState:
|
| 296 |
+
"""Response Agent: Generates final answer from retrieved documents"""
|
| 297 |
+
logger.info("📝 RESPONSE AGENT: Starting document retrieval and answer generation")
|
| 298 |
+
|
| 299 |
+
rag_query = state["rag_query"]
|
| 300 |
+
filters = state["rag_filters"]
|
| 301 |
+
|
| 302 |
+
logger.info(f"📝 RESPONSE AGENT: Starting RAG retrieval with query: '{rag_query}'")
|
| 303 |
+
logger.info(f"📝 RESPONSE AGENT: Using filters: {filters}")
|
| 304 |
+
|
| 305 |
+
# Perform RAG retrieval
|
| 306 |
+
logger.info(f"📝 RESPONSE AGENT: Calling pipeline manager for retrieval")
|
| 307 |
+
logger.info(f"🔍 ACTUAL RAG QUERY: '{rag_query}'")
|
| 308 |
+
logger.info(f"🔍 ACTUAL FILTERS: {filters}")
|
| 309 |
+
try:
|
| 310 |
+
# Extract filenames from filters if present
|
| 311 |
+
filenames = filters.get("filenames") if filters else None
|
| 312 |
+
|
| 313 |
+
result = self.pipeline_manager.run(
|
| 314 |
+
query=rag_query,
|
| 315 |
+
sources=filters.get("sources") if filters else None,
|
| 316 |
+
auto_infer_filters=False,
|
| 317 |
+
filters=filters if filters else None
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
logger.info(f"📝 RESPONSE AGENT: RAG retrieval completed - {len(result.sources)} documents retrieved")
|
| 321 |
+
logger.info(f"🔍 RETRIEVAL DEBUG: Result type: {type(result)}")
|
| 322 |
+
logger.info(f"🔍 RETRIEVAL DEBUG: Result sources type: {type(result.sources)}")
|
| 323 |
+
# logger.info(f"🔍 RETRIEVAL DEBUG: Result metadata: {getattr(result, 'metadata', 'No metadata')}")
|
| 324 |
+
|
| 325 |
+
if len(result.sources) == 0:
|
| 326 |
+
logger.warning(f"⚠️ NO DOCUMENTS RETRIEVED: Query='{rag_query}', Filters={filters}")
|
| 327 |
+
logger.warning(f"⚠️ RETRIEVAL DEBUG: This could be due to:")
|
| 328 |
+
logger.warning(f" - Query too specific for available documents")
|
| 329 |
+
logger.warning(f" - Filters too restrictive")
|
| 330 |
+
logger.warning(f" - Vector store connection issues")
|
| 331 |
+
logger.warning(f" - Embedding model issues")
|
| 332 |
+
else:
|
| 333 |
+
logger.info(f"✅ DOCUMENTS RETRIEVED: {len(result.sources)} documents found")
|
| 334 |
+
for i, doc in enumerate(result.sources[:3]): # Log first 3 docs
|
| 335 |
+
logger.info(f" Doc {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...")
|
| 336 |
+
|
| 337 |
+
state["retrieved_documents"] = result.sources
|
| 338 |
+
state["agent_logs"].append(f"RESPONSE AGENT: Retrieved {len(result.sources)} documents")
|
| 339 |
+
|
| 340 |
+
# Check highest similarity score
|
| 341 |
+
highest_score = 0.0
|
| 342 |
+
if result.sources:
|
| 343 |
+
# Check reranked_score first (more accurate), fallback to original_score
|
| 344 |
+
for doc in result.sources:
|
| 345 |
+
score = doc.metadata.get('reranked_score') or doc.metadata.get('original_score', 0.0)
|
| 346 |
+
if score > highest_score:
|
| 347 |
+
highest_score = score
|
| 348 |
+
|
| 349 |
+
logger.info(f"📝 RESPONSE AGENT: Highest similarity score: {highest_score:.4f}")
|
| 350 |
+
|
| 351 |
+
# If highest score is too low, don't use retrieved documents
|
| 352 |
+
if highest_score <= 0.15:
|
| 353 |
+
logger.warning(f"⚠️ RESPONSE AGENT: Low similarity score ({highest_score:.4f} <= 0.15), using LLM knowledge only")
|
| 354 |
+
response = self._generate_conversational_response_without_docs(
|
| 355 |
+
state["current_query"],
|
| 356 |
+
state["messages"]
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
# Generate conversational response with documents
|
| 360 |
+
logger.info(f"📝 RESPONSE AGENT: Generating conversational response from {len(result.sources)} documents")
|
| 361 |
+
response = self._generate_conversational_response(
|
| 362 |
+
state["current_query"],
|
| 363 |
+
result.sources,
|
| 364 |
+
result.answer,
|
| 365 |
+
state["messages"]
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
logger.info(f"📝 RESPONSE AGENT: Response generated: {response[:100]}...")
|
| 369 |
+
|
| 370 |
+
state["final_response"] = response
|
| 371 |
+
state["last_ai_message_time"] = time.time()
|
| 372 |
+
|
| 373 |
+
logger.info(f"📝 RESPONSE AGENT: Answer generation complete")
|
| 374 |
+
|
| 375 |
+
except Exception as e:
|
| 376 |
+
logger.error(f"❌ RESPONSE AGENT ERROR: {e}")
|
| 377 |
+
state["final_response"] = "I apologize, but I encountered an error while retrieving information. Please try again."
|
| 378 |
+
state["last_ai_message_time"] = time.time()
|
| 379 |
+
|
| 380 |
+
return state
|
| 381 |
+
|
| 382 |
+
def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
|
| 383 |
+
"""Extract UI filters from query"""
|
| 384 |
+
filters = {}
|
| 385 |
+
|
| 386 |
+
# Look for FILTER CONTEXT in query
|
| 387 |
+
if "FILTER CONTEXT:" in query:
|
| 388 |
+
# Extract the entire filter section (until USER QUERY: or end of query)
|
| 389 |
+
filter_section = query.split("FILTER CONTEXT:")[1]
|
| 390 |
+
if "USER QUERY:" in filter_section:
|
| 391 |
+
filter_section = filter_section.split("USER QUERY:")[0]
|
| 392 |
+
filter_section = filter_section.strip()
|
| 393 |
+
|
| 394 |
+
# Parse sources
|
| 395 |
+
if "Sources:" in filter_section:
|
| 396 |
+
sources_line = [line for line in filter_section.split('\n') if line.strip().startswith('Sources:')][0]
|
| 397 |
+
sources_str = sources_line.split("Sources:")[1].strip()
|
| 398 |
+
if sources_str and sources_str != "None":
|
| 399 |
+
filters["sources"] = [s.strip() for s in sources_str.split(",")]
|
| 400 |
+
|
| 401 |
+
# Parse years
|
| 402 |
+
if "Years:" in filter_section:
|
| 403 |
+
years_line = [line for line in filter_section.split('\n') if line.strip().startswith('Years:')][0]
|
| 404 |
+
years_str = years_line.split("Years:")[1].strip()
|
| 405 |
+
if years_str and years_str != "None":
|
| 406 |
+
filters["years"] = [y.strip() for y in years_str.split(",")]
|
| 407 |
+
|
| 408 |
+
# Parse districts
|
| 409 |
+
if "Districts:" in filter_section:
|
| 410 |
+
districts_line = [line for line in filter_section.split('\n') if line.strip().startswith('Districts:')][0]
|
| 411 |
+
districts_str = districts_line.split("Districts:")[1].strip()
|
| 412 |
+
if districts_str and districts_str != "None":
|
| 413 |
+
filters["districts"] = [d.strip() for d in districts_str.split(",")]
|
| 414 |
+
|
| 415 |
+
# Parse filenames
|
| 416 |
+
if "Filenames:" in filter_section:
|
| 417 |
+
filenames_line = [line for line in filter_section.split('\n') if line.strip().startswith('Filenames:')][0]
|
| 418 |
+
filenames_str = filenames_line.split("Filenames:")[1].strip()
|
| 419 |
+
if filenames_str and filenames_str != "None":
|
| 420 |
+
filters["filenames"] = [f.strip() for f in filenames_str.split(",")]
|
| 421 |
+
|
| 422 |
+
return filters
|
| 423 |
+
|
| 424 |
+
def _analyze_query_context(self, query: str, messages: List[Any], ui_filters: Dict[str, List[str]]) -> QueryContext:
|
| 425 |
+
"""Analyze query context using LLM"""
|
| 426 |
+
logger.info(f"🔍 QUERY ANALYSIS: '{query[:50]}...' | UI filters: {ui_filters} | Messages: {len(messages)}")
|
| 427 |
+
|
| 428 |
+
# Build conversation context
|
| 429 |
+
conversation_context = ""
|
| 430 |
+
for i, msg in enumerate(messages[-6:]): # Last 6 messages
|
| 431 |
+
if isinstance(msg, HumanMessage):
|
| 432 |
+
conversation_context += f"User: {msg.content}\n"
|
| 433 |
+
elif isinstance(msg, AIMessage):
|
| 434 |
+
conversation_context += f"Assistant: {msg.content}\n"
|
| 435 |
+
|
| 436 |
+
# Create analysis prompt
|
| 437 |
+
analysis_prompt = ChatPromptTemplate.from_messages([
|
| 438 |
+
SystemMessage(content=f"""You are the Main Agent in an advanced multi-agent RAG system for audit report analysis.
|
| 439 |
+
|
| 440 |
+
🎯 PRIMARY GOAL: Intelligently analyze user queries and determine the optimal conversation flow, whether that's answering directly, asking follow-ups, or proceeding to RAG retrieval.
|
| 441 |
+
|
| 442 |
+
🧠 INTELLIGENCE LEVEL: You are a sophisticated conversational AI that can handle any type of user interaction - from greetings to complex audit queries.
|
| 443 |
+
|
| 444 |
+
📊 YOUR EXPERTISE: You specialize in analyzing audit reports from various sources (Local Government, Ministry, Hospital, etc.) across different years and districts in Uganda.
|
| 445 |
+
|
| 446 |
+
🔍 AVAILABLE FILTERS:
|
| 447 |
+
- Years: {', '.join(self.year_whitelist)}
|
| 448 |
+
- Current year: {self.current_year}, Previous year: {self.previous_year}
|
| 449 |
+
- Sources: {', '.join(self.source_whitelist)}
|
| 450 |
+
- Districts: {', '.join(self.district_whitelist[:50])}... (and {len(self.district_whitelist)-50} more)
|
| 451 |
+
|
| 452 |
+
🎛️ UI FILTERS PROVIDED: {ui_filters}
|
| 453 |
+
|
| 454 |
+
📋 UI FILTER HANDLING:
|
| 455 |
+
- If UI filters contain multiple values (e.g., districts: ['Lwengo', 'Kiboga']), extract ALL values
|
| 456 |
+
- For multiple districts: extract each district separately and validate each one
|
| 457 |
+
- For multiple years: extract each year separately and validate each one
|
| 458 |
+
- For multiple sources: extract each source separately and validate each one
|
| 459 |
+
- UI filters take PRIORITY over conversation context - use them first
|
| 460 |
+
|
| 461 |
+
🧭 CONVERSATION FLOW INTELLIGENCE:
|
| 462 |
+
|
| 463 |
+
1. **GREETINGS & GENERAL CHAT**:
|
| 464 |
+
- If user greets you ("Hi", "Hello", "How are you"), respond warmly and guide them to audit-related questions
|
| 465 |
+
- Example: "Hello! I'm here to help you analyze audit reports. What would you like to know about budget allocations, expenditures, or audit findings?"
|
| 466 |
+
|
| 467 |
+
2. **EDGE CASES**:
|
| 468 |
+
- Handle "What can you do?", "Help", "I don't know what to ask" with helpful guidance
|
| 469 |
+
- Example: "I can help you analyze audit reports! Try asking about budget allocations, salary management, PDM implementation, or any specific audit findings."
|
| 470 |
+
|
| 471 |
+
3. **AUDIT QUERIES**:
|
| 472 |
+
- Extract ONLY values that EXACTLY match the available lists above
|
| 473 |
+
- DO NOT hallucinate or infer values not in the lists
|
| 474 |
+
- If user mentions "salary payroll management" - this is NOT a valid source filter
|
| 475 |
+
|
| 476 |
+
**YEAR EXTRACTION**:
|
| 477 |
+
- If user mentions "2023" and it's in the years list - extract "2023"
|
| 478 |
+
- If user mentions "2022 / 23" - extract ["2022", "2023"] (as a JSON array)
|
| 479 |
+
- If user mentions "2022-2023" - extract ["2022", "2023"] (as a JSON array)
|
| 480 |
+
- If user mentions "latest couple of years" - extract the 2 most recent years from available data as JSON array
|
| 481 |
+
- Always return years as JSON arrays when multiple years are mentioned
|
| 482 |
+
|
| 483 |
+
**DISTRICT EXTRACTION**:
|
| 484 |
+
- If user mentions "Kampala" and it's in the districts list - extract "Kampala"
|
| 485 |
+
- If user mentions "Pader District" - extract "Pader" (remove "District" suffix)
|
| 486 |
+
- If user mentions "Lwengo, Kiboga and Namutumba" - extract ["Lwengo", "Kiboga", "Namutumba"] (as JSON array)
|
| 487 |
+
- If user mentions "Lwengo District and Kiboga District" - extract ["Lwengo", "Kiboga"] (as JSON array, remove "District" suffix)
|
| 488 |
+
- Always return districts as JSON arrays when multiple districts are mentioned
|
| 489 |
+
- If no exact matches found, set extracted values to null
|
| 490 |
+
|
| 491 |
+
4. **FILENAME FILTERING (MUTUALLY EXCLUSIVE)**:
|
| 492 |
+
- If UI provides filenames filter - ONLY use that, ignore all other filters (year, district, source)
|
| 493 |
+
- With filenames filter, no follow-ups needed - proceed directly to RAG
|
| 494 |
+
- When filenames are specified, skip filter inference entirely
|
| 495 |
+
|
| 496 |
+
5. **HALLUCINATION PREVENTION**:
|
| 497 |
+
- If user asks about a specific report but NO filename is selected in UI and NONE is extracted from conversation - DO NOT hallucinate
|
| 498 |
+
- Clearly state: "I don't have any specific report selected. Could you please select a report from the list or tell me which report you'd like to analyze?"
|
| 499 |
+
- DO NOT pretend to know which report they mean
|
| 500 |
+
- DO NOT infer reports from context alone - only use explicitly mentioned reports
|
| 501 |
+
|
| 502 |
+
6. **CONVERSATION CONTEXT AWARENESS**:
|
| 503 |
+
- ALWAYS consider the full conversation context when extracting filters
|
| 504 |
+
- If district was mentioned in previous messages, include it in current analysis
|
| 505 |
+
- If year was mentioned in previous messages, include it in current analysis
|
| 506 |
+
- If source was mentioned in previous messages, include it in current analysis
|
| 507 |
+
- Example: If conversation shows "User: Tell me about Pader District" then "User: 2023", extract both: district="Pader" and year="2023"
|
| 508 |
+
|
| 509 |
+
5. **SMART FOLLOW-UP STRATEGY**:
|
| 510 |
+
- NEVER ask the same question twice in a row
|
| 511 |
+
- If user provides source info, ask for year or district next
|
| 512 |
+
- If user provides year info, ask for source or district next
|
| 513 |
+
- If user provides district info, ask for year or source next
|
| 514 |
+
- If user provides 2+ pieces of info, proceed to RAG instead of asking more
|
| 515 |
+
- Make follow-ups conversational and contextual, not robotic
|
| 516 |
+
|
| 517 |
+
5. **DYNAMIC FOLLOW-UP EXAMPLES**:
|
| 518 |
+
- Budget queries: "What year are you interested in?" or "Which department - Local Government or Ministry?"
|
| 519 |
+
- PDM queries: "Which district are you interested in?" or "What year?"
|
| 520 |
+
- General queries: "Could you be more specific about what you'd like to know?"
|
| 521 |
+
|
| 522 |
+
🎯 DECISION LOGIC:
|
| 523 |
+
- If query is a greeting/general chat → needs_follow_up: true, provide helpful guidance
|
| 524 |
+
- If query has 2+ pieces of info → needs_follow_up: false, proceed to RAG
|
| 525 |
+
- If query has 1 piece of info → needs_follow_up: true, ask for missing piece
|
| 526 |
+
- If query has 0 pieces of info → needs_follow_up: true, ask for clarification
|
| 527 |
+
|
| 528 |
+
RESPOND WITH JSON ONLY:
|
| 529 |
+
{{
|
| 530 |
+
"has_district": boolean,
|
| 531 |
+
"has_source": boolean,
|
| 532 |
+
"has_year": boolean,
|
| 533 |
+
"extracted_district": "single district name or JSON array of districts or null",
|
| 534 |
+
"extracted_source": "single source name or JSON array of sources or null",
|
| 535 |
+
"extracted_year": "single year or JSON array of years or null",
|
| 536 |
+
"confidence_score": 0.0-1.0,
|
| 537 |
+
"needs_follow_up": boolean,
|
| 538 |
+
"follow_up_question": "conversational question or helpful guidance or null"
|
| 539 |
+
}}"""),
|
| 540 |
+
HumanMessage(content=f"""Query: {query}
|
| 541 |
+
|
| 542 |
+
Conversation Context:
|
| 543 |
+
{conversation_context}
|
| 544 |
+
|
| 545 |
+
CRITICAL: You MUST analyze the FULL conversation context above, not just the current query.
|
| 546 |
+
- If ANY district was mentioned in previous messages, extract it
|
| 547 |
+
- If ANY year was mentioned in previous messages, extract it
|
| 548 |
+
- If ANY source was mentioned in previous messages, extract it
|
| 549 |
+
- Combine information from ALL messages in the conversation
|
| 550 |
+
|
| 551 |
+
Analyze this query using ONLY the exact values provided above:""")
|
| 552 |
+
])
|
| 553 |
+
|
| 554 |
+
try:
|
| 555 |
+
response = self.llm.invoke(analysis_prompt.format_messages())
|
| 556 |
+
|
| 557 |
+
# Clean the response to extract JSON
|
| 558 |
+
content = response.content.strip()
|
| 559 |
+
if content.startswith("```json"):
|
| 560 |
+
# Remove markdown formatting
|
| 561 |
+
content = content.replace("```json", "").replace("```", "").strip()
|
| 562 |
+
elif content.startswith("```"):
|
| 563 |
+
# Remove generic markdown formatting
|
| 564 |
+
content = content.replace("```", "").strip()
|
| 565 |
+
|
| 566 |
+
# Clean and parse JSON with better error handling
|
| 567 |
+
try:
|
| 568 |
+
# Remove comments (// and /* */) from JSON
|
| 569 |
+
import re
|
| 570 |
+
# Remove single-line comments
|
| 571 |
+
content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
|
| 572 |
+
# Remove multi-line comments
|
| 573 |
+
content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL)
|
| 574 |
+
|
| 575 |
+
analysis = json.loads(content)
|
| 576 |
+
logger.info(f"🔍 QUERY ANALYSIS: ✅ Parsed successfully")
|
| 577 |
+
except json.JSONDecodeError as e:
|
| 578 |
+
logger.error(f"❌ JSON parsing failed: {e}")
|
| 579 |
+
logger.error(f"❌ Raw content: {content[:200]}...")
|
| 580 |
+
|
| 581 |
+
# Try to extract JSON from text if embedded
|
| 582 |
+
import re
|
| 583 |
+
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
| 584 |
+
if json_match:
|
| 585 |
+
try:
|
| 586 |
+
# Clean the extracted JSON
|
| 587 |
+
cleaned_json = json_match.group()
|
| 588 |
+
cleaned_json = re.sub(r'//.*?$', '', cleaned_json, flags=re.MULTILINE)
|
| 589 |
+
cleaned_json = re.sub(r'/\*.*?\*/', '', cleaned_json, flags=re.DOTALL)
|
| 590 |
+
analysis = json.loads(cleaned_json)
|
| 591 |
+
logger.info(f"🔍 QUERY ANALYSIS: ✅ Extracted and cleaned JSON from text")
|
| 592 |
+
except json.JSONDecodeError as e2:
|
| 593 |
+
logger.error(f"❌ Failed to extract JSON from text: {e2}")
|
| 594 |
+
# Return fallback context
|
| 595 |
+
context = QueryContext(
|
| 596 |
+
has_district=False,
|
| 597 |
+
has_source=False,
|
| 598 |
+
has_year=False,
|
| 599 |
+
extracted_district=None,
|
| 600 |
+
extracted_source=None,
|
| 601 |
+
extracted_year=None,
|
| 602 |
+
confidence_score=0.0,
|
| 603 |
+
needs_follow_up=True,
|
| 604 |
+
follow_up_question="I apologize, but I'm having trouble processing your request. Could you please rephrase it or ask for help?"
|
| 605 |
+
)
|
| 606 |
+
return context
|
| 607 |
+
else:
|
| 608 |
+
# Return fallback context
|
| 609 |
+
context = QueryContext(
|
| 610 |
+
has_district=False,
|
| 611 |
+
has_source=False,
|
| 612 |
+
has_year=False,
|
| 613 |
+
extracted_district=None,
|
| 614 |
+
extracted_source=None,
|
| 615 |
+
extracted_year=None,
|
| 616 |
+
confidence_score=0.0,
|
| 617 |
+
needs_follow_up=True,
|
| 618 |
+
follow_up_question="I apologize, but I'm having trouble processing your request. Could you please rephrase it or ask for help?"
|
| 619 |
+
)
|
| 620 |
+
return context
|
| 621 |
+
|
| 622 |
+
# Validate extracted values against whitelists
|
| 623 |
+
extracted_district = analysis.get("extracted_district")
|
| 624 |
+
extracted_source = analysis.get("extracted_source")
|
| 625 |
+
extracted_year = analysis.get("extracted_year")
|
| 626 |
+
|
| 627 |
+
logger.info(f"🔍 QUERY ANALYSIS: Raw extracted values - district: {extracted_district}, source: {extracted_source}, year: {extracted_year}")
|
| 628 |
+
|
| 629 |
+
# Validate district (handle both single values and arrays)
|
| 630 |
+
if extracted_district:
|
| 631 |
+
if isinstance(extracted_district, list):
|
| 632 |
+
# Validate each district in the array
|
| 633 |
+
valid_districts = []
|
| 634 |
+
for district in extracted_district:
|
| 635 |
+
if district in self.district_whitelist:
|
| 636 |
+
valid_districts.append(district)
|
| 637 |
+
else:
|
| 638 |
+
# Try removing "District" suffix
|
| 639 |
+
district_name = district.replace(" District", "").replace(" district", "")
|
| 640 |
+
if district_name in self.district_whitelist:
|
| 641 |
+
valid_districts.append(district_name)
|
| 642 |
+
|
| 643 |
+
if valid_districts:
|
| 644 |
+
extracted_district = valid_districts[0] if len(valid_districts) == 1 else valid_districts
|
| 645 |
+
logger.info(f"🔍 QUERY ANALYSIS: Extracted districts: {extracted_district}")
|
| 646 |
+
else:
|
| 647 |
+
logger.warning(f"⚠️ No valid districts found in: '{extracted_district}'")
|
| 648 |
+
extracted_district = None
|
| 649 |
+
else:
|
| 650 |
+
# Single district validation
|
| 651 |
+
if extracted_district not in self.district_whitelist:
|
| 652 |
+
# Try removing "District" suffix
|
| 653 |
+
district_name = extracted_district.replace(" District", "").replace(" district", "")
|
| 654 |
+
if district_name in self.district_whitelist:
|
| 655 |
+
logger.info(f"🔍 QUERY ANALYSIS: Normalized district '{extracted_district}' to '{district_name}'")
|
| 656 |
+
extracted_district = district_name
|
| 657 |
+
else:
|
| 658 |
+
logger.warning(f"⚠️ Invalid district extracted: '{extracted_district}' not in whitelist")
|
| 659 |
+
extracted_district = None
|
| 660 |
+
|
| 661 |
+
# Validate source (handle both single values and arrays)
|
| 662 |
+
if extracted_source:
|
| 663 |
+
if isinstance(extracted_source, list):
|
| 664 |
+
# Validate each source in the array
|
| 665 |
+
valid_sources = []
|
| 666 |
+
for source in extracted_source:
|
| 667 |
+
if source in self.source_whitelist:
|
| 668 |
+
valid_sources.append(source)
|
| 669 |
+
else:
|
| 670 |
+
logger.warning(f"⚠️ Invalid source in array: '{source}' not in whitelist")
|
| 671 |
+
|
| 672 |
+
if valid_sources:
|
| 673 |
+
extracted_source = valid_sources[0] if len(valid_sources) == 1 else valid_sources
|
| 674 |
+
logger.info(f"🔍 QUERY ANALYSIS: Extracted sources: {extracted_source}")
|
| 675 |
+
else:
|
| 676 |
+
logger.warning(f"⚠️ No valid sources found in: '{extracted_source}'")
|
| 677 |
+
extracted_source = None
|
| 678 |
+
else:
|
| 679 |
+
# Single source validation
|
| 680 |
+
if extracted_source not in self.source_whitelist:
|
| 681 |
+
logger.warning(f"⚠️ Invalid source extracted: '{extracted_source}' not in whitelist")
|
| 682 |
+
extracted_source = None
|
| 683 |
+
|
| 684 |
+
# Validate year (handle both single values and arrays)
|
| 685 |
+
if extracted_year:
|
| 686 |
+
if isinstance(extracted_year, list):
|
| 687 |
+
# Validate each year in the array
|
| 688 |
+
valid_years = []
|
| 689 |
+
for year in extracted_year:
|
| 690 |
+
year_str = str(year)
|
| 691 |
+
if year_str in self.year_whitelist:
|
| 692 |
+
valid_years.append(year_str)
|
| 693 |
+
|
| 694 |
+
if valid_years:
|
| 695 |
+
extracted_year = valid_years[0] if len(valid_years) == 1 else valid_years
|
| 696 |
+
logger.info(f"🔍 QUERY ANALYSIS: Extracted years: {extracted_year}")
|
| 697 |
+
else:
|
| 698 |
+
logger.warning(f"⚠️ No valid years found in: '{extracted_year}'")
|
| 699 |
+
extracted_year = None
|
| 700 |
+
else:
|
| 701 |
+
# Single year validation
|
| 702 |
+
year_str = str(extracted_year)
|
| 703 |
+
if year_str not in self.year_whitelist:
|
| 704 |
+
logger.warning(f"⚠️ Invalid year extracted: '{extracted_year}' not in whitelist")
|
| 705 |
+
extracted_year = None
|
| 706 |
+
else:
|
| 707 |
+
extracted_year = year_str
|
| 708 |
+
|
| 709 |
+
logger.info(f"🔍 QUERY ANALYSIS: Validated values - district: {extracted_district}, source: {extracted_source}, year: {extracted_year}")
|
| 710 |
+
|
| 711 |
+
# Create QueryContext object
|
| 712 |
+
context = QueryContext(
|
| 713 |
+
has_district=bool(extracted_district),
|
| 714 |
+
has_source=bool(extracted_source),
|
| 715 |
+
has_year=bool(extracted_year),
|
| 716 |
+
extracted_district=extracted_district,
|
| 717 |
+
extracted_source=extracted_source,
|
| 718 |
+
extracted_year=extracted_year,
|
| 719 |
+
ui_filters=ui_filters,
|
| 720 |
+
confidence_score=analysis.get("confidence_score", 0.0),
|
| 721 |
+
needs_follow_up=analysis.get("needs_follow_up", False),
|
| 722 |
+
follow_up_question=analysis.get("follow_up_question")
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
logger.info(f"🔍 QUERY ANALYSIS: Analysis complete - needs_follow_up: {context.needs_follow_up}, confidence: {context.confidence_score}")
|
| 726 |
+
|
| 727 |
+
# If filenames are provided in UI, skip follow-ups and proceed to RAG
|
| 728 |
+
if ui_filters and ui_filters.get("filenames"):
|
| 729 |
+
logger.info(f"🔍 QUERY ANALYSIS: Filenames provided, skipping follow-ups, proceeding to RAG")
|
| 730 |
+
context.needs_follow_up = False
|
| 731 |
+
context.follow_up_question = None
|
| 732 |
+
|
| 733 |
+
# Additional smart decision logic
|
| 734 |
+
if context.needs_follow_up:
|
| 735 |
+
# Check if we have enough information to proceed
|
| 736 |
+
info_count = sum([
|
| 737 |
+
bool(context.extracted_district),
|
| 738 |
+
bool(context.extracted_source),
|
| 739 |
+
bool(context.extracted_year)
|
| 740 |
+
])
|
| 741 |
+
|
| 742 |
+
# Check if user is asking for more info vs providing it
|
| 743 |
+
query_lower = query.lower()
|
| 744 |
+
is_requesting_info = any(phrase in query_lower for phrase in [
|
| 745 |
+
"please provide", "could you provide", "can you provide",
|
| 746 |
+
"what is", "what are", "how much", "which", "what year",
|
| 747 |
+
"what district", "what source", "tell me about"
|
| 748 |
+
])
|
| 749 |
+
|
| 750 |
+
# If we have 2+ pieces of info AND user is not requesting more info, proceed to RAG
|
| 751 |
+
if info_count >= 2 and not is_requesting_info:
|
| 752 |
+
logger.info(f"🔍 QUERY ANALYSIS: Smart override - have {info_count} pieces of info and user not requesting more, proceeding to RAG")
|
| 753 |
+
context.needs_follow_up = False
|
| 754 |
+
context.follow_up_question = None
|
| 755 |
+
elif info_count >= 2 and is_requesting_info:
|
| 756 |
+
logger.info(f"🔍 QUERY ANALYSIS: User requesting more info despite having {info_count} pieces, proceeding to RAG with comprehensive answer")
|
| 757 |
+
context.needs_follow_up = False
|
| 758 |
+
context.follow_up_question = None
|
| 759 |
+
|
| 760 |
+
return context
|
| 761 |
+
|
| 762 |
+
except Exception as e:
|
| 763 |
+
logger.error(f"❌ Query analysis failed: {e}")
|
| 764 |
+
# Fallback: proceed with RAG
|
| 765 |
+
return QueryContext(
|
| 766 |
+
has_district=bool(ui_filters.get("districts")),
|
| 767 |
+
has_source=bool(ui_filters.get("sources")),
|
| 768 |
+
has_year=bool(ui_filters.get("years")),
|
| 769 |
+
ui_filters=ui_filters,
|
| 770 |
+
confidence_score=0.5,
|
| 771 |
+
needs_follow_up=False
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
def _rewrite_query_for_rag(self, messages: List[Any], context: QueryContext) -> str:
|
| 775 |
+
"""Rewrite query for optimal RAG retrieval"""
|
| 776 |
+
logger.info("🔄 QUERY REWRITING: Starting query rewrite for RAG")
|
| 777 |
+
logger.info(f"🔄 QUERY REWRITING: Processing {len(messages)} messages")
|
| 778 |
+
|
| 779 |
+
# Build conversation context
|
| 780 |
+
logger.info(f"🔄 QUERY REWRITING: Building conversation context from last 6 messages")
|
| 781 |
+
conversation_lines = []
|
| 782 |
+
for i, msg in enumerate(messages[-6:]):
|
| 783 |
+
if isinstance(msg, HumanMessage):
|
| 784 |
+
conversation_lines.append(f"User: {msg.content}")
|
| 785 |
+
logger.info(f"🔄 QUERY REWRITING: Message {i+1}: User - {msg.content[:50]}...")
|
| 786 |
+
elif isinstance(msg, AIMessage):
|
| 787 |
+
conversation_lines.append(f"Assistant: {msg.content}")
|
| 788 |
+
logger.info(f"🔄 QUERY REWRITING: Message {i+1}: Assistant - {msg.content[:50]}...")
|
| 789 |
+
|
| 790 |
+
convo_text = "\n".join(conversation_lines)
|
| 791 |
+
logger.info(f"🔄 QUERY REWRITING: Conversation context built ({len(convo_text)} chars)")
|
| 792 |
+
|
| 793 |
+
# Create rewrite prompt
|
| 794 |
+
rewrite_prompt = ChatPromptTemplate.from_messages([
|
| 795 |
+
SystemMessage(content=f"""You are a query rewriter for RAG retrieval.
|
| 796 |
+
|
| 797 |
+
GOAL: Create the best possible search query for document retrieval.
|
| 798 |
+
|
| 799 |
+
CRITICAL RULES:
|
| 800 |
+
1. Focus on the core information need from the conversation
|
| 801 |
+
2. Remove meta-verbs like "summarize", "list", "compare", "how much", "what" - keep the content focus
|
| 802 |
+
3. DO NOT include filter details (years, districts, sources) - these are applied separately as filters
|
| 803 |
+
4. DO NOT include specific years, district names, or source types in the query
|
| 804 |
+
5. Output ONE clear sentence suitable for vector search
|
| 805 |
+
6. Keep it generic and focused on the topic/subject matter
|
| 806 |
+
|
| 807 |
+
EXAMPLES:
|
| 808 |
+
- "What are the top challenges in budget allocation?" → "budget allocation challenges"
|
| 809 |
+
- "How were PDM administrative costs utilized in 2023?" → "PDM administrative costs utilization"
|
| 810 |
+
- "Compare salary management across districts" → "salary management"
|
| 811 |
+
- "How much was budget allocation for Local Government in 2023?" → "budget allocation"
|
| 812 |
+
|
| 813 |
+
OUTPUT FORMAT:
|
| 814 |
+
Provide your response in this exact format:
|
| 815 |
+
|
| 816 |
+
EXPLANATION: [Your reasoning here]
|
| 817 |
+
QUERY: [One clean sentence for retrieval]
|
| 818 |
+
|
| 819 |
+
The QUERY line will be extracted and used directly for RAG retrieval."""),
|
| 820 |
+
HumanMessage(content=f"""Conversation:
|
| 821 |
+
{convo_text}
|
| 822 |
+
|
| 823 |
+
Rewrite the best retrieval query:""")
|
| 824 |
+
])
|
| 825 |
+
|
| 826 |
+
try:
|
| 827 |
+
logger.info(f"🔄 QUERY REWRITING: Calling LLM for query rewrite")
|
| 828 |
+
response = self.llm.invoke(rewrite_prompt.format_messages())
|
| 829 |
+
logger.info(f"🔄 QUERY REWRITING: LLM response received: {response.content[:100]}...")
|
| 830 |
+
|
| 831 |
+
rewritten = response.content.strip()
|
| 832 |
+
|
| 833 |
+
# Extract only the QUERY line from the structured response
|
| 834 |
+
lines = rewritten.split('\n')
|
| 835 |
+
query_line = None
|
| 836 |
+
for line in lines:
|
| 837 |
+
if line.strip().startswith('QUERY:'):
|
| 838 |
+
query_line = line.replace('QUERY:', '').strip()
|
| 839 |
+
break
|
| 840 |
+
|
| 841 |
+
if query_line and len(query_line) > 5:
|
| 842 |
+
logger.info(f"🔄 QUERY REWRITING: Query rewritten successfully: '{query_line[:50]}...'")
|
| 843 |
+
return query_line
|
| 844 |
+
else:
|
| 845 |
+
logger.info(f"🔄 QUERY REWRITING: No QUERY line found or too short, using fallback")
|
| 846 |
+
# Fallback to last user message
|
| 847 |
+
for msg in reversed(messages):
|
| 848 |
+
if isinstance(msg, HumanMessage):
|
| 849 |
+
logger.info(f"🔄 QUERY REWRITING: Using fallback message: '{msg.content[:50]}...'")
|
| 850 |
+
return msg.content
|
| 851 |
+
logger.info(f"🔄 QUERY REWRITING: Using default fallback")
|
| 852 |
+
return "audit report information"
|
| 853 |
+
|
| 854 |
+
except Exception as e:
|
| 855 |
+
logger.error(f"❌ QUERY REWRITING: Error during rewrite: {e}")
|
| 856 |
+
# Fallback
|
| 857 |
+
for msg in reversed(messages):
|
| 858 |
+
if isinstance(msg, HumanMessage):
|
| 859 |
+
logger.info(f"🔄 QUERY REWRITING: Using error fallback message: '{msg.content[:50]}...'")
|
| 860 |
+
return msg.content
|
| 861 |
+
logger.info(f"🔄 QUERY REWRITING: Using default error fallback")
|
| 862 |
+
return "audit report information"
|
| 863 |
+
|
| 864 |
+
def _build_filters(self, context: QueryContext) -> Dict[str, Any]:
|
| 865 |
+
"""Build filters for RAG retrieval"""
|
| 866 |
+
logger.info("🔧 FILTER BUILDING: Starting filter construction")
|
| 867 |
+
filters = {}
|
| 868 |
+
|
| 869 |
+
# Check for filename filtering first (mutually exclusive)
|
| 870 |
+
if context.ui_filters and context.ui_filters.get("filenames"):
|
| 871 |
+
logger.info(f"🔧 FILTER BUILDING: Filename filtering requested (mutually exclusive mode)")
|
| 872 |
+
filters["filenames"] = context.ui_filters["filenames"]
|
| 873 |
+
logger.info(f"🔧 FILTER BUILDING: Added filenames filter: {context.ui_filters['filenames']}")
|
| 874 |
+
logger.info(f"🔧 FILTER BUILDING: Final filters: {filters}")
|
| 875 |
+
return filters # Return early, skip all other filters
|
| 876 |
+
|
| 877 |
+
# UI filters take priority, but merge with extracted context if UI filters are incomplete
|
| 878 |
+
if context.ui_filters:
|
| 879 |
+
logger.info(f"🔧 FILTER BUILDING: UI filters present: {context.ui_filters}")
|
| 880 |
+
|
| 881 |
+
# Add UI filters first
|
| 882 |
+
if context.ui_filters.get("sources"):
|
| 883 |
+
filters["sources"] = context.ui_filters["sources"]
|
| 884 |
+
logger.info(f"🔧 FILTER BUILDING: Added sources filter from UI: {context.ui_filters['sources']}")
|
| 885 |
+
|
| 886 |
+
if context.ui_filters.get("years"):
|
| 887 |
+
filters["year"] = context.ui_filters["years"]
|
| 888 |
+
logger.info(f"🔧 FILTER BUILDING: Added years filter from UI: {context.ui_filters['years']}")
|
| 889 |
+
|
| 890 |
+
if context.ui_filters.get("districts"):
|
| 891 |
+
# Normalize district names to title case (match Qdrant metadata format)
|
| 892 |
+
normalized_districts = [d.title() for d in context.ui_filters['districts']]
|
| 893 |
+
filters["district"] = normalized_districts
|
| 894 |
+
logger.info(f"🔧 FILTER BUILDING: Added districts filter from UI: {context.ui_filters['districts']} → normalized: {normalized_districts}")
|
| 895 |
+
|
| 896 |
+
# Merge with extracted context for missing filters
|
| 897 |
+
if not filters.get("year") and context.extracted_year:
|
| 898 |
+
# Handle both single values and arrays
|
| 899 |
+
if isinstance(context.extracted_year, list):
|
| 900 |
+
filters["year"] = context.extracted_year
|
| 901 |
+
else:
|
| 902 |
+
filters["year"] = [context.extracted_year]
|
| 903 |
+
logger.info(f"🔧 FILTER BUILDING: Added extracted year filter (UI missing): {context.extracted_year}")
|
| 904 |
+
|
| 905 |
+
if not filters.get("district") and context.extracted_district:
|
| 906 |
+
# Handle both single values and arrays
|
| 907 |
+
if isinstance(context.extracted_district, list):
|
| 908 |
+
# Normalize district names to title case (match Qdrant metadata format)
|
| 909 |
+
normalized = [d.title() for d in context.extracted_district]
|
| 910 |
+
filters["district"] = normalized
|
| 911 |
+
else:
|
| 912 |
+
filters["district"] = [context.extracted_district.title()]
|
| 913 |
+
logger.info(f"🔧 FILTER BUILDING: Added extracted district filter (UI missing): {context.extracted_district}")
|
| 914 |
+
|
| 915 |
+
if not filters.get("sources") and context.extracted_source:
|
| 916 |
+
# Handle both single values and arrays
|
| 917 |
+
if isinstance(context.extracted_source, list):
|
| 918 |
+
filters["sources"] = context.extracted_source
|
| 919 |
+
else:
|
| 920 |
+
filters["sources"] = [context.extracted_source]
|
| 921 |
+
logger.info(f"🔧 FILTER BUILDING: Added extracted source filter (UI missing): {context.extracted_source}")
|
| 922 |
+
else:
|
| 923 |
+
logger.info(f"🔧 FILTER BUILDING: No UI filters, using extracted context")
|
| 924 |
+
# Use extracted context
|
| 925 |
+
if context.extracted_source:
|
| 926 |
+
# Handle both single values and arrays
|
| 927 |
+
if isinstance(context.extracted_source, list):
|
| 928 |
+
filters["sources"] = context.extracted_source
|
| 929 |
+
else:
|
| 930 |
+
filters["sources"] = [context.extracted_source]
|
| 931 |
+
logger.info(f"🔧 FILTER BUILDING: Added extracted source filter: {context.extracted_source}")
|
| 932 |
+
|
| 933 |
+
if context.extracted_year:
|
| 934 |
+
# Handle both single values and arrays
|
| 935 |
+
if isinstance(context.extracted_year, list):
|
| 936 |
+
filters["year"] = context.extracted_year
|
| 937 |
+
else:
|
| 938 |
+
filters["year"] = [context.extracted_year]
|
| 939 |
+
logger.info(f"🔧 FILTER BUILDING: Added extracted year filter: {context.extracted_year}")
|
| 940 |
+
|
| 941 |
+
if context.extracted_district:
|
| 942 |
+
# Handle both single values and arrays
|
| 943 |
+
if isinstance(context.extracted_district, list):
|
| 944 |
+
filters["district"] = context.extracted_district
|
| 945 |
+
else:
|
| 946 |
+
filters["district"] = [context.extracted_district]
|
| 947 |
+
logger.info(f"🔧 FILTER BUILDING: Added extracted district filter: {context.extracted_district}")
|
| 948 |
+
|
| 949 |
+
logger.info(f"🔧 FILTER BUILDING: Final filters: {filters}")
|
| 950 |
+
return filters
|
| 951 |
+
|
| 952 |
+
def _generate_conversational_response(self, query: str, documents: List[Any], rag_answer: str, messages: List[Any]) -> str:
|
| 953 |
+
"""Generate conversational response from RAG results"""
|
| 954 |
+
logger.info("💬 RESPONSE GENERATION: Starting conversational response generation")
|
| 955 |
+
logger.info(f"💬 RESPONSE GENERATION: Processing {len(documents)} documents")
|
| 956 |
+
logger.info(f"💬 RESPONSE GENERATION: Query: '{query[:50]}...'")
|
| 957 |
+
|
| 958 |
+
# Create response prompt
|
| 959 |
+
logger.info(f"💬 RESPONSE GENERATION: Building response prompt")
|
| 960 |
+
response_prompt = ChatPromptTemplate.from_messages([
|
| 961 |
+
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 962 |
+
|
| 963 |
+
RULES:
|
| 964 |
+
1. Answer the user's question directly and clearly
|
| 965 |
+
2. Use the retrieved documents as evidence
|
| 966 |
+
3. Be conversational, not technical
|
| 967 |
+
4. Don't mention scores, retrieval details, or technical implementation
|
| 968 |
+
5. If relevant documents were found, reference them naturally
|
| 969 |
+
6. If no relevant documents, explain based on your knowledge (if you have it) or just say you do not have enough information.
|
| 970 |
+
7. If the passages have useful facts or numbers, use them in your answer.
|
| 971 |
+
8. When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
|
| 972 |
+
9. Do not use the sentence 'Doc i says ...' to say where information came from.
|
| 973 |
+
10. If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
|
| 974 |
+
11. Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 975 |
+
12. If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 976 |
+
13. You do not need to use every passage. Only use the ones that help answer the question.
|
| 977 |
+
14. If the documents do not have the information needed to answer the question, just say you do not have enough information.
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
TONE: Professional but friendly, like talking to a colleague."""),
|
| 981 |
+
HumanMessage(content=f"""User Question: {query}
|
| 982 |
+
|
| 983 |
+
Retrieved Documents: {len(documents)} documents found
|
| 984 |
+
|
| 985 |
+
RAG Answer: {rag_answer}
|
| 986 |
+
|
| 987 |
+
Generate a conversational response:""")
|
| 988 |
+
])
|
| 989 |
+
|
| 990 |
+
try:
|
| 991 |
+
logger.info(f"💬 RESPONSE GENERATION: Calling LLM for final response")
|
| 992 |
+
response = self.llm.invoke(response_prompt.format_messages())
|
| 993 |
+
logger.info(f"💬 RESPONSE GENERATION: LLM response received: {response.content[:100]}...")
|
| 994 |
+
return response.content.strip()
|
| 995 |
+
except Exception as e:
|
| 996 |
+
logger.error(f"❌ RESPONSE GENERATION: Error during generation: {e}")
|
| 997 |
+
logger.info(f"💬 RESPONSE GENERATION: Using RAG answer as fallback")
|
| 998 |
+
return rag_answer # Fallback to RAG answer
|
| 999 |
+
|
| 1000 |
+
def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
|
| 1001 |
+
"""Generate conversational response using only LLM knowledge and conversation history"""
|
| 1002 |
+
logger.info("💬 RESPONSE GENERATION (NO DOCS): Starting response generation without documents")
|
| 1003 |
+
logger.info(f"💬 RESPONSE GENERATION (NO DOCS): Query: '{query[:50]}...'")
|
| 1004 |
+
|
| 1005 |
+
# Build conversation context
|
| 1006 |
+
conversation_context = ""
|
| 1007 |
+
for i, msg in enumerate(messages[-6:]): # Last 6 messages for context
|
| 1008 |
+
if isinstance(msg, HumanMessage):
|
| 1009 |
+
conversation_context += f"User: {msg.content}\n"
|
| 1010 |
+
elif isinstance(msg, AIMessage):
|
| 1011 |
+
conversation_context += f"Assistant: {msg.content}\n"
|
| 1012 |
+
|
| 1013 |
+
# Create response prompt
|
| 1014 |
+
logger.info(f"💬 RESPONSE GENERATION (NO DOCS): Building response prompt")
|
| 1015 |
+
response_prompt = ChatPromptTemplate.from_messages([
|
| 1016 |
+
SystemMessage(content="""You are a helpful audit report assistant. Generate a natural, conversational response.
|
| 1017 |
+
|
| 1018 |
+
RULES:
|
| 1019 |
+
1. Answer the user's question directly and clearly based on your knowledge
|
| 1020 |
+
2. Use conversation history for context
|
| 1021 |
+
3. Be conversational, not technical
|
| 1022 |
+
4. Acknowledge if the answer is based on general knowledge rather than specific documents
|
| 1023 |
+
5. Stay professional but friendly
|
| 1024 |
+
|
| 1025 |
+
TONE: Professional but friendly, like talking to a colleague."""),
|
| 1026 |
+
HumanMessage(content=f"""Current Question: {query}
|
| 1027 |
+
|
| 1028 |
+
Conversation History:
|
| 1029 |
+
{conversation_context}
|
| 1030 |
+
|
| 1031 |
+
Generate a conversational response based on your knowledge:""")
|
| 1032 |
+
])
|
| 1033 |
+
|
| 1034 |
+
try:
|
| 1035 |
+
logger.info(f"💬 RESPONSE GENERATION (NO DOCS): Calling LLM")
|
| 1036 |
+
response = self.llm.invoke(response_prompt.format_messages())
|
| 1037 |
+
logger.info(f"💬 RESPONSE GENERATION (NO DOCS): LLM response received: {response.content[:100]}...")
|
| 1038 |
+
return response.content.strip()
|
| 1039 |
+
except Exception as e:
|
| 1040 |
+
logger.error(f"❌ RESPONSE GENERATION (NO DOCS): Error during generation: {e}")
|
| 1041 |
+
return "I apologize, but I encountered an error. Please try asking your question differently."
|
| 1042 |
+
|
| 1043 |
+
def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
|
| 1044 |
+
"""Main chat interface"""
|
| 1045 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Processing '{user_input[:50]}...'")
|
| 1046 |
+
|
| 1047 |
+
# Load conversation
|
| 1048 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Loading conversation {conversation_id}")
|
| 1049 |
+
conversation_file = self.conversations_dir / f"{conversation_id}.json"
|
| 1050 |
+
conversation = self._load_conversation(conversation_file)
|
| 1051 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Loaded {len(conversation['messages'])} previous messages")
|
| 1052 |
+
|
| 1053 |
+
# Add user message
|
| 1054 |
+
conversation["messages"].append(HumanMessage(content=user_input))
|
| 1055 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Added user message to conversation")
|
| 1056 |
+
|
| 1057 |
+
# Prepare state
|
| 1058 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Preparing state for graph execution")
|
| 1059 |
+
state = MultiAgentState(
|
| 1060 |
+
conversation_id=conversation_id,
|
| 1061 |
+
messages=conversation["messages"],
|
| 1062 |
+
current_query=user_input,
|
| 1063 |
+
query_context=None,
|
| 1064 |
+
rag_query=None,
|
| 1065 |
+
rag_filters=None,
|
| 1066 |
+
retrieved_documents=None,
|
| 1067 |
+
final_response=None,
|
| 1068 |
+
agent_logs=[],
|
| 1069 |
+
conversation_context=conversation.get("context", {}),
|
| 1070 |
+
session_start_time=conversation["session_start_time"],
|
| 1071 |
+
last_ai_message_time=conversation["last_ai_message_time"]
|
| 1072 |
+
)
|
| 1073 |
+
|
| 1074 |
+
# Run multi-agent graph
|
| 1075 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Executing multi-agent graph")
|
| 1076 |
+
final_state = self.graph.invoke(state)
|
| 1077 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Graph execution completed")
|
| 1078 |
+
|
| 1079 |
+
# Add AI response to conversation
|
| 1080 |
+
if final_state["final_response"]:
|
| 1081 |
+
conversation["messages"].append(AIMessage(content=final_state["final_response"]))
|
| 1082 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Added AI response to conversation")
|
| 1083 |
+
|
| 1084 |
+
# Update conversation
|
| 1085 |
+
conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
|
| 1086 |
+
conversation["context"] = final_state["conversation_context"]
|
| 1087 |
+
|
| 1088 |
+
# Save conversation
|
| 1089 |
+
logger.info(f"💬 MULTI-AGENT CHAT: Saving conversation")
|
| 1090 |
+
self._save_conversation(conversation_file, conversation)
|
| 1091 |
+
|
| 1092 |
+
logger.info("✅ MULTI-AGENT CHAT: Completed")
|
| 1093 |
+
|
| 1094 |
+
# Return response and RAG results
|
| 1095 |
+
return {
|
| 1096 |
+
'response': final_state["final_response"],
|
| 1097 |
+
'rag_result': {
|
| 1098 |
+
'sources': final_state["retrieved_documents"] or [],
|
| 1099 |
+
'answer': final_state["final_response"]
|
| 1100 |
+
},
|
| 1101 |
+
'agent_logs': final_state["agent_logs"],
|
| 1102 |
+
'actual_rag_query': final_state.get("rag_query", "")
|
| 1103 |
+
}
|
| 1104 |
+
|
| 1105 |
+
def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
|
| 1106 |
+
"""Load conversation from file"""
|
| 1107 |
+
if conversation_file.exists():
|
| 1108 |
+
try:
|
| 1109 |
+
with open(conversation_file) as f:
|
| 1110 |
+
data = json.load(f)
|
| 1111 |
+
# Convert message dicts back to LangChain messages
|
| 1112 |
+
messages = []
|
| 1113 |
+
for msg_data in data.get("messages", []):
|
| 1114 |
+
if msg_data["type"] == "human":
|
| 1115 |
+
messages.append(HumanMessage(content=msg_data["content"]))
|
| 1116 |
+
elif msg_data["type"] == "ai":
|
| 1117 |
+
messages.append(AIMessage(content=msg_data["content"]))
|
| 1118 |
+
data["messages"] = messages
|
| 1119 |
+
return data
|
| 1120 |
+
except Exception as e:
|
| 1121 |
+
logger.warning(f"Could not load conversation: {e}")
|
| 1122 |
+
|
| 1123 |
+
# Return default conversation
|
| 1124 |
+
return {
|
| 1125 |
+
"messages": [],
|
| 1126 |
+
"session_start_time": time.time(),
|
| 1127 |
+
"last_ai_message_time": time.time(),
|
| 1128 |
+
"context": {}
|
| 1129 |
+
}
|
| 1130 |
+
|
| 1131 |
+
def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
|
| 1132 |
+
"""Save conversation to file"""
|
| 1133 |
+
try:
|
| 1134 |
+
# Convert messages to serializable format
|
| 1135 |
+
messages_data = []
|
| 1136 |
+
for msg in conversation["messages"]:
|
| 1137 |
+
if isinstance(msg, HumanMessage):
|
| 1138 |
+
messages_data.append({"type": "human", "content": msg.content})
|
| 1139 |
+
elif isinstance(msg, AIMessage):
|
| 1140 |
+
messages_data.append({"type": "ai", "content": msg.content})
|
| 1141 |
+
|
| 1142 |
+
conversation_data = {
|
| 1143 |
+
"messages": messages_data,
|
| 1144 |
+
"session_start_time": conversation["session_start_time"],
|
| 1145 |
+
"last_ai_message_time": conversation["last_ai_message_time"],
|
| 1146 |
+
"context": conversation.get("context", {})
|
| 1147 |
+
}
|
| 1148 |
+
|
| 1149 |
+
with open(conversation_file, 'w') as f:
|
| 1150 |
+
json.dump(conversation_data, f, indent=2)
|
| 1151 |
+
|
| 1152 |
+
except Exception as e:
|
| 1153 |
+
logger.error(f"Could not save conversation: {e}")
|
| 1154 |
+
|
| 1155 |
+
|
| 1156 |
+
def get_multi_agent_chatbot():
|
| 1157 |
+
"""Get multi-agent chatbot instance"""
|
| 1158 |
+
return MultiAgentRAGChatbot()
|
| 1159 |
+
|
| 1160 |
+
if __name__ == "__main__":
|
| 1161 |
+
# Test the multi-agent system
|
| 1162 |
+
chatbot = MultiAgentRAGChatbot()
|
| 1163 |
+
|
| 1164 |
+
# Test conversation
|
| 1165 |
+
result = chatbot.chat("List me top 10 challenges in budget allocation for the last 3 years")
|
| 1166 |
+
print("Response:", result['response'])
|
| 1167 |
+
print("Agent Logs:", result['agent_logs'])
|
requirements.txt
CHANGED
|
@@ -1,3 +1,9 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.28.0
|
| 2 |
+
langchain>=0.1.0
|
| 3 |
+
langchain-core>=0.1.0
|
| 4 |
+
langgraph>=0.0.20
|
| 5 |
+
qdrant-client>=1.7.0
|
| 6 |
+
python-dotenv>=1.0.0
|
| 7 |
+
openai>=1.0.0
|
| 8 |
+
snowflake-connector-python>=4.0.0
|
| 9 |
+
pydantic>=2.0.0
|
smart_chatbot.py
ADDED
|
@@ -0,0 +1,1098 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Intelligent RAG Chatbot with Smart Query Analysis and Conversation Management
|
| 3 |
+
|
| 4 |
+
This chatbot provides intelligent conversation flow with:
|
| 5 |
+
- Smart query analysis and expansion
|
| 6 |
+
- Single LangSmith conversation traces
|
| 7 |
+
- Local conversation logging
|
| 8 |
+
- Context-aware RAG retrieval
|
| 9 |
+
- Natural conversation without technical jargon
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import json
|
| 14 |
+
import time
|
| 15 |
+
import logging
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from datetime import datetime, timedelta
|
| 19 |
+
from typing import Dict, List, Any, Optional, TypedDict
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
import re
|
| 23 |
+
from langgraph.graph import StateGraph, END
|
| 24 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 25 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 26 |
+
|
| 27 |
+
from src.pipeline import PipelineManager
|
| 28 |
+
from src.config.loader import load_config
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class QueryAnalysis:
|
| 33 |
+
"""Analysis result of a user query"""
|
| 34 |
+
has_district: bool
|
| 35 |
+
has_source: bool
|
| 36 |
+
has_year: bool
|
| 37 |
+
extracted_district: Optional[str]
|
| 38 |
+
extracted_source: Optional[str]
|
| 39 |
+
extracted_year: Optional[str]
|
| 40 |
+
confidence_score: float
|
| 41 |
+
can_answer_directly: bool
|
| 42 |
+
missing_filters: List[str]
|
| 43 |
+
suggested_follow_up: Optional[str]
|
| 44 |
+
expanded_query: Optional[str] = None # Query expansion for better RAG
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ConversationState(TypedDict):
|
| 48 |
+
"""State for the conversation flow"""
|
| 49 |
+
conversation_id: str
|
| 50 |
+
messages: List[Any]
|
| 51 |
+
current_query: str
|
| 52 |
+
query_analysis: Optional[QueryAnalysis]
|
| 53 |
+
rag_query: Optional[str]
|
| 54 |
+
rag_result: Optional[Any]
|
| 55 |
+
final_response: Optional[str]
|
| 56 |
+
conversation_context: Dict[str, Any] # Store conversation context
|
| 57 |
+
session_start_time: float
|
| 58 |
+
last_ai_message_time: float
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class IntelligentRAGChatbot:
|
| 62 |
+
"""Intelligent chatbot with smart query analysis and conversation management"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, suppress_logs=False):
|
| 65 |
+
"""Initialize the intelligent chatbot"""
|
| 66 |
+
# Setup logger to avoid cluttering UI
|
| 67 |
+
self.logger = logging.getLogger(__name__)
|
| 68 |
+
if suppress_logs:
|
| 69 |
+
self.logger.setLevel(logging.CRITICAL) # Suppress all logs
|
| 70 |
+
else:
|
| 71 |
+
self.logger.setLevel(logging.INFO)
|
| 72 |
+
if not self.logger.handlers:
|
| 73 |
+
handler = logging.StreamHandler()
|
| 74 |
+
formatter = logging.Formatter('%(message)s')
|
| 75 |
+
handler.setFormatter(formatter)
|
| 76 |
+
self.logger.addHandler(handler)
|
| 77 |
+
|
| 78 |
+
self.logger.info("🤖 INITIALIZING: Intelligent RAG Chatbot")
|
| 79 |
+
|
| 80 |
+
# Load configuration first
|
| 81 |
+
self.config = load_config()
|
| 82 |
+
|
| 83 |
+
# Use the same LLM configuration as the existing system
|
| 84 |
+
from auditqa.llm.adapters import get_llm_client
|
| 85 |
+
|
| 86 |
+
# Get LLM client using the same configuration
|
| 87 |
+
reader_config = self.config.get("reader", {})
|
| 88 |
+
default_type = reader_config.get("default_type", "INF_PROVIDERS")
|
| 89 |
+
|
| 90 |
+
# Convert to lowercase as that's how it's registered
|
| 91 |
+
provider_name = default_type.lower()
|
| 92 |
+
|
| 93 |
+
self.llm_adapter = get_llm_client(provider_name, self.config)
|
| 94 |
+
|
| 95 |
+
# Create a simple wrapper for LangChain compatibility
|
| 96 |
+
class LLMWrapper:
|
| 97 |
+
def __init__(self, adapter):
|
| 98 |
+
self.adapter = adapter
|
| 99 |
+
|
| 100 |
+
def invoke(self, messages):
|
| 101 |
+
# Convert LangChain messages to the format expected by the adapter
|
| 102 |
+
if isinstance(messages, list):
|
| 103 |
+
# Convert LangChain messages to dict format
|
| 104 |
+
message_dicts = []
|
| 105 |
+
for msg in messages:
|
| 106 |
+
if hasattr(msg, 'content'):
|
| 107 |
+
role = "user" if isinstance(msg, HumanMessage) else "assistant"
|
| 108 |
+
message_dicts.append({"role": role, "content": msg.content})
|
| 109 |
+
else:
|
| 110 |
+
message_dicts.append({"role": "user", "content": str(msg)})
|
| 111 |
+
else:
|
| 112 |
+
# Single message
|
| 113 |
+
message_dicts = [{"role": "user", "content": str(messages)}]
|
| 114 |
+
|
| 115 |
+
# Use the adapter to generate response
|
| 116 |
+
llm_response = self.adapter.generate(message_dicts)
|
| 117 |
+
|
| 118 |
+
# Return in LangChain format
|
| 119 |
+
class MockResponse:
|
| 120 |
+
def __init__(self, content):
|
| 121 |
+
self.content = content
|
| 122 |
+
|
| 123 |
+
return MockResponse(llm_response.content)
|
| 124 |
+
|
| 125 |
+
self.llm = LLMWrapper(self.llm_adapter)
|
| 126 |
+
|
| 127 |
+
# Initialize pipeline manager for RAG
|
| 128 |
+
self.logger.info("🔧 PIPELINE: Initializing PipelineManager...")
|
| 129 |
+
self.pipeline_manager = PipelineManager(self.config)
|
| 130 |
+
|
| 131 |
+
# Ensure vectorstore is connected
|
| 132 |
+
self.logger.info("🔗 VECTORSTORE: Connecting to Qdrant...")
|
| 133 |
+
try:
|
| 134 |
+
vectorstore = self.pipeline_manager.vectorstore_manager.connect_to_existing()
|
| 135 |
+
self.logger.info("✅ VECTORSTORE: Connected successfully")
|
| 136 |
+
except Exception as e:
|
| 137 |
+
self.logger.error(f"❌ VECTORSTORE: Connection failed: {e}")
|
| 138 |
+
|
| 139 |
+
# Fix LLM client to use the same provider as chatbot
|
| 140 |
+
self.logger.info("🔧 LLM: Fixing PipelineManager LLM client...")
|
| 141 |
+
self.pipeline_manager.llm_client = self.llm_adapter
|
| 142 |
+
self.logger.info("✅ LLM: PipelineManager now uses same LLM as chatbot")
|
| 143 |
+
|
| 144 |
+
self.logger.info("✅ PIPELINE: PipelineManager initialized")
|
| 145 |
+
|
| 146 |
+
# Available metadata for filtering
|
| 147 |
+
self.available_metadata = {
|
| 148 |
+
'sources': [
|
| 149 |
+
'KCCA', 'MAAIF', 'MWTS', 'Gulu DLG', 'Kalangala DLG', 'Namutumba DLG',
|
| 150 |
+
'Lwengo DLG', 'Kiboga DLG', 'Annual Consolidated OAG', 'Consolidated',
|
| 151 |
+
'Hospital', 'Local Government', 'Ministry, Department and Agency',
|
| 152 |
+
'Project', 'Thematic', 'Value for Money'
|
| 153 |
+
],
|
| 154 |
+
'years': ['2018', '2019', '2020', '2021', '2022', '2023', '2024', '2025'],
|
| 155 |
+
'districts': [
|
| 156 |
+
'Gulu', 'Kalangala', 'Kampala', 'Namutumba', 'Lwengo', 'Kiboga',
|
| 157 |
+
'Fort Portal', 'Arua', 'Kasese', 'Kabale', 'Masindi', 'Mbale', 'Jinja', 'Masaka', 'Mbarara',
|
| 158 |
+
'KCCA'
|
| 159 |
+
]
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
# Try to load district whitelist from filter_options.json
|
| 163 |
+
try:
|
| 164 |
+
fo = Path("filter_options.json")
|
| 165 |
+
if fo.exists():
|
| 166 |
+
with open(fo) as f:
|
| 167 |
+
data = json.load(f)
|
| 168 |
+
if isinstance(data, dict) and data.get("districts"):
|
| 169 |
+
self.district_whitelist = [d.strip() for d in data["districts"] if d]
|
| 170 |
+
else:
|
| 171 |
+
self.district_whitelist = self.available_metadata['districts']
|
| 172 |
+
else:
|
| 173 |
+
self.district_whitelist = self.available_metadata['districts']
|
| 174 |
+
except Exception:
|
| 175 |
+
self.district_whitelist = self.available_metadata['districts']
|
| 176 |
+
|
| 177 |
+
# Enrich whitelist from add_district_metadata.py if available
|
| 178 |
+
try:
|
| 179 |
+
from add_district_metadata import DistrictMetadataProcessor
|
| 180 |
+
proc = DistrictMetadataProcessor()
|
| 181 |
+
names = set()
|
| 182 |
+
for key, mapping in proc.district_mappings.items():
|
| 183 |
+
if getattr(mapping, 'is_district', True):
|
| 184 |
+
names.add(mapping.name)
|
| 185 |
+
if names:
|
| 186 |
+
# Merge while preserving order: existing first, then new ones not present
|
| 187 |
+
merged = list(self.district_whitelist)
|
| 188 |
+
for n in sorted(names):
|
| 189 |
+
if n not in merged:
|
| 190 |
+
merged.append(n)
|
| 191 |
+
self.district_whitelist = merged
|
| 192 |
+
self.logger.info(f"🧭 District whitelist enriched: {len(self.district_whitelist)} entries")
|
| 193 |
+
except Exception as e:
|
| 194 |
+
self.logger.info(f"ℹ️ Could not enrich districts from add_district_metadata: {e}")
|
| 195 |
+
|
| 196 |
+
# Get dynamic year list from filter_options.json
|
| 197 |
+
try:
|
| 198 |
+
fo = Path("filter_options.json")
|
| 199 |
+
if fo.exists():
|
| 200 |
+
with open(fo) as f:
|
| 201 |
+
data = json.load(f)
|
| 202 |
+
if isinstance(data, dict) and data.get("years"):
|
| 203 |
+
self.year_whitelist = [str(y).strip() for y in data["years"] if y]
|
| 204 |
+
else:
|
| 205 |
+
self.year_whitelist = self.available_metadata['years']
|
| 206 |
+
else:
|
| 207 |
+
self.year_whitelist = self.available_metadata['years']
|
| 208 |
+
except Exception:
|
| 209 |
+
self.year_whitelist = self.available_metadata['years']
|
| 210 |
+
|
| 211 |
+
# Calculate current year dynamically
|
| 212 |
+
from datetime import datetime
|
| 213 |
+
self.current_year = str(datetime.now().year)
|
| 214 |
+
self.previous_year = str(datetime.now().year - 1)
|
| 215 |
+
|
| 216 |
+
# Data context for system prompt
|
| 217 |
+
self.data_context = self._load_data_context()
|
| 218 |
+
|
| 219 |
+
# Build the LangGraph
|
| 220 |
+
self.graph = self._build_graph()
|
| 221 |
+
|
| 222 |
+
# Conversation logging
|
| 223 |
+
self.conversations_dir = Path("conversations")
|
| 224 |
+
self.conversations_dir.mkdir(exist_ok=True)
|
| 225 |
+
|
| 226 |
+
def _load_data_context(self) -> str:
|
| 227 |
+
"""Load and analyze data context for system prompt"""
|
| 228 |
+
try:
|
| 229 |
+
# Try to load from generated context file
|
| 230 |
+
context_file = Path("data_context.md")
|
| 231 |
+
if context_file.exists():
|
| 232 |
+
with open(context_file) as f:
|
| 233 |
+
return f.read()
|
| 234 |
+
|
| 235 |
+
# Fallback to basic analysis
|
| 236 |
+
reports_dir = Path("reports")
|
| 237 |
+
testset_dir = Path("outputs/datasets/testset")
|
| 238 |
+
|
| 239 |
+
context_parts = []
|
| 240 |
+
|
| 241 |
+
# Report analysis
|
| 242 |
+
if reports_dir.exists():
|
| 243 |
+
report_folders = [d for d in reports_dir.iterdir() if d.is_dir()]
|
| 244 |
+
context_parts.append(f"📊 Available Reports: {len(report_folders)} audit report folders")
|
| 245 |
+
|
| 246 |
+
# Get year range
|
| 247 |
+
years = []
|
| 248 |
+
for folder in report_folders:
|
| 249 |
+
if "2018" in folder.name:
|
| 250 |
+
years.append("2018")
|
| 251 |
+
elif "2019" in folder.name:
|
| 252 |
+
years.append("2019")
|
| 253 |
+
elif "2020" in folder.name:
|
| 254 |
+
years.append("2020")
|
| 255 |
+
elif "2021" in folder.name:
|
| 256 |
+
years.append("2021")
|
| 257 |
+
elif "2022" in folder.name:
|
| 258 |
+
years.append("2022")
|
| 259 |
+
elif "2023" in folder.name:
|
| 260 |
+
years.append("2023")
|
| 261 |
+
|
| 262 |
+
if years:
|
| 263 |
+
context_parts.append(f"📅 Years covered: {', '.join(sorted(set(years)))}")
|
| 264 |
+
|
| 265 |
+
# Test dataset analysis
|
| 266 |
+
if testset_dir.exists():
|
| 267 |
+
test_files = list(testset_dir.glob("*.json"))
|
| 268 |
+
context_parts.append(f"🧪 Test dataset: {len(test_files)} files with sample questions")
|
| 269 |
+
|
| 270 |
+
return "\n".join(context_parts) if context_parts else "📊 Audit report database with comprehensive coverage"
|
| 271 |
+
|
| 272 |
+
except Exception as e:
|
| 273 |
+
self.logger.warning(f"⚠️ Could not load data context: {e}")
|
| 274 |
+
return "📊 Comprehensive audit report database"
|
| 275 |
+
|
| 276 |
+
def _build_graph(self) -> StateGraph:
|
| 277 |
+
"""Build the LangGraph for intelligent conversation flow"""
|
| 278 |
+
|
| 279 |
+
# Define the graph
|
| 280 |
+
workflow = StateGraph(ConversationState)
|
| 281 |
+
|
| 282 |
+
# Add nodes
|
| 283 |
+
workflow.add_node("analyze_query", self._analyze_query)
|
| 284 |
+
workflow.add_node("decide_action", self._decide_action)
|
| 285 |
+
workflow.add_node("perform_rag", self._perform_rag)
|
| 286 |
+
workflow.add_node("ask_follow_up", self._ask_follow_up)
|
| 287 |
+
workflow.add_node("generate_response", self._generate_response)
|
| 288 |
+
|
| 289 |
+
# Add edges
|
| 290 |
+
workflow.add_edge("analyze_query", "decide_action")
|
| 291 |
+
|
| 292 |
+
# Conditional edges from decide_action
|
| 293 |
+
workflow.add_conditional_edges(
|
| 294 |
+
"decide_action",
|
| 295 |
+
self._should_perform_rag,
|
| 296 |
+
{
|
| 297 |
+
"rag": "perform_rag",
|
| 298 |
+
"follow_up": "ask_follow_up"
|
| 299 |
+
}
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# From perform_rag, go to generate_response
|
| 303 |
+
workflow.add_edge("perform_rag", "generate_response")
|
| 304 |
+
|
| 305 |
+
# From ask_follow_up, end
|
| 306 |
+
workflow.add_edge("ask_follow_up", END)
|
| 307 |
+
|
| 308 |
+
# From generate_response, end
|
| 309 |
+
workflow.add_edge("generate_response", END)
|
| 310 |
+
|
| 311 |
+
# Set entry point
|
| 312 |
+
workflow.set_entry_point("analyze_query")
|
| 313 |
+
|
| 314 |
+
return workflow.compile()
|
| 315 |
+
|
| 316 |
+
def _extract_districts_list(self, text: str) -> List[str]:
|
| 317 |
+
"""Extract one or more districts from free text using whitelist matching.
|
| 318 |
+
- Case-insensitive substring match for each known district name
|
| 319 |
+
- Handles multi-district inputs like "Lwengo Kiboga District & Namutumba"
|
| 320 |
+
"""
|
| 321 |
+
if not text:
|
| 322 |
+
return []
|
| 323 |
+
q = text.lower()
|
| 324 |
+
found: List[str] = []
|
| 325 |
+
for name in self.district_whitelist:
|
| 326 |
+
n = name.lower()
|
| 327 |
+
if n in q:
|
| 328 |
+
# Map Kampala -> KCCA canonical
|
| 329 |
+
canonical = 'KCCA' if name.lower() == 'kampala' else name
|
| 330 |
+
if canonical not in found:
|
| 331 |
+
found.append(canonical)
|
| 332 |
+
return found
|
| 333 |
+
|
| 334 |
+
def _extract_years_list(self, text: str) -> List[str]:
|
| 335 |
+
"""Extract year list from text, supporting forms like '2022 / 23', '2022-2023', '2022–23'."""
|
| 336 |
+
if not text:
|
| 337 |
+
return []
|
| 338 |
+
years: List[str] = []
|
| 339 |
+
q = text
|
| 340 |
+
# Full 4-digit years
|
| 341 |
+
for y in re.findall(r"\b(20\d{2})\b", q):
|
| 342 |
+
if y not in years:
|
| 343 |
+
years.append(y)
|
| 344 |
+
# Shorthand like 2022/23 or 2022-23
|
| 345 |
+
for m in re.finditer(r"\b(20\d{2})\s*[\-/–]\s*(\d{2})\b", q):
|
| 346 |
+
y1 = m.group(1)
|
| 347 |
+
y2_short = int(m.group(2))
|
| 348 |
+
y2 = f"20{y2_short:02d}"
|
| 349 |
+
for y in [y1, y2]:
|
| 350 |
+
if y not in years:
|
| 351 |
+
years.append(y)
|
| 352 |
+
return years
|
| 353 |
+
|
| 354 |
+
def _analyze_query(self, state: ConversationState) -> ConversationState:
|
| 355 |
+
"""Analyze the user query with conversation context"""
|
| 356 |
+
|
| 357 |
+
query = state["current_query"]
|
| 358 |
+
conversation_context = state.get("conversation_context", {})
|
| 359 |
+
|
| 360 |
+
self.logger.info(f"🧠 QUERY ANALYSIS: Starting analysis for: '{query[:50]}...'")
|
| 361 |
+
|
| 362 |
+
# Build conversation context for analysis
|
| 363 |
+
context_info = ""
|
| 364 |
+
if conversation_context:
|
| 365 |
+
context_info = f"\n\nConversation context:\n"
|
| 366 |
+
for key, value in conversation_context.items():
|
| 367 |
+
if value:
|
| 368 |
+
context_info += f"- {key}: {value}\n"
|
| 369 |
+
|
| 370 |
+
# Also include recent conversation messages for better context
|
| 371 |
+
recent_messages = state.get("messages", [])
|
| 372 |
+
if recent_messages and len(recent_messages) > 1:
|
| 373 |
+
context_info += f"\nRecent conversation:\n"
|
| 374 |
+
# Get last 3 messages for context
|
| 375 |
+
for msg in recent_messages[-3:]:
|
| 376 |
+
if hasattr(msg, 'content'):
|
| 377 |
+
role = "User" if isinstance(msg, HumanMessage) else "Assistant"
|
| 378 |
+
context_info += f"- {role}: {msg.content[:100]}...\n"
|
| 379 |
+
|
| 380 |
+
# Create analysis prompt with data context
|
| 381 |
+
analysis_prompt = ChatPromptTemplate.from_messages([
|
| 382 |
+
SystemMessage(content=f"""You are an expert at analyzing audit report queries. Your job is to extract specific information and determine if a query can be answered directly.
|
| 383 |
+
|
| 384 |
+
{self.data_context}
|
| 385 |
+
|
| 386 |
+
DISTRICT RECOGNITION RULES:
|
| 387 |
+
- Kampala = KCCA (Kampala Capital City Authority)
|
| 388 |
+
- Available districts: {', '.join(self.district_whitelist[:15])}... (and {len(self.district_whitelist)-15} more)
|
| 389 |
+
- DLG = District Local Government
|
| 390 |
+
- Uganda has {len(self.district_whitelist)} districts - recognize common ones
|
| 391 |
+
|
| 392 |
+
SOURCE RECOGNITION RULES:
|
| 393 |
+
- KCCA = Kampala Capital City Authority
|
| 394 |
+
- MAAIF = Ministry of Agriculture, Animal Industry and Fisheries
|
| 395 |
+
- MWTS = Ministry of Works and Transport
|
| 396 |
+
- OAG = Office of the Auditor General
|
| 397 |
+
- Consolidated = Annual Consolidated reports
|
| 398 |
+
|
| 399 |
+
YEAR RECOGNITION RULES:
|
| 400 |
+
- Available years: {', '.join(self.year_whitelist)}
|
| 401 |
+
- Current year is {self.current_year} - use this to reason about relative years
|
| 402 |
+
- If user mentions "last year", "previous year" - infer {self.previous_year}
|
| 403 |
+
- If user mentions "this year", "current year" - infer {self.current_year}
|
| 404 |
+
|
| 405 |
+
Analysis rules:
|
| 406 |
+
1. Be SMART - if you have enough context to search, do it
|
| 407 |
+
2. Use conversation context to fill in missing information
|
| 408 |
+
3. For budget/expenditure queries, try to infer missing details from context
|
| 409 |
+
4. Current year is {self.current_year} - use this to reason about relative years
|
| 410 |
+
5. If user mentions "last year", "previous year" - infer {self.previous_year}
|
| 411 |
+
6. If user mentions "this year", "current year" - infer {self.current_year}
|
| 412 |
+
7. If user mentions a department/ministry, infer the source
|
| 413 |
+
8. If user is getting frustrated or asking for results, proceed with RAG even if not perfect
|
| 414 |
+
9. Recognize Kampala as a district (KCCA)
|
| 415 |
+
|
| 416 |
+
IMPORTANT: You must respond with ONLY valid JSON. No additional text.
|
| 417 |
+
|
| 418 |
+
Return your analysis as JSON with these exact fields:
|
| 419 |
+
{{
|
| 420 |
+
"has_district": boolean,
|
| 421 |
+
"has_source": boolean,
|
| 422 |
+
"has_year": boolean,
|
| 423 |
+
"extracted_district": "string or null",
|
| 424 |
+
"extracted_source": "string or null",
|
| 425 |
+
"extracted_year": "string or null",
|
| 426 |
+
"confidence_score": 0.0-1.0,
|
| 427 |
+
"can_answer_directly": boolean,
|
| 428 |
+
"missing_filters": ["list", "of", "missing", "filters"],
|
| 429 |
+
"suggested_follow_up": "string or null",
|
| 430 |
+
"expanded_query": "string or null"
|
| 431 |
+
}}
|
| 432 |
+
|
| 433 |
+
The expanded_query should be a natural language query that combines the original question with any inferred context for better RAG retrieval."""),
|
| 434 |
+
HumanMessage(content=f"Analyze this query: '{query}'{context_info}")
|
| 435 |
+
])
|
| 436 |
+
|
| 437 |
+
# Get analysis from LLM
|
| 438 |
+
response = self.llm.invoke(analysis_prompt.format_messages())
|
| 439 |
+
|
| 440 |
+
try:
|
| 441 |
+
# Clean the response content to extract JSON
|
| 442 |
+
content = response.content.strip()
|
| 443 |
+
|
| 444 |
+
# Try to find JSON in the response
|
| 445 |
+
if content.startswith('{') and content.endswith('}'):
|
| 446 |
+
json_content = content
|
| 447 |
+
else:
|
| 448 |
+
# Try to extract JSON from the response
|
| 449 |
+
import re
|
| 450 |
+
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
| 451 |
+
if json_match:
|
| 452 |
+
json_content = json_match.group()
|
| 453 |
+
else:
|
| 454 |
+
raise json.JSONDecodeError("No JSON found in response", content, 0)
|
| 455 |
+
|
| 456 |
+
# Parse JSON response
|
| 457 |
+
analysis_data = json.loads(json_content)
|
| 458 |
+
|
| 459 |
+
query_analysis = QueryAnalysis(
|
| 460 |
+
has_district=analysis_data.get("has_district", False),
|
| 461 |
+
has_source=analysis_data.get("has_source", False),
|
| 462 |
+
has_year=analysis_data.get("has_year", False),
|
| 463 |
+
extracted_district=analysis_data.get("extracted_district"),
|
| 464 |
+
extracted_source=analysis_data.get("extracted_source"),
|
| 465 |
+
extracted_year=analysis_data.get("extracted_year"),
|
| 466 |
+
confidence_score=analysis_data.get("confidence_score", 0.0),
|
| 467 |
+
can_answer_directly=analysis_data.get("can_answer_directly", False),
|
| 468 |
+
missing_filters=analysis_data.get("missing_filters", []),
|
| 469 |
+
suggested_follow_up=analysis_data.get("suggested_follow_up"),
|
| 470 |
+
expanded_query=analysis_data.get("expanded_query")
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
except (json.JSONDecodeError, KeyError, AttributeError) as e:
|
| 474 |
+
self.logger.info(f"⚠️ JSON parsing failed: {e}")
|
| 475 |
+
# Fallback analysis - be more permissive
|
| 476 |
+
query_lower = query.lower()
|
| 477 |
+
|
| 478 |
+
# Simple keyword matching - improved district recognition
|
| 479 |
+
has_district = any(district.lower() in query_lower for district in [
|
| 480 |
+
'gulu', 'kalangala', 'kampala', 'namutumba', 'lwengo', 'kiboga', 'kcca', 'maaif', 'mwts'
|
| 481 |
+
])
|
| 482 |
+
|
| 483 |
+
# Special case: Kampala = KCCA
|
| 484 |
+
if 'kampala' in query_lower and not has_district:
|
| 485 |
+
has_district = True
|
| 486 |
+
|
| 487 |
+
has_source = any(source.lower() in query_lower for source in [
|
| 488 |
+
'kcca', 'maaif', 'mwts', 'gulu', 'kalangala', 'consolidated', 'oag', 'government'
|
| 489 |
+
])
|
| 490 |
+
|
| 491 |
+
# Check for year mentions using dynamic year list
|
| 492 |
+
has_year = any(year in query_lower for year in self.year_whitelist)
|
| 493 |
+
|
| 494 |
+
# Also check for explicit relative year terms
|
| 495 |
+
has_year = has_year or any(term in query_lower for term in [
|
| 496 |
+
'this year', 'last year', 'previous year', 'current year'
|
| 497 |
+
])
|
| 498 |
+
|
| 499 |
+
# Extract specific values
|
| 500 |
+
extracted_district = None
|
| 501 |
+
extracted_source = None
|
| 502 |
+
extracted_year = None
|
| 503 |
+
|
| 504 |
+
# Extract districts using comprehensive whitelist
|
| 505 |
+
for district_name in self.district_whitelist:
|
| 506 |
+
if district_name.lower() in query_lower:
|
| 507 |
+
extracted_district = district_name
|
| 508 |
+
break
|
| 509 |
+
|
| 510 |
+
# Also check common aliases
|
| 511 |
+
district_aliases = {
|
| 512 |
+
'kampala': 'Kampala',
|
| 513 |
+
'kcca': 'Kampala',
|
| 514 |
+
'gulu': 'Gulu',
|
| 515 |
+
'kalangala': 'Kalangala'
|
| 516 |
+
}
|
| 517 |
+
for alias, full_name in district_aliases.items():
|
| 518 |
+
if alias in query_lower and not extracted_district:
|
| 519 |
+
extracted_district = full_name
|
| 520 |
+
break
|
| 521 |
+
|
| 522 |
+
for source in ['kcca', 'maaif', 'mwts', 'consolidated', 'oag']:
|
| 523 |
+
if source in query_lower:
|
| 524 |
+
extracted_source = source.upper()
|
| 525 |
+
break
|
| 526 |
+
|
| 527 |
+
# Extract year using dynamic year list
|
| 528 |
+
for year in self.year_whitelist:
|
| 529 |
+
if year in query_lower:
|
| 530 |
+
extracted_year = year
|
| 531 |
+
has_year = True
|
| 532 |
+
break
|
| 533 |
+
|
| 534 |
+
# Only handle relative year terms if explicitly mentioned
|
| 535 |
+
if not extracted_year:
|
| 536 |
+
if 'last year' in query_lower or 'previous year' in query_lower:
|
| 537 |
+
extracted_year = self.previous_year
|
| 538 |
+
has_year = True
|
| 539 |
+
elif 'this year' in query_lower or 'current year' in query_lower:
|
| 540 |
+
extracted_year = self.current_year
|
| 541 |
+
has_year = True
|
| 542 |
+
elif 'recent' in query_lower and 'year' in query_lower:
|
| 543 |
+
# Use the most recent year from available data
|
| 544 |
+
extracted_year = max(self.year_whitelist) if self.year_whitelist else self.previous_year
|
| 545 |
+
has_year = True
|
| 546 |
+
|
| 547 |
+
# Be more permissive - if we have some context, try to answer
|
| 548 |
+
missing_filters = []
|
| 549 |
+
if not has_district:
|
| 550 |
+
missing_filters.append("district")
|
| 551 |
+
if not has_source:
|
| 552 |
+
missing_filters.append("source")
|
| 553 |
+
if not has_year:
|
| 554 |
+
missing_filters.append("year")
|
| 555 |
+
|
| 556 |
+
# If user seems frustrated or asking for results, be more permissive
|
| 557 |
+
frustration_indicators = ['already', 'just said', 'specified', 'provided', 'crazy', 'answer']
|
| 558 |
+
is_frustrated = any(indicator in query_lower for indicator in frustration_indicators)
|
| 559 |
+
|
| 560 |
+
can_answer_directly = len(missing_filters) <= 1 or is_frustrated # More permissive
|
| 561 |
+
confidence_score = 0.8 if can_answer_directly else 0.3
|
| 562 |
+
|
| 563 |
+
# Generate follow-up suggestion
|
| 564 |
+
if missing_filters and not is_frustrated:
|
| 565 |
+
if "district" in missing_filters and "source" in missing_filters:
|
| 566 |
+
suggested_follow_up = "I'd be happy to help you with that information! Could you please specify which district and department/ministry you're asking about?"
|
| 567 |
+
elif "district" in missing_filters:
|
| 568 |
+
suggested_follow_up = "Thanks for your question! Could you please specify which district you're asking about?"
|
| 569 |
+
elif "source" in missing_filters:
|
| 570 |
+
suggested_follow_up = "I can help you with that! Could you please specify which department or ministry you're asking about?"
|
| 571 |
+
elif "year" in missing_filters:
|
| 572 |
+
suggested_follow_up = "Great question! Could you please specify which year you're interested in?"
|
| 573 |
+
else:
|
| 574 |
+
suggested_follow_up = "Could you please provide more specific details to help me give you a precise answer?"
|
| 575 |
+
else:
|
| 576 |
+
suggested_follow_up = None
|
| 577 |
+
|
| 578 |
+
# Create expanded query
|
| 579 |
+
expanded_query = query
|
| 580 |
+
if extracted_district:
|
| 581 |
+
expanded_query += f" for {extracted_district} district"
|
| 582 |
+
if extracted_source:
|
| 583 |
+
expanded_query += f" from {extracted_source}"
|
| 584 |
+
if extracted_year:
|
| 585 |
+
expanded_query += f" in {extracted_year}"
|
| 586 |
+
|
| 587 |
+
query_analysis = QueryAnalysis(
|
| 588 |
+
has_district=has_district,
|
| 589 |
+
has_source=has_source,
|
| 590 |
+
has_year=has_year,
|
| 591 |
+
extracted_district=extracted_district,
|
| 592 |
+
extracted_source=extracted_source,
|
| 593 |
+
extracted_year=extracted_year,
|
| 594 |
+
confidence_score=confidence_score,
|
| 595 |
+
can_answer_directly=can_answer_directly,
|
| 596 |
+
missing_filters=missing_filters,
|
| 597 |
+
suggested_follow_up=suggested_follow_up,
|
| 598 |
+
expanded_query=expanded_query
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
# Update conversation context
|
| 602 |
+
if query_analysis.extracted_district:
|
| 603 |
+
conversation_context["district"] = query_analysis.extracted_district
|
| 604 |
+
if query_analysis.extracted_source:
|
| 605 |
+
conversation_context["source"] = query_analysis.extracted_source
|
| 606 |
+
if query_analysis.extracted_year:
|
| 607 |
+
conversation_context["year"] = query_analysis.extracted_year
|
| 608 |
+
|
| 609 |
+
state["query_analysis"] = query_analysis
|
| 610 |
+
state["conversation_context"] = conversation_context
|
| 611 |
+
|
| 612 |
+
self.logger.info(f"✅ ANALYSIS COMPLETE: district={query_analysis.has_district}, source={query_analysis.has_source}, year={query_analysis.has_year}")
|
| 613 |
+
self.logger.info(f"📈 Confidence: {query_analysis.confidence_score:.2f}, Can answer directly: {query_analysis.can_answer_directly}")
|
| 614 |
+
if query_analysis.expanded_query:
|
| 615 |
+
self.logger.info(f"🔄 Expanded query: {query_analysis.expanded_query}")
|
| 616 |
+
|
| 617 |
+
return state
|
| 618 |
+
|
| 619 |
+
def _decide_action(self, state: ConversationState) -> ConversationState:
|
| 620 |
+
"""Decide what action to take based on query analysis"""
|
| 621 |
+
|
| 622 |
+
analysis = state["query_analysis"]
|
| 623 |
+
|
| 624 |
+
# Add decision reasoning
|
| 625 |
+
if analysis.can_answer_directly and analysis.confidence_score > 0.7:
|
| 626 |
+
self.logger.info(f"🚀 DECISION: Query is complete, proceeding with RAG")
|
| 627 |
+
self.logger.info(f"📊 REASONING: Confidence={analysis.confidence_score:.2f}, Missing filters={len(analysis.missing_filters or [])}")
|
| 628 |
+
if analysis.missing_filters:
|
| 629 |
+
self.logger.info(f"📋 Missing: {', '.join(analysis.missing_filters)}")
|
| 630 |
+
else:
|
| 631 |
+
self.logger.info(f"✅ All required information available")
|
| 632 |
+
else:
|
| 633 |
+
self.logger.info(f"❓ DECISION: Query incomplete, asking follow-up")
|
| 634 |
+
self.logger.info(f"📊 REASONING: Confidence={analysis.confidence_score:.2f}, Missing filters={len(analysis.missing_filters or [])}")
|
| 635 |
+
if analysis.missing_filters:
|
| 636 |
+
self.logger.info(f"📋 Missing: {', '.join(analysis.missing_filters)}")
|
| 637 |
+
self.logger.info(f"💡 Follow-up needed: {analysis.suggested_follow_up}")
|
| 638 |
+
|
| 639 |
+
return state
|
| 640 |
+
|
| 641 |
+
def _should_perform_rag(self, state: ConversationState) -> str:
|
| 642 |
+
"""Determine whether to perform RAG or ask follow-up"""
|
| 643 |
+
|
| 644 |
+
analysis = state["query_analysis"]
|
| 645 |
+
conversation_context = state.get("conversation_context", {})
|
| 646 |
+
recent_messages = state.get("messages", [])
|
| 647 |
+
|
| 648 |
+
# Check if we have enough context from conversation history
|
| 649 |
+
has_district_context = analysis.has_district or conversation_context.get("district")
|
| 650 |
+
has_source_context = analysis.has_source or conversation_context.get("source")
|
| 651 |
+
has_year_context = analysis.has_year or conversation_context.get("year")
|
| 652 |
+
|
| 653 |
+
# Count how many context pieces we have
|
| 654 |
+
context_count = sum([bool(has_district_context), bool(has_source_context), bool(has_year_context)])
|
| 655 |
+
|
| 656 |
+
# For PDM queries, we need more specific information
|
| 657 |
+
current_query = state["current_query"].lower()
|
| 658 |
+
recent_messages = state.get("messages", [])
|
| 659 |
+
|
| 660 |
+
# Check if this is a PDM query by looking at current query OR recent conversation
|
| 661 |
+
is_pdm_query = "pdm" in current_query or "parish development" in current_query
|
| 662 |
+
|
| 663 |
+
# Also check recent messages for PDM context
|
| 664 |
+
if not is_pdm_query and recent_messages:
|
| 665 |
+
for msg in recent_messages[-3:]: # Check last 3 messages
|
| 666 |
+
if isinstance(msg, HumanMessage) and ("pdm" in msg.content.lower() or "parish development" in msg.content.lower()):
|
| 667 |
+
is_pdm_query = True
|
| 668 |
+
break
|
| 669 |
+
|
| 670 |
+
if is_pdm_query:
|
| 671 |
+
# For PDM queries, we need district AND year to be specific enough
|
| 672 |
+
# But we need them to be explicitly provided in the current conversation, not just inferred
|
| 673 |
+
if has_district_context and has_year_context:
|
| 674 |
+
# Check if both district and year are explicitly mentioned in recent messages
|
| 675 |
+
explicit_district = False
|
| 676 |
+
explicit_year = False
|
| 677 |
+
|
| 678 |
+
for msg in recent_messages[-3:]: # Check last 3 messages
|
| 679 |
+
if isinstance(msg, HumanMessage):
|
| 680 |
+
content = msg.content.lower()
|
| 681 |
+
if any(district in content for district in ["gulu", "kalangala", "kampala", "namutumba"]):
|
| 682 |
+
explicit_district = True
|
| 683 |
+
if any(year in content for year in ["2022", "2023", "2022/23", "2023/24"]):
|
| 684 |
+
explicit_year = True
|
| 685 |
+
|
| 686 |
+
if explicit_district and explicit_year:
|
| 687 |
+
self.logger.info(f"🚀 DECISION: PDM query with explicit district and year, proceeding with RAG")
|
| 688 |
+
self.logger.info(f"📊 REASONING: PDM query - explicit_district={explicit_district}, explicit_year={explicit_year}")
|
| 689 |
+
return "rag"
|
| 690 |
+
else:
|
| 691 |
+
self.logger.info(f"❓ DECISION: PDM query needs explicit district and year, asking follow-up")
|
| 692 |
+
self.logger.info(f"📊 REASONING: PDM query - explicit_district={explicit_district}, explicit_year={explicit_year}")
|
| 693 |
+
return "follow_up"
|
| 694 |
+
else:
|
| 695 |
+
self.logger.info(f"❓ DECISION: PDM query needs more specific info, asking follow-up")
|
| 696 |
+
self.logger.info(f"📊 REASONING: PDM query - district={has_district_context}, year={has_year_context}")
|
| 697 |
+
return "follow_up"
|
| 698 |
+
|
| 699 |
+
# For general queries, be more conservative - need at least 2 pieces AND high confidence
|
| 700 |
+
if context_count >= 2 and analysis.confidence_score > 0.8:
|
| 701 |
+
self.logger.info(f"🚀 DECISION: Sufficient context with high confidence, proceeding with RAG")
|
| 702 |
+
self.logger.info(f"📊 REASONING: Context pieces: district={has_district_context}, source={has_source_context}, year={has_year_context}, confidence={analysis.confidence_score}")
|
| 703 |
+
return "rag"
|
| 704 |
+
|
| 705 |
+
# If user seems frustrated (short responses like "no"), proceed with RAG
|
| 706 |
+
if recent_messages and len(recent_messages) >= 3: # Need more messages to detect frustration
|
| 707 |
+
last_user_message = None
|
| 708 |
+
for msg in reversed(recent_messages):
|
| 709 |
+
if isinstance(msg, HumanMessage):
|
| 710 |
+
last_user_message = msg.content.lower().strip()
|
| 711 |
+
break
|
| 712 |
+
|
| 713 |
+
if last_user_message and len(last_user_message) < 10 and any(word in last_user_message for word in ["no", "yes", "ok", "sure"]):
|
| 714 |
+
self.logger.info(f"🚀 DECISION: User seems frustrated with short response, proceeding with RAG")
|
| 715 |
+
return "rag"
|
| 716 |
+
|
| 717 |
+
# Original logic for direct answers
|
| 718 |
+
if analysis.can_answer_directly and analysis.confidence_score > 0.7:
|
| 719 |
+
return "rag"
|
| 720 |
+
else:
|
| 721 |
+
return "follow_up"
|
| 722 |
+
|
| 723 |
+
def _ask_follow_up(self, state: ConversationState) -> ConversationState:
|
| 724 |
+
"""Generate a follow-up question to clarify missing information"""
|
| 725 |
+
|
| 726 |
+
analysis = state["query_analysis"]
|
| 727 |
+
current_query = state["current_query"].lower()
|
| 728 |
+
conversation_context = state.get("conversation_context", {})
|
| 729 |
+
|
| 730 |
+
# Check if this is a PDM query
|
| 731 |
+
is_pdm_query = "pdm" in current_query or "parish development" in current_query
|
| 732 |
+
|
| 733 |
+
if is_pdm_query:
|
| 734 |
+
# Generate PDM-specific follow-up questions
|
| 735 |
+
missing_info = []
|
| 736 |
+
|
| 737 |
+
if not analysis.has_district and not conversation_context.get("district"):
|
| 738 |
+
missing_info.append("district (e.g., Gulu, Kalangala)")
|
| 739 |
+
|
| 740 |
+
if not analysis.has_year and not conversation_context.get("year"):
|
| 741 |
+
missing_info.append("year (e.g., 2022, 2023)")
|
| 742 |
+
|
| 743 |
+
if missing_info:
|
| 744 |
+
follow_up_message = f"For PDM administrative costs information, I need to know the {', '.join(missing_info)}. Could you please specify these details?"
|
| 745 |
+
else:
|
| 746 |
+
follow_up_message = "Could you please provide more specific details about the PDM administrative costs you're looking for?"
|
| 747 |
+
else:
|
| 748 |
+
# Use the original follow-up logic
|
| 749 |
+
if analysis.suggested_follow_up:
|
| 750 |
+
follow_up_message = analysis.suggested_follow_up
|
| 751 |
+
else:
|
| 752 |
+
follow_up_message = "Could you please provide more specific details to help me give you a precise answer?"
|
| 753 |
+
|
| 754 |
+
state["final_response"] = follow_up_message
|
| 755 |
+
state["last_ai_message_time"] = time.time()
|
| 756 |
+
|
| 757 |
+
return state
|
| 758 |
+
|
| 759 |
+
def _build_comprehensive_query(self, current_query: str, analysis, conversation_context: dict, recent_messages: list) -> str:
|
| 760 |
+
"""Build a better RAG query from conversation.
|
| 761 |
+
- If latest message is a short modifier (e.g., "financial"), merge it into the last substantive question.
|
| 762 |
+
- If latest message looks like filters (district/year), keep the last question unchanged.
|
| 763 |
+
- Otherwise, use the current message.
|
| 764 |
+
"""
|
| 765 |
+
|
| 766 |
+
def is_interrogative(text: str) -> bool:
|
| 767 |
+
t = text.lower().strip()
|
| 768 |
+
return any(t.startswith(w) for w in ["what", "how", "why", "when", "where", "which", "who"]) or t.endswith("?")
|
| 769 |
+
|
| 770 |
+
def is_filter_like(text: str) -> bool:
|
| 771 |
+
t = text.lower()
|
| 772 |
+
if "district" in t:
|
| 773 |
+
return True
|
| 774 |
+
if re.search(r"\b20\d{2}\b", t) or re.search(r"20\d{2}\s*[\-/–]\s*\d{2}\b", t):
|
| 775 |
+
return True
|
| 776 |
+
if self._extract_districts_list(text):
|
| 777 |
+
return True
|
| 778 |
+
return False
|
| 779 |
+
|
| 780 |
+
# Find last substantive user question
|
| 781 |
+
last_question = None
|
| 782 |
+
for msg in reversed(recent_messages[:-1] if recent_messages else []):
|
| 783 |
+
if isinstance(msg, HumanMessage):
|
| 784 |
+
if is_interrogative(msg.content) and len(msg.content.strip()) > 15:
|
| 785 |
+
last_question = msg.content.strip()
|
| 786 |
+
break
|
| 787 |
+
|
| 788 |
+
cq = current_query.strip()
|
| 789 |
+
words = cq.split()
|
| 790 |
+
is_short_modifier = (not is_interrogative(cq)) and (len(words) <= 3)
|
| 791 |
+
|
| 792 |
+
if is_filter_like(cq) and last_question:
|
| 793 |
+
comprehensive_query = last_question
|
| 794 |
+
elif is_short_modifier and last_question:
|
| 795 |
+
modifier = cq
|
| 796 |
+
if modifier.lower() in last_question.lower():
|
| 797 |
+
comprehensive_query = last_question
|
| 798 |
+
else:
|
| 799 |
+
if last_question.endswith('?'):
|
| 800 |
+
comprehensive_query = last_question[:-1] + f" for {modifier}?"
|
| 801 |
+
else:
|
| 802 |
+
comprehensive_query = last_question + f" for {modifier}"
|
| 803 |
+
else:
|
| 804 |
+
comprehensive_query = current_query
|
| 805 |
+
|
| 806 |
+
self.logger.info(f"🔄 COMPREHENSIVE QUERY: '{comprehensive_query}'")
|
| 807 |
+
return comprehensive_query
|
| 808 |
+
|
| 809 |
+
def _rewrite_query_with_llm(self, recent_messages: list, draft_query: str) -> str:
|
| 810 |
+
"""Use the LLM to rewrite a clean, focused RAG query from the conversation.
|
| 811 |
+
Rules enforced in prompt:
|
| 812 |
+
- Keep the user's main information need from the last substantive question
|
| 813 |
+
- Integrate short modifiers (e.g., "financial") into that question when appropriate
|
| 814 |
+
- Do NOT include filter text (years/districts/sources) in the query; those are handled separately
|
| 815 |
+
- Return a single plain sentence only (no quotes, no markdown)
|
| 816 |
+
"""
|
| 817 |
+
try:
|
| 818 |
+
# Build a compact conversation transcript (last 6 messages max)
|
| 819 |
+
convo_lines = []
|
| 820 |
+
for msg in recent_messages[-6:]:
|
| 821 |
+
if isinstance(msg, HumanMessage):
|
| 822 |
+
convo_lines.append(f"User: {msg.content}")
|
| 823 |
+
elif isinstance(msg, AIMessage):
|
| 824 |
+
convo_lines.append(f"Assistant: {msg.content}")
|
| 825 |
+
|
| 826 |
+
convo_text = "\n".join(convo_lines)
|
| 827 |
+
|
| 828 |
+
"""
|
| 829 |
+
"DECISION GUIDANCE:\n"
|
| 830 |
+
"- If the latest user message looks like a modifier (e.g., 'financial'), merge it into the best prior question.\n"
|
| 831 |
+
"- If the latest message provides filters (e.g., districts, years), DO NOT embed them; keep the base question.\n"
|
| 832 |
+
"- If the latest message itself is a full, clear question, use it.\n"
|
| 833 |
+
"- If the draft query is already good, you may refine its clarity but keep the same intent.\n\n"
|
| 834 |
+
"""
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 838 |
+
SystemMessage(content=(
|
| 839 |
+
"ROLE: Query Rewriter for a RAG system.\n\n"
|
| 840 |
+
"PRIMARY OBJECTIVE:\n- Produce ONE retrieval-focused sentence that best represents the user's information need.\n"
|
| 841 |
+
"- Maximize recall of relevant evidence; be specific but not overconstrained.\n\n"
|
| 842 |
+
"INPUTS:\n- Conversation with User and Assistant turns (latest last).\n- A draft query (heuristic).\n\n"
|
| 843 |
+
"OPERATING PRINCIPLES:\n"
|
| 844 |
+
"1) Use the last substantive USER question as the backbone of intent.\n"
|
| 845 |
+
"2) Merge helpful domain modifiers from any USER turns (financial, procurement, risk) when they sharpen focus; ignore if not helpful.\n"
|
| 846 |
+
"3) Treat Assistant messages as guidance only; if the user later provided filters (years, districts, sources), DO NOT embed them in the query (filters are applied separately).\n"
|
| 847 |
+
"4) Remove meta-verbs like 'summarize', 'list', 'explain', 'compare' from the query.\n"
|
| 848 |
+
"5) Prefer content-bearing terms (topics, programs, outcomes) over task phrasing.\n"
|
| 849 |
+
"6) If the latest user message is filters-only, keep the prior substantive question unchanged.\n"
|
| 850 |
+
"7) If the draft query is already strong, refine wording for clarity but keep the same intent.\n\n"
|
| 851 |
+
"EXAMPLES (multi-turn):\n"
|
| 852 |
+
"A)\nUser: What are the top 5 priorities for improving audit procedures?\nAssistant: Could you specify the scope (e.g., financial, procurement)?\nUser: Financial\n→ Output: Top priorities for improving financial audit procedures.\n\n"
|
| 853 |
+
"B)\nUser: How were PDM administrative costs utilized and what was the impact of shortfalls?\nAssistant: Please specify district/year for precision.\nUser: Namutumba and Lwengo Districts (2022/23)\n→ Output: How were PDM administrative costs utilized and what was the impact of shortfalls.\n(Exclude districts/years; they are filters.)\n\n"
|
| 854 |
+
"C)\nUser: Summarize risk management issues in audit reports.\n→ Output: Key risk management issues in audit reports.\n\n"
|
| 855 |
+
"CONSTRAINTS:\n- Do NOT include filters (years, districts, sources, filenames).\n- Do NOT include quotes/markdown/bullets or multiple sentences.\n- Return exactly one plain sentence."
|
| 856 |
+
)),
|
| 857 |
+
HumanMessage(content=(
|
| 858 |
+
f"Conversation (most recent last):\n{convo_text}\n\n"
|
| 859 |
+
f"Draft query: {draft_query}\n\n"
|
| 860 |
+
"Rewrite the single best retrieval query sentence now:"
|
| 861 |
+
)),
|
| 862 |
+
])
|
| 863 |
+
|
| 864 |
+
# Add timeout for LLM call
|
| 865 |
+
import signal
|
| 866 |
+
|
| 867 |
+
def timeout_handler(signum, frame):
|
| 868 |
+
raise TimeoutError("LLM rewrite timeout")
|
| 869 |
+
|
| 870 |
+
# Set 10 second timeout
|
| 871 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
| 872 |
+
signal.alarm(10)
|
| 873 |
+
|
| 874 |
+
try:
|
| 875 |
+
resp = self.llm.invoke(prompt.format_messages())
|
| 876 |
+
signal.alarm(0) # Cancel timeout
|
| 877 |
+
|
| 878 |
+
rewritten = getattr(resp, 'content', '').strip()
|
| 879 |
+
# Basic sanitization: keep it one line
|
| 880 |
+
rewritten = rewritten.replace('\n', ' ').strip()
|
| 881 |
+
if rewritten and len(rewritten) > 5: # Basic quality check
|
| 882 |
+
self.logger.info(f"🛠️ LLM REWRITER: '{rewritten}'")
|
| 883 |
+
return rewritten
|
| 884 |
+
else:
|
| 885 |
+
self.logger.info(f"⚠️ LLM rewrite too short/empty, using draft query")
|
| 886 |
+
return draft_query
|
| 887 |
+
except TimeoutError:
|
| 888 |
+
signal.alarm(0)
|
| 889 |
+
self.logger.info(f"⚠️ LLM rewrite timeout after 10s, using draft query")
|
| 890 |
+
return draft_query
|
| 891 |
+
except Exception as e:
|
| 892 |
+
signal.alarm(0)
|
| 893 |
+
self.logger.info(f"⚠️ LLM rewrite failed, using draft query. Error: {e}")
|
| 894 |
+
return draft_query
|
| 895 |
+
except Exception as e:
|
| 896 |
+
self.logger.info(f"⚠️ LLM rewrite setup failed, using draft query. Error: {e}")
|
| 897 |
+
return draft_query
|
| 898 |
+
|
| 899 |
+
def _perform_rag(self, state: ConversationState) -> ConversationState:
|
| 900 |
+
"""Perform RAG retrieval with smart query expansion"""
|
| 901 |
+
|
| 902 |
+
query = state["current_query"]
|
| 903 |
+
analysis = state["query_analysis"]
|
| 904 |
+
conversation_context = state.get("conversation_context", {})
|
| 905 |
+
recent_messages = state.get("messages", [])
|
| 906 |
+
|
| 907 |
+
# Build comprehensive query from conversation history
|
| 908 |
+
draft_query = self._build_comprehensive_query(query, analysis, conversation_context, recent_messages)
|
| 909 |
+
# Let LLM rewrite a clean, focused search query
|
| 910 |
+
search_query = self._rewrite_query_with_llm(recent_messages, draft_query)
|
| 911 |
+
|
| 912 |
+
self.logger.info(f"🔍 RAG RETRIEVAL: Starting for query: '{search_query[:50]}...'")
|
| 913 |
+
self.logger.info(f"📊 Analysis: district={analysis.has_district}, source={analysis.has_source}, year={analysis.has_year}")
|
| 914 |
+
|
| 915 |
+
try:
|
| 916 |
+
# Build filters from analysis and conversation context
|
| 917 |
+
filters = {}
|
| 918 |
+
|
| 919 |
+
# Use conversation context to fill in missing filters
|
| 920 |
+
source = analysis.extracted_source or conversation_context.get("source")
|
| 921 |
+
district = analysis.extracted_district or conversation_context.get("district")
|
| 922 |
+
year = analysis.extracted_year or conversation_context.get("year")
|
| 923 |
+
|
| 924 |
+
if source:
|
| 925 |
+
filters["source"] = [source] # Qdrant expects lists
|
| 926 |
+
self.logger.info(f"🎯 Filter: source={source}")
|
| 927 |
+
|
| 928 |
+
if year:
|
| 929 |
+
filters["year"] = [year]
|
| 930 |
+
self.logger.info(f"🎯 Filter: year={year}")
|
| 931 |
+
|
| 932 |
+
if district:
|
| 933 |
+
# Map district to source if needed
|
| 934 |
+
if district.upper() == "KAMPALA":
|
| 935 |
+
filters["source"] = ["KCCA"]
|
| 936 |
+
self.logger.info(f"🎯 Filter: district={district} -> source=KCCA")
|
| 937 |
+
elif district.upper() in ["GULU", "KALANGALA"]:
|
| 938 |
+
filters["source"] = [f"{district.upper()} DLG"]
|
| 939 |
+
self.logger.info(f"🎯 Filter: district={district} -> source={district.upper()} DLG")
|
| 940 |
+
|
| 941 |
+
# Run RAG pipeline with correct parameters
|
| 942 |
+
result = self.pipeline_manager.run(
|
| 943 |
+
query=search_query, # Use expanded query
|
| 944 |
+
sources=filters.get("source") if filters.get("source") else None,
|
| 945 |
+
auto_infer_filters=False, # Our agent already handled filter inference
|
| 946 |
+
filters=filters if filters else None
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
self.logger.info(f"✅ RAG completed: Found {len(result.sources)} sources")
|
| 950 |
+
self.logger.info(f"⏱️ Execution time: {result.execution_time:.2f}s")
|
| 951 |
+
|
| 952 |
+
# Store RAG result in state
|
| 953 |
+
state["rag_result"] = result
|
| 954 |
+
state["rag_query"] = search_query
|
| 955 |
+
|
| 956 |
+
except Exception as e:
|
| 957 |
+
self.logger.info(f"❌ RAG retrieval failed: {e}")
|
| 958 |
+
state["rag_result"] = None
|
| 959 |
+
|
| 960 |
+
return state
|
| 961 |
+
|
| 962 |
+
def _generate_response(self, state: ConversationState) -> ConversationState:
|
| 963 |
+
"""Generate final response using RAG results"""
|
| 964 |
+
|
| 965 |
+
rag_result = state["rag_result"]
|
| 966 |
+
|
| 967 |
+
self.logger.info(f"📝 RESPONSE: Using RAG result ({len(rag_result.answer)} chars)")
|
| 968 |
+
|
| 969 |
+
# Store the final response directly from RAG
|
| 970 |
+
state["final_response"] = rag_result.answer
|
| 971 |
+
state["last_ai_message_time"] = time.time()
|
| 972 |
+
|
| 973 |
+
return state
|
| 974 |
+
|
| 975 |
+
def chat(self, user_input: str, conversation_id: str = "default") -> str:
|
| 976 |
+
"""Main chat interface with conversation management"""
|
| 977 |
+
|
| 978 |
+
self.logger.info(f"💬 CHAT: Processing user input: '{user_input[:50]}...'")
|
| 979 |
+
self.logger.info(f"📊 Session: {conversation_id}")
|
| 980 |
+
|
| 981 |
+
# Load conversation history
|
| 982 |
+
conversation_file = self.conversations_dir / f"{conversation_id}.json"
|
| 983 |
+
conversation = self._load_conversation(conversation_file)
|
| 984 |
+
|
| 985 |
+
# Add user message to conversation
|
| 986 |
+
conversation["messages"].append(HumanMessage(content=user_input))
|
| 987 |
+
|
| 988 |
+
self.logger.info(f"🔄 LANGGRAPH: Starting graph execution")
|
| 989 |
+
|
| 990 |
+
# Prepare state for LangGraph with conversation context
|
| 991 |
+
state = ConversationState(
|
| 992 |
+
conversation_id=conversation_id,
|
| 993 |
+
messages=conversation["messages"],
|
| 994 |
+
current_query=user_input,
|
| 995 |
+
query_analysis=None,
|
| 996 |
+
conversation_context=conversation.get("context", {}),
|
| 997 |
+
rag_result=None,
|
| 998 |
+
final_response=None,
|
| 999 |
+
session_start_time=conversation["session_start_time"],
|
| 1000 |
+
last_ai_message_time=conversation["last_ai_message_time"]
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
# Run the graph
|
| 1004 |
+
final_state = self.graph.invoke(state)
|
| 1005 |
+
|
| 1006 |
+
# Add the AI response to conversation
|
| 1007 |
+
if final_state["final_response"]:
|
| 1008 |
+
conversation["messages"].append(AIMessage(content=final_state["final_response"]))
|
| 1009 |
+
|
| 1010 |
+
# Update conversation state
|
| 1011 |
+
conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
|
| 1012 |
+
conversation["context"] = final_state["conversation_context"]
|
| 1013 |
+
|
| 1014 |
+
# Save conversation
|
| 1015 |
+
self._save_conversation(conversation_file, conversation)
|
| 1016 |
+
|
| 1017 |
+
self.logger.info(f"✅ LANGGRAPH: Graph execution completed")
|
| 1018 |
+
self.logger.info(f"🎯 CHAT COMPLETE: Response ready")
|
| 1019 |
+
|
| 1020 |
+
# Return both response and RAG result for UI
|
| 1021 |
+
return {
|
| 1022 |
+
'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
|
| 1023 |
+
'rag_result': final_state["rag_result"],
|
| 1024 |
+
'actual_rag_query': final_state.get("rag_query", "")
|
| 1025 |
+
}
|
| 1026 |
+
|
| 1027 |
+
def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
|
| 1028 |
+
"""Load conversation from file"""
|
| 1029 |
+
if conversation_file.exists():
|
| 1030 |
+
try:
|
| 1031 |
+
with open(conversation_file) as f:
|
| 1032 |
+
data = json.load(f)
|
| 1033 |
+
# Convert message dicts back to LangChain messages
|
| 1034 |
+
messages = []
|
| 1035 |
+
for msg_data in data.get("messages", []):
|
| 1036 |
+
if msg_data["type"] == "human":
|
| 1037 |
+
messages.append(HumanMessage(content=msg_data["content"]))
|
| 1038 |
+
elif msg_data["type"] == "ai":
|
| 1039 |
+
messages.append(AIMessage(content=msg_data["content"]))
|
| 1040 |
+
data["messages"] = messages
|
| 1041 |
+
return data
|
| 1042 |
+
except Exception as e:
|
| 1043 |
+
self.logger.info(f"⚠️ Could not load conversation: {e}")
|
| 1044 |
+
|
| 1045 |
+
# Return default conversation
|
| 1046 |
+
return {
|
| 1047 |
+
"messages": [],
|
| 1048 |
+
"session_start_time": time.time(),
|
| 1049 |
+
"last_ai_message_time": time.time(),
|
| 1050 |
+
"context": {}
|
| 1051 |
+
}
|
| 1052 |
+
|
| 1053 |
+
def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
|
| 1054 |
+
"""Save conversation to file"""
|
| 1055 |
+
try:
|
| 1056 |
+
# Convert LangChain messages to serializable format
|
| 1057 |
+
messages_data = []
|
| 1058 |
+
for msg in conversation["messages"]:
|
| 1059 |
+
if isinstance(msg, HumanMessage):
|
| 1060 |
+
messages_data.append({"type": "human", "content": msg.content})
|
| 1061 |
+
elif isinstance(msg, AIMessage):
|
| 1062 |
+
messages_data.append({"type": "ai", "content": msg.content})
|
| 1063 |
+
|
| 1064 |
+
data = {
|
| 1065 |
+
"messages": messages_data,
|
| 1066 |
+
"session_start_time": conversation["session_start_time"],
|
| 1067 |
+
"last_ai_message_time": conversation["last_ai_message_time"],
|
| 1068 |
+
"context": conversation.get("context", {}),
|
| 1069 |
+
"last_updated": datetime.now().isoformat()
|
| 1070 |
+
}
|
| 1071 |
+
|
| 1072 |
+
with open(conversation_file, "w") as f:
|
| 1073 |
+
json.dump(data, f, indent=2)
|
| 1074 |
+
|
| 1075 |
+
except Exception as e:
|
| 1076 |
+
self.logger.info(f"⚠️ Could not save conversation: {e}")
|
| 1077 |
+
|
| 1078 |
+
|
| 1079 |
+
def get_chatbot():
|
| 1080 |
+
"""Get chatbot instance"""
|
| 1081 |
+
return IntelligentRAGChatbot()
|
| 1082 |
+
|
| 1083 |
+
if __name__ == "__main__":
|
| 1084 |
+
# Test the chatbot
|
| 1085 |
+
chatbot = IntelligentRAGChatbot()
|
| 1086 |
+
|
| 1087 |
+
# Test conversation
|
| 1088 |
+
test_queries = [
|
| 1089 |
+
"How much was the budget allocation for government salary payroll management?",
|
| 1090 |
+
"Namutumba district in 2023",
|
| 1091 |
+
"KCCA"
|
| 1092 |
+
]
|
| 1093 |
+
|
| 1094 |
+
for query in test_queries:
|
| 1095 |
+
self.logger.info(f"\n{'='*50}")
|
| 1096 |
+
self.logger.info(f"User: {query}")
|
| 1097 |
+
response = chatbot.chat(query)
|
| 1098 |
+
self.logger.info(f"Bot: {response}")
|
src/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audit QA Refactored Module
|
| 3 |
+
A modular and maintainable RAG pipeline for audit report analysis.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .pipeline import PipelineManager
|
| 7 |
+
from .config.loader import load_config
|
| 8 |
+
|
| 9 |
+
__version__ = "2.0.0"
|
| 10 |
+
__all__ = ["PipelineManager", "load_config"]
|
src/config/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration management for Audit QA."""
|
| 2 |
+
|
| 3 |
+
from .loader import load_config, get_nested_config
|
| 4 |
+
|
| 5 |
+
__all__ = ["load_config", "get_nested_config"]
|
src/config/collections.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"docling": {
|
| 3 |
+
"model": "BAAI/bge-m3",
|
| 4 |
+
"description": "Default collection with BGE-M3 embedding model"
|
| 5 |
+
},
|
| 6 |
+
"modernbert-embed-base-akryl-matryoshka": {
|
| 7 |
+
"model": "Akryl/modernbert-embed-base-akryl-matryoshka",
|
| 8 |
+
"description": "ModernBERT embedding model with matryoshka representation"
|
| 9 |
+
},
|
| 10 |
+
"sentence-transformers-all-MiniLM-L6-v2": {
|
| 11 |
+
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
| 12 |
+
"description": "Sentence transformers MiniLM model"
|
| 13 |
+
},
|
| 14 |
+
"sentence-transformers-all-mpnet-base-v2": {
|
| 15 |
+
"model": "sentence-transformers/all-mpnet-base-v2",
|
| 16 |
+
"description": "Sentence transformers MPNet model"
|
| 17 |
+
},
|
| 18 |
+
"BAAI-bge-m3": {
|
| 19 |
+
"model": "BAAI/bge-m3",
|
| 20 |
+
"description": "BAAI BGE-M3 multilingual embedding model"
|
| 21 |
+
}
|
| 22 |
+
}
|
src/config/loader.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration loader for YAML settings."""
|
| 2 |
+
|
| 3 |
+
import yaml
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Any, Optional
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
def load_config(config_path: str = None) -> Dict[str, Any]:
|
| 13 |
+
"""
|
| 14 |
+
Load configuration from YAML file.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
config_path: Path to config file. If None, uses default settings.yaml
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Dictionary containing configuration settings
|
| 21 |
+
"""
|
| 22 |
+
if config_path is None:
|
| 23 |
+
# Default to settings.yaml in the same directory as this file
|
| 24 |
+
config_path = Path(__file__).parent / "settings.yaml"
|
| 25 |
+
|
| 26 |
+
config_path = Path(config_path)
|
| 27 |
+
|
| 28 |
+
if not config_path.exists():
|
| 29 |
+
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
| 30 |
+
|
| 31 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 32 |
+
content = f.read()
|
| 33 |
+
|
| 34 |
+
# Replace environment variables in the content
|
| 35 |
+
import os
|
| 36 |
+
import re
|
| 37 |
+
|
| 38 |
+
def replace_env_vars(match):
|
| 39 |
+
env_var = match.group(1)
|
| 40 |
+
return os.getenv(env_var, match.group(0)) # Return original if env var not found
|
| 41 |
+
|
| 42 |
+
# Replace ${VAR} patterns with environment variables
|
| 43 |
+
content = re.sub(r'\$\{([^}]+)\}', replace_env_vars, content)
|
| 44 |
+
|
| 45 |
+
config = yaml.safe_load(content)
|
| 46 |
+
|
| 47 |
+
# Override with environment variables if they exist
|
| 48 |
+
config = _override_with_env_vars(config)
|
| 49 |
+
|
| 50 |
+
return config
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _override_with_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
|
| 54 |
+
"""Override config values with environment variables where available."""
|
| 55 |
+
|
| 56 |
+
# Map environment variables to config paths
|
| 57 |
+
env_mappings = {
|
| 58 |
+
'QDRANT_URL': ['qdrant', 'url'],
|
| 59 |
+
'QDRANT_COLLECTION': ['qdrant', 'collection_name'],
|
| 60 |
+
'QDRANT_API_KEY': ['qdrant', 'api_key'],
|
| 61 |
+
'RETRIEVER_MODEL': ['retriever', 'model'],
|
| 62 |
+
'RANKER_MODEL': ['ranker', 'model'],
|
| 63 |
+
'READER_TYPE': ['reader', 'default_type'],
|
| 64 |
+
'MAX_TOKENS': ['reader', 'max_tokens'],
|
| 65 |
+
'MISTRAL_API_KEY': ['reader', 'MISTRAL', 'api_key'],
|
| 66 |
+
'OPENAI_API_KEY': ['reader', 'OPENAI', 'api_key'],
|
| 67 |
+
'NEBIUS_API_KEY': ['reader', 'INF_PROVIDERS', 'api_key'],
|
| 68 |
+
'NVIDIA_SERVER_API_KEY': ['reader', 'NVIDIA', 'api_key'],
|
| 69 |
+
'SERVERLESS_API_KEY': ['reader', 'SERVERLESS', 'api_key'],
|
| 70 |
+
'DEDICATED_API_KEY': ['reader', 'DEDICATED', 'api_key'],
|
| 71 |
+
'OPENROUTER_API_KEY': ['reader', 'OPENROUTER', 'api_key'],
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
for env_var, config_path in env_mappings.items():
|
| 75 |
+
env_value = os.getenv(env_var)
|
| 76 |
+
if env_value:
|
| 77 |
+
# Navigate to the nested config location
|
| 78 |
+
current = config
|
| 79 |
+
for key in config_path[:-1]:
|
| 80 |
+
if key not in current:
|
| 81 |
+
current[key] = {}
|
| 82 |
+
current = current[key]
|
| 83 |
+
|
| 84 |
+
# Set the final value, converting to appropriate type
|
| 85 |
+
final_key = config_path[-1]
|
| 86 |
+
if final_key in ['top_k', 'max_tokens', 'num_predict']:
|
| 87 |
+
current[final_key] = int(env_value)
|
| 88 |
+
elif final_key in ['normalize', 'prefer_grpc']:
|
| 89 |
+
current[final_key] = env_value.lower() in ('true', '1', 'yes')
|
| 90 |
+
elif final_key == 'temperature':
|
| 91 |
+
current[final_key] = float(env_value)
|
| 92 |
+
else:
|
| 93 |
+
current[final_key] = env_value
|
| 94 |
+
|
| 95 |
+
return config
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_nested_config(config: Dict[str, Any], path: str, default=None):
|
| 99 |
+
"""
|
| 100 |
+
Get a nested configuration value using dot notation.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
config: Configuration dictionary
|
| 104 |
+
path: Dot-separated path (e.g., 'reader.MISTRAL.model')
|
| 105 |
+
default: Default value if path not found
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
Configuration value or default
|
| 109 |
+
"""
|
| 110 |
+
keys = path.split('.')
|
| 111 |
+
current = config
|
| 112 |
+
|
| 113 |
+
try:
|
| 114 |
+
for key in keys:
|
| 115 |
+
current = current[key]
|
| 116 |
+
return current
|
| 117 |
+
except (KeyError, TypeError):
|
| 118 |
+
return default
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def load_collections_mapping() -> Dict[str, Dict[str, str]]:
|
| 122 |
+
"""Load collections mapping from JSON file."""
|
| 123 |
+
collections_file = Path(__file__).parent / "collections.json"
|
| 124 |
+
|
| 125 |
+
if not collections_file.exists():
|
| 126 |
+
# Return default mapping if file doesn't exist
|
| 127 |
+
return {
|
| 128 |
+
"docling": {
|
| 129 |
+
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
| 130 |
+
"description": "Default collection"
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
with open(collections_file, 'r') as f:
|
| 135 |
+
return json.load(f)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_embedding_model_for_collection(collection_name: str) -> Optional[str]:
|
| 139 |
+
"""Get embedding model for a specific collection name."""
|
| 140 |
+
collections = load_collections_mapping()
|
| 141 |
+
|
| 142 |
+
if collection_name in collections:
|
| 143 |
+
return collections[collection_name]["model"]
|
| 144 |
+
|
| 145 |
+
# Try to infer from collection name patterns
|
| 146 |
+
if "modernbert" in collection_name.lower():
|
| 147 |
+
return "Akryl/modernbert-embed-base-akryl-matryoshka"
|
| 148 |
+
elif "minilm" in collection_name.lower():
|
| 149 |
+
return "sentence-transformers/all-MiniLM-L6-v2"
|
| 150 |
+
elif "mpnet" in collection_name.lower():
|
| 151 |
+
return "sentence-transformers/all-mpnet-base-v2"
|
| 152 |
+
elif "bge" in collection_name.lower():
|
| 153 |
+
return "BAAI/bge-m3"
|
| 154 |
+
|
| 155 |
+
return None
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def get_collection_info(collection_name: str) -> Dict[str, str]:
|
| 159 |
+
"""Get full collection information including model and description."""
|
| 160 |
+
collections = load_collections_mapping()
|
| 161 |
+
|
| 162 |
+
if collection_name in collections:
|
| 163 |
+
return collections[collection_name]
|
| 164 |
+
|
| 165 |
+
# Return inferred info for unknown collections
|
| 166 |
+
model = get_embedding_model_for_collection(collection_name)
|
| 167 |
+
return {
|
| 168 |
+
"model": model or "unknown",
|
| 169 |
+
"description": f"Auto-inferred collection: {collection_name}"
|
| 170 |
+
}
|
src/config/settings.yaml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Audit QA Configuration
|
| 2 |
+
# Converted from model_params.cfg to YAML format
|
| 3 |
+
|
| 4 |
+
qdrant:
|
| 5 |
+
# url: "http://10.1.4.192:8803"`
|
| 6 |
+
url: "https://2c6d0136-b6ca-4400-bac5-1703f58abc43.europe-west3-0.gcp.cloud.qdrant.io"
|
| 7 |
+
collection_name: "docling"
|
| 8 |
+
prefer_grpc: true
|
| 9 |
+
api_key: "${QDRANT_API_KEY}" # Load from environment variable
|
| 10 |
+
|
| 11 |
+
retriever:
|
| 12 |
+
model: "BAAI/bge-m3"
|
| 13 |
+
normalize: true
|
| 14 |
+
top_k: 20
|
| 15 |
+
|
| 16 |
+
retrieval:
|
| 17 |
+
use_reranking: true
|
| 18 |
+
reranker_model: "BAAI/bge-reranker-v2-m3"
|
| 19 |
+
reranker_top_k: 5
|
| 20 |
+
|
| 21 |
+
ranker:
|
| 22 |
+
model: "BAAI/bge-reranker-v2-m3"
|
| 23 |
+
top_k: 5
|
| 24 |
+
|
| 25 |
+
bm25:
|
| 26 |
+
top_k: 20
|
| 27 |
+
|
| 28 |
+
hybrid:
|
| 29 |
+
default_mode: "vector_only" # Options: vector_only, sparse_only, hybrid
|
| 30 |
+
default_alpha: 0.5 # Weight for vector scores (0.5 = equal weight)
|
| 31 |
+
|
| 32 |
+
reader:
|
| 33 |
+
default_type: "OPENAI"
|
| 34 |
+
max_tokens: 768
|
| 35 |
+
|
| 36 |
+
# Different LLM provider configurations
|
| 37 |
+
INF_PROVIDERS:
|
| 38 |
+
model: "meta-llama/Llama-3.1-8B-Instruct"
|
| 39 |
+
provider: "nebius"
|
| 40 |
+
|
| 41 |
+
# Not working
|
| 42 |
+
NVIDIA:
|
| 43 |
+
model: "meta-llama/Llama-3.1-8B-Instruct"
|
| 44 |
+
endpoint: "https://huggingface.co/api/integrations/dgx/v1"
|
| 45 |
+
|
| 46 |
+
# Not working
|
| 47 |
+
DEDICATED:
|
| 48 |
+
model: "meta-llama/Llama-3.1-8B-Instruct"
|
| 49 |
+
endpoint: "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud"
|
| 50 |
+
|
| 51 |
+
MISTRAL:
|
| 52 |
+
model: "mistral-medium-latest"
|
| 53 |
+
|
| 54 |
+
OPENAI:
|
| 55 |
+
model: "gpt-4o-mini"
|
| 56 |
+
|
| 57 |
+
OLLAMA:
|
| 58 |
+
model: "mistral-small3.1:24b-instruct-2503-q8_0"
|
| 59 |
+
base_url: "http://10.1.4.192:11434/"
|
| 60 |
+
temperature: 0.8
|
| 61 |
+
num_predict: 256
|
| 62 |
+
|
| 63 |
+
OPENROUTER:
|
| 64 |
+
model: "moonshotai/kimi-k2:free"
|
| 65 |
+
base_url: "https://openrouter.ai/api/v1"
|
| 66 |
+
temperature: 0.7
|
| 67 |
+
max_tokens: 1000
|
| 68 |
+
# site_url: "https://your-site.com" # optional, for OpenRouter ranking
|
| 69 |
+
# site_name: "Your Site Name" # optional, for OpenRouter ranking
|
| 70 |
+
|
| 71 |
+
app:
|
| 72 |
+
dropdown_default: "Annual Consolidated OAG 2024"
|
| 73 |
+
|
| 74 |
+
# File paths
|
| 75 |
+
paths:
|
| 76 |
+
chunks_file: "reports/docling_chunks.json"
|
| 77 |
+
reports_dir: "reports"
|
| 78 |
+
|
| 79 |
+
# Feature toggles
|
| 80 |
+
features:
|
| 81 |
+
enable_session: true
|
| 82 |
+
enable_logging: true
|
| 83 |
+
|
| 84 |
+
# Logging and HuggingFace scheduler configuration
|
| 85 |
+
logging:
|
| 86 |
+
json_dataset_dir: "json_dataset"
|
| 87 |
+
huggingface:
|
| 88 |
+
repo_id: "GIZ/spaces_logs"
|
| 89 |
+
repo_type: "dataset"
|
| 90 |
+
folder_path: "json_dataset"
|
| 91 |
+
path_in_repo: "audit_chatbot"
|
| 92 |
+
token_env_var: "SPACES_LOG"
|
src/llm/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM adapters and utilities."""
|
| 2 |
+
|
| 3 |
+
from .adapters import LLMRegistry, get_llm_client
|
| 4 |
+
from .templates import get_message_template, PromptTemplate, create_audit_prompt
|
| 5 |
+
|
| 6 |
+
__all__ = ["LLMRegistry", "get_llm_client", "get_message_template", "PromptTemplate", "create_audit_prompt"]
|
src/llm/adapters.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM client adapters for different providers."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Any, List, Optional, Union
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
# LangChain imports
|
| 8 |
+
from langchain_mistralai.chat_models import ChatMistralAI
|
| 9 |
+
from langchain_openai.chat_models import ChatOpenAI
|
| 10 |
+
from langchain_ollama import ChatOllama
|
| 11 |
+
|
| 12 |
+
# Legacy client dependencies
|
| 13 |
+
from huggingface_hub import InferenceClient
|
| 14 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
| 15 |
+
from langchain_community.llms import HuggingFaceEndpoint
|
| 16 |
+
from langchain_community.chat_models.huggingface import ChatHuggingFace
|
| 17 |
+
|
| 18 |
+
# Configuration loader
|
| 19 |
+
from ..config.loader import load_config
|
| 20 |
+
|
| 21 |
+
# Load configuration once at module level
|
| 22 |
+
_config = load_config()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Legacy client factory functions (inlined from auditqa_old.reader)
|
| 26 |
+
def _create_inf_provider_client():
|
| 27 |
+
"""Create INF_PROVIDERS client."""
|
| 28 |
+
reader_config = _config.get("reader", {})
|
| 29 |
+
inf_config = reader_config.get("INF_PROVIDERS", {})
|
| 30 |
+
|
| 31 |
+
api_key = inf_config.get("api_key")
|
| 32 |
+
if not api_key:
|
| 33 |
+
raise ValueError("INF_PROVIDERS api_key not found in configuration")
|
| 34 |
+
|
| 35 |
+
provider = inf_config.get("provider")
|
| 36 |
+
if not provider:
|
| 37 |
+
raise ValueError("INF_PROVIDERS provider not found in configuration")
|
| 38 |
+
|
| 39 |
+
return InferenceClient(
|
| 40 |
+
provider=provider,
|
| 41 |
+
api_key=api_key,
|
| 42 |
+
bill_to="GIZ",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _create_nvidia_client():
|
| 47 |
+
"""Create NVIDIA client."""
|
| 48 |
+
reader_config = _config.get("reader", {})
|
| 49 |
+
nvidia_config = reader_config.get("NVIDIA", {})
|
| 50 |
+
|
| 51 |
+
api_key = nvidia_config.get("api_key")
|
| 52 |
+
if not api_key:
|
| 53 |
+
raise ValueError("NVIDIA api_key not found in configuration")
|
| 54 |
+
|
| 55 |
+
endpoint = nvidia_config.get("endpoint")
|
| 56 |
+
if not endpoint:
|
| 57 |
+
raise ValueError("NVIDIA endpoint not found in configuration")
|
| 58 |
+
|
| 59 |
+
return InferenceClient(
|
| 60 |
+
base_url=endpoint,
|
| 61 |
+
api_key=api_key
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _create_serverless_client():
|
| 66 |
+
"""Create serverless API client."""
|
| 67 |
+
reader_config = _config.get("reader", {})
|
| 68 |
+
serverless_config = reader_config.get("SERVERLESS", {})
|
| 69 |
+
|
| 70 |
+
api_key = serverless_config.get("api_key")
|
| 71 |
+
if not api_key:
|
| 72 |
+
raise ValueError("SERVERLESS api_key not found in configuration")
|
| 73 |
+
|
| 74 |
+
model_id = serverless_config.get("model", "meta-llama/Meta-Llama-3-8B-Instruct")
|
| 75 |
+
|
| 76 |
+
return InferenceClient(
|
| 77 |
+
model=model_id,
|
| 78 |
+
api_key=api_key,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _create_dedicated_endpoint_client():
|
| 83 |
+
"""Create dedicated endpoint client."""
|
| 84 |
+
reader_config = _config.get("reader", {})
|
| 85 |
+
dedicated_config = reader_config.get("DEDICATED", {})
|
| 86 |
+
|
| 87 |
+
api_key = dedicated_config.get("api_key")
|
| 88 |
+
if not api_key:
|
| 89 |
+
raise ValueError("DEDICATED api_key not found in configuration")
|
| 90 |
+
|
| 91 |
+
endpoint = dedicated_config.get("endpoint")
|
| 92 |
+
if not endpoint:
|
| 93 |
+
raise ValueError("DEDICATED endpoint not found in configuration")
|
| 94 |
+
|
| 95 |
+
max_tokens = dedicated_config.get("max_tokens", 768)
|
| 96 |
+
|
| 97 |
+
# Set up the streaming callback handler
|
| 98 |
+
callback = StreamingStdOutCallbackHandler()
|
| 99 |
+
|
| 100 |
+
# Initialize the HuggingFaceEndpoint with streaming enabled
|
| 101 |
+
llm_qa = HuggingFaceEndpoint(
|
| 102 |
+
endpoint_url=endpoint,
|
| 103 |
+
max_new_tokens=int(max_tokens),
|
| 104 |
+
repetition_penalty=1.03,
|
| 105 |
+
timeout=70,
|
| 106 |
+
huggingfacehub_api_token=api_key,
|
| 107 |
+
streaming=True,
|
| 108 |
+
callbacks=[callback]
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Create a ChatHuggingFace instance with the streaming-enabled endpoint
|
| 112 |
+
return ChatHuggingFace(llm=llm_qa)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass
|
| 116 |
+
class LLMResponse:
|
| 117 |
+
"""Standardized LLM response format."""
|
| 118 |
+
content: str
|
| 119 |
+
model: str
|
| 120 |
+
provider: str
|
| 121 |
+
metadata: Dict[str, Any] = None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class BaseLLMAdapter(ABC):
|
| 125 |
+
"""Base class for LLM adapters."""
|
| 126 |
+
|
| 127 |
+
def __init__(self, config: Dict[str, Any]):
|
| 128 |
+
self.config = config
|
| 129 |
+
|
| 130 |
+
@abstractmethod
|
| 131 |
+
def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
|
| 132 |
+
"""Generate response from messages."""
|
| 133 |
+
pass
|
| 134 |
+
|
| 135 |
+
@abstractmethod
|
| 136 |
+
def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
|
| 137 |
+
"""Generate streaming response from messages."""
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class MistralAdapter(BaseLLMAdapter):
|
| 142 |
+
"""Adapter for Mistral AI models."""
|
| 143 |
+
|
| 144 |
+
def __init__(self, config: Dict[str, Any]):
|
| 145 |
+
super().__init__(config)
|
| 146 |
+
self.model = ChatMistralAI(
|
| 147 |
+
model=config.get("model", "mistral-medium-latest")
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
|
| 151 |
+
"""Generate response using Mistral."""
|
| 152 |
+
response = self.model.invoke(messages)
|
| 153 |
+
|
| 154 |
+
return LLMResponse(
|
| 155 |
+
content=response.content,
|
| 156 |
+
model=self.config.get("model", "mistral-medium-latest"),
|
| 157 |
+
provider="mistral",
|
| 158 |
+
metadata={"usage": getattr(response, 'usage_metadata', {})}
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
|
| 162 |
+
"""Generate streaming response using Mistral."""
|
| 163 |
+
for chunk in self.model.stream(messages):
|
| 164 |
+
if chunk.content:
|
| 165 |
+
yield chunk.content
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class OpenAIAdapter(BaseLLMAdapter):
|
| 169 |
+
"""Adapter for OpenAI models."""
|
| 170 |
+
|
| 171 |
+
def __init__(self, config: Dict[str, Any]):
|
| 172 |
+
super().__init__(config)
|
| 173 |
+
self.model = ChatOpenAI(
|
| 174 |
+
model=config.get("model", "gpt-4o-mini")
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
|
| 178 |
+
"""Generate response using OpenAI."""
|
| 179 |
+
response = self.model.invoke(messages)
|
| 180 |
+
|
| 181 |
+
return LLMResponse(
|
| 182 |
+
content=response.content,
|
| 183 |
+
model=self.config.get("model", "gpt-4o-mini"),
|
| 184 |
+
provider="openai",
|
| 185 |
+
metadata={"usage": getattr(response, 'usage_metadata', {})}
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
|
| 189 |
+
"""Generate streaming response using OpenAI."""
|
| 190 |
+
for chunk in self.model.stream(messages):
|
| 191 |
+
if chunk.content:
|
| 192 |
+
yield chunk.content
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class OllamaAdapter(BaseLLMAdapter):
|
| 196 |
+
"""Adapter for Ollama models."""
|
| 197 |
+
|
| 198 |
+
def __init__(self, config: Dict[str, Any]):
|
| 199 |
+
super().__init__(config)
|
| 200 |
+
self.model = ChatOllama(
|
| 201 |
+
model=config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
|
| 202 |
+
base_url=config.get("base_url", "http://localhost:11434/"),
|
| 203 |
+
temperature=config.get("temperature", 0.8),
|
| 204 |
+
num_predict=config.get("num_predict", 256)
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
|
| 208 |
+
"""Generate response using Ollama."""
|
| 209 |
+
response = self.model.invoke(messages)
|
| 210 |
+
|
| 211 |
+
return LLMResponse(
|
| 212 |
+
content=response.content,
|
| 213 |
+
model=self.config.get("model", "mistral-small3.1:24b-instruct-2503-q8_0"),
|
| 214 |
+
provider="ollama",
|
| 215 |
+
metadata={}
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
|
| 219 |
+
"""Generate streaming response using Ollama."""
|
| 220 |
+
for chunk in self.model.stream(messages):
|
| 221 |
+
if chunk.content:
|
| 222 |
+
yield chunk.content
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class OpenRouterAdapter(BaseLLMAdapter):
|
| 226 |
+
"""Adapter for OpenRouter models."""
|
| 227 |
+
|
| 228 |
+
def __init__(self, config: Dict[str, Any]):
|
| 229 |
+
super().__init__(config)
|
| 230 |
+
|
| 231 |
+
# Prepare custom headers for OpenRouter (optional)
|
| 232 |
+
headers = {}
|
| 233 |
+
if config.get("site_url"):
|
| 234 |
+
headers["HTTP-Referer"] = config["site_url"]
|
| 235 |
+
if config.get("site_name"):
|
| 236 |
+
headers["X-Title"] = config["site_name"]
|
| 237 |
+
|
| 238 |
+
# Initialize ChatOpenAI with OpenRouter configuration
|
| 239 |
+
self.model = ChatOpenAI(
|
| 240 |
+
model=config.get("model", "openai/gpt-3.5-turbo"),
|
| 241 |
+
api_key=config.get("api_key"),
|
| 242 |
+
base_url=config.get("base_url", "https://openrouter.ai/api/v1"),
|
| 243 |
+
default_headers= headers if headers else {},
|
| 244 |
+
temperature=config.get("temperature", 0.7),
|
| 245 |
+
max_tokens=config.get("max_tokens", 1000)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
|
| 249 |
+
"""Generate response using OpenRouter."""
|
| 250 |
+
response = self.model.invoke(messages)
|
| 251 |
+
|
| 252 |
+
return LLMResponse(
|
| 253 |
+
content=response.content,
|
| 254 |
+
model=self.config.get("model", "openai/gpt-3.5-turbo"),
|
| 255 |
+
provider="openrouter",
|
| 256 |
+
metadata={"usage": getattr(response, 'usage_metadata', {})}
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
|
| 260 |
+
"""Generate streaming response using OpenRouter."""
|
| 261 |
+
for chunk in self.model.stream(messages):
|
| 262 |
+
if chunk.content:
|
| 263 |
+
yield chunk.content
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class LegacyAdapter(BaseLLMAdapter):
|
| 267 |
+
"""Adapter for legacy LLM clients (INF_PROVIDERS, NVIDIA, etc.)."""
|
| 268 |
+
|
| 269 |
+
def __init__(self, config: Dict[str, Any], client_type: str):
|
| 270 |
+
super().__init__(config)
|
| 271 |
+
self.client_type = client_type
|
| 272 |
+
self.client = self._create_client()
|
| 273 |
+
|
| 274 |
+
def _create_client(self):
|
| 275 |
+
"""Create legacy client based on type."""
|
| 276 |
+
if self.client_type == "INF_PROVIDERS":
|
| 277 |
+
return _create_inf_provider_client()
|
| 278 |
+
elif self.client_type == "NVIDIA":
|
| 279 |
+
return _create_nvidia_client()
|
| 280 |
+
elif self.client_type == "DEDICATED":
|
| 281 |
+
return _create_dedicated_endpoint_client()
|
| 282 |
+
else: # SERVERLESS
|
| 283 |
+
return _create_serverless_client()
|
| 284 |
+
|
| 285 |
+
def generate(self, messages: List[Dict[str, str]], **kwargs) -> LLMResponse:
|
| 286 |
+
"""Generate response using legacy client."""
|
| 287 |
+
max_tokens = kwargs.get('max_tokens', self.config.get('max_tokens', 768))
|
| 288 |
+
|
| 289 |
+
if self.client_type == "INF_PROVIDERS":
|
| 290 |
+
response = self.client.chat.completions.create(
|
| 291 |
+
model=self.config.get("model"),
|
| 292 |
+
messages=messages,
|
| 293 |
+
max_tokens=max_tokens
|
| 294 |
+
)
|
| 295 |
+
content = response.choices[0].message.content
|
| 296 |
+
|
| 297 |
+
elif self.client_type == "NVIDIA":
|
| 298 |
+
response = self.client.chat_completion(
|
| 299 |
+
model=self.config.get("model"),
|
| 300 |
+
messages=messages,
|
| 301 |
+
max_tokens=max_tokens
|
| 302 |
+
)
|
| 303 |
+
content = response.choices[0].message.content
|
| 304 |
+
|
| 305 |
+
else: # DEDICATED or SERVERLESS
|
| 306 |
+
response = self.client.chat_completion(
|
| 307 |
+
messages=messages,
|
| 308 |
+
max_tokens=max_tokens
|
| 309 |
+
)
|
| 310 |
+
content = response.choices[0].message.content
|
| 311 |
+
|
| 312 |
+
return LLMResponse(
|
| 313 |
+
content=content,
|
| 314 |
+
model=self.config.get("model", "unknown"),
|
| 315 |
+
provider=self.client_type.lower(),
|
| 316 |
+
metadata={}
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
def stream_generate(self, messages: List[Dict[str, str]], **kwargs):
|
| 320 |
+
"""Generate streaming response using legacy client."""
|
| 321 |
+
# Legacy clients may not support streaming in the same way
|
| 322 |
+
# This is a simplified implementation
|
| 323 |
+
response = self.generate(messages, **kwargs)
|
| 324 |
+
words = response.content.split()
|
| 325 |
+
for word in words:
|
| 326 |
+
yield word + " "
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class LLMRegistry:
|
| 330 |
+
"""Registry for managing different LLM adapters."""
|
| 331 |
+
|
| 332 |
+
def __init__(self):
|
| 333 |
+
self.adapters = {}
|
| 334 |
+
self.adapter_configs = {}
|
| 335 |
+
|
| 336 |
+
def register_adapter(self, name: str, adapter_class: type, config: Dict[str, Any]):
|
| 337 |
+
"""Register an LLM adapter (lazy instantiation)."""
|
| 338 |
+
self.adapter_configs[name] = (adapter_class, config)
|
| 339 |
+
|
| 340 |
+
def get_adapter(self, name: str) -> BaseLLMAdapter:
|
| 341 |
+
"""Get an LLM adapter by name (lazy instantiation)."""
|
| 342 |
+
if name not in self.adapter_configs:
|
| 343 |
+
raise ValueError(f"Unknown LLM adapter: {name}")
|
| 344 |
+
|
| 345 |
+
# Lazy instantiation - only create when needed
|
| 346 |
+
if name not in self.adapters:
|
| 347 |
+
adapter_class, config = self.adapter_configs[name]
|
| 348 |
+
self.adapters[name] = adapter_class(config)
|
| 349 |
+
|
| 350 |
+
return self.adapters[name]
|
| 351 |
+
|
| 352 |
+
def list_adapters(self) -> List[str]:
|
| 353 |
+
"""List available adapter names."""
|
| 354 |
+
return list(self.adapter_configs.keys())
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def create_llm_registry(config: Dict[str, Any]) -> LLMRegistry:
|
| 358 |
+
"""
|
| 359 |
+
Create and populate LLM registry from configuration.
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
config: Configuration dictionary
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
Populated LLMRegistry
|
| 366 |
+
"""
|
| 367 |
+
registry = LLMRegistry()
|
| 368 |
+
reader_config = config.get("reader", {})
|
| 369 |
+
|
| 370 |
+
# Register simple adapters
|
| 371 |
+
if "MISTRAL" in reader_config:
|
| 372 |
+
registry.register_adapter("mistral", MistralAdapter, reader_config["MISTRAL"])
|
| 373 |
+
|
| 374 |
+
if "OPENAI" in reader_config:
|
| 375 |
+
registry.register_adapter("openai", OpenAIAdapter, reader_config["OPENAI"])
|
| 376 |
+
|
| 377 |
+
if "OLLAMA" in reader_config:
|
| 378 |
+
registry.register_adapter("ollama", OllamaAdapter, reader_config["OLLAMA"])
|
| 379 |
+
|
| 380 |
+
if "OPENROUTER" in reader_config:
|
| 381 |
+
registry.register_adapter("openrouter", OpenRouterAdapter, reader_config["OPENROUTER"])
|
| 382 |
+
|
| 383 |
+
# Register legacy adapters
|
| 384 |
+
# legacy_types = ["INF_PROVIDERS", "NVIDIA", "DEDICATED"]
|
| 385 |
+
legacy_types = ["INF_PROVIDERS"]
|
| 386 |
+
for legacy_type in legacy_types:
|
| 387 |
+
if legacy_type in reader_config:
|
| 388 |
+
registry.register_adapter(
|
| 389 |
+
legacy_type.lower(),
|
| 390 |
+
lambda cfg, lt=legacy_type: LegacyAdapter(cfg, lt),
|
| 391 |
+
reader_config[legacy_type]
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
return registry
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def get_llm_client(provider: str, config: Dict[str, Any]) -> BaseLLMAdapter:
|
| 398 |
+
"""
|
| 399 |
+
Get LLM client for specified provider.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
provider: Provider name (mistral, openai, ollama, etc.)
|
| 403 |
+
config: Configuration dictionary
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
LLM adapter instance
|
| 407 |
+
"""
|
| 408 |
+
registry = create_llm_registry(config)
|
| 409 |
+
return registry.get_adapter(provider)
|
src/llm/templates.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM prompt templates and message formatting utilities."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Dict, Any, Union
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from langchain.schema import SystemMessage, HumanMessage
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
|
| 9 |
+
class PromptTemplate:
|
| 10 |
+
"""Template for managing prompts with variables."""
|
| 11 |
+
|
| 12 |
+
system_prompt: str
|
| 13 |
+
user_prompt_template: str
|
| 14 |
+
|
| 15 |
+
def format(self, **kwargs) -> tuple:
|
| 16 |
+
"""Format the template with provided variables."""
|
| 17 |
+
formatted_user = self.user_prompt_template.format(**kwargs)
|
| 18 |
+
return self.system_prompt, formatted_user
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Default system prompt for audit Q&A
|
| 22 |
+
DEFAULT_AUDIT_SYSTEM_PROMPT = """
|
| 23 |
+
You are AuditQ&A, an AI Assistant for audit reports. Answer questions directly and factually based on the provided context.
|
| 24 |
+
|
| 25 |
+
Guidelines:
|
| 26 |
+
- Answer directly and concisely (2-3 sentences maximum)
|
| 27 |
+
- Use specific facts and numbers from the context
|
| 28 |
+
- Cite sources using [Doc i] format
|
| 29 |
+
- Be factual, not opinionated
|
| 30 |
+
- Avoid phrases like "From my point of view", "I think", "It seems"
|
| 31 |
+
|
| 32 |
+
Examples:
|
| 33 |
+
|
| 34 |
+
Query: "What challenges arise from contradictory PDM implementation guidelines?"
|
| 35 |
+
Context: [Retrieved documents about PDM guidelines contradictions]
|
| 36 |
+
Answer: "Contradictory PDM implementation guidelines cause challenges during implementation, as entities receive numerous and often conflicting directives from different authorities. For example, guidelines on transfer of funds to PDM SACCOs differ between the PDM Secretariat and PSST, and there are conflicting directives on fund diversion from various authorities."
|
| 37 |
+
|
| 38 |
+
Query: "What was the supplementary funding obtained for the wage budget?"
|
| 39 |
+
Context: [Retrieved documents about wage budget funding]
|
| 40 |
+
Answer: "The supplementary funding obtained for the wage budget was UGX.2,208,040,656."
|
| 41 |
+
|
| 42 |
+
Now answer the following question based on the provided context:
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
# Default user prompt template
|
| 46 |
+
DEFAULT_USER_PROMPT_TEMPLATE = """Passages:
|
| 47 |
+
{context}
|
| 48 |
+
-----------------------
|
| 49 |
+
Question: {question} - Explained to audit expert
|
| 50 |
+
Answer in english with the passages citations:
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def create_audit_prompt(context_list: List[str], query: str) -> List[Dict[str, str]]:
|
| 55 |
+
"""
|
| 56 |
+
Create audit Q&A prompt messages from context and query.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
context_list: List of context passages
|
| 60 |
+
query: User query
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
List of message dictionaries for LLM
|
| 64 |
+
"""
|
| 65 |
+
# Join context passages with numbering
|
| 66 |
+
numbered_context = []
|
| 67 |
+
for i, passage in enumerate(context_list, 1):
|
| 68 |
+
numbered_context.append(f"Doc {i}: {passage}")
|
| 69 |
+
|
| 70 |
+
context_str = "\n\n".join(numbered_context)
|
| 71 |
+
|
| 72 |
+
# Format user prompt
|
| 73 |
+
user_prompt = DEFAULT_USER_PROMPT_TEMPLATE.format(
|
| 74 |
+
context=context_str,
|
| 75 |
+
question=query
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Return as message format
|
| 79 |
+
messages = [
|
| 80 |
+
{"role": "system", "content": DEFAULT_AUDIT_SYSTEM_PROMPT},
|
| 81 |
+
{"role": "user", "content": user_prompt}
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
return messages
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_message_template(
|
| 88 |
+
provider_type: str,
|
| 89 |
+
system_prompt: str,
|
| 90 |
+
user_prompt: str
|
| 91 |
+
) -> List[Union[Dict[str, str], SystemMessage, HumanMessage]]:
|
| 92 |
+
"""
|
| 93 |
+
Get message template based on LLM provider type.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
provider_type: Type of LLM provider
|
| 97 |
+
system_prompt: System prompt content
|
| 98 |
+
user_prompt: User prompt content
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
List of messages in the appropriate format for the provider
|
| 102 |
+
"""
|
| 103 |
+
provider_type = provider_type.upper()
|
| 104 |
+
|
| 105 |
+
if provider_type in ['NVIDIA', 'INF_PROVIDERS', 'MISTRAL', 'OPENAI', 'OPENROUTER']:
|
| 106 |
+
# Dictionary format for API-based providers
|
| 107 |
+
messages = [
|
| 108 |
+
{"role": "system", "content": system_prompt},
|
| 109 |
+
{"role": "user", "content": user_prompt}
|
| 110 |
+
]
|
| 111 |
+
elif provider_type in ['DEDICATED', 'SERVERLESS', 'OLLAMA']:
|
| 112 |
+
# LangChain message objects for local/dedicated providers
|
| 113 |
+
messages = [
|
| 114 |
+
SystemMessage(content=system_prompt),
|
| 115 |
+
HumanMessage(content=user_prompt)
|
| 116 |
+
]
|
| 117 |
+
else:
|
| 118 |
+
# Default to dictionary format
|
| 119 |
+
messages = [
|
| 120 |
+
{"role": "system", "content": system_prompt},
|
| 121 |
+
{"role": "user", "content": user_prompt}
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
return messages
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def create_custom_prompt_template(
|
| 128 |
+
system_prompt: str,
|
| 129 |
+
user_template: str
|
| 130 |
+
) -> PromptTemplate:
|
| 131 |
+
"""
|
| 132 |
+
Create a custom prompt template.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
system_prompt: System prompt content
|
| 136 |
+
user_template: User prompt template with placeholders
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
PromptTemplate instance
|
| 140 |
+
"""
|
| 141 |
+
return PromptTemplate(
|
| 142 |
+
system_prompt=system_prompt,
|
| 143 |
+
user_prompt_template=user_template
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def create_evaluation_prompt(context_list: List[str], query: str, expected_answer: str) -> List[Dict[str, str]]:
|
| 148 |
+
"""
|
| 149 |
+
Create prompt for evaluation purposes with expected answer.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
context_list: List of context passages
|
| 153 |
+
query: User query
|
| 154 |
+
expected_answer: Expected/ground truth answer
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
List of message dictionaries for evaluation
|
| 158 |
+
"""
|
| 159 |
+
# Join context passages
|
| 160 |
+
context_str = "\n\n".join([f"Doc {i}: {passage}" for i, passage in enumerate(context_list, 1)])
|
| 161 |
+
|
| 162 |
+
evaluation_system_prompt = """
|
| 163 |
+
You are an evaluation assistant. Given context passages, a question, and an expected answer,
|
| 164 |
+
evaluate how well the provided context supports answering the question accurately.
|
| 165 |
+
|
| 166 |
+
Provide your evaluation focusing on:
|
| 167 |
+
1. Relevance of the context to the question
|
| 168 |
+
2. Completeness of information needed to answer
|
| 169 |
+
3. Quality and accuracy of supporting details
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
user_prompt = f"""Context Passages:
|
| 173 |
+
{context_str}
|
| 174 |
+
|
| 175 |
+
Question: {query}
|
| 176 |
+
Expected Answer: {expected_answer}
|
| 177 |
+
|
| 178 |
+
Evaluation:"""
|
| 179 |
+
|
| 180 |
+
return [
|
| 181 |
+
{"role": "system", "content": evaluation_system_prompt},
|
| 182 |
+
{"role": "user", "content": user_prompt}
|
| 183 |
+
]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_prompt_variants() -> Dict[str, PromptTemplate]:
|
| 187 |
+
"""
|
| 188 |
+
Get different prompt template variants for testing.
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
Dictionary of named prompt templates
|
| 192 |
+
"""
|
| 193 |
+
variants = {
|
| 194 |
+
"standard": create_custom_prompt_template(
|
| 195 |
+
DEFAULT_AUDIT_SYSTEM_PROMPT,
|
| 196 |
+
DEFAULT_USER_PROMPT_TEMPLATE
|
| 197 |
+
),
|
| 198 |
+
|
| 199 |
+
"concise": create_custom_prompt_template(
|
| 200 |
+
"""You are an audit report AI assistant. Provide clear, concise answers based on the given context passages. Always cite sources using [Doc i] format.""",
|
| 201 |
+
"""Context:\n{context}\n\nQuestion: {question}\nAnswer:"""
|
| 202 |
+
),
|
| 203 |
+
|
| 204 |
+
"detailed": create_custom_prompt_template(
|
| 205 |
+
DEFAULT_AUDIT_SYSTEM_PROMPT + """\n\nAdditional Instructions:
|
| 206 |
+
- Provide detailed explanations with specific examples
|
| 207 |
+
- Include relevant numbers, dates, and financial figures when available
|
| 208 |
+
- Structure your response with clear headings when appropriate
|
| 209 |
+
- Explain the significance of findings in the context of governance and accountability""",
|
| 210 |
+
DEFAULT_USER_PROMPT_TEMPLATE
|
| 211 |
+
)
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
return variants
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# Backward compatibility function
|
| 218 |
+
def format_context_with_citations(context_list: List[str]) -> str:
|
| 219 |
+
"""
|
| 220 |
+
Format context list with document citations.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
context_list: List of context passages
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
Formatted context string with citations
|
| 227 |
+
"""
|
| 228 |
+
formatted_passages = []
|
| 229 |
+
for i, passage in enumerate(context_list, 1):
|
| 230 |
+
formatted_passages.append(f"Doc {i}: {passage}")
|
| 231 |
+
|
| 232 |
+
return "\n\n".join(formatted_passages)
|
src/loader.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data loading utilities for chunks and JSON files."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
from langchain.docstore.document import Document
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_json(filepath: Path | str) -> List[Dict[str, Any]]:
|
| 10 |
+
"""
|
| 11 |
+
Load JSON data from file.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
filepath: Path to JSON file
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
List of dictionaries containing the JSON data
|
| 18 |
+
"""
|
| 19 |
+
filepath = Path(filepath)
|
| 20 |
+
|
| 21 |
+
if not filepath.exists():
|
| 22 |
+
raise FileNotFoundError(f"JSON file not found: {filepath}")
|
| 23 |
+
|
| 24 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 25 |
+
data = json.load(f)
|
| 26 |
+
|
| 27 |
+
return data
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def open_file(filepath: Path | str) -> str:
|
| 31 |
+
"""
|
| 32 |
+
Open and read a text file.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
filepath: Path to text file
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
File contents as string
|
| 39 |
+
"""
|
| 40 |
+
filepath = Path(filepath)
|
| 41 |
+
|
| 42 |
+
if not filepath.exists():
|
| 43 |
+
raise FileNotFoundError(f"File not found: {filepath}")
|
| 44 |
+
|
| 45 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 46 |
+
content = f.read()
|
| 47 |
+
|
| 48 |
+
return content
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_chunks(chunks_file: Path | str = None) -> List[Dict[str, Any]]:
|
| 52 |
+
"""
|
| 53 |
+
Load document chunks from JSON file.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
chunks_file: Path to chunks JSON file. If None, uses default path.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
List of chunk dictionaries
|
| 60 |
+
"""
|
| 61 |
+
if chunks_file is None:
|
| 62 |
+
chunks_file = Path("reports/docling_chunks.json")
|
| 63 |
+
|
| 64 |
+
return load_json(chunks_file)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def chunks_to_documents(chunks: List[Dict[str, Any]]) -> List[Document]:
|
| 68 |
+
"""
|
| 69 |
+
Convert chunk dictionaries to LangChain Document objects.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
chunks: List of chunk dictionaries
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
List of Document objects
|
| 76 |
+
"""
|
| 77 |
+
documents = []
|
| 78 |
+
|
| 79 |
+
for chunk in chunks:
|
| 80 |
+
doc = Document(
|
| 81 |
+
page_content=chunk.get("content", ""),
|
| 82 |
+
metadata=chunk.get("metadata", {})
|
| 83 |
+
)
|
| 84 |
+
documents.append(doc)
|
| 85 |
+
|
| 86 |
+
return documents
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def validate_chunks(chunks: List[Dict[str, Any]]) -> bool:
|
| 90 |
+
"""
|
| 91 |
+
Validate that chunks have required fields.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
chunks: List of chunk dictionaries
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
True if valid, raises ValueError if invalid
|
| 98 |
+
"""
|
| 99 |
+
required_fields = ["content", "metadata"]
|
| 100 |
+
|
| 101 |
+
for i, chunk in enumerate(chunks):
|
| 102 |
+
for field in required_fields:
|
| 103 |
+
if field not in chunk:
|
| 104 |
+
raise ValueError(f"Chunk {i} missing required field: {field}")
|
| 105 |
+
|
| 106 |
+
# Validate metadata has required fields
|
| 107 |
+
metadata = chunk["metadata"]
|
| 108 |
+
if not isinstance(metadata, dict):
|
| 109 |
+
raise ValueError(f"Chunk {i} metadata must be a dictionary")
|
| 110 |
+
|
| 111 |
+
# Check for common metadata fields
|
| 112 |
+
if "filename" not in metadata:
|
| 113 |
+
raise ValueError(f"Chunk {i} metadata missing 'filename' field")
|
| 114 |
+
|
| 115 |
+
return True
|
src/logging.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging utilities (placeholder for legacy compatibility)."""
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from uuid import uuid4
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from threading import Lock
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
|
| 10 |
+
from .config import load_config
|
| 11 |
+
|
| 12 |
+
def save_logs(
|
| 13 |
+
scheduler=None,
|
| 14 |
+
json_dataset_path: Path = None,
|
| 15 |
+
logs_data: Dict[str, Any] = None,
|
| 16 |
+
feedback: str = None
|
| 17 |
+
) -> None:
|
| 18 |
+
"""
|
| 19 |
+
Save logs (placeholder for legacy compatibility).
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
scheduler: HuggingFace scheduler (not used in refactored version)
|
| 23 |
+
json_dataset_path: Path to JSON dataset
|
| 24 |
+
logs_data: Log data dictionary
|
| 25 |
+
feedback: User feedback
|
| 26 |
+
|
| 27 |
+
Note:
|
| 28 |
+
This is a placeholder function for backward compatibility.
|
| 29 |
+
In the refactored version, logging would be handled differently.
|
| 30 |
+
"""
|
| 31 |
+
if not is_logging_enabled():
|
| 32 |
+
return
|
| 33 |
+
try:
|
| 34 |
+
current_time = datetime.now().timestamp()
|
| 35 |
+
logs_data["time"] = str(current_time)
|
| 36 |
+
if feedback:
|
| 37 |
+
logs_data["feedback"] = feedback
|
| 38 |
+
logs_data["record_id"] = str(uuid4())
|
| 39 |
+
field_order = [
|
| 40 |
+
"record_id",
|
| 41 |
+
"session_id",
|
| 42 |
+
"time",
|
| 43 |
+
"session_duration_seconds",
|
| 44 |
+
"client_location",
|
| 45 |
+
"platform",
|
| 46 |
+
"system_prompt",
|
| 47 |
+
"sources",
|
| 48 |
+
"reports",
|
| 49 |
+
"subtype",
|
| 50 |
+
"year",
|
| 51 |
+
"question",
|
| 52 |
+
"retriever",
|
| 53 |
+
"endpoint_type",
|
| 54 |
+
"reader",
|
| 55 |
+
"docs",
|
| 56 |
+
"answer",
|
| 57 |
+
"feedback"
|
| 58 |
+
]
|
| 59 |
+
ordered_logs = {k: logs_data.get(k) for k in field_order if k in logs_data}
|
| 60 |
+
lock = getattr(scheduler, "lock", None)
|
| 61 |
+
if lock is None:
|
| 62 |
+
lock = Lock()
|
| 63 |
+
with lock:
|
| 64 |
+
with open(json_dataset_path, 'a') as f:
|
| 65 |
+
json.dump(ordered_logs, f)
|
| 66 |
+
f.write("\n")
|
| 67 |
+
logging.info("logging done")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logging.error(f"Error saving logs: {e}")
|
| 70 |
+
raise
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def setup_logging(log_level: str = "INFO", log_file: str = None) -> None:
|
| 74 |
+
"""
|
| 75 |
+
Set up logging configuration.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
log_level: Logging level
|
| 79 |
+
log_file: Optional log file path
|
| 80 |
+
"""
|
| 81 |
+
if not is_logging_enabled():
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
# Configure logging
|
| 85 |
+
logging.basicConfig(
|
| 86 |
+
level=getattr(logging, log_level.upper()),
|
| 87 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 88 |
+
handlers=[
|
| 89 |
+
logging.StreamHandler(),
|
| 90 |
+
logging.FileHandler(log_file) if log_file else logging.NullHandler()
|
| 91 |
+
]
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def log_query_response(
|
| 96 |
+
query: str,
|
| 97 |
+
response: str,
|
| 98 |
+
metadata: Dict[str, Any] = None
|
| 99 |
+
) -> None:
|
| 100 |
+
"""
|
| 101 |
+
Log query and response for analysis.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
query: User query
|
| 105 |
+
response: System response
|
| 106 |
+
metadata: Additional metadata
|
| 107 |
+
"""
|
| 108 |
+
if not is_logging_enabled():
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
logger = logging.getLogger(__name__)
|
| 112 |
+
|
| 113 |
+
log_entry = {
|
| 114 |
+
"query": query,
|
| 115 |
+
"response_length": len(response),
|
| 116 |
+
"metadata": metadata or {}
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
logger.info(f"Query processed: {log_entry}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def log_error(error: Exception, context: Dict[str, Any] = None) -> None:
|
| 123 |
+
"""
|
| 124 |
+
Log error with context.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
error: Exception that occurred
|
| 128 |
+
context: Additional context information
|
| 129 |
+
"""
|
| 130 |
+
if not is_logging_enabled():
|
| 131 |
+
return
|
| 132 |
+
|
| 133 |
+
logger = logging.getLogger(__name__)
|
| 134 |
+
|
| 135 |
+
error_info = {
|
| 136 |
+
"error_type": type(error).__name__,
|
| 137 |
+
"error_message": str(error),
|
| 138 |
+
"context": context or {}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
logger.error(f"Error occurred: {error_info}")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def log_performance_metrics(
|
| 145 |
+
operation: str,
|
| 146 |
+
duration: float,
|
| 147 |
+
metadata: Dict[str, Any] = None
|
| 148 |
+
) -> None:
|
| 149 |
+
"""
|
| 150 |
+
Log performance metrics.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
operation: Name of the operation
|
| 154 |
+
duration: Duration in seconds
|
| 155 |
+
metadata: Additional metadata
|
| 156 |
+
"""
|
| 157 |
+
if not is_logging_enabled():
|
| 158 |
+
return
|
| 159 |
+
|
| 160 |
+
logger = logging.getLogger(__name__)
|
| 161 |
+
|
| 162 |
+
metrics = {
|
| 163 |
+
"operation": operation,
|
| 164 |
+
"duration_seconds": duration,
|
| 165 |
+
"metadata": metadata or {}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
logger.info(f"Performance metrics: {metrics}")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def is_session_enabled() -> bool:
|
| 172 |
+
"""
|
| 173 |
+
Returns True if session management is enabled, False otherwise.
|
| 174 |
+
Checks environment variable ENABLE_SESSION first, then config.
|
| 175 |
+
"""
|
| 176 |
+
env = os.getenv("ENABLE_SESSION")
|
| 177 |
+
if env is not None:
|
| 178 |
+
return env.lower() in ("1", "true", "yes", "on")
|
| 179 |
+
config = load_config()
|
| 180 |
+
return config.get("features", {}).get("enable_session", True)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def is_logging_enabled() -> bool:
|
| 184 |
+
"""
|
| 185 |
+
Returns True if logging is enabled, False otherwise.
|
| 186 |
+
Checks environment variable ENABLE_LOGGING first, then config.
|
| 187 |
+
"""
|
| 188 |
+
env = os.getenv("ENABLE_LOGGING")
|
| 189 |
+
if env is not None:
|
| 190 |
+
return env.lower() in ("1", "true", "yes", "on")
|
| 191 |
+
config = load_config()
|
| 192 |
+
return config.get("features", {}).get("enable_logging", True)
|
| 193 |
+
|
src/pipeline.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Main pipeline orchestrator for the Audit QA system."""
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Dict, Any, List, Optional
|
| 6 |
+
|
| 7 |
+
from langchain.docstore.document import Document
|
| 8 |
+
|
| 9 |
+
from .logging import log_error
|
| 10 |
+
from .llm.adapters import LLMRegistry
|
| 11 |
+
from .loader import chunks_to_documents
|
| 12 |
+
from .vectorstore import VectorStoreManager
|
| 13 |
+
from .retrieval.context import ContextRetriever
|
| 14 |
+
from .config.loader import get_embedding_model_for_collection
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class PipelineResult:
|
| 20 |
+
"""Result of pipeline execution."""
|
| 21 |
+
answer: str
|
| 22 |
+
sources: List[Document]
|
| 23 |
+
execution_time: float
|
| 24 |
+
metadata: Dict[str, Any]
|
| 25 |
+
query: str = "" # Add default value for query
|
| 26 |
+
|
| 27 |
+
def __post_init__(self):
|
| 28 |
+
"""Post-initialization processing."""
|
| 29 |
+
if not self.query:
|
| 30 |
+
self.query = "Unknown query"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class PipelineManager:
|
| 34 |
+
"""Main pipeline manager for the RAG system."""
|
| 35 |
+
|
| 36 |
+
def __init__(self, config: dict = None):
|
| 37 |
+
"""
|
| 38 |
+
Initialize the pipeline manager.
|
| 39 |
+
"""
|
| 40 |
+
self.config = config or {}
|
| 41 |
+
self.vectorstore_manager = None
|
| 42 |
+
self.context_retriever = None # Initialize as None
|
| 43 |
+
self.llm_client = None
|
| 44 |
+
self.report_service = None
|
| 45 |
+
self.chunks = None
|
| 46 |
+
|
| 47 |
+
# Initialize components
|
| 48 |
+
self._initialize_components()
|
| 49 |
+
|
| 50 |
+
def update_config(self, new_config: dict):
|
| 51 |
+
"""
|
| 52 |
+
Update the pipeline configuration.
|
| 53 |
+
This is useful for experiments that need different settings.
|
| 54 |
+
"""
|
| 55 |
+
if not isinstance(new_config, dict):
|
| 56 |
+
return
|
| 57 |
+
|
| 58 |
+
# Deep merge the new config with existing config
|
| 59 |
+
def deep_merge(base_dict, update_dict):
|
| 60 |
+
for key, value in update_dict.items():
|
| 61 |
+
if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
|
| 62 |
+
deep_merge(base_dict[key], value)
|
| 63 |
+
else:
|
| 64 |
+
base_dict[key] = value
|
| 65 |
+
|
| 66 |
+
deep_merge(self.config, new_config)
|
| 67 |
+
|
| 68 |
+
# Auto-infer embedding model from collection name if not "docling"
|
| 69 |
+
collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
|
| 70 |
+
if collection_name != 'docling':
|
| 71 |
+
inferred_model = get_embedding_model_for_collection(collection_name)
|
| 72 |
+
if inferred_model:
|
| 73 |
+
print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
|
| 74 |
+
if 'retriever' not in self.config:
|
| 75 |
+
self.config['retriever'] = {}
|
| 76 |
+
self.config['retriever']['model'] = inferred_model
|
| 77 |
+
# Set default normalize parameter if not present
|
| 78 |
+
if 'normalize' not in self.config['retriever']:
|
| 79 |
+
self.config['retriever']['normalize'] = True
|
| 80 |
+
|
| 81 |
+
# Also update vectorstore config if it exists
|
| 82 |
+
if 'vectorstore' in self.config:
|
| 83 |
+
self.config['vectorstore']['embedding_model'] = inferred_model
|
| 84 |
+
|
| 85 |
+
print(f"🔧 CONFIG UPDATED: Pipeline config updated with experiment settings")
|
| 86 |
+
|
| 87 |
+
# Re-initialize vectorstore manager with updated config
|
| 88 |
+
self._reinitialize_vectorstore_manager()
|
| 89 |
+
|
| 90 |
+
def _reinitialize_vectorstore_manager(self):
|
| 91 |
+
"""Re-initialize vectorstore manager with current config."""
|
| 92 |
+
try:
|
| 93 |
+
self.vectorstore_manager = VectorStoreManager(self.config)
|
| 94 |
+
print("🔄 VectorStore manager re-initialized with updated config")
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"❌ Error re-initializing vectorstore manager: {e}")
|
| 97 |
+
|
| 98 |
+
def _get_reranker_model_name(self) -> str:
|
| 99 |
+
"""
|
| 100 |
+
Get the reranker model name from configuration.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
Reranker model name or default
|
| 104 |
+
"""
|
| 105 |
+
return (
|
| 106 |
+
self.config.get('retrieval', {}).get('reranker_model') or
|
| 107 |
+
self.config.get('ranker', {}).get('model') or
|
| 108 |
+
self.config.get('reranker_model') or
|
| 109 |
+
'BAAI/bge-reranker-v2-m3'
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def _initialize_components(self):
|
| 113 |
+
"""Initialize pipeline components."""
|
| 114 |
+
try:
|
| 115 |
+
# Load config if not provided
|
| 116 |
+
if not self.config:
|
| 117 |
+
from auditqa.config.loader import load_config
|
| 118 |
+
self.config = load_config()
|
| 119 |
+
|
| 120 |
+
# Auto-infer embedding model from collection name if not "docling"
|
| 121 |
+
collection_name = self.config.get('qdrant', {}).get('collection_name', 'docling')
|
| 122 |
+
if collection_name != 'docling':
|
| 123 |
+
inferred_model = get_embedding_model_for_collection(collection_name)
|
| 124 |
+
if inferred_model:
|
| 125 |
+
print(f"🔍 Auto-inferred embedding model for collection '{collection_name}': {inferred_model}")
|
| 126 |
+
if 'retriever' not in self.config:
|
| 127 |
+
self.config['retriever'] = {}
|
| 128 |
+
self.config['retriever']['model'] = inferred_model
|
| 129 |
+
# Set default normalize parameter if not present
|
| 130 |
+
if 'normalize' not in self.config['retriever']:
|
| 131 |
+
self.config['retriever']['normalize'] = True
|
| 132 |
+
|
| 133 |
+
# Also update vectorstore config if it exists
|
| 134 |
+
if 'vectorstore' in self.config:
|
| 135 |
+
self.config['vectorstore']['embedding_model'] = inferred_model
|
| 136 |
+
|
| 137 |
+
self.vectorstore_manager = VectorStoreManager(self.config)
|
| 138 |
+
|
| 139 |
+
self.llm_manager = LLMRegistry()
|
| 140 |
+
|
| 141 |
+
# Try to get LLM client using the correct method
|
| 142 |
+
self.llm_client = None
|
| 143 |
+
try:
|
| 144 |
+
# Try using get_adapter method (most likely correct)
|
| 145 |
+
self.llm_client = self.llm_manager.get_adapter("openai")
|
| 146 |
+
print("✅ LLM CLIENT: Initialized using get_adapter method")
|
| 147 |
+
except Exception as e:
|
| 148 |
+
try:
|
| 149 |
+
# Try direct instantiation with config
|
| 150 |
+
from auditqa.llm.adapters import get_llm_client
|
| 151 |
+
self.llm_client = get_llm_client("openai", self.config)
|
| 152 |
+
print("✅ LLM CLIENT: Initialized using direct get_llm_client function with config")
|
| 153 |
+
except Exception as e2:
|
| 154 |
+
print(f"❌ LLM CLIENT: Registry methods failed - {e2}")
|
| 155 |
+
# Try to create a simple LLM client directly
|
| 156 |
+
try:
|
| 157 |
+
from langchain_openai import ChatOpenAI
|
| 158 |
+
import os
|
| 159 |
+
api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
|
| 160 |
+
if api_key:
|
| 161 |
+
self.llm_client = ChatOpenAI(
|
| 162 |
+
model="gpt-3.5-turbo",
|
| 163 |
+
api_key=api_key,
|
| 164 |
+
temperature=0.1,
|
| 165 |
+
max_tokens=1000
|
| 166 |
+
)
|
| 167 |
+
print("✅ LLM CLIENT: Initialized using direct ChatOpenAI")
|
| 168 |
+
else:
|
| 169 |
+
print("❌ LLM CLIENT: No API key available")
|
| 170 |
+
except Exception as e3:
|
| 171 |
+
print(f"❌ LLM CLIENT: Direct instantiation also failed - {e3}")
|
| 172 |
+
self.llm_client = None
|
| 173 |
+
|
| 174 |
+
# Load system prompt
|
| 175 |
+
from auditqa.llm.templates import DEFAULT_AUDIT_SYSTEM_PROMPT
|
| 176 |
+
self.system_prompt = DEFAULT_AUDIT_SYSTEM_PROMPT
|
| 177 |
+
|
| 178 |
+
# Initialize report service
|
| 179 |
+
try:
|
| 180 |
+
from auditqa.reporting.service import ReportService
|
| 181 |
+
self.report_service = ReportService()
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print(f"Warning: Could not initialize report service: {e}")
|
| 184 |
+
self.report_service = None
|
| 185 |
+
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(f"Warning: Error initializing components: {e}")
|
| 188 |
+
|
| 189 |
+
def test_retrieval(
|
| 190 |
+
self,
|
| 191 |
+
query: str,
|
| 192 |
+
reports: List[str] = None,
|
| 193 |
+
sources: str = None,
|
| 194 |
+
subtype: List[str] = None,
|
| 195 |
+
k: int = None,
|
| 196 |
+
search_mode: str = None,
|
| 197 |
+
search_alpha: float = None,
|
| 198 |
+
use_reranking: bool = True
|
| 199 |
+
) -> Dict[str, Any]:
|
| 200 |
+
"""
|
| 201 |
+
Test retrieval only without LLM inference.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
query: User query
|
| 205 |
+
reports: List of specific report filenames
|
| 206 |
+
sources: Source category
|
| 207 |
+
subtype: List of subtypes
|
| 208 |
+
k: Number of documents to retrieve
|
| 209 |
+
search_mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
|
| 210 |
+
search_alpha: Weight for vector scores in hybrid mode
|
| 211 |
+
use_reranking: Whether to use reranking
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
Dictionary with retrieval results and metadata
|
| 215 |
+
"""
|
| 216 |
+
start_time = time.time()
|
| 217 |
+
|
| 218 |
+
try:
|
| 219 |
+
# Set default search parameters if not provided
|
| 220 |
+
if search_mode is None:
|
| 221 |
+
search_mode = self.config.get("hybrid", {}).get("default_mode", "vector_only")
|
| 222 |
+
if search_alpha is None:
|
| 223 |
+
search_alpha = self.config.get("hybrid", {}).get("default_alpha", 0.5)
|
| 224 |
+
|
| 225 |
+
# Get vector store
|
| 226 |
+
vectorstore = self.vectorstore_manager.get_vectorstore()
|
| 227 |
+
if not vectorstore:
|
| 228 |
+
raise ValueError(
|
| 229 |
+
"Vector store not available. Call connect_vectorstore() or create_vectorstore() first."
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Retrieve context with scores for test retrieval
|
| 233 |
+
context_docs_with_scores = self.context_retriever.retrieve_with_scores(
|
| 234 |
+
vectorstore=vectorstore,
|
| 235 |
+
query=query,
|
| 236 |
+
reports=reports,
|
| 237 |
+
sources=sources,
|
| 238 |
+
subtype=subtype,
|
| 239 |
+
k=k,
|
| 240 |
+
search_mode=search_mode,
|
| 241 |
+
alpha=search_alpha,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Extract documents and scores
|
| 245 |
+
context_docs = [doc for doc, score in context_docs_with_scores]
|
| 246 |
+
context_scores = [score for doc, score in context_docs_with_scores]
|
| 247 |
+
|
| 248 |
+
execution_time = time.time() - start_time
|
| 249 |
+
|
| 250 |
+
# Format results with actual scores
|
| 251 |
+
results = []
|
| 252 |
+
for i, (doc, score) in enumerate(zip(context_docs, context_scores)):
|
| 253 |
+
results.append({
|
| 254 |
+
"rank": i + 1,
|
| 255 |
+
"content": doc.page_content, # Return full content without truncation
|
| 256 |
+
"metadata": doc.metadata,
|
| 257 |
+
"score": score if score is not None else 0.0
|
| 258 |
+
})
|
| 259 |
+
|
| 260 |
+
return {
|
| 261 |
+
"results": results,
|
| 262 |
+
"num_results": len(results),
|
| 263 |
+
"execution_time": execution_time,
|
| 264 |
+
"search_mode": search_mode,
|
| 265 |
+
"search_alpha": search_alpha,
|
| 266 |
+
"query": query
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
except Exception as e:
|
| 270 |
+
print(f"❌ Error during retrieval test: {e}")
|
| 271 |
+
log_error(e, {"component": "retrieval_test", "query": query})
|
| 272 |
+
return {
|
| 273 |
+
"results": [],
|
| 274 |
+
"num_results": 0,
|
| 275 |
+
"execution_time": time.time() - start_time,
|
| 276 |
+
"error": str(e),
|
| 277 |
+
"search_mode": search_mode or "unknown",
|
| 278 |
+
"search_alpha": search_alpha or 0.5,
|
| 279 |
+
"query": query
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
def connect_vectorstore(self, force_recreate: bool = False) -> bool:
|
| 283 |
+
"""
|
| 284 |
+
Connect to existing vector store.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
force_recreate: If True, recreate the collection if dimension mismatch occurs
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
True if successful, False otherwise
|
| 291 |
+
"""
|
| 292 |
+
try:
|
| 293 |
+
vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=force_recreate)
|
| 294 |
+
if vectorstore:
|
| 295 |
+
print("✅ Connected to vector store")
|
| 296 |
+
return True
|
| 297 |
+
else:
|
| 298 |
+
print("❌ Failed to connect to vector store")
|
| 299 |
+
return False
|
| 300 |
+
except Exception as e:
|
| 301 |
+
print(f"❌ Error connecting to vector store: {e}")
|
| 302 |
+
log_error(e, {"component": "vectorstore_connection"})
|
| 303 |
+
|
| 304 |
+
# If it's a dimension mismatch error, try with force_recreate
|
| 305 |
+
if "dimensions" in str(e).lower() and not force_recreate:
|
| 306 |
+
print("🔄 Dimension mismatch detected, attempting to recreate collection...")
|
| 307 |
+
try:
|
| 308 |
+
vectorstore = self.vectorstore_manager.connect_to_existing(force_recreate=True)
|
| 309 |
+
if vectorstore:
|
| 310 |
+
print("✅ Connected to vector store (recreated)")
|
| 311 |
+
return True
|
| 312 |
+
except Exception as recreate_error:
|
| 313 |
+
print(f"❌ Failed to recreate vector store: {recreate_error}")
|
| 314 |
+
log_error(recreate_error, {"component": "vectorstore_recreation"})
|
| 315 |
+
|
| 316 |
+
return False
|
| 317 |
+
|
| 318 |
+
def create_vectorstore(self) -> bool:
|
| 319 |
+
"""
|
| 320 |
+
Create new vector store from chunks.
|
| 321 |
+
|
| 322 |
+
Returns:
|
| 323 |
+
True if successful, False otherwise
|
| 324 |
+
"""
|
| 325 |
+
try:
|
| 326 |
+
if not self.chunks:
|
| 327 |
+
raise ValueError("No chunks available for vector store creation")
|
| 328 |
+
|
| 329 |
+
documents = chunks_to_documents(self.chunks)
|
| 330 |
+
self.vectorstore_manager.create_from_documents(documents)
|
| 331 |
+
print("✅ Vector store created successfully")
|
| 332 |
+
return True
|
| 333 |
+
except Exception as e:
|
| 334 |
+
print(f"❌ Error creating vector store: {e}")
|
| 335 |
+
log_error(e, {"component": "vectorstore_creation"})
|
| 336 |
+
return False
|
| 337 |
+
|
| 338 |
+
def create_audit_prompt(self, query: str, context_docs: List[Document]) -> str:
|
| 339 |
+
"""Create a prompt for the LLM to generate an answer."""
|
| 340 |
+
try:
|
| 341 |
+
# Ensure query is not None
|
| 342 |
+
if not query or not isinstance(query, str) or query.strip() == "":
|
| 343 |
+
return "Error: No query provided"
|
| 344 |
+
|
| 345 |
+
# Ensure context_docs is not None and is a list
|
| 346 |
+
if context_docs is None:
|
| 347 |
+
context_docs = []
|
| 348 |
+
|
| 349 |
+
# Filter out None documents and ensure they have content
|
| 350 |
+
valid_docs = []
|
| 351 |
+
for doc in context_docs:
|
| 352 |
+
if doc is not None:
|
| 353 |
+
if hasattr(doc, 'page_content') and doc.page_content and isinstance(doc.page_content, str):
|
| 354 |
+
valid_docs.append(doc)
|
| 355 |
+
elif isinstance(doc, str) and doc.strip():
|
| 356 |
+
valid_docs.append(doc)
|
| 357 |
+
|
| 358 |
+
# Create context string
|
| 359 |
+
if valid_docs:
|
| 360 |
+
context_parts = []
|
| 361 |
+
for i, doc in enumerate(valid_docs, 1):
|
| 362 |
+
if hasattr(doc, 'page_content') and doc.page_content:
|
| 363 |
+
context_parts.append(f"Doc {i}: {doc.page_content}")
|
| 364 |
+
elif isinstance(doc, str) and doc.strip():
|
| 365 |
+
context_parts.append(f"Doc {i}: {doc}")
|
| 366 |
+
|
| 367 |
+
context_string = "\n\n".join(context_parts)
|
| 368 |
+
else:
|
| 369 |
+
context_string = "No relevant context found."
|
| 370 |
+
|
| 371 |
+
# Create the prompt
|
| 372 |
+
prompt = f"""
|
| 373 |
+
{self.system_prompt}
|
| 374 |
+
|
| 375 |
+
Context:
|
| 376 |
+
{context_string}
|
| 377 |
+
|
| 378 |
+
Query: {query}
|
| 379 |
+
|
| 380 |
+
Answer:"""
|
| 381 |
+
|
| 382 |
+
return prompt
|
| 383 |
+
|
| 384 |
+
except Exception as e:
|
| 385 |
+
print(f"Error creating audit prompt: {e}")
|
| 386 |
+
return f"Error creating prompt: {e}"
|
| 387 |
+
|
| 388 |
+
def _generate_answer(self, prompt: str) -> str:
|
| 389 |
+
"""Generate answer using the LLM."""
|
| 390 |
+
try:
|
| 391 |
+
if not prompt or not isinstance(prompt, str) or prompt.strip() == "":
|
| 392 |
+
return "Error: No prompt provided"
|
| 393 |
+
|
| 394 |
+
# Ensure LLM client is available
|
| 395 |
+
if not self.llm_client:
|
| 396 |
+
return "Error: LLM client not available"
|
| 397 |
+
|
| 398 |
+
# Generate response using the correct method
|
| 399 |
+
if hasattr(self.llm_client, 'generate'):
|
| 400 |
+
# Use the generate method (for adapters)
|
| 401 |
+
response = self.llm_client.generate([{"role": "user", "content": prompt}])
|
| 402 |
+
|
| 403 |
+
# Extract content from LLMResponse
|
| 404 |
+
if hasattr(response, 'content'):
|
| 405 |
+
answer = response.content
|
| 406 |
+
else:
|
| 407 |
+
answer = str(response)
|
| 408 |
+
|
| 409 |
+
elif hasattr(self.llm_client, 'invoke'):
|
| 410 |
+
# Use the invoke method (for direct LangChain models)
|
| 411 |
+
response = self.llm_client.invoke(prompt)
|
| 412 |
+
|
| 413 |
+
# Extract content safely
|
| 414 |
+
if hasattr(response, 'content') and response.content is not None:
|
| 415 |
+
answer = response.content
|
| 416 |
+
elif isinstance(response, str) and response.strip():
|
| 417 |
+
answer = response
|
| 418 |
+
else:
|
| 419 |
+
answer = str(response) if response is not None else "Error: LLM returned None response"
|
| 420 |
+
else:
|
| 421 |
+
return "Error: LLM client has no generate or invoke method"
|
| 422 |
+
|
| 423 |
+
# Ensure answer is not None and is a string
|
| 424 |
+
if answer is None or not isinstance(answer, str):
|
| 425 |
+
return "Error: LLM returned invalid response"
|
| 426 |
+
|
| 427 |
+
return answer.strip()
|
| 428 |
+
|
| 429 |
+
except Exception as e:
|
| 430 |
+
print(f"Error generating answer: {e}")
|
| 431 |
+
return f"Error generating answer: {e}"
|
| 432 |
+
|
| 433 |
+
def run(
|
| 434 |
+
self,
|
| 435 |
+
query: str,
|
| 436 |
+
reports: List[str] = None,
|
| 437 |
+
sources: List[str] = None,
|
| 438 |
+
subtype: List[str] = None,
|
| 439 |
+
llm_provider: str = None,
|
| 440 |
+
use_reranking: bool = True,
|
| 441 |
+
search_mode: str = None,
|
| 442 |
+
search_alpha: float = None,
|
| 443 |
+
auto_infer_filters: bool = True,
|
| 444 |
+
filters: Dict[str, Any] = None,
|
| 445 |
+
) -> PipelineResult:
|
| 446 |
+
"""
|
| 447 |
+
Run the complete RAG pipeline.
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
query: User query
|
| 451 |
+
reports: List of specific report filenames
|
| 452 |
+
sources: Source category filter
|
| 453 |
+
subtype: List of subtypes/filenames
|
| 454 |
+
llm_provider: LLM provider to use
|
| 455 |
+
use_reranking: Whether to use reranking
|
| 456 |
+
search_mode: Search mode (vector, sparse, hybrid)
|
| 457 |
+
search_alpha: Alpha value for hybrid search
|
| 458 |
+
auto_infer_filters: Whether to auto-infer filters from query
|
| 459 |
+
|
| 460 |
+
Returns:
|
| 461 |
+
PipelineResult object
|
| 462 |
+
"""
|
| 463 |
+
try:
|
| 464 |
+
# Validate input
|
| 465 |
+
if not query or not isinstance(query, str) or query.strip() == "":
|
| 466 |
+
return PipelineResult(
|
| 467 |
+
answer="Error: Invalid query provided",
|
| 468 |
+
sources=[],
|
| 469 |
+
execution_time=0.0,
|
| 470 |
+
metadata={'error': 'Invalid query'},
|
| 471 |
+
query=query
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Ensure lists are not None
|
| 475 |
+
if reports is None:
|
| 476 |
+
reports = []
|
| 477 |
+
if subtype is None:
|
| 478 |
+
subtype = []
|
| 479 |
+
|
| 480 |
+
start_time = time.time()
|
| 481 |
+
|
| 482 |
+
# Auto-infer filters if enabled and no explicit filters provided
|
| 483 |
+
inferred_filters = {}
|
| 484 |
+
filters_applied = False
|
| 485 |
+
qdrant_filter = None # Add this
|
| 486 |
+
|
| 487 |
+
if auto_infer_filters and not any([reports, sources, subtype]):
|
| 488 |
+
print(f"🤖 AUTO-INFERRING FILTERS: No explicit filters provided, analyzing query...")
|
| 489 |
+
try:
|
| 490 |
+
# Import get_available_metadata here to avoid circular imports
|
| 491 |
+
from auditqa.retrieval.filter import get_available_metadata, infer_filters_from_query
|
| 492 |
+
|
| 493 |
+
# Get available metadata
|
| 494 |
+
available_metadata = get_available_metadata(self.vectorstore_manager.get_vectorstore())
|
| 495 |
+
|
| 496 |
+
# Infer filters from query - this returns a Qdrant filter
|
| 497 |
+
qdrant_filter, filter_summary = infer_filters_from_query(
|
| 498 |
+
query=query,
|
| 499 |
+
available_metadata=available_metadata,
|
| 500 |
+
llm_client=self.llm_client
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
if qdrant_filter:
|
| 504 |
+
print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
|
| 505 |
+
filters_applied = True
|
| 506 |
+
# Don't set sources/reports/subtype - use the Qdrant filter directly
|
| 507 |
+
else:
|
| 508 |
+
print(f"⚠️ NO QDRANT FILTER: Could not build Qdrant filter from query")
|
| 509 |
+
|
| 510 |
+
except Exception as e:
|
| 511 |
+
print(f"❌ AUTO-INFERENCE FAILED: {e}")
|
| 512 |
+
qdrant_filter = None
|
| 513 |
+
else:
|
| 514 |
+
# Check if any explicit filters were provided
|
| 515 |
+
filters_applied = any([reports, sources, subtype])
|
| 516 |
+
if filters_applied:
|
| 517 |
+
print(f"✅ EXPLICIT FILTERS: Using provided filters")
|
| 518 |
+
else:
|
| 519 |
+
print(f"⚠️ NO FILTERS: No explicit filters and auto-inference disabled")
|
| 520 |
+
|
| 521 |
+
# Extract filter parameters from the filters parameter
|
| 522 |
+
reports = filters.get('reports', []) if filters else []
|
| 523 |
+
sources = filters.get('sources', []) if filters else []
|
| 524 |
+
subtype = filters.get('subtype', []) if filters else []
|
| 525 |
+
year = filters.get('year', []) if filters else []
|
| 526 |
+
district = filters.get('district', []) if filters else []
|
| 527 |
+
filenames = filters.get('filenames', []) if filters else [] # Support mutually exclusive filename filtering
|
| 528 |
+
|
| 529 |
+
# Get vectorstore
|
| 530 |
+
vectorstore = self.vectorstore_manager.get_vectorstore()
|
| 531 |
+
if not vectorstore:
|
| 532 |
+
return PipelineResult(
|
| 533 |
+
answer="Error: Vector store not available",
|
| 534 |
+
sources=[],
|
| 535 |
+
execution_time=0.0,
|
| 536 |
+
metadata={'error': 'Vector store not available'},
|
| 537 |
+
query=query
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
# Initialize context retriever if not already done
|
| 541 |
+
if not hasattr(self, 'context_retriever') or self.context_retriever is None:
|
| 542 |
+
# Get the actual vectorstore object
|
| 543 |
+
vectorstore_obj = self.vectorstore_manager.get_vectorstore()
|
| 544 |
+
if vectorstore_obj is None:
|
| 545 |
+
print("❌ ERROR: Vectorstore is None, cannot initialize ContextRetriever")
|
| 546 |
+
return None
|
| 547 |
+
self.context_retriever = ContextRetriever(vectorstore_obj, self.config)
|
| 548 |
+
print("✅ ContextRetriever initialized successfully")
|
| 549 |
+
|
| 550 |
+
# Debug config access
|
| 551 |
+
print(f" CONFIG DEBUG: Full config keys: {list(self.config.keys()) if isinstance(self.config, dict) else 'Not a dict'}")
|
| 552 |
+
print(f"🔍 CONFIG DEBUG: Retriever config: {self.config.get('retriever', {})}")
|
| 553 |
+
print(f"🔍 CONFIG DEBUG: Retrieval config: {self.config.get('retrieval', {})}")
|
| 554 |
+
print(f"🔍 CONFIG DEBUG: use_reranking from config: {self.config.get('retrieval', {}).get('use_reranking', 'NOT_FOUND')}")
|
| 555 |
+
|
| 556 |
+
# Get the correct top_k value
|
| 557 |
+
# Priority: experiment config > retriever config > default
|
| 558 |
+
top_k = (
|
| 559 |
+
self.config.get('retrieval', {}).get('top_k') or
|
| 560 |
+
self.config.get('retriever', {}).get('top_k') or
|
| 561 |
+
5
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
# Get reranking setting
|
| 565 |
+
use_reranking = self.config.get('retrieval', {}).get('use_reranking', False)
|
| 566 |
+
|
| 567 |
+
print(f"🔍 CONFIG DEBUG: Final top_k: {top_k}")
|
| 568 |
+
print(f"🔍 CONFIG DEBUG: Final use_reranking: {use_reranking}")
|
| 569 |
+
|
| 570 |
+
# Retrieve context using the context retriever
|
| 571 |
+
context_docs = self.context_retriever.retrieve_context(
|
| 572 |
+
query=query,
|
| 573 |
+
k=top_k,
|
| 574 |
+
reports=reports,
|
| 575 |
+
sources=sources,
|
| 576 |
+
subtype=subtype,
|
| 577 |
+
year=year,
|
| 578 |
+
district=district,
|
| 579 |
+
filenames=filenames,
|
| 580 |
+
use_reranking=use_reranking,
|
| 581 |
+
qdrant_filter=qdrant_filter
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
# Ensure context_docs is not None
|
| 585 |
+
if context_docs is None:
|
| 586 |
+
context_docs = []
|
| 587 |
+
|
| 588 |
+
# Generate answer
|
| 589 |
+
answer = self._generate_answer(self.create_audit_prompt(query, context_docs))
|
| 590 |
+
|
| 591 |
+
execution_time = time.time() - start_time
|
| 592 |
+
|
| 593 |
+
# Create result with comprehensive metadata
|
| 594 |
+
result = PipelineResult(
|
| 595 |
+
answer=answer,
|
| 596 |
+
sources=context_docs,
|
| 597 |
+
execution_time=execution_time,
|
| 598 |
+
metadata={
|
| 599 |
+
'llm_provider': llm_provider,
|
| 600 |
+
'use_reranking': use_reranking,
|
| 601 |
+
'search_mode': search_mode,
|
| 602 |
+
'search_alpha': search_alpha,
|
| 603 |
+
'auto_infer_filters': auto_infer_filters,
|
| 604 |
+
'filters_applied': filters_applied,
|
| 605 |
+
'with_filtering': filters_applied,
|
| 606 |
+
'filter_conditions': {
|
| 607 |
+
'reports': reports,
|
| 608 |
+
'sources': sources,
|
| 609 |
+
'subtype': subtype
|
| 610 |
+
},
|
| 611 |
+
'inferred_filters': inferred_filters,
|
| 612 |
+
'applied_filters': {
|
| 613 |
+
'reports': reports,
|
| 614 |
+
'sources': sources,
|
| 615 |
+
'subtype': subtype
|
| 616 |
+
},
|
| 617 |
+
# Store filter and reranking metadata
|
| 618 |
+
'filter_details': {
|
| 619 |
+
'explicit_filters': {
|
| 620 |
+
'reports': reports,
|
| 621 |
+
'sources': sources,
|
| 622 |
+
'subtype': subtype,
|
| 623 |
+
'year': year
|
| 624 |
+
},
|
| 625 |
+
'inferred_filters': inferred_filters if auto_infer_filters else {},
|
| 626 |
+
'auto_inference_enabled': auto_infer_filters,
|
| 627 |
+
'qdrant_filter_applied': qdrant_filter is not None,
|
| 628 |
+
'filter_summary': filter_summary if 'filter_summary' in locals() else None
|
| 629 |
+
},
|
| 630 |
+
'reranker_model': self._get_reranker_model_name() if use_reranking else None,
|
| 631 |
+
'reranker_applied': use_reranking,
|
| 632 |
+
'reranking_info': {
|
| 633 |
+
'model': self._get_reranker_model_name(),
|
| 634 |
+
'applied': use_reranking,
|
| 635 |
+
'top_k': len(context_docs) if context_docs else 0,
|
| 636 |
+
# 'original_documents': [
|
| 637 |
+
# {
|
| 638 |
+
# 'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
|
| 639 |
+
# 'metadata': doc.metadata,
|
| 640 |
+
# 'score': getattr(doc, 'score', getattr(doc, 'original_score', 0.0))
|
| 641 |
+
# } for doc in context_docs
|
| 642 |
+
# ] if use_reranking else None,
|
| 643 |
+
'reranked_documents': [
|
| 644 |
+
{
|
| 645 |
+
'content': doc.page_content[:200] + '...' if len(doc.page_content) > 200 else doc.page_content,
|
| 646 |
+
'metadata': doc.metadata,
|
| 647 |
+
'score': doc.metadata.get('original_score', getattr(doc, 'score', 0.0)),
|
| 648 |
+
'original_rank': doc.metadata.get('original_rank', None),
|
| 649 |
+
'final_rank': doc.metadata.get('final_rank', None),
|
| 650 |
+
'reranked_score': doc.metadata.get('reranked_score', None)
|
| 651 |
+
} for doc in context_docs
|
| 652 |
+
] if use_reranking else None
|
| 653 |
+
}
|
| 654 |
+
},
|
| 655 |
+
query=query
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
return result
|
| 659 |
+
|
| 660 |
+
except Exception as e:
|
| 661 |
+
print(f"Error in pipeline run: {e}")
|
| 662 |
+
return PipelineResult(
|
| 663 |
+
answer=f"Error processing query: {e}",
|
| 664 |
+
sources=[],
|
| 665 |
+
execution_time=0.0,
|
| 666 |
+
metadata={'error': str(e)},
|
| 667 |
+
query=query
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
def get_system_status(self) -> Dict[str, Any]:
|
| 673 |
+
"""
|
| 674 |
+
Get system status information.
|
| 675 |
+
|
| 676 |
+
Returns:
|
| 677 |
+
Dictionary with system status
|
| 678 |
+
"""
|
| 679 |
+
status = {
|
| 680 |
+
"config_loaded": bool(self.config),
|
| 681 |
+
"chunks_loaded": bool(self.chunks),
|
| 682 |
+
"vectorstore_connected": bool(
|
| 683 |
+
self.vectorstore_manager and self.vectorstore_manager.get_vectorstore()
|
| 684 |
+
),
|
| 685 |
+
"components_initialized": bool(
|
| 686 |
+
self.context_retriever and self.report_service
|
| 687 |
+
),
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
if self.chunks:
|
| 691 |
+
status["num_chunks"] = len(self.chunks)
|
| 692 |
+
|
| 693 |
+
if self.report_service:
|
| 694 |
+
status["available_sources"] = self.report_service.get_available_sources()
|
| 695 |
+
status["available_reports"] = len(
|
| 696 |
+
self.report_service.get_available_reports()
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
status["overall_status"] = (
|
| 700 |
+
"ready"
|
| 701 |
+
if all(
|
| 702 |
+
[
|
| 703 |
+
status["config_loaded"],
|
| 704 |
+
status["chunks_loaded"],
|
| 705 |
+
status["vectorstore_connected"],
|
| 706 |
+
status["components_initialized"],
|
| 707 |
+
]
|
| 708 |
+
)
|
| 709 |
+
else "not_ready"
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
return status
|
| 713 |
+
|
| 714 |
+
def get_available_llm_providers(self) -> List[str]:
|
| 715 |
+
"""Get list of available LLM providers."""
|
| 716 |
+
providers = []
|
| 717 |
+
reader_config = self.config.get("reader", {})
|
| 718 |
+
|
| 719 |
+
for provider in [
|
| 720 |
+
"MISTRAL",
|
| 721 |
+
"OPENAI",
|
| 722 |
+
"OLLAMA",
|
| 723 |
+
"INF_PROVIDERS",
|
| 724 |
+
"NVIDIA",
|
| 725 |
+
"DEDICATED",
|
| 726 |
+
"OPENROUTER",
|
| 727 |
+
]:
|
| 728 |
+
if provider in reader_config:
|
| 729 |
+
providers.append(provider.lower())
|
| 730 |
+
|
| 731 |
+
return providers
|
src/reporting/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Report metadata and utilities."""
|
| 2 |
+
|
| 3 |
+
from .metadata import get_report_metadata, get_available_sources
|
| 4 |
+
from .service import ReportService
|
| 5 |
+
|
| 6 |
+
__all__ = ["get_report_metadata", "get_available_sources", "ReportService"]
|
src/reporting/feedback_schema.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Schema for RAG Chatbot
|
| 3 |
+
|
| 4 |
+
This module defines dataclasses for feedback data structures
|
| 5 |
+
and provides Snowflake schema generation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from dataclasses import dataclass, asdict, field
|
| 9 |
+
from typing import List, Optional, Dict, Any, Union
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class RetrievedDocument:
|
| 15 |
+
"""Single retrieved document metadata"""
|
| 16 |
+
doc_id: str
|
| 17 |
+
filename: str
|
| 18 |
+
page: int
|
| 19 |
+
score: float
|
| 20 |
+
content: str
|
| 21 |
+
metadata: Dict[str, Any]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class RetrievalEntry:
|
| 26 |
+
"""Single retrieval operation metadata"""
|
| 27 |
+
rag_query: str
|
| 28 |
+
documents_retrieved: List[RetrievedDocument]
|
| 29 |
+
conversation_length: int
|
| 30 |
+
filters_applied: Optional[Dict[str, Any]] = None
|
| 31 |
+
timestamp: Optional[float] = None
|
| 32 |
+
_raw_data: Optional[Dict[str, Any]] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class UserFeedback:
|
| 37 |
+
"""User feedback submission data"""
|
| 38 |
+
feedback_id: str
|
| 39 |
+
open_ended_feedback: Optional[str]
|
| 40 |
+
score: int
|
| 41 |
+
is_feedback_about_last_retrieval: bool
|
| 42 |
+
retrieved_data: List[RetrievalEntry]
|
| 43 |
+
conversation_id: str
|
| 44 |
+
timestamp: float
|
| 45 |
+
message_count: int
|
| 46 |
+
has_retrievals: bool
|
| 47 |
+
retrieval_count: int
|
| 48 |
+
user_query: Optional[str] = None
|
| 49 |
+
bot_response: Optional[str] = None
|
| 50 |
+
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
| 51 |
+
|
| 52 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 53 |
+
"""Convert to dictionary with nested data structures"""
|
| 54 |
+
result = asdict(self)
|
| 55 |
+
# Handle nested objects
|
| 56 |
+
if self.retrieved_data:
|
| 57 |
+
result['retrieved_data'] = [self._serialize_retrieval_entry(entry) for entry in self.retrieved_data]
|
| 58 |
+
return result
|
| 59 |
+
|
| 60 |
+
def _serialize_retrieval_entry(self, entry: RetrievalEntry) -> Dict[str, Any]:
|
| 61 |
+
"""Serialize retrieval entry to dict"""
|
| 62 |
+
# If raw data exists, use it (it's already properly formatted)
|
| 63 |
+
if hasattr(entry, '_raw_data') and entry._raw_data:
|
| 64 |
+
return entry._raw_data
|
| 65 |
+
|
| 66 |
+
# Otherwise, serialize the dataclass
|
| 67 |
+
result = asdict(entry)
|
| 68 |
+
if entry.documents_retrieved:
|
| 69 |
+
result['documents_retrieved'] = [asdict(doc) for doc in entry.documents_retrieved]
|
| 70 |
+
return result
|
| 71 |
+
|
| 72 |
+
def to_snowflake_schema(self) -> Dict[str, Any]:
|
| 73 |
+
"""Generate Snowflake schema for this dataclass"""
|
| 74 |
+
schema = {
|
| 75 |
+
"feedback_id": "VARCHAR(255)",
|
| 76 |
+
"open_ended_feedback": "VARCHAR(16777216)", # Large text
|
| 77 |
+
"score": "INTEGER",
|
| 78 |
+
"is_feedback_about_last_retrieval": "BOOLEAN",
|
| 79 |
+
"conversation_id": "VARCHAR(255)",
|
| 80 |
+
"timestamp": "NUMBER(20, 0)",
|
| 81 |
+
"message_count": "INTEGER",
|
| 82 |
+
"has_retrievals": "BOOLEAN",
|
| 83 |
+
"retrieval_count": "INTEGER",
|
| 84 |
+
"user_query": "VARCHAR(16777216)",
|
| 85 |
+
"bot_response": "VARCHAR(16777216)",
|
| 86 |
+
"created_at": "TIMESTAMP_NTZ",
|
| 87 |
+
"retrieved_data": "VARIANT", # Array of retrieval entries
|
| 88 |
+
# retrieved_data structure:
|
| 89 |
+
# [
|
| 90 |
+
# {
|
| 91 |
+
# "rag_query": "...",
|
| 92 |
+
# "conversation_length": 5,
|
| 93 |
+
# "timestamp": 1234567890,
|
| 94 |
+
# "docs_retrieved": [
|
| 95 |
+
# {"filename": "...", "page": 14, "score": 0.95, ...},
|
| 96 |
+
# ...
|
| 97 |
+
# ]
|
| 98 |
+
# },
|
| 99 |
+
# ...
|
| 100 |
+
# ]
|
| 101 |
+
}
|
| 102 |
+
return schema
|
| 103 |
+
|
| 104 |
+
@classmethod
|
| 105 |
+
def get_snowflake_create_table_sql(cls, table_name: str = "user_feedback") -> str:
|
| 106 |
+
"""Generate CREATE TABLE SQL for Snowflake"""
|
| 107 |
+
schema = cls.to_snowflake_schema(None)
|
| 108 |
+
|
| 109 |
+
columns = []
|
| 110 |
+
for col_name, col_type in schema.items():
|
| 111 |
+
nullable = "NULL" if col_name not in ["feedback_id", "score", "timestamp"] else "NOT NULL"
|
| 112 |
+
columns.append(f" {col_name} {col_type} {nullable}")
|
| 113 |
+
|
| 114 |
+
# Build SQL string properly
|
| 115 |
+
columns_str = ",\n".join(columns)
|
| 116 |
+
|
| 117 |
+
sql = f"""CREATE TABLE IF NOT EXISTS {table_name} (
|
| 118 |
+
{columns_str},
|
| 119 |
+
PRIMARY KEY (feedback_id)
|
| 120 |
+
);
|
| 121 |
+
|
| 122 |
+
-- Create index on timestamp for querying by time
|
| 123 |
+
CREATE INDEX IF NOT EXISTS idx_feedback_timestamp ON {table_name} (timestamp);
|
| 124 |
+
|
| 125 |
+
-- Create index on conversation_id for querying by conversation
|
| 126 |
+
CREATE INDEX IF NOT EXISTS idx_feedback_conversation ON {table_name} (conversation_id);
|
| 127 |
+
|
| 128 |
+
-- Create index on score for feedback analysis
|
| 129 |
+
CREATE INDEX IF NOT EXISTS idx_feedback_score ON {table_name} (score);
|
| 130 |
+
"""
|
| 131 |
+
return sql
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# Snowflake variant schema for retrieved_data array
|
| 135 |
+
RETRIEVAL_ENTRY_SCHEMA = {
|
| 136 |
+
"rag_query": "VARCHAR",
|
| 137 |
+
"documents_retrieved": "ARRAY", # Array of document objects
|
| 138 |
+
"conversation_length": "INTEGER",
|
| 139 |
+
"filters_applied": "OBJECT",
|
| 140 |
+
"timestamp": "NUMBER"
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
DOCUMENT_SCHEMA = {
|
| 144 |
+
"doc_id": "VARCHAR",
|
| 145 |
+
"filename": "VARCHAR",
|
| 146 |
+
"page": "INTEGER",
|
| 147 |
+
"score": "DOUBLE",
|
| 148 |
+
"content": "VARCHAR(16777216)",
|
| 149 |
+
"metadata": "OBJECT"
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def generate_snowflake_schema_sql() -> str:
|
| 154 |
+
"""Generate complete Snowflake schema SQL for feedback system"""
|
| 155 |
+
return UserFeedback.get_snowflake_create_table_sql("user_feedback")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def create_feedback_from_dict(data: Dict[str, Any]) -> UserFeedback:
|
| 159 |
+
"""Create UserFeedback instance from dictionary"""
|
| 160 |
+
# Parse retrieved_data if present
|
| 161 |
+
retrieved_data = []
|
| 162 |
+
if "retrieved_data" in data and data["retrieved_data"]:
|
| 163 |
+
for entry_dict in data.get("retrieved_data", []):
|
| 164 |
+
# Map the actual structure from rag_retrieval_history
|
| 165 |
+
# Entry has: conversation_up_to, rag_query_expansion, docs_retrieved
|
| 166 |
+
try:
|
| 167 |
+
# Try to map to expected structure
|
| 168 |
+
entry = RetrievalEntry(
|
| 169 |
+
rag_query=entry_dict.get("rag_query_expansion", ""),
|
| 170 |
+
documents_retrieved=[], # Empty for now, will store as raw data
|
| 171 |
+
conversation_length=len(entry_dict.get("conversation_up_to", [])),
|
| 172 |
+
filters_applied=None,
|
| 173 |
+
timestamp=entry_dict.get("timestamp", None)
|
| 174 |
+
)
|
| 175 |
+
# Store raw data in the entry
|
| 176 |
+
entry._raw_data = entry_dict # Store original for preservation
|
| 177 |
+
retrieved_data.append(entry)
|
| 178 |
+
except Exception as e:
|
| 179 |
+
# If mapping fails, store as-is without strict typing
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
return UserFeedback(
|
| 183 |
+
feedback_id=data.get("feedback_id", f"feedback_{data.get('timestamp', 'unknown')}"),
|
| 184 |
+
open_ended_feedback=data.get("open_ended_feedback"),
|
| 185 |
+
score=data["score"],
|
| 186 |
+
is_feedback_about_last_retrieval=data["is_feedback_about_last_retrieval"],
|
| 187 |
+
retrieved_data=retrieved_data,
|
| 188 |
+
conversation_id=data["conversation_id"],
|
| 189 |
+
timestamp=data["timestamp"],
|
| 190 |
+
message_count=data["message_count"],
|
| 191 |
+
has_retrievals=data["has_retrievals"],
|
| 192 |
+
retrieval_count=data["retrieval_count"],
|
| 193 |
+
user_query=data.get("user_query"),
|
| 194 |
+
bot_response=data.get("bot_response")
|
| 195 |
+
)
|
| 196 |
+
|
src/reporting/metadata.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Report metadata management."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Any, Set
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_report_metadata(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 8 |
+
"""
|
| 9 |
+
Extract metadata from chunks.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
chunks: List of chunk dictionaries
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
Dictionary with report metadata
|
| 16 |
+
"""
|
| 17 |
+
if not chunks:
|
| 18 |
+
return {}
|
| 19 |
+
|
| 20 |
+
sources = set()
|
| 21 |
+
filenames = set()
|
| 22 |
+
years = set()
|
| 23 |
+
|
| 24 |
+
for chunk in chunks:
|
| 25 |
+
metadata = chunk.get("metadata", {})
|
| 26 |
+
|
| 27 |
+
if "source" in metadata:
|
| 28 |
+
sources.add(metadata["source"])
|
| 29 |
+
|
| 30 |
+
if "filename" in metadata:
|
| 31 |
+
filenames.add(metadata["filename"])
|
| 32 |
+
|
| 33 |
+
if "year" in metadata:
|
| 34 |
+
years.add(metadata["year"])
|
| 35 |
+
|
| 36 |
+
return {
|
| 37 |
+
"sources": sorted(list(sources)),
|
| 38 |
+
"filenames": sorted(list(filenames)),
|
| 39 |
+
"years": sorted(list(years)),
|
| 40 |
+
"total_chunks": len(chunks)
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_available_sources() -> List[str]:
|
| 45 |
+
"""
|
| 46 |
+
Get list of available report sources (legacy compatibility).
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
List of source categories
|
| 50 |
+
"""
|
| 51 |
+
# This would typically come from the original auditqa_old.reports module
|
| 52 |
+
# For now, return common categories
|
| 53 |
+
return [
|
| 54 |
+
"Consolidated",
|
| 55 |
+
"Ministry, Department, Agency and Projects",
|
| 56 |
+
"Local Government",
|
| 57 |
+
"Value for Money",
|
| 58 |
+
"Thematic",
|
| 59 |
+
"Hospital",
|
| 60 |
+
"Project"
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_source_subtypes() -> Dict[str, List[str]]:
|
| 65 |
+
"""
|
| 66 |
+
Get mapping of sources to their subtypes (placeholder).
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Dictionary mapping sources to subtypes
|
| 70 |
+
"""
|
| 71 |
+
# This was originally imported from auditqa_old.reports.new_files
|
| 72 |
+
# For now, return a placeholder structure
|
| 73 |
+
return {
|
| 74 |
+
"Consolidated": ["Annual Consolidated OAG 2024", "Annual Consolidated OAG 2023"],
|
| 75 |
+
"Local Government": ["District Reports", "Municipal Reports"],
|
| 76 |
+
"Ministry, Department, Agency and Projects": ["Ministry Reports", "Agency Reports"],
|
| 77 |
+
"Value for Money": ["VFM Reports 2024", "VFM Reports 2023"],
|
| 78 |
+
"Thematic": ["Thematic Reports 2024", "Thematic Reports 2023"],
|
| 79 |
+
"Hospital": ["Hospital Reports 2024", "Hospital Reports 2023"],
|
| 80 |
+
"Project": ["Project Reports 2024", "Project Reports 2023"]
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def validate_report_filters(
|
| 85 |
+
reports: List[str] = None,
|
| 86 |
+
sources: str = None,
|
| 87 |
+
subtype: List[str] = None,
|
| 88 |
+
available_metadata: Dict[str, Any] = None
|
| 89 |
+
) -> Dict[str, Any]:
|
| 90 |
+
"""
|
| 91 |
+
Validate report filter parameters.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
reports: List of specific report filenames
|
| 95 |
+
sources: Source category
|
| 96 |
+
subtype: List of subtypes
|
| 97 |
+
available_metadata: Available metadata for validation
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
Dictionary with validation results
|
| 101 |
+
"""
|
| 102 |
+
validation_result = {
|
| 103 |
+
"valid": True,
|
| 104 |
+
"warnings": [],
|
| 105 |
+
"errors": []
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
if not available_metadata:
|
| 109 |
+
validation_result["warnings"].append("No metadata available for validation")
|
| 110 |
+
return validation_result
|
| 111 |
+
|
| 112 |
+
available_sources = available_metadata.get("sources", [])
|
| 113 |
+
available_filenames = available_metadata.get("filenames", [])
|
| 114 |
+
|
| 115 |
+
# Validate sources
|
| 116 |
+
if sources and sources not in available_sources:
|
| 117 |
+
validation_result["errors"].append(f"Source '{sources}' not found in available sources")
|
| 118 |
+
validation_result["valid"] = False
|
| 119 |
+
|
| 120 |
+
# Validate reports
|
| 121 |
+
if reports:
|
| 122 |
+
for report in reports:
|
| 123 |
+
if report not in available_filenames:
|
| 124 |
+
validation_result["warnings"].append(f"Report '{report}' not found in available reports")
|
| 125 |
+
|
| 126 |
+
# Validate subtypes
|
| 127 |
+
if subtype:
|
| 128 |
+
for sub in subtype:
|
| 129 |
+
if sub not in available_filenames:
|
| 130 |
+
validation_result["warnings"].append(f"Subtype '{sub}' not found in available reports")
|
| 131 |
+
|
| 132 |
+
return validation_result
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def get_report_statistics(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 136 |
+
"""
|
| 137 |
+
Get statistics about reports in chunks.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
chunks: List of chunk dictionaries
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Dictionary with report statistics
|
| 144 |
+
"""
|
| 145 |
+
if not chunks:
|
| 146 |
+
return {}
|
| 147 |
+
|
| 148 |
+
stats = {
|
| 149 |
+
"total_chunks": len(chunks),
|
| 150 |
+
"sources": {},
|
| 151 |
+
"years": {},
|
| 152 |
+
"avg_chunk_length": 0,
|
| 153 |
+
"total_content_length": 0
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
total_length = 0
|
| 157 |
+
|
| 158 |
+
for chunk in chunks:
|
| 159 |
+
content = chunk.get("content", "")
|
| 160 |
+
total_length += len(content)
|
| 161 |
+
|
| 162 |
+
metadata = chunk.get("metadata", {})
|
| 163 |
+
|
| 164 |
+
# Count by source
|
| 165 |
+
source = metadata.get("source", "Unknown")
|
| 166 |
+
stats["sources"][source] = stats["sources"].get(source, 0) + 1
|
| 167 |
+
|
| 168 |
+
# Count by year
|
| 169 |
+
year = metadata.get("year", "Unknown")
|
| 170 |
+
stats["years"][year] = stats["years"].get(year, 0) + 1
|
| 171 |
+
|
| 172 |
+
stats["total_content_length"] = total_length
|
| 173 |
+
stats["avg_chunk_length"] = total_length / len(chunks) if chunks else 0
|
| 174 |
+
|
| 175 |
+
return stats
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def filter_chunks_by_metadata(
|
| 179 |
+
chunks: List[Dict[str, Any]],
|
| 180 |
+
source_filter: str = None,
|
| 181 |
+
filename_filter: List[str] = None,
|
| 182 |
+
year_filter: List[str] = None
|
| 183 |
+
) -> List[Dict[str, Any]]:
|
| 184 |
+
"""
|
| 185 |
+
Filter chunks by metadata criteria.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
chunks: List of chunk dictionaries
|
| 189 |
+
source_filter: Source to filter by
|
| 190 |
+
filename_filter: List of filenames to filter by
|
| 191 |
+
year_filter: List of years to filter by
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
Filtered list of chunks
|
| 195 |
+
"""
|
| 196 |
+
filtered_chunks = chunks
|
| 197 |
+
|
| 198 |
+
if source_filter:
|
| 199 |
+
filtered_chunks = [
|
| 200 |
+
chunk for chunk in filtered_chunks
|
| 201 |
+
if chunk.get("metadata", {}).get("source") == source_filter
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
if filename_filter:
|
| 205 |
+
filtered_chunks = [
|
| 206 |
+
chunk for chunk in filtered_chunks
|
| 207 |
+
if chunk.get("metadata", {}).get("filename") in filename_filter
|
| 208 |
+
]
|
| 209 |
+
|
| 210 |
+
if year_filter:
|
| 211 |
+
filtered_chunks = [
|
| 212 |
+
chunk for chunk in filtered_chunks
|
| 213 |
+
if chunk.get("metadata", {}).get("year") in year_filter
|
| 214 |
+
]
|
| 215 |
+
|
| 216 |
+
return filtered_chunks
|
src/reporting/service.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Report service for managing report operations."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, List, Any, Optional
|
| 4 |
+
from .metadata import get_report_metadata, get_available_sources, get_source_subtypes
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ReportService:
|
| 8 |
+
"""Service class for report operations."""
|
| 9 |
+
|
| 10 |
+
def __init__(self, chunks: List[Dict[str, Any]] = None):
|
| 11 |
+
"""
|
| 12 |
+
Initialize report service.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
chunks: List of chunk dictionaries
|
| 16 |
+
"""
|
| 17 |
+
self.chunks = chunks or []
|
| 18 |
+
self.metadata = get_report_metadata(self.chunks) if self.chunks else {}
|
| 19 |
+
|
| 20 |
+
def get_available_sources(self) -> List[str]:
|
| 21 |
+
"""Get available report sources."""
|
| 22 |
+
if self.metadata:
|
| 23 |
+
return self.metadata.get("sources", [])
|
| 24 |
+
return get_available_sources()
|
| 25 |
+
|
| 26 |
+
def get_available_reports(self) -> List[str]:
|
| 27 |
+
"""Get available report filenames."""
|
| 28 |
+
return self.metadata.get("filenames", [])
|
| 29 |
+
|
| 30 |
+
def get_source_subtypes(self) -> Dict[str, List[str]]:
|
| 31 |
+
"""Get source to subtype mapping."""
|
| 32 |
+
# For now, use the placeholder function
|
| 33 |
+
# In a full implementation, this would be derived from actual data
|
| 34 |
+
return get_source_subtypes()
|
| 35 |
+
|
| 36 |
+
def get_reports_by_source(self, source: str) -> List[str]:
|
| 37 |
+
"""
|
| 38 |
+
Get reports filtered by source.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
source: Source category
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
List of report filenames
|
| 45 |
+
"""
|
| 46 |
+
if not self.chunks:
|
| 47 |
+
return []
|
| 48 |
+
|
| 49 |
+
reports = set()
|
| 50 |
+
for chunk in self.chunks:
|
| 51 |
+
metadata = chunk.get("metadata", {})
|
| 52 |
+
if metadata.get("source") == source:
|
| 53 |
+
filename = metadata.get("filename")
|
| 54 |
+
if filename:
|
| 55 |
+
reports.add(filename)
|
| 56 |
+
|
| 57 |
+
return sorted(list(reports))
|
| 58 |
+
|
| 59 |
+
def get_years_by_source(self, source: str) -> List[str]:
|
| 60 |
+
"""
|
| 61 |
+
Get years available for a specific source.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
source: Source category
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
List of years
|
| 68 |
+
"""
|
| 69 |
+
if not self.chunks:
|
| 70 |
+
return []
|
| 71 |
+
|
| 72 |
+
years = set()
|
| 73 |
+
for chunk in self.chunks:
|
| 74 |
+
metadata = chunk.get("metadata", {})
|
| 75 |
+
if metadata.get("source") == source:
|
| 76 |
+
year = metadata.get("year")
|
| 77 |
+
if year:
|
| 78 |
+
years.add(year)
|
| 79 |
+
|
| 80 |
+
return sorted(list(years))
|
| 81 |
+
|
| 82 |
+
def search_reports(self, query: str) -> List[str]:
|
| 83 |
+
"""
|
| 84 |
+
Search for reports by name.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
query: Search query
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
List of matching report filenames
|
| 91 |
+
"""
|
| 92 |
+
if not self.chunks:
|
| 93 |
+
return []
|
| 94 |
+
|
| 95 |
+
query_lower = query.lower()
|
| 96 |
+
matching_reports = set()
|
| 97 |
+
|
| 98 |
+
for chunk in self.chunks:
|
| 99 |
+
metadata = chunk.get("metadata", {})
|
| 100 |
+
filename = metadata.get("filename", "")
|
| 101 |
+
|
| 102 |
+
if query_lower in filename.lower():
|
| 103 |
+
matching_reports.add(filename)
|
| 104 |
+
|
| 105 |
+
return sorted(list(matching_reports))
|
| 106 |
+
|
| 107 |
+
def get_report_info(self, filename: str) -> Dict[str, Any]:
|
| 108 |
+
"""
|
| 109 |
+
Get information about a specific report.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
filename: Report filename
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Dictionary with report information
|
| 116 |
+
"""
|
| 117 |
+
if not self.chunks:
|
| 118 |
+
return {}
|
| 119 |
+
|
| 120 |
+
report_info = {
|
| 121 |
+
"filename": filename,
|
| 122 |
+
"chunk_count": 0,
|
| 123 |
+
"sources": set(),
|
| 124 |
+
"years": set(),
|
| 125 |
+
"total_content_length": 0
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
for chunk in self.chunks:
|
| 129 |
+
metadata = chunk.get("metadata", {})
|
| 130 |
+
if metadata.get("filename") == filename:
|
| 131 |
+
report_info["chunk_count"] += 1
|
| 132 |
+
report_info["total_content_length"] += len(chunk.get("content", ""))
|
| 133 |
+
|
| 134 |
+
if "source" in metadata:
|
| 135 |
+
report_info["sources"].add(metadata["source"])
|
| 136 |
+
|
| 137 |
+
if "year" in metadata:
|
| 138 |
+
report_info["years"].add(metadata["year"])
|
| 139 |
+
|
| 140 |
+
# Convert sets to lists
|
| 141 |
+
report_info["sources"] = list(report_info["sources"])
|
| 142 |
+
report_info["years"] = list(report_info["years"])
|
| 143 |
+
|
| 144 |
+
return report_info
|
src/reporting/snowflake_connector.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Snowflake Connector for Feedback System
|
| 3 |
+
|
| 4 |
+
This module handles inserting user feedback into Snowflake.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
+
from src.reporting.feedback_schema import UserFeedback
|
| 12 |
+
|
| 13 |
+
# Try to import snowflake connector
|
| 14 |
+
try:
|
| 15 |
+
import snowflake.connector
|
| 16 |
+
SNOWFLAKE_AVAILABLE = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
SNOWFLAKE_AVAILABLE = False
|
| 19 |
+
logging.warning("⚠️ snowflake-connector-python not installed. Install with: pip install snowflake-connector-python")
|
| 20 |
+
|
| 21 |
+
# Configure logging
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SnowflakeFeedbackConnector:
|
| 27 |
+
"""Connector for inserting feedback into Snowflake"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
user: str,
|
| 32 |
+
password: str,
|
| 33 |
+
account: str,
|
| 34 |
+
warehouse: str,
|
| 35 |
+
database: str = "SNOWFLAKE_LEARNING",
|
| 36 |
+
schema: str = "PUBLIC"
|
| 37 |
+
):
|
| 38 |
+
self.user = user
|
| 39 |
+
self.password = password
|
| 40 |
+
self.account = account
|
| 41 |
+
self.warehouse = warehouse
|
| 42 |
+
self.database = database
|
| 43 |
+
self.schema = schema
|
| 44 |
+
self._connection = None
|
| 45 |
+
|
| 46 |
+
def connect(self):
|
| 47 |
+
"""Establish Snowflake connection"""
|
| 48 |
+
if not SNOWFLAKE_AVAILABLE:
|
| 49 |
+
raise ImportError("snowflake-connector-python is not installed. Install with: pip install snowflake-connector-python")
|
| 50 |
+
|
| 51 |
+
logger.info("=" * 80)
|
| 52 |
+
logger.info("🔌 SNOWFLAKE CONNECTION: Attempting to connect...")
|
| 53 |
+
logger.info(f" - Account: {self.account}")
|
| 54 |
+
logger.info(f" - Warehouse: {self.warehouse}")
|
| 55 |
+
logger.info(f" - Database: {self.database}")
|
| 56 |
+
logger.info(f" - Schema: {self.schema}")
|
| 57 |
+
logger.info(f" - User: {self.user}")
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
self._connection = snowflake.connector.connect(
|
| 61 |
+
user=self.user,
|
| 62 |
+
password=self.password,
|
| 63 |
+
account=self.account,
|
| 64 |
+
warehouse=self.warehouse
|
| 65 |
+
# Don't set database/schema in connection - we'll do it per query
|
| 66 |
+
)
|
| 67 |
+
logger.info("✅ SNOWFLAKE CONNECTION: Successfully connected")
|
| 68 |
+
logger.info("=" * 80)
|
| 69 |
+
print(f"✅ Connected to Snowflake: {self.database}.{self.schema}")
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"❌ SNOWFLAKE CONNECTION FAILED: {e}")
|
| 72 |
+
logger.error("=" * 80)
|
| 73 |
+
print(f"❌ Failed to connect to Snowflake: {e}")
|
| 74 |
+
raise
|
| 75 |
+
|
| 76 |
+
def disconnect(self):
|
| 77 |
+
"""Close Snowflake connection"""
|
| 78 |
+
if self._connection:
|
| 79 |
+
self._connection.close()
|
| 80 |
+
print("✅ Disconnected from Snowflake")
|
| 81 |
+
|
| 82 |
+
def insert_feedback(self, feedback: UserFeedback) -> bool:
|
| 83 |
+
"""Insert a single feedback record into Snowflake"""
|
| 84 |
+
logger.info("=" * 80)
|
| 85 |
+
logger.info("🔄 SNOWFLAKE INSERT: Starting feedback insertion process")
|
| 86 |
+
logger.info(f"📝 Feedback ID: {feedback.feedback_id}")
|
| 87 |
+
|
| 88 |
+
if not self._connection:
|
| 89 |
+
logger.error("❌ Not connected to Snowflake. Call connect() first.")
|
| 90 |
+
raise RuntimeError("Not connected to Snowflake. Call connect() first.")
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
logger.info("📊 VALIDATION: Validating feedback data structure...")
|
| 94 |
+
|
| 95 |
+
# Validate feedback object
|
| 96 |
+
validation_errors = []
|
| 97 |
+
if not feedback.feedback_id:
|
| 98 |
+
validation_errors.append("Missing feedback_id")
|
| 99 |
+
if feedback.score is None:
|
| 100 |
+
validation_errors.append("Missing score")
|
| 101 |
+
if feedback.timestamp is None:
|
| 102 |
+
validation_errors.append("Missing timestamp")
|
| 103 |
+
|
| 104 |
+
if validation_errors:
|
| 105 |
+
logger.error(f"❌ VALIDATION FAILED: {validation_errors}")
|
| 106 |
+
return False
|
| 107 |
+
else:
|
| 108 |
+
logger.info("✅ VALIDATION PASSED: All required fields present")
|
| 109 |
+
|
| 110 |
+
logger.info("📋 Data Summary:")
|
| 111 |
+
logger.info(f" - Feedback ID: {feedback.feedback_id}")
|
| 112 |
+
logger.info(f" - Score: {feedback.score}")
|
| 113 |
+
logger.info(f" - Conversation ID: {feedback.conversation_id}")
|
| 114 |
+
logger.info(f" - Has Retrievals: {feedback.has_retrievals}")
|
| 115 |
+
logger.info(f" - Retrieval Count: {feedback.retrieval_count}")
|
| 116 |
+
logger.info(f" - Message Count: {feedback.message_count}")
|
| 117 |
+
logger.info(f" - Timestamp: {feedback.timestamp}")
|
| 118 |
+
|
| 119 |
+
cursor = self._connection.cursor()
|
| 120 |
+
logger.info("✅ SNOWFLAKE CONNECTION: Cursor created")
|
| 121 |
+
|
| 122 |
+
# Set database and schema context
|
| 123 |
+
logger.info(f"🔧 SETTING CONTEXT: Database={self.database}, Schema={self.schema}")
|
| 124 |
+
try:
|
| 125 |
+
cursor.execute(f'USE DATABASE "{self.database}"')
|
| 126 |
+
cursor.execute(f'USE SCHEMA "{self.schema}"')
|
| 127 |
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
| 128 |
+
current_db, current_schema = cursor.fetchone()
|
| 129 |
+
logger.info(f"✅ Current context verified: Database={current_db}, Schema={current_schema}")
|
| 130 |
+
except Exception as e:
|
| 131 |
+
logger.error(f"❌ Could not set context: {e}")
|
| 132 |
+
raise
|
| 133 |
+
|
| 134 |
+
# Prepare data
|
| 135 |
+
logger.info("🔧 DATA PREPARATION: Preparing retrieved_data...")
|
| 136 |
+
retrieved_data_raw = feedback.to_dict()['retrieved_data']
|
| 137 |
+
|
| 138 |
+
logger.info(f" - Retrieved data type (raw): {type(retrieved_data_raw).__name__}")
|
| 139 |
+
logger.info(f" - Retrieved data: {repr(retrieved_data_raw)[:200]}")
|
| 140 |
+
|
| 141 |
+
# If retrieved_data is already a string (from UI), parse it
|
| 142 |
+
if isinstance(retrieved_data_raw, str):
|
| 143 |
+
logger.info(" - Parsing string to Python object")
|
| 144 |
+
retrieved_data = json.loads(retrieved_data_raw)
|
| 145 |
+
elif retrieved_data_raw is None:
|
| 146 |
+
retrieved_data = None
|
| 147 |
+
else:
|
| 148 |
+
# It's already a Python object (list/dict)
|
| 149 |
+
logger.info(" - Data is already a Python object")
|
| 150 |
+
retrieved_data = retrieved_data_raw
|
| 151 |
+
|
| 152 |
+
logger.info(f" - Retrieved data size: {len(str(retrieved_data)) if retrieved_data else 0} characters")
|
| 153 |
+
logger.info(f" - Retrieved data type: {type(retrieved_data).__name__}")
|
| 154 |
+
|
| 155 |
+
# Convert to JSON string for TEXT column
|
| 156 |
+
if retrieved_data:
|
| 157 |
+
retrieved_data_for_db = json.dumps(retrieved_data)
|
| 158 |
+
logger.info(f" - Converting to JSON string for TEXT column")
|
| 159 |
+
logger.info(f" - JSON string length: {len(retrieved_data_for_db)}")
|
| 160 |
+
else:
|
| 161 |
+
logger.info(f" - Retrieved data is None, using NULL")
|
| 162 |
+
retrieved_data_for_db = None
|
| 163 |
+
|
| 164 |
+
# Build SQL with retrieved_data as a TEXT column parameter
|
| 165 |
+
sql = f"""INSERT INTO user_feedback (
|
| 166 |
+
feedback_id,
|
| 167 |
+
open_ended_feedback,
|
| 168 |
+
score,
|
| 169 |
+
is_feedback_about_last_retrieval,
|
| 170 |
+
conversation_id,
|
| 171 |
+
timestamp,
|
| 172 |
+
message_count,
|
| 173 |
+
has_retrievals,
|
| 174 |
+
retrieval_count,
|
| 175 |
+
user_query,
|
| 176 |
+
bot_response,
|
| 177 |
+
created_at,
|
| 178 |
+
retrieved_data
|
| 179 |
+
) VALUES (
|
| 180 |
+
%(feedback_id)s, %(open_ended_feedback)s, %(score)s, %(is_feedback_about_last_retrieval)s,
|
| 181 |
+
%(conversation_id)s, %(timestamp)s, %(message_count)s, %(has_retrievals)s,
|
| 182 |
+
%(retrieval_count)s, %(user_query)s, %(bot_response)s, %(created_at)s,
|
| 183 |
+
%(retrieved_data)s
|
| 184 |
+
)"""
|
| 185 |
+
|
| 186 |
+
logger.info("📝 SQL PREPARATION: Building INSERT statement...")
|
| 187 |
+
logger.info(f" - Target table: user_feedback")
|
| 188 |
+
logger.info(f" - Database: {self.database}")
|
| 189 |
+
logger.info(f" - Schema: {self.schema}")
|
| 190 |
+
|
| 191 |
+
# Prepare parameters
|
| 192 |
+
params = {
|
| 193 |
+
'feedback_id': feedback.feedback_id,
|
| 194 |
+
'open_ended_feedback': feedback.open_ended_feedback,
|
| 195 |
+
'score': feedback.score,
|
| 196 |
+
'is_feedback_about_last_retrieval': feedback.is_feedback_about_last_retrieval,
|
| 197 |
+
'conversation_id': feedback.conversation_id,
|
| 198 |
+
'timestamp': int(feedback.timestamp),
|
| 199 |
+
'message_count': feedback.message_count,
|
| 200 |
+
'has_retrievals': feedback.has_retrievals,
|
| 201 |
+
'retrieval_count': feedback.retrieval_count,
|
| 202 |
+
'user_query': feedback.user_query,
|
| 203 |
+
'bot_response': feedback.bot_response,
|
| 204 |
+
'created_at': feedback.created_at,
|
| 205 |
+
'retrieved_data': retrieved_data_for_db
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
# Execute insert
|
| 209 |
+
logger.info("🚀 SQL EXECUTION: Executing INSERT query...")
|
| 210 |
+
cursor.execute(sql, params)
|
| 211 |
+
|
| 212 |
+
logger.info("✅ SQL EXECUTION: Query executed successfully")
|
| 213 |
+
logger.info(f" - Rows affected: 1")
|
| 214 |
+
logger.info(f" - Status: SUCCESS")
|
| 215 |
+
|
| 216 |
+
cursor.close()
|
| 217 |
+
logger.info("✅ SNOWFLAKE INSERT: Feedback inserted successfully")
|
| 218 |
+
logger.info(f"📝 Inserted feedback: {feedback.feedback_id}")
|
| 219 |
+
logger.info("=" * 80)
|
| 220 |
+
return True
|
| 221 |
+
|
| 222 |
+
except Exception as e:
|
| 223 |
+
# Check if it's a Snowflake error
|
| 224 |
+
if SNOWFLAKE_AVAILABLE and "ProgrammingError" in str(type(e)):
|
| 225 |
+
logger.error(f"❌ SQL EXECUTION ERROR: {e}")
|
| 226 |
+
logger.error(f" - Error code: {getattr(e, 'errno', 'Unknown')}")
|
| 227 |
+
logger.error(f" - SQL state: {getattr(e, 'sqlstate', 'Unknown')}")
|
| 228 |
+
else:
|
| 229 |
+
logger.error(f"❌ SNOWFLAKE INSERT FAILED: {type(e).__name__}")
|
| 230 |
+
logger.error(f" - Error: {e}")
|
| 231 |
+
logger.error("=" * 80)
|
| 232 |
+
return False
|
| 233 |
+
|
| 234 |
+
def __enter__(self):
|
| 235 |
+
"""Context manager entry"""
|
| 236 |
+
self.connect()
|
| 237 |
+
return self
|
| 238 |
+
|
| 239 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 240 |
+
"""Context manager exit"""
|
| 241 |
+
self.disconnect()
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def get_snowflake_connector_from_env() -> Optional[SnowflakeFeedbackConnector]:
|
| 245 |
+
"""Create Snowflake connector from environment variables"""
|
| 246 |
+
user = os.getenv("SNOWFLAKE_USER")
|
| 247 |
+
password = os.getenv("SNOWFLAKE_PASSWORD")
|
| 248 |
+
account = os.getenv("SNOWFLAKE_ACCOUNT")
|
| 249 |
+
warehouse = os.getenv("SNOWFLAKE_WAREHOUSE")
|
| 250 |
+
database = os.getenv("SNOWFLAKE_DATABASE", "SNOWFLAKE_LEARN")
|
| 251 |
+
schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
|
| 252 |
+
|
| 253 |
+
if not all([user, password, account, warehouse]):
|
| 254 |
+
print("⚠️ Snowflake credentials not found in environment variables")
|
| 255 |
+
print("Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 256 |
+
return None
|
| 257 |
+
|
| 258 |
+
return SnowflakeFeedbackConnector(
|
| 259 |
+
user=user,
|
| 260 |
+
password=password,
|
| 261 |
+
account=account,
|
| 262 |
+
warehouse=warehouse,
|
| 263 |
+
database=database,
|
| 264 |
+
schema=schema
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def save_to_snowflake(feedback: UserFeedback) -> bool:
|
| 269 |
+
"""Helper function to save feedback to Snowflake"""
|
| 270 |
+
logger.info("=" * 80)
|
| 271 |
+
logger.info("🔵 SNOWFLAKE SAVE: Starting save process")
|
| 272 |
+
logger.info(f"📝 Feedback ID: {feedback.feedback_id}")
|
| 273 |
+
|
| 274 |
+
connector = get_snowflake_connector_from_env()
|
| 275 |
+
|
| 276 |
+
if not connector:
|
| 277 |
+
logger.warning("⚠️ SNOWFLAKE SAVE: Skipping insertion (credentials not configured)")
|
| 278 |
+
logger.warning(" Required variables: SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, SNOWFLAKE_ACCOUNT, SNOWFLAKE_WAREHOUSE")
|
| 279 |
+
logger.info("=" * 80)
|
| 280 |
+
return False
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
logger.info("📡 SNOWFLAKE SAVE: Establishing connection...")
|
| 284 |
+
connector.connect()
|
| 285 |
+
logger.info("✅ SNOWFLAKE SAVE: Connection established")
|
| 286 |
+
|
| 287 |
+
logger.info("📥 SNOWFLAKE SAVE: Attempting to insert feedback...")
|
| 288 |
+
success = connector.insert_feedback(feedback)
|
| 289 |
+
|
| 290 |
+
logger.info("🔌 SNOWFLAKE SAVE: Disconnecting...")
|
| 291 |
+
connector.disconnect()
|
| 292 |
+
|
| 293 |
+
if success:
|
| 294 |
+
logger.info("✅ SNOWFLAKE SAVE: Successfully saved feedback")
|
| 295 |
+
else:
|
| 296 |
+
logger.error("❌ SNOWFLAKE SAVE: Failed to save feedback")
|
| 297 |
+
|
| 298 |
+
logger.info("=" * 80)
|
| 299 |
+
return success
|
| 300 |
+
except Exception as e:
|
| 301 |
+
logger.error(f"❌ SNOWFLAKE SAVE ERROR: {type(e).__name__}")
|
| 302 |
+
logger.error(f" - Error: {e}")
|
| 303 |
+
logger.info("=" * 80)
|
| 304 |
+
return False
|
| 305 |
+
|
src/retrieval/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document retrieval and filtering utilities."""
|
| 2 |
+
|
| 3 |
+
from .filter import create_filter, FilterBuilder
|
| 4 |
+
from .context import ContextRetriever, get_context
|
| 5 |
+
from .hybrid import HybridRetriever, get_available_search_modes, get_search_mode_description
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"create_filter",
|
| 9 |
+
"FilterBuilder",
|
| 10 |
+
"ContextRetriever",
|
| 11 |
+
"get_context",
|
| 12 |
+
"HybridRetriever",
|
| 13 |
+
"get_available_search_modes",
|
| 14 |
+
"get_search_mode_description"
|
| 15 |
+
]
|
src/retrieval/colbert_cache.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ColBERT embeddings cache for test set documents.
|
| 3 |
+
Provides O(1) lookup for ColBERT embeddings during late interaction.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Optional, Any
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ColBERTCache:
|
| 13 |
+
"""Cache for ColBERT embeddings of test set documents."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, cache_file: str = "test_set_colbert_cache.json"):
|
| 16 |
+
self.cache_file = Path("outputs/caches") / cache_file
|
| 17 |
+
self.embeddings_cache: Dict[str, np.ndarray] = {}
|
| 18 |
+
self._load_cache()
|
| 19 |
+
|
| 20 |
+
def _load_cache(self):
|
| 21 |
+
"""Load embeddings from cache file."""
|
| 22 |
+
if not self.cache_file.exists():
|
| 23 |
+
print(f"⚠️ ColBERT cache not found: {self.cache_file}")
|
| 24 |
+
print("💡 Run 'python precalculate_test_set_colbert.py' to create cache")
|
| 25 |
+
return
|
| 26 |
+
|
| 27 |
+
print(f"📂 Loading ColBERT cache from {self.cache_file}...")
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
with open(self.cache_file, 'r') as f:
|
| 31 |
+
cache_data = json.load(f)
|
| 32 |
+
|
| 33 |
+
# Reconstruct embeddings from compressed format
|
| 34 |
+
for doc_id, data in cache_data.items():
|
| 35 |
+
embedding_min = data['min']
|
| 36 |
+
embedding_max = data['max']
|
| 37 |
+
quantized_embedding = np.array(data['embedding'], dtype=np.uint8)
|
| 38 |
+
|
| 39 |
+
# Reconstruct original embedding
|
| 40 |
+
reconstructed = (quantized_embedding.astype(np.float32) / 255.0) * (embedding_max - embedding_min) + embedding_min
|
| 41 |
+
self.embeddings_cache[doc_id] = reconstructed.reshape(data['shape'])
|
| 42 |
+
|
| 43 |
+
print(f"✅ Loaded {len(self.embeddings_cache)} ColBERT embeddings from cache")
|
| 44 |
+
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"❌ Error loading ColBERT cache: {e}")
|
| 47 |
+
self.embeddings_cache = {}
|
| 48 |
+
|
| 49 |
+
def get_embedding(self, document_text: str) -> Optional[np.ndarray]:
|
| 50 |
+
"""Get ColBERT embedding for a document (O(1) lookup)."""
|
| 51 |
+
return self.embeddings_cache.get(document_text)
|
| 52 |
+
|
| 53 |
+
def has_embedding(self, document_text: str) -> bool:
|
| 54 |
+
"""Check if embedding exists for document."""
|
| 55 |
+
return document_text in self.embeddings_cache
|
| 56 |
+
|
| 57 |
+
def get_cache_stats(self) -> Dict[str, Any]:
|
| 58 |
+
"""Get cache statistics."""
|
| 59 |
+
return {
|
| 60 |
+
'total_embeddings': len(self.embeddings_cache),
|
| 61 |
+
'cache_file': str(self.cache_file),
|
| 62 |
+
'cache_exists': self.cache_file.exists()
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Global cache instance
|
| 67 |
+
_colbert_cache = None
|
| 68 |
+
|
| 69 |
+
def get_colbert_cache() -> ColBERTCache:
|
| 70 |
+
"""Get global ColBERT cache instance."""
|
| 71 |
+
global _colbert_cache
|
| 72 |
+
if _colbert_cache is None:
|
| 73 |
+
_colbert_cache = ColBERTCache()
|
| 74 |
+
return _colbert_cache
|
src/retrieval/context.py
ADDED
|
@@ -0,0 +1,881 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Context retrieval with reranking capabilities."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Optional, Tuple, Dict, Any
|
| 5 |
+
from langchain.schema import Document
|
| 6 |
+
from langchain_community.vectorstores import Qdrant
|
| 7 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 8 |
+
from sentence_transformers import CrossEncoder
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from qdrant_client.http import models as rest
|
| 12 |
+
import traceback
|
| 13 |
+
|
| 14 |
+
from .filter import create_filter
|
| 15 |
+
|
| 16 |
+
class ContextRetriever:
|
| 17 |
+
"""
|
| 18 |
+
Context retriever for hybrid search with optional filtering and reranking.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, vectorstore: Qdrant, config: dict = None):
|
| 22 |
+
"""
|
| 23 |
+
Initialize the context retriever.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
vectorstore: Qdrant vector store instance
|
| 27 |
+
config: Configuration dictionary
|
| 28 |
+
"""
|
| 29 |
+
self.vectorstore = vectorstore
|
| 30 |
+
self.config = config or {}
|
| 31 |
+
self.reranker = None
|
| 32 |
+
|
| 33 |
+
# BM25 attributes
|
| 34 |
+
self.bm25_vectorizer = None
|
| 35 |
+
self.bm25_matrix = None
|
| 36 |
+
self.bm25_documents = None
|
| 37 |
+
|
| 38 |
+
# Initialize reranker if available
|
| 39 |
+
# Try to get reranker model from different config paths
|
| 40 |
+
self.reranker_model_name = (
|
| 41 |
+
config.get('retrieval', {}).get('reranker_model') or
|
| 42 |
+
config.get('ranker', {}).get('model') or
|
| 43 |
+
config.get('reranker_model') or
|
| 44 |
+
'BAAI/bge-reranker-v2-m3'
|
| 45 |
+
)
|
| 46 |
+
self.reranker_type = self._detect_reranker_type(self.reranker_model_name)
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
if self.reranker_type == 'colbert':
|
| 50 |
+
from colbert.infra import Run, ColBERTConfig
|
| 51 |
+
from colbert.modeling.checkpoint import Checkpoint
|
| 52 |
+
# ColBERT uses late interaction - different implementation needed
|
| 53 |
+
print(f"✅ RERANKER: ColBERT model detected ({self.reranker_model_name})")
|
| 54 |
+
print(f"🔍 INTERACTION TYPE: Late interaction (token-level embeddings)")
|
| 55 |
+
|
| 56 |
+
# Create ColBERT config for CPU mode
|
| 57 |
+
colbert_config = ColBERTConfig(
|
| 58 |
+
doc_maxlen=300,
|
| 59 |
+
query_maxlen=32,
|
| 60 |
+
nbits=2,
|
| 61 |
+
kmeans_niters=4,
|
| 62 |
+
root="./colbert_data"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Load checkpoint (e.g. "colbert-ir/colbertv2.0")
|
| 66 |
+
self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config)
|
| 67 |
+
self.colbert_model = self.colbert_checkpoint.model
|
| 68 |
+
self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer
|
| 69 |
+
self.reranker = self._colbert_rerank # attach wrapper function
|
| 70 |
+
print(f"✅ COLBERT: Model and tokenizer loaded successfully")
|
| 71 |
+
|
| 72 |
+
else:
|
| 73 |
+
# Standard CrossEncoder for BGE and other models
|
| 74 |
+
from sentence_transformers import CrossEncoder
|
| 75 |
+
self.reranker = CrossEncoder(self.reranker_model_name)
|
| 76 |
+
print(f"✅ RERANKER: Initialized {self.reranker_model_name}")
|
| 77 |
+
print(f"🔍 INTERACTION TYPE: Cross-encoder (single relevance score)")
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"⚠️ Reranker initialization failed: {e}")
|
| 80 |
+
self.reranker = None
|
| 81 |
+
|
| 82 |
+
def _detect_reranker_type(self, model_name: str) -> str:
|
| 83 |
+
"""
|
| 84 |
+
Detect the type of reranker based on model name.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
model_name: Name of the reranker model
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
'colbert' for ColBERT models, 'crossencoder' for others
|
| 91 |
+
"""
|
| 92 |
+
model_name_lower = model_name.lower()
|
| 93 |
+
|
| 94 |
+
# ColBERT model patterns
|
| 95 |
+
colbert_patterns = [
|
| 96 |
+
'colbert',
|
| 97 |
+
'colbert-ir',
|
| 98 |
+
'colbertv2',
|
| 99 |
+
'colbert-v2'
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
for pattern in colbert_patterns:
|
| 103 |
+
if pattern in model_name_lower:
|
| 104 |
+
return 'colbert'
|
| 105 |
+
|
| 106 |
+
# Default to cross-encoder for BGE and other models
|
| 107 |
+
return 'crossencoder'
|
| 108 |
+
|
| 109 |
+
def _similarity_search_with_colbert_embeddings(self, query: str, k: int = 5, **kwargs) -> List[Tuple[Document, float]]:
|
| 110 |
+
"""
|
| 111 |
+
Perform similarity search and fetch ColBERT embeddings for documents.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
query: Search query
|
| 115 |
+
k: Number of documents to retrieve
|
| 116 |
+
**kwargs: Additional search parameters (filter, etc.)
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
List of (Document, score) tuples with ColBERT embeddings in metadata
|
| 120 |
+
"""
|
| 121 |
+
try:
|
| 122 |
+
print(f"🔍 COLBERT RETRIEVAL: Fetching documents with ColBERT embeddings")
|
| 123 |
+
|
| 124 |
+
# Use the vectorstore's similarity_search_with_score method instead of direct client
|
| 125 |
+
# This ensures proper filter handling
|
| 126 |
+
if 'filter' in kwargs and kwargs['filter']:
|
| 127 |
+
# Use the vectorstore method with filter
|
| 128 |
+
result = self.vectorstore.similarity_search_with_score(
|
| 129 |
+
query,
|
| 130 |
+
k=k,
|
| 131 |
+
filter=kwargs['filter']
|
| 132 |
+
)
|
| 133 |
+
else:
|
| 134 |
+
# Use the vectorstore method without filter
|
| 135 |
+
result = self.vectorstore.similarity_search_with_score(query, k=k)
|
| 136 |
+
|
| 137 |
+
# Convert to the format we need
|
| 138 |
+
if isinstance(result, tuple) and len(result) == 2:
|
| 139 |
+
documents, scores = result
|
| 140 |
+
elif isinstance(result, list):
|
| 141 |
+
documents = []
|
| 142 |
+
scores = []
|
| 143 |
+
for item in result:
|
| 144 |
+
if isinstance(item, tuple) and len(item) == 2:
|
| 145 |
+
doc, score = item
|
| 146 |
+
documents.append(doc)
|
| 147 |
+
scores.append(score)
|
| 148 |
+
else:
|
| 149 |
+
documents.append(item)
|
| 150 |
+
scores.append(0.0)
|
| 151 |
+
else:
|
| 152 |
+
documents = []
|
| 153 |
+
scores = []
|
| 154 |
+
|
| 155 |
+
# Now we need to fetch the ColBERT embeddings for these documents
|
| 156 |
+
# We'll use the Qdrant client directly for this part since we need specific payload fields
|
| 157 |
+
from qdrant_client.http import models as rest
|
| 158 |
+
|
| 159 |
+
collection_name = self.vectorstore.collection_name
|
| 160 |
+
|
| 161 |
+
# Get document IDs from the retrieved documents
|
| 162 |
+
doc_ids = []
|
| 163 |
+
for doc in documents:
|
| 164 |
+
# Extract ID from document metadata or use page_content hash as fallback
|
| 165 |
+
doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
|
| 166 |
+
if not doc_id:
|
| 167 |
+
# Use a hash of the content as ID
|
| 168 |
+
import hashlib
|
| 169 |
+
doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
|
| 170 |
+
doc_ids.append(doc_id)
|
| 171 |
+
|
| 172 |
+
# Fetch documents with ColBERT embeddings from Qdrant
|
| 173 |
+
search_result = self.vectorstore.client.retrieve(
|
| 174 |
+
collection_name=collection_name,
|
| 175 |
+
ids=doc_ids,
|
| 176 |
+
with_payload=True,
|
| 177 |
+
with_vectors=False
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Convert results to Document objects with ColBERT embeddings
|
| 181 |
+
enhanced_documents = []
|
| 182 |
+
enhanced_scores = []
|
| 183 |
+
|
| 184 |
+
# Create a mapping from doc_id to original score
|
| 185 |
+
doc_id_to_score = {}
|
| 186 |
+
for i, doc in enumerate(documents):
|
| 187 |
+
doc_id = doc.metadata.get('id') or doc.metadata.get('_id')
|
| 188 |
+
if not doc_id:
|
| 189 |
+
import hashlib
|
| 190 |
+
doc_id = hashlib.md5(doc.page_content.encode()).hexdigest()
|
| 191 |
+
doc_id_to_score[doc_id] = scores[i]
|
| 192 |
+
|
| 193 |
+
for point in search_result:
|
| 194 |
+
# Extract payload
|
| 195 |
+
payload = point.payload
|
| 196 |
+
|
| 197 |
+
# Get the original score for this document
|
| 198 |
+
doc_id = str(point.id)
|
| 199 |
+
original_score = doc_id_to_score.get(doc_id, 0.0)
|
| 200 |
+
|
| 201 |
+
# Create Document object with ColBERT embeddings
|
| 202 |
+
doc = Document(
|
| 203 |
+
page_content=payload.get('page_content', ''),
|
| 204 |
+
metadata={
|
| 205 |
+
**payload.get('metadata', {}),
|
| 206 |
+
'colbert_embedding': payload.get('colbert_embedding'),
|
| 207 |
+
'colbert_model': payload.get('colbert_model'),
|
| 208 |
+
'colbert_calculated_at': payload.get('colbert_calculated_at')
|
| 209 |
+
}
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
enhanced_documents.append(doc)
|
| 213 |
+
enhanced_scores.append(original_score)
|
| 214 |
+
|
| 215 |
+
print(f"✅ COLBERT RETRIEVAL: Retrieved {len(enhanced_documents)} documents with ColBERT embeddings")
|
| 216 |
+
|
| 217 |
+
return list(zip(enhanced_documents, enhanced_scores))
|
| 218 |
+
|
| 219 |
+
except Exception as e:
|
| 220 |
+
print(f"❌ COLBERT RETRIEVAL ERROR: {e}")
|
| 221 |
+
print(f"❌ Falling back to regular similarity search")
|
| 222 |
+
|
| 223 |
+
# Fallback to regular search - handle filter parameter correctly
|
| 224 |
+
if 'filter' in kwargs and kwargs['filter']:
|
| 225 |
+
return self.vectorstore.similarity_search_with_score(query, k=k, filter=kwargs['filter'])
|
| 226 |
+
else:
|
| 227 |
+
return self.vectorstore.similarity_search_with_score(query, k=k)
|
| 228 |
+
|
| 229 |
+
def retrieve_context(
|
| 230 |
+
self,
|
| 231 |
+
query: str,
|
| 232 |
+
k: int = 5,
|
| 233 |
+
reports: Optional[List[str]] = None,
|
| 234 |
+
sources: Optional[List[str]] = None,
|
| 235 |
+
subtype: Optional[str] = None,
|
| 236 |
+
year: Optional[str] = None,
|
| 237 |
+
district: Optional[List[str]] = None,
|
| 238 |
+
filenames: Optional[List[str]] = None,
|
| 239 |
+
use_reranking: bool = False,
|
| 240 |
+
qdrant_filter: Optional[rest.Filter] = None
|
| 241 |
+
) -> List[Document]:
|
| 242 |
+
"""
|
| 243 |
+
Retrieve context documents using hybrid search with optional filtering and reranking.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
query: User query
|
| 247 |
+
top_k: Number of documents to retrieve
|
| 248 |
+
reports: List of report names to filter by
|
| 249 |
+
sources: List of sources to filter by
|
| 250 |
+
subtype: Document subtype to filter by
|
| 251 |
+
year: Year to filter by
|
| 252 |
+
use_reranking: Whether to apply reranking
|
| 253 |
+
qdrant_filter: Pre-built Qdrant filter to use
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
List of retrieved documents
|
| 257 |
+
"""
|
| 258 |
+
try:
|
| 259 |
+
# Determine how many documents to retrieve
|
| 260 |
+
retrieve_k = k #* 3 if use_reranking else k # Retrieve more for reranking
|
| 261 |
+
|
| 262 |
+
# Build search kwargs
|
| 263 |
+
search_kwargs = {}
|
| 264 |
+
|
| 265 |
+
# Use qdrant_filter if provided (this takes precedence)
|
| 266 |
+
if qdrant_filter:
|
| 267 |
+
search_kwargs = {"filter": qdrant_filter}
|
| 268 |
+
print(f"✅ FILTERS APPLIED: Using inferred Qdrant filter")
|
| 269 |
+
else:
|
| 270 |
+
# Build filter from individual parameters
|
| 271 |
+
filter_obj = create_filter(
|
| 272 |
+
reports=reports,
|
| 273 |
+
sources=sources,
|
| 274 |
+
subtype=subtype,
|
| 275 |
+
year=year,
|
| 276 |
+
district=district,
|
| 277 |
+
filenames=filenames
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
if filter_obj:
|
| 281 |
+
search_kwargs = {"filter": filter_obj}
|
| 282 |
+
print(f"✅ FILTERS APPLIED: Using built filter")
|
| 283 |
+
else:
|
| 284 |
+
search_kwargs = {}
|
| 285 |
+
print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")
|
| 286 |
+
|
| 287 |
+
# Perform vector search
|
| 288 |
+
try:
|
| 289 |
+
# Check if we need ColBERT embeddings for reranking
|
| 290 |
+
if use_reranking and self.reranker_type == 'colbert':
|
| 291 |
+
result = self._similarity_search_with_colbert_embeddings(
|
| 292 |
+
query,
|
| 293 |
+
k=retrieve_k,
|
| 294 |
+
**search_kwargs
|
| 295 |
+
)
|
| 296 |
+
else:
|
| 297 |
+
result = self.vectorstore.similarity_search_with_score(
|
| 298 |
+
query,
|
| 299 |
+
k=retrieve_k,
|
| 300 |
+
**search_kwargs
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# Handle different return formats
|
| 304 |
+
if isinstance(result, tuple) and len(result) == 2:
|
| 305 |
+
documents, scores = result
|
| 306 |
+
elif isinstance(result, list) and len(result) > 0:
|
| 307 |
+
# Handle case where result is a list of (Document, score) tuples
|
| 308 |
+
documents = []
|
| 309 |
+
scores = []
|
| 310 |
+
for item in result:
|
| 311 |
+
if isinstance(item, tuple) and len(item) == 2:
|
| 312 |
+
doc, score = item
|
| 313 |
+
documents.append(doc)
|
| 314 |
+
scores.append(score)
|
| 315 |
+
else:
|
| 316 |
+
# Handle case where item is just a Document
|
| 317 |
+
documents.append(item)
|
| 318 |
+
scores.append(0.0) # Default score
|
| 319 |
+
else:
|
| 320 |
+
documents = []
|
| 321 |
+
scores = []
|
| 322 |
+
|
| 323 |
+
print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(documents)} documents (requested: {retrieve_k})")
|
| 324 |
+
|
| 325 |
+
# If we got fewer documents than requested, try without filters
|
| 326 |
+
if len(documents) < retrieve_k and search_kwargs.get('filter'):
|
| 327 |
+
print(f"⚠️ RETRIEVAL: Got {len(documents)} docs with filters, trying without filters...")
|
| 328 |
+
try:
|
| 329 |
+
result_no_filter = self.vectorstore.similarity_search_with_score(
|
| 330 |
+
query,
|
| 331 |
+
k=retrieve_k
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
if isinstance(result_no_filter, tuple) and len(result_no_filter) == 2:
|
| 335 |
+
documents_no_filter, scores_no_filter = result_no_filter
|
| 336 |
+
elif isinstance(result_no_filter, list):
|
| 337 |
+
documents_no_filter = []
|
| 338 |
+
scores_no_filter = []
|
| 339 |
+
for item in result_no_filter:
|
| 340 |
+
if isinstance(item, tuple) and len(item) == 2:
|
| 341 |
+
doc, score = item
|
| 342 |
+
documents_no_filter.append(doc)
|
| 343 |
+
scores_no_filter.append(score)
|
| 344 |
+
else:
|
| 345 |
+
documents_no_filter.append(item)
|
| 346 |
+
scores_no_filter.append(0.0)
|
| 347 |
+
else:
|
| 348 |
+
documents_no_filter = []
|
| 349 |
+
scores_no_filter = []
|
| 350 |
+
|
| 351 |
+
if len(documents_no_filter) > len(documents):
|
| 352 |
+
print(f"✅ RETRIEVAL: Got {len(documents_no_filter)} docs without filters")
|
| 353 |
+
documents = documents_no_filter
|
| 354 |
+
scores = scores_no_filter
|
| 355 |
+
except Exception as e:
|
| 356 |
+
print(f"⚠️ RETRIEVAL: Fallback search failed: {e}")
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
print(f"❌ RETRIEVAL ERROR: {str(e)}")
|
| 360 |
+
return []
|
| 361 |
+
|
| 362 |
+
# Apply reranking if enabled
|
| 363 |
+
reranking_applied = False
|
| 364 |
+
if use_reranking and len(documents) > 1:
|
| 365 |
+
print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(documents)} documents...")
|
| 366 |
+
try:
|
| 367 |
+
original_docs = documents.copy()
|
| 368 |
+
original_scores = scores.copy()
|
| 369 |
+
|
| 370 |
+
# Apply reranking
|
| 371 |
+
# print(f"🔍 ORIGINAL DOCS: {documents[0]}")
|
| 372 |
+
reranked_docs = self._apply_reranking(query, documents, scores)
|
| 373 |
+
# print(f"🔍 RERANKED DOCS: {reranked_docs[0]}")
|
| 374 |
+
reranking_applied = len(reranked_docs) > 0
|
| 375 |
+
|
| 376 |
+
if reranking_applied:
|
| 377 |
+
print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
|
| 378 |
+
documents = reranked_docs
|
| 379 |
+
# Update scores to reflect reranking
|
| 380 |
+
# scores = [0.0] * len(documents) # Reranked scores are not directly comparable
|
| 381 |
+
else:
|
| 382 |
+
print(f"⚠️ RERANKING FAILED: Using original order")
|
| 383 |
+
documents = original_docs
|
| 384 |
+
scores = original_scores
|
| 385 |
+
return documents
|
| 386 |
+
|
| 387 |
+
except Exception as e:
|
| 388 |
+
print(f"❌ RERANKING ERROR: {str(e)}")
|
| 389 |
+
print(f"⚠️ RERANKING FAILED: Using original order")
|
| 390 |
+
reranking_applied = False
|
| 391 |
+
elif use_reranking and len(documents) <= 1:
|
| 392 |
+
print(f"ℹ️ RERANKING: Skipped (only {len(documents)} document(s) retrieved)")
|
| 393 |
+
if use_reranking:
|
| 394 |
+
print(f"ℹ️ RERANKING: Skipped (disabled or insufficient documents)")
|
| 395 |
+
# Store original scores in metadata
|
| 396 |
+
for i, (doc, score) in enumerate(zip(documents, scores)):
|
| 397 |
+
doc.metadata['original_score'] = float(score)
|
| 398 |
+
doc.metadata['reranking_applied'] = False
|
| 399 |
+
return documents
|
| 400 |
+
else:
|
| 401 |
+
print(f"ℹ️ RERANKING: Skipped (disabled or insufficient documents)")
|
| 402 |
+
|
| 403 |
+
# Limit to requested number of documents
|
| 404 |
+
documents = documents[:k]
|
| 405 |
+
scores = scores[:k] if scores else [0.0] * len(documents)
|
| 406 |
+
|
| 407 |
+
# Add metadata to documents
|
| 408 |
+
for i, (doc, score) in enumerate(zip(documents, scores)):
|
| 409 |
+
if hasattr(doc, 'metadata'):
|
| 410 |
+
doc.metadata.update({
|
| 411 |
+
'reranking_applied': reranking_applied,
|
| 412 |
+
'reranker_model': 'BAAI/bge-reranker-v2-m3' if reranking_applied else None,
|
| 413 |
+
'original_rank': i + 1,
|
| 414 |
+
'final_rank': i + 1,
|
| 415 |
+
'original_score': float(score) if score is not None else 0.0
|
| 416 |
+
})
|
| 417 |
+
|
| 418 |
+
return documents
|
| 419 |
+
|
| 420 |
+
except Exception as e:
|
| 421 |
+
print(f"❌ CONTEXT RETRIEVAL ERROR: {str(e)}")
|
| 422 |
+
return []
|
| 423 |
+
|
| 424 |
+
def _apply_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
|
| 425 |
+
"""
|
| 426 |
+
Apply reranking to documents using the appropriate reranker.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
query: User query
|
| 430 |
+
documents: List of documents to rerank
|
| 431 |
+
scores: Original scores
|
| 432 |
+
|
| 433 |
+
Returns:
|
| 434 |
+
Reranked list of documents
|
| 435 |
+
"""
|
| 436 |
+
if not self.reranker or len(documents) == 0:
|
| 437 |
+
return documents
|
| 438 |
+
|
| 439 |
+
try:
|
| 440 |
+
print(f"🔍 RERANKING METHOD: Starting reranking with {len(documents)} documents")
|
| 441 |
+
print(f"🔍 RERANKING TYPE: {self.reranker_type.upper()}")
|
| 442 |
+
|
| 443 |
+
if self.reranker_type == 'colbert':
|
| 444 |
+
return self._apply_colbert_reranking(query, documents, scores)
|
| 445 |
+
else:
|
| 446 |
+
return self._apply_crossencoder_reranking(query, documents, scores)
|
| 447 |
+
|
| 448 |
+
except Exception as e:
|
| 449 |
+
print(f"❌ RERANKING ERROR: {str(e)}")
|
| 450 |
+
return documents
|
| 451 |
+
|
| 452 |
+
def _apply_crossencoder_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
|
| 453 |
+
"""
|
| 454 |
+
Apply reranking using CrossEncoder (BGE and other models).
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
query: User query
|
| 458 |
+
documents: List of documents to rerank
|
| 459 |
+
scores: Original scores
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
Reranked list of documents
|
| 463 |
+
"""
|
| 464 |
+
# Prepare pairs for reranking
|
| 465 |
+
pairs = []
|
| 466 |
+
for doc in documents:
|
| 467 |
+
pairs.append([query, doc.page_content])
|
| 468 |
+
|
| 469 |
+
print(f"🔍 CROSS-ENCODER: Prepared {len(pairs)} pairs for reranking")
|
| 470 |
+
|
| 471 |
+
# Get reranking scores using the correct CrossEncoder API
|
| 472 |
+
rerank_scores = self.reranker.predict(pairs)
|
| 473 |
+
|
| 474 |
+
# Handle single score case
|
| 475 |
+
if not isinstance(rerank_scores, (list, np.ndarray)):
|
| 476 |
+
rerank_scores = [rerank_scores]
|
| 477 |
+
|
| 478 |
+
# Ensure we have the right number of scores
|
| 479 |
+
if len(rerank_scores) != len(documents):
|
| 480 |
+
print(f"⚠️ RERANKING WARNING: Expected {len(documents)} scores, got {len(rerank_scores)}")
|
| 481 |
+
return documents
|
| 482 |
+
|
| 483 |
+
print(f"🔍 CROSS-ENCODER: Got {len(rerank_scores)} rerank scores")
|
| 484 |
+
print(f"🔍 CROSS-ENCODER SCORES: {rerank_scores[:5]}...") # Show first 5 scores
|
| 485 |
+
|
| 486 |
+
# Combine documents with their rerank scores
|
| 487 |
+
doc_scores = list(zip(documents, rerank_scores))
|
| 488 |
+
|
| 489 |
+
# Sort by rerank score (descending)
|
| 490 |
+
doc_scores.sort(key=lambda x: x[1], reverse=True)
|
| 491 |
+
|
| 492 |
+
# Extract reranked documents and store scores in metadata
|
| 493 |
+
reranked_docs = []
|
| 494 |
+
for i, (doc, rerank_score) in enumerate(doc_scores):
|
| 495 |
+
# Find original index for original score
|
| 496 |
+
original_idx = documents.index(doc)
|
| 497 |
+
original_score = scores[original_idx] if original_idx < len(scores) else 0.0
|
| 498 |
+
|
| 499 |
+
# Create new document with reranking metadata
|
| 500 |
+
new_doc = Document(
|
| 501 |
+
page_content=doc.page_content,
|
| 502 |
+
metadata={
|
| 503 |
+
**doc.metadata,
|
| 504 |
+
'reranking_applied': True,
|
| 505 |
+
'reranker_model': self.reranker_model_name,
|
| 506 |
+
'reranker_type': self.reranker_type,
|
| 507 |
+
'original_rank': original_idx + 1,
|
| 508 |
+
'final_rank': i + 1,
|
| 509 |
+
'original_score': float(original_score),
|
| 510 |
+
'reranked_score': float(rerank_score)
|
| 511 |
+
}
|
| 512 |
+
)
|
| 513 |
+
reranked_docs.append(new_doc)
|
| 514 |
+
|
| 515 |
+
print(f"✅ CROSS-ENCODER: Reranked {len(reranked_docs)} documents")
|
| 516 |
+
|
| 517 |
+
return reranked_docs
|
| 518 |
+
|
| 519 |
+
def _apply_colbert_reranking(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
|
| 520 |
+
"""
|
| 521 |
+
Apply reranking using ColBERT late interaction.
|
| 522 |
+
|
| 523 |
+
Args:
|
| 524 |
+
query: User query
|
| 525 |
+
documents: List of documents to rerank
|
| 526 |
+
scores: Original scores
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
Reranked list of documents
|
| 530 |
+
"""
|
| 531 |
+
# Use the actual ColBERT reranking implementation
|
| 532 |
+
return self._colbert_rerank(query, documents, scores)
|
| 533 |
+
|
| 534 |
+
def _colbert_rerank(self, query: str, documents: List[Document], scores: List[float]) -> List[Document]:
|
| 535 |
+
"""
|
| 536 |
+
ColBERT reranking using late interaction with pre-calculated embeddings support.
|
| 537 |
+
|
| 538 |
+
Args:
|
| 539 |
+
query: User query
|
| 540 |
+
documents: List of documents to rerank
|
| 541 |
+
scores: Original scores
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
Reranked list of documents
|
| 545 |
+
"""
|
| 546 |
+
try:
|
| 547 |
+
print(f"🔍 COLBERT: Starting late interaction reranking with {len(documents)} documents")
|
| 548 |
+
|
| 549 |
+
# Check if documents have pre-calculated ColBERT embeddings
|
| 550 |
+
pre_calculated_embeddings = []
|
| 551 |
+
documents_without_embeddings = []
|
| 552 |
+
documents_without_indices = []
|
| 553 |
+
|
| 554 |
+
for i, doc in enumerate(documents):
|
| 555 |
+
if (hasattr(doc, 'metadata') and
|
| 556 |
+
'colbert_embedding' in doc.metadata and
|
| 557 |
+
doc.metadata['colbert_embedding'] is not None):
|
| 558 |
+
# Use pre-calculated embedding
|
| 559 |
+
colbert_embedding = doc.metadata['colbert_embedding']
|
| 560 |
+
if isinstance(colbert_embedding, list):
|
| 561 |
+
colbert_embedding = torch.tensor(colbert_embedding)
|
| 562 |
+
pre_calculated_embeddings.append(colbert_embedding)
|
| 563 |
+
else:
|
| 564 |
+
# Need to calculate embedding
|
| 565 |
+
documents_without_embeddings.append(doc)
|
| 566 |
+
documents_without_indices.append(i)
|
| 567 |
+
|
| 568 |
+
# Calculate query embedding
|
| 569 |
+
query_embeddings = self.colbert_checkpoint.queryFromText([query])
|
| 570 |
+
|
| 571 |
+
# Calculate embeddings for documents without pre-calculated ones
|
| 572 |
+
if documents_without_embeddings:
|
| 573 |
+
print(f"🔄 COLBERT: Calculating embeddings for {len(documents_without_embeddings)} documents without pre-calculated embeddings")
|
| 574 |
+
doc_texts = [doc.page_content for doc in documents_without_embeddings]
|
| 575 |
+
doc_embeddings = self.colbert_checkpoint.docFromText(doc_texts)
|
| 576 |
+
|
| 577 |
+
# Insert calculated embeddings into the right positions
|
| 578 |
+
for i, embedding in enumerate(doc_embeddings):
|
| 579 |
+
idx = documents_without_indices[i]
|
| 580 |
+
pre_calculated_embeddings.insert(idx, embedding)
|
| 581 |
+
else:
|
| 582 |
+
print(f"✅ COLBERT: Using pre-calculated embeddings for all {len(documents)} documents")
|
| 583 |
+
|
| 584 |
+
# Calculate late interaction scores
|
| 585 |
+
# ColBERT uses MaxSim: for each query token, find max similarity with document tokens
|
| 586 |
+
colbert_scores = []
|
| 587 |
+
for i, doc_embedding in enumerate(pre_calculated_embeddings):
|
| 588 |
+
# Calculate similarity matrix between query and document i
|
| 589 |
+
sim_matrix = torch.matmul(query_embeddings[0], doc_embedding.transpose(-1, -2))
|
| 590 |
+
|
| 591 |
+
# MaxSim: for each query token, take max similarity with document
|
| 592 |
+
max_sim_per_query_token = torch.max(sim_matrix, dim=-1)[0]
|
| 593 |
+
|
| 594 |
+
# Sum over query tokens to get final score
|
| 595 |
+
final_score = torch.sum(max_sim_per_query_token).item()
|
| 596 |
+
colbert_scores.append(final_score)
|
| 597 |
+
|
| 598 |
+
# Sort documents by ColBERT scores
|
| 599 |
+
doc_scores = list(zip(documents, colbert_scores))
|
| 600 |
+
doc_scores.sort(key=lambda x: x[1], reverse=True)
|
| 601 |
+
|
| 602 |
+
# Create reranked documents with metadata
|
| 603 |
+
reranked_docs = []
|
| 604 |
+
for i, (doc, colbert_score) in enumerate(doc_scores):
|
| 605 |
+
original_idx = documents.index(doc)
|
| 606 |
+
original_score = scores[original_idx] if original_idx < len(scores) else 0.0
|
| 607 |
+
|
| 608 |
+
new_doc = Document(
|
| 609 |
+
page_content=doc.page_content,
|
| 610 |
+
metadata={
|
| 611 |
+
**doc.metadata,
|
| 612 |
+
'reranking_applied': True,
|
| 613 |
+
'reranker_model': self.reranker_model_name,
|
| 614 |
+
'reranker_type': self.reranker_type,
|
| 615 |
+
'original_rank': original_idx + 1,
|
| 616 |
+
'final_rank': i + 1,
|
| 617 |
+
'original_score': float(original_score),
|
| 618 |
+
'reranked_score': float(colbert_score),
|
| 619 |
+
'colbert_score': float(colbert_score),
|
| 620 |
+
'colbert_embedding_pre_calculated': 'colbert_embedding' in doc.metadata
|
| 621 |
+
}
|
| 622 |
+
)
|
| 623 |
+
reranked_docs.append(new_doc)
|
| 624 |
+
|
| 625 |
+
print(f"✅ COLBERT: Reranked {len(reranked_docs)} documents using late interaction")
|
| 626 |
+
print(f"🔍 COLBERT SCORES: {[f'{score:.4f}' for score in colbert_scores[:5]]}...")
|
| 627 |
+
|
| 628 |
+
return reranked_docs
|
| 629 |
+
|
| 630 |
+
except Exception as e:
|
| 631 |
+
print(f"❌ COLBERT RERANKING ERROR: {str(e)}")
|
| 632 |
+
print(f"❌ COLBERT TRACEBACK: {traceback.format_exc()}")
|
| 633 |
+
# Fallback to original order - return documents as-is
|
| 634 |
+
return documents
|
| 635 |
+
|
| 636 |
+
def retrieve_with_scores(self, query: str, vectorstore=None, k: int = 5, reports: List[str] = None,
|
| 637 |
+
sources: List[str] = None, subtype: List[str] = None,
|
| 638 |
+
year: List[str] = None, use_reranking: bool = False,
|
| 639 |
+
qdrant_filter: Optional[rest.Filter] = None) -> Tuple[List[Document], List[float]]:
|
| 640 |
+
"""
|
| 641 |
+
Retrieve context documents with scores using hybrid search with optional reranking.
|
| 642 |
+
|
| 643 |
+
Args:
|
| 644 |
+
query: User query
|
| 645 |
+
vectorstore: Optional vectorstore instance (for compatibility)
|
| 646 |
+
k: Number of documents to retrieve
|
| 647 |
+
reports: List of report names to filter by
|
| 648 |
+
sources: List of sources to filter by
|
| 649 |
+
subtype: Document subtype to filter by
|
| 650 |
+
year: List of years to filter by
|
| 651 |
+
use_reranking: Whether to apply reranking
|
| 652 |
+
qdrant_filter: Pre-built Qdrant filter
|
| 653 |
+
|
| 654 |
+
Returns:
|
| 655 |
+
Tuple of (documents, scores)
|
| 656 |
+
"""
|
| 657 |
+
try:
|
| 658 |
+
# Use the provided vectorstore if available, otherwise use the instance one
|
| 659 |
+
if vectorstore:
|
| 660 |
+
self.vectorstore = vectorstore
|
| 661 |
+
|
| 662 |
+
# Determine search strategy
|
| 663 |
+
search_strategy = self.config.get('retrieval', {}).get('search_strategy', 'vector_only')
|
| 664 |
+
|
| 665 |
+
if search_strategy == 'vector_only':
|
| 666 |
+
# Vector search only
|
| 667 |
+
print(f"🔄 VECTOR SEARCH: Retrieving {k} documents...")
|
| 668 |
+
|
| 669 |
+
if qdrant_filter:
|
| 670 |
+
print(f"✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
|
| 671 |
+
# Pass filter as positional argument, not keyword argument
|
| 672 |
+
results = self.vectorstore.similarity_search_with_score(
|
| 673 |
+
query,
|
| 674 |
+
k=k,
|
| 675 |
+
filter=qdrant_filter
|
| 676 |
+
)
|
| 677 |
+
else:
|
| 678 |
+
# Build filter from individual parameters
|
| 679 |
+
filter_conditions = self._build_filter_conditions(reports, sources, subtype, year)
|
| 680 |
+
if filter_conditions:
|
| 681 |
+
print(f"✅ FILTER APPLIED: {filter_conditions}")
|
| 682 |
+
results = self.vectorstore.similarity_search_with_score(
|
| 683 |
+
query,
|
| 684 |
+
k=k,
|
| 685 |
+
filter=filter_conditions
|
| 686 |
+
)
|
| 687 |
+
else:
|
| 688 |
+
print(f"ℹ️ NO FILTERS APPLIED: All documents will be searched")
|
| 689 |
+
results = self.vectorstore.similarity_search_with_score(query, k=k)
|
| 690 |
+
|
| 691 |
+
print(f"🔍 SEARCH DEBUG: Raw result type: {type(results)}")
|
| 692 |
+
print(f"🔍 SEARCH DEBUG: Raw result length: {len(results)}")
|
| 693 |
+
|
| 694 |
+
# Handle different result formats
|
| 695 |
+
if results and isinstance(results[0], tuple):
|
| 696 |
+
documents = [doc for doc, score in results]
|
| 697 |
+
scores = [score for doc, score in results]
|
| 698 |
+
print(f"🔍 SEARCH DEBUG: After unpacking - documents: {len(documents)}, scores: {len(scores)}")
|
| 699 |
+
else:
|
| 700 |
+
documents = results
|
| 701 |
+
scores = [0.0] * len(documents)
|
| 702 |
+
print(f"🔍 SEARCH DEBUG: No scores available, using default")
|
| 703 |
+
|
| 704 |
+
print(f"🔧 CONVERTING: Converting {len(documents)} documents")
|
| 705 |
+
|
| 706 |
+
# Convert to Document objects and store original scores
|
| 707 |
+
final_documents = []
|
| 708 |
+
for i, (doc, score) in enumerate(zip(documents, scores)):
|
| 709 |
+
if hasattr(doc, 'page_content'):
|
| 710 |
+
new_doc = Document(
|
| 711 |
+
page_content=doc.page_content,
|
| 712 |
+
metadata=doc.metadata.copy()
|
| 713 |
+
)
|
| 714 |
+
# Store original score in metadata
|
| 715 |
+
new_doc.metadata['original_score'] = float(score) if score is not None else 0.0
|
| 716 |
+
final_documents.append(new_doc)
|
| 717 |
+
else:
|
| 718 |
+
print(f"⚠️ WARNING: Document {i} has no page_content")
|
| 719 |
+
|
| 720 |
+
print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(final_documents)} documents")
|
| 721 |
+
|
| 722 |
+
# Apply reranking if enabled
|
| 723 |
+
if use_reranking and len(final_documents) > 1:
|
| 724 |
+
print(f"🔄 RERANKING: Applying {self.reranker_model} to {len(final_documents)} documents...")
|
| 725 |
+
final_documents = self._apply_reranking(query, final_documents, scores)
|
| 726 |
+
print(f"✅ RERANKING APPLIED: {self.reranker_model}")
|
| 727 |
+
else:
|
| 728 |
+
print(f"ℹ️ RERANKING: Skipped (disabled or no documents)")
|
| 729 |
+
|
| 730 |
+
return final_documents, scores
|
| 731 |
+
|
| 732 |
+
else:
|
| 733 |
+
print(f"❌ UNSUPPORTED STRATEGY: {search_strategy}")
|
| 734 |
+
return [], []
|
| 735 |
+
|
| 736 |
+
except Exception as e:
|
| 737 |
+
print(f"❌ RETRIEVAL ERROR: {e}")
|
| 738 |
+
print(f"❌ RETRIEVAL TRACEBACK: {traceback.format_exc()}")
|
| 739 |
+
return [], []
|
| 740 |
+
|
| 741 |
+
def _build_filter_conditions(self, reports: List[str] = None, sources: List[str] = None,
|
| 742 |
+
subtype: List[str] = None, year: List[str] = None) -> Optional[rest.Filter]:
|
| 743 |
+
"""
|
| 744 |
+
Build Qdrant filter conditions from individual parameters.
|
| 745 |
+
|
| 746 |
+
Args:
|
| 747 |
+
reports: List of report names
|
| 748 |
+
sources: List of sources
|
| 749 |
+
subtype: Document subtype
|
| 750 |
+
year: List of years
|
| 751 |
+
|
| 752 |
+
Returns:
|
| 753 |
+
Qdrant filter or None
|
| 754 |
+
"""
|
| 755 |
+
conditions = []
|
| 756 |
+
|
| 757 |
+
if reports:
|
| 758 |
+
conditions.append(rest.FieldCondition(
|
| 759 |
+
key="metadata.filename",
|
| 760 |
+
match=rest.MatchAny(any=reports)
|
| 761 |
+
))
|
| 762 |
+
|
| 763 |
+
if sources:
|
| 764 |
+
conditions.append(rest.FieldCondition(
|
| 765 |
+
key="metadata.source",
|
| 766 |
+
match=rest.MatchAny(any=sources)
|
| 767 |
+
))
|
| 768 |
+
|
| 769 |
+
if subtype:
|
| 770 |
+
conditions.append(rest.FieldCondition(
|
| 771 |
+
key="metadata.subtype",
|
| 772 |
+
match=rest.MatchAny(any=subtype)
|
| 773 |
+
))
|
| 774 |
+
|
| 775 |
+
if year:
|
| 776 |
+
conditions.append(rest.FieldCondition(
|
| 777 |
+
key="metadata.year",
|
| 778 |
+
match=rest.MatchAny(any=year)
|
| 779 |
+
))
|
| 780 |
+
|
| 781 |
+
if conditions:
|
| 782 |
+
return rest.Filter(must=conditions)
|
| 783 |
+
|
| 784 |
+
return None
|
| 785 |
+
|
| 786 |
+
def get_context(
|
| 787 |
+
query: str,
|
| 788 |
+
vectorstore: Qdrant,
|
| 789 |
+
k: int = 5,
|
| 790 |
+
reports: Optional[List[str]] = None,
|
| 791 |
+
sources: Optional[List[str]] = None,
|
| 792 |
+
subtype: Optional[str] = None,
|
| 793 |
+
year: Optional[str] = None,
|
| 794 |
+
use_reranking: bool = False,
|
| 795 |
+
qdrant_filter: Optional[rest.Filter] = None
|
| 796 |
+
) -> List[Document]:
|
| 797 |
+
"""
|
| 798 |
+
Convenience function to get context documents.
|
| 799 |
+
|
| 800 |
+
Args:
|
| 801 |
+
query: User query
|
| 802 |
+
vectorstore: Qdrant vector store instance
|
| 803 |
+
k: Number of documents to retrieve
|
| 804 |
+
reports: Optional list of report names to filter by
|
| 805 |
+
sources: Optional list of source categories to filter by
|
| 806 |
+
subtype: Optional subtype to filter by
|
| 807 |
+
year: Optional year to filter by
|
| 808 |
+
use_reranking: Whether to apply reranking
|
| 809 |
+
qdrant_filter: Optional pre-built Qdrant filter
|
| 810 |
+
|
| 811 |
+
Returns:
|
| 812 |
+
List of retrieved documents
|
| 813 |
+
"""
|
| 814 |
+
retriever = ContextRetriever(vectorstore)
|
| 815 |
+
return retriever.retrieve_context(
|
| 816 |
+
query=query,
|
| 817 |
+
k=k,
|
| 818 |
+
reports=reports,
|
| 819 |
+
sources=sources,
|
| 820 |
+
subtype=subtype,
|
| 821 |
+
year=year,
|
| 822 |
+
use_reranking=use_reranking,
|
| 823 |
+
qdrant_filter=qdrant_filter
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
def format_context_for_llm(documents: List[Document]) -> str:
|
| 828 |
+
"""
|
| 829 |
+
Format retrieved documents for LLM input.
|
| 830 |
+
|
| 831 |
+
Args:
|
| 832 |
+
documents: List of Document objects
|
| 833 |
+
|
| 834 |
+
Returns:
|
| 835 |
+
Formatted string for LLM
|
| 836 |
+
"""
|
| 837 |
+
if not documents:
|
| 838 |
+
return ""
|
| 839 |
+
|
| 840 |
+
formatted_parts = []
|
| 841 |
+
for i, doc in enumerate(documents, 1):
|
| 842 |
+
content = doc.page_content.strip()
|
| 843 |
+
source = doc.metadata.get('filename', 'Unknown')
|
| 844 |
+
|
| 845 |
+
formatted_parts.append(f"Document {i} (Source: {source}):\n{content}")
|
| 846 |
+
|
| 847 |
+
return "\n\n".join(formatted_parts)
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def get_context_metadata(documents: List[Document]) -> Dict[str, Any]:
|
| 851 |
+
"""
|
| 852 |
+
Extract metadata summary from retrieved documents.
|
| 853 |
+
|
| 854 |
+
Args:
|
| 855 |
+
documents: List of Document objects
|
| 856 |
+
|
| 857 |
+
Returns:
|
| 858 |
+
Dictionary with metadata summary
|
| 859 |
+
"""
|
| 860 |
+
if not documents:
|
| 861 |
+
return {}
|
| 862 |
+
|
| 863 |
+
sources = set()
|
| 864 |
+
years = set()
|
| 865 |
+
doc_types = set()
|
| 866 |
+
|
| 867 |
+
for doc in documents:
|
| 868 |
+
metadata = doc.metadata
|
| 869 |
+
if 'filename' in metadata:
|
| 870 |
+
sources.add(metadata['filename'])
|
| 871 |
+
if 'year' in metadata:
|
| 872 |
+
years.add(metadata['year'])
|
| 873 |
+
if 'source' in metadata:
|
| 874 |
+
doc_types.add(metadata['source'])
|
| 875 |
+
|
| 876 |
+
return {
|
| 877 |
+
"num_documents": len(documents),
|
| 878 |
+
"sources": list(sources),
|
| 879 |
+
"years": list(years),
|
| 880 |
+
"document_types": list(doc_types)
|
| 881 |
+
}
|
src/retrieval/filter.py
ADDED
|
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document filtering utilities for Qdrant vector store."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Union, Dict, Tuple, Any
|
| 4 |
+
from qdrant_client.http import models as rest
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class FilterBuilder:
|
| 9 |
+
"""Builder class for creating Qdrant filters."""
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
self.conditions = []
|
| 13 |
+
|
| 14 |
+
def add_source_filter(self, source: Union[str, List[str]]) -> 'FilterBuilder':
|
| 15 |
+
"""Add source filter condition."""
|
| 16 |
+
if source:
|
| 17 |
+
if isinstance(source, list):
|
| 18 |
+
condition = rest.FieldCondition(
|
| 19 |
+
key="metadata.source",
|
| 20 |
+
match=rest.MatchAny(any=source)
|
| 21 |
+
)
|
| 22 |
+
print(f"🔧 FilterBuilder: Added source filter for {source}")
|
| 23 |
+
else:
|
| 24 |
+
condition = rest.FieldCondition(
|
| 25 |
+
key="metadata.source",
|
| 26 |
+
match=rest.MatchValue(value=source)
|
| 27 |
+
)
|
| 28 |
+
print(f"🔧 FilterBuilder: Added source filter for '{source}'")
|
| 29 |
+
self.conditions.append(condition)
|
| 30 |
+
return self
|
| 31 |
+
|
| 32 |
+
def add_filename_filter(self, filenames: List[str]) -> 'FilterBuilder':
|
| 33 |
+
"""Add filename filter condition."""
|
| 34 |
+
if filenames:
|
| 35 |
+
condition = rest.FieldCondition(
|
| 36 |
+
key="metadata.filename",
|
| 37 |
+
match=rest.MatchAny(any=filenames)
|
| 38 |
+
)
|
| 39 |
+
self.conditions.append(condition)
|
| 40 |
+
print(f"🔧 FilterBuilder: Added filename filter for {filenames}")
|
| 41 |
+
return self
|
| 42 |
+
|
| 43 |
+
def add_year_filter(self, years: List[str]) -> 'FilterBuilder':
|
| 44 |
+
"""Add year filter condition."""
|
| 45 |
+
if years:
|
| 46 |
+
condition = rest.FieldCondition(
|
| 47 |
+
key="metadata.year",
|
| 48 |
+
match=rest.MatchAny(any=years)
|
| 49 |
+
)
|
| 50 |
+
self.conditions.append(condition)
|
| 51 |
+
print(f"🔧 FilterBuilder: Added year filter for {years}")
|
| 52 |
+
return self
|
| 53 |
+
|
| 54 |
+
def add_district_filter(self, districts: List[str]) -> 'FilterBuilder':
|
| 55 |
+
"""Add district filter condition."""
|
| 56 |
+
if districts:
|
| 57 |
+
condition = rest.FieldCondition(
|
| 58 |
+
key="metadata.district",
|
| 59 |
+
match=rest.MatchAny(any=districts)
|
| 60 |
+
)
|
| 61 |
+
self.conditions.append(condition)
|
| 62 |
+
print(f"🔧 FilterBuilder: Added district filter for {districts}")
|
| 63 |
+
return self
|
| 64 |
+
|
| 65 |
+
def add_custom_filter(self, key: str, value: Union[str, List[str]]) -> 'FilterBuilder':
|
| 66 |
+
"""Add custom filter condition."""
|
| 67 |
+
if isinstance(value, list):
|
| 68 |
+
condition = rest.FieldCondition(
|
| 69 |
+
key=key,
|
| 70 |
+
match=rest.MatchAny(any=value)
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
condition = rest.FieldCondition(
|
| 74 |
+
key=key,
|
| 75 |
+
match=rest.MatchValue(value=value)
|
| 76 |
+
)
|
| 77 |
+
self.conditions.append(condition)
|
| 78 |
+
return self
|
| 79 |
+
|
| 80 |
+
def build(self) -> rest.Filter:
|
| 81 |
+
"""Build the final filter."""
|
| 82 |
+
if not self.conditions:
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
return rest.Filter(must=self.conditions)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def create_filter(
|
| 89 |
+
reports: List[str] = None,
|
| 90 |
+
sources: Union[str, List[str]] = None,
|
| 91 |
+
subtype: List[str] = None,
|
| 92 |
+
year: List[str] = None,
|
| 93 |
+
district: List[str] = None,
|
| 94 |
+
filenames: List[str] = None
|
| 95 |
+
) -> rest.Filter:
|
| 96 |
+
"""
|
| 97 |
+
Create a search filter for Qdrant (legacy function for compatibility).
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
reports: List of specific report filenames
|
| 101 |
+
sources: Source category
|
| 102 |
+
subtype: List of subtypes/filenames
|
| 103 |
+
year: List of years
|
| 104 |
+
district: List of districts
|
| 105 |
+
filenames: List of specific filenames (mutually exclusive with other filters)
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
Qdrant Filter object
|
| 109 |
+
|
| 110 |
+
Note:
|
| 111 |
+
If filenames are provided, ONLY filename filtering is applied (mutually exclusive)
|
| 112 |
+
"""
|
| 113 |
+
builder = FilterBuilder()
|
| 114 |
+
|
| 115 |
+
# Check if filename filtering is requested (mutually exclusive)
|
| 116 |
+
# Both filenames and reports serve the same purpose (backward compatibility)
|
| 117 |
+
# Prefer filenames, fallback to reports for legacy support
|
| 118 |
+
target_filenames = filenames if filenames else reports
|
| 119 |
+
|
| 120 |
+
if target_filenames and len(target_filenames) > 0:
|
| 121 |
+
# ONLY apply filename filter, ignore all other filters
|
| 122 |
+
print(f"🔍 FILTER APPLIED: Filenames = {target_filenames} (mutually exclusive mode)")
|
| 123 |
+
builder.add_filename_filter(target_filenames)
|
| 124 |
+
else:
|
| 125 |
+
# Otherwise, filter by source and subtype
|
| 126 |
+
print(f"🔍 FILTER APPLIED: Sources = {sources}, Subtype = {subtype}, Year = {year}, District = {district}")
|
| 127 |
+
if sources:
|
| 128 |
+
print(f"✅ Adding source filter: metadata.source = '{sources}'")
|
| 129 |
+
builder.add_source_filter(sources)
|
| 130 |
+
if subtype:
|
| 131 |
+
print(f"✅ Adding subtype filter: metadata.filename IN {subtype}")
|
| 132 |
+
builder.add_filename_filter(subtype)
|
| 133 |
+
if year:
|
| 134 |
+
print(f"✅ Adding year filter: metadata.year IN {year}")
|
| 135 |
+
builder.add_year_filter(year)
|
| 136 |
+
|
| 137 |
+
if district:
|
| 138 |
+
print(f"✅ Adding district filter: metadata.district IN {district}")
|
| 139 |
+
builder.add_district_filter(district)
|
| 140 |
+
|
| 141 |
+
filter_obj = builder.build()
|
| 142 |
+
|
| 143 |
+
if filter_obj:
|
| 144 |
+
print(f"�� FINAL FILTER: {len(filter_obj.must)} condition(s) applied")
|
| 145 |
+
for i, condition in enumerate(filter_obj.must, 1):
|
| 146 |
+
print(f" Condition {i}: {condition.key} = {condition.match}")
|
| 147 |
+
else:
|
| 148 |
+
print("⚠️ NO FILTERS APPLIED: All documents will be searched")
|
| 149 |
+
|
| 150 |
+
return filter_obj
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def create_advanced_filter(
|
| 154 |
+
must_conditions: List[dict] = None,
|
| 155 |
+
should_conditions: List[dict] = None,
|
| 156 |
+
must_not_conditions: List[dict] = None
|
| 157 |
+
) -> rest.Filter:
|
| 158 |
+
"""
|
| 159 |
+
Create advanced filter with multiple condition types.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
must_conditions: Conditions that must match
|
| 163 |
+
should_conditions: Conditions that should match (OR logic)
|
| 164 |
+
must_not_conditions: Conditions that must not match
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
Qdrant Filter object
|
| 168 |
+
"""
|
| 169 |
+
filter_dict = {}
|
| 170 |
+
|
| 171 |
+
if must_conditions:
|
| 172 |
+
filter_dict["must"] = [
|
| 173 |
+
_dict_to_field_condition(cond) for cond in must_conditions
|
| 174 |
+
]
|
| 175 |
+
|
| 176 |
+
if should_conditions:
|
| 177 |
+
filter_dict["should"] = [
|
| 178 |
+
_dict_to_field_condition(cond) for cond in should_conditions
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
if must_not_conditions:
|
| 182 |
+
filter_dict["must_not"] = [
|
| 183 |
+
_dict_to_field_condition(cond) for cond in must_not_conditions
|
| 184 |
+
]
|
| 185 |
+
|
| 186 |
+
if not filter_dict:
|
| 187 |
+
return None
|
| 188 |
+
|
| 189 |
+
return rest.Filter(**filter_dict)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _dict_to_field_condition(condition_dict: dict) -> rest.FieldCondition:
|
| 193 |
+
"""Convert dictionary to FieldCondition."""
|
| 194 |
+
key = condition_dict["key"]
|
| 195 |
+
value = condition_dict["value"]
|
| 196 |
+
|
| 197 |
+
if isinstance(value, list):
|
| 198 |
+
match = rest.MatchAny(any=value)
|
| 199 |
+
else:
|
| 200 |
+
match = rest.MatchValue(value=value)
|
| 201 |
+
|
| 202 |
+
return rest.FieldCondition(key=key, match=match)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def validate_filter(filter_obj: rest.Filter) -> bool:
|
| 206 |
+
"""
|
| 207 |
+
Validate that a filter object is properly constructed.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
filter_obj: Qdrant Filter object
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
True if valid, raises ValueError if invalid
|
| 214 |
+
"""
|
| 215 |
+
if filter_obj is None:
|
| 216 |
+
return True
|
| 217 |
+
|
| 218 |
+
if not isinstance(filter_obj, rest.Filter):
|
| 219 |
+
raise ValueError("Filter must be a rest.Filter object")
|
| 220 |
+
|
| 221 |
+
# Check that at least one condition type is present
|
| 222 |
+
has_conditions = any([
|
| 223 |
+
hasattr(filter_obj, 'must') and filter_obj.must,
|
| 224 |
+
hasattr(filter_obj, 'should') and filter_obj.should,
|
| 225 |
+
hasattr(filter_obj, 'must_not') and filter_obj.must_not
|
| 226 |
+
])
|
| 227 |
+
|
| 228 |
+
if not has_conditions:
|
| 229 |
+
raise ValueError("Filter must have at least one condition")
|
| 230 |
+
|
| 231 |
+
return True
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def infer_filters_from_query(
|
| 235 |
+
query: str,
|
| 236 |
+
available_metadata: dict,
|
| 237 |
+
llm_client=None
|
| 238 |
+
) -> Tuple[rest.Filter, Union[dict, None]]:
|
| 239 |
+
"""
|
| 240 |
+
Automatically infer filters from a query using LLM analysis.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
query: User query to analyze
|
| 244 |
+
available_metadata: Available metadata values in the vectorstore
|
| 245 |
+
llm_client: LLM client for analysis (optional)
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
Qdrant Filter object with inferred conditions
|
| 249 |
+
"""
|
| 250 |
+
print(f"�� AUTO-INFERRING FILTERS from query: '{query[:50]}...'")
|
| 251 |
+
|
| 252 |
+
# Check if LLM client is available
|
| 253 |
+
if not llm_client:
|
| 254 |
+
print(f"❌ LLM CLIENT MISSING: Cannot use LLM analysis, falling back to rule-based")
|
| 255 |
+
return _infer_filters_rule_based(query, available_metadata), None
|
| 256 |
+
|
| 257 |
+
# Extract available options
|
| 258 |
+
available_sources = available_metadata.get('sources', [])
|
| 259 |
+
available_years = available_metadata.get('years', [])
|
| 260 |
+
available_filenames = available_metadata.get('filenames', [])
|
| 261 |
+
|
| 262 |
+
print(f"📊 Available metadata: sources={len(available_sources)}, years={len(available_years)}, filenames={len(available_filenames)}")
|
| 263 |
+
|
| 264 |
+
# Try LLM analysis first
|
| 265 |
+
print(f" LLM ANALYSIS: Attempting LLM-based filter inference...")
|
| 266 |
+
llm_result = _analyze_query_with_llm(
|
| 267 |
+
query=query,
|
| 268 |
+
available_metadata=available_metadata,
|
| 269 |
+
llm_client=llm_client
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
if llm_result:
|
| 273 |
+
print(f"✅ LLM SUCCESS: LLM successfully inferred filters")
|
| 274 |
+
# Use the _build_qdrant_filter function to properly build the Qdrant filter
|
| 275 |
+
qdrant_filter, filter_summary = _build_qdrant_filter(llm_result)
|
| 276 |
+
if qdrant_filter:
|
| 277 |
+
print(f"✅ QDRANT FILTER: Successfully built Qdrant filter")
|
| 278 |
+
# print(f"✅ INFERRED FILTERS: {qdrant_filter}")
|
| 279 |
+
return qdrant_filter, filter_summary
|
| 280 |
+
else:
|
| 281 |
+
print(f"❌ QDRANT FILTER: Failed to build Qdrant filter, trying rule-based fallback")
|
| 282 |
+
rule_based_result = _infer_filters_rule_based(query, available_metadata)
|
| 283 |
+
# Use the _build_qdrant_filter function to properly build the Qdrant filter
|
| 284 |
+
qdrant_filter, filter_summary = _build_qdrant_filter(rule_based_result)
|
| 285 |
+
if qdrant_filter:
|
| 286 |
+
print(f"✅ RULE-BASED QDRANT FILTER: Successfully built Qdrant filter")
|
| 287 |
+
return qdrant_filter, filter_summary
|
| 288 |
+
else:
|
| 289 |
+
print(f"❌ RULE-BASED QDRANT FILTER: Failed to build Qdrant filter")
|
| 290 |
+
return None, None
|
| 291 |
+
else:
|
| 292 |
+
print(f"⚠️ LLM FAILED: LLM could not infer filters, trying rule-based fallback")
|
| 293 |
+
rule_based_result = _infer_filters_rule_based(query, available_metadata)
|
| 294 |
+
# Use the _build_qdrant_filter function to properly build the Qdrant filter
|
| 295 |
+
qdrant_filter, filter_summary = _build_qdrant_filter(rule_based_result)
|
| 296 |
+
if qdrant_filter:
|
| 297 |
+
print(f"✅ RULE-BASED QDRANT FILTER: Successfully built Qdrant filter")
|
| 298 |
+
return qdrant_filter, filter_summary
|
| 299 |
+
else:
|
| 300 |
+
print(f"❌ RULE-BASED QDRANT FILTER: Failed to build Qdrant filter")
|
| 301 |
+
return None, None
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def _analyze_query_with_llm(
|
| 305 |
+
query: str,
|
| 306 |
+
available_metadata: Dict[str, List[str]],
|
| 307 |
+
llm_client=None
|
| 308 |
+
) -> dict:
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
"""
|
| 312 |
+
- Filenames: {available_metadata.get('filenames', [])}
|
| 313 |
+
|
| 314 |
+
📁 FILENAME FILTERING (Use Sparingly):
|
| 315 |
+
- Only if specific filename explicitly mentioned
|
| 316 |
+
- Prefer source/subtype over filename
|
| 317 |
+
- Be very conservative
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
"filenames": ["filename1", "filename2"] or [],
|
| 321 |
+
- For filenames: Only use if you have high confidence and can identify specific files
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
"""
|
| 326 |
+
Use LLM to analyze query and infer appropriate filters.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
query: User query to analyze
|
| 330 |
+
available_metadata: Available metadata values in the vectorstore
|
| 331 |
+
llm_client: LLM client for analysis
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Dictionary with inferred filters or empty dict if failed
|
| 335 |
+
"""
|
| 336 |
+
if not llm_client:
|
| 337 |
+
print("❌ LLM CLIENT MISSING: Cannot analyze query without LLM client")
|
| 338 |
+
return {}
|
| 339 |
+
|
| 340 |
+
try:
|
| 341 |
+
print(f" LLM ANALYSIS: Analyzing query with LLM...")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
"""
|
| 345 |
+
For example: "What is the expected ... in 2024" - this refference to a future statement, so retrieving documents for 2023, 2022 and 2021 can be relevant too
|
| 346 |
+
Another example: "What is the GDP increase now compared to 2022" - this is a relative statement, refferring to past data, so both Year 2022, and now - 2025 needs to be detected/marked
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
# Create prompt for LLM analysis
|
| 350 |
+
prompt = f"""
|
| 351 |
+
You are a filter inference system. Analyze this query and return ONLY a JSON object.
|
| 352 |
+
|
| 353 |
+
Query: "{query}"
|
| 354 |
+
|
| 355 |
+
Available metadata:
|
| 356 |
+
- Sources: {available_metadata.get('sources', [])}
|
| 357 |
+
- Years: {available_metadata.get('years', [])}
|
| 358 |
+
|
| 359 |
+
FILTER INFERENCE GUIDELINES:
|
| 360 |
+
|
| 361 |
+
YEAR FILTERING (Be VERY Conservative):
|
| 362 |
+
✅ INFER YEARS ONLY IF:
|
| 363 |
+
- Explicit 4-digit years: "2022", "2023", "2021"
|
| 364 |
+
- Clear relative terms: "last year", "this year", "recent", "current year" (for the context, now is 2025)
|
| 365 |
+
- Temporal context: "annual report 2022", "audit for 2023"
|
| 366 |
+
- Give multiple years for complex queries.
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
❌ DO NOT INFER YEARS FOR:
|
| 370 |
+
- Vague terms: "implementation", "activities", "costs", "challenges", "issues"
|
| 371 |
+
- General concepts: "PDM", "administrative", "budget", "staff"
|
| 372 |
+
- Process descriptions: "how were", "what challenges", "management of"
|
| 373 |
+
|
| 374 |
+
🏛️ SOURCE FILTERING (Context-Based):
|
| 375 |
+
- "Ministry, Department and Agency" → Central government, ministries, departments, PS/ST
|
| 376 |
+
- "Local Government" → Districts, municipalities, local authorities, DLG
|
| 377 |
+
- "Consolidated" → Annual consolidated reports, OAG reports
|
| 378 |
+
- "Thematic" → Special studies, thematic reports
|
| 379 |
+
|
| 380 |
+
�� SUBTYPE FILTERING (Document Type):
|
| 381 |
+
- "audit" → Audit reports, reviews, examinations
|
| 382 |
+
- "report" → General reports, annual reports
|
| 383 |
+
- "guidance" → Guidelines, directives, circulars
|
| 384 |
+
|
| 385 |
+
CONFIDENCE SCORING:
|
| 386 |
+
- 0.9-1.0: Crystal clear indicators (explicit years, specific sources)
|
| 387 |
+
- 0.7-0.8: Good indicators (relative years, clear context)
|
| 388 |
+
- 0.5-0.6: Moderate indicators (some context clues)
|
| 389 |
+
- 0.0-0.4: Low confidence (vague or unclear)
|
| 390 |
+
|
| 391 |
+
EXAMPLES:
|
| 392 |
+
✅ "What challenges arose in 2022?" → years: ["2022"], confidence: 1
|
| 393 |
+
✅ "How were administrative costs managed in our government?" → sources: ["Local Government"], confidence: 0.75
|
| 394 |
+
✅ "PDM implementation guidelines from last year" → years: ["2024"], confidence: 0.9
|
| 395 |
+
❌ "What issues arose with budget execution?" → NO FILTERS, confidence: 0.2
|
| 396 |
+
❌ "How were tools related to administrative costs?" → NO FILTERS, confidence: 0.1
|
| 397 |
+
|
| 398 |
+
RESPONSE FORMAT (JSON only):
|
| 399 |
+
{{
|
| 400 |
+
"years": ["2022", "2023"] or [],
|
| 401 |
+
"sources": ["Ministry, Department and Agency", "Local Government"] or [],
|
| 402 |
+
"subtype": ["audit", "report"] or [],
|
| 403 |
+
"confidence": 0.8,
|
| 404 |
+
"reasoning": "Very brief explanation of filter choices"
|
| 405 |
+
}}
|
| 406 |
+
|
| 407 |
+
Rules:
|
| 408 |
+
- Use OR logic (SHOULD) for multiple values
|
| 409 |
+
- Prefer sources over filenames
|
| 410 |
+
- Only include years if clearly mentioned
|
| 411 |
+
- Return null for unclear fields
|
| 412 |
+
- For sources/subtypes: Include at least 3 candidates unless confidence is high and you can identify exactly one source (MUST)
|
| 413 |
+
- For years: If you want to include, then include at least 2 candidates unless confidence is high and you can identify exactly one year (MUST)
|
| 414 |
+
"""
|
| 415 |
+
|
| 416 |
+
print(f"🔄 LLM CALL: Sending prompt to LLM...")
|
| 417 |
+
try:
|
| 418 |
+
# Try different methods to call the LLM
|
| 419 |
+
if hasattr(llm_client, 'invoke'):
|
| 420 |
+
response = llm_client.invoke(prompt)
|
| 421 |
+
elif hasattr(llm_client, 'generate'):
|
| 422 |
+
response = llm_client.generate([{"role": "user", "content": prompt}])
|
| 423 |
+
elif hasattr(llm_client, 'call'):
|
| 424 |
+
response = llm_client.call(prompt)
|
| 425 |
+
elif hasattr(llm_client, 'predict'):
|
| 426 |
+
response = llm_client.predict(prompt)
|
| 427 |
+
else:
|
| 428 |
+
# Try to call it directly
|
| 429 |
+
response = llm_client(prompt)
|
| 430 |
+
|
| 431 |
+
print(f"✅ LLM CALL SUCCESS: Received response from LLM")
|
| 432 |
+
|
| 433 |
+
# Extract content from response
|
| 434 |
+
if hasattr(response, 'content'):
|
| 435 |
+
response_content = response.content
|
| 436 |
+
elif hasattr(response, 'text'):
|
| 437 |
+
response_content = response.text
|
| 438 |
+
elif isinstance(response, str):
|
| 439 |
+
response_content = response
|
| 440 |
+
else:
|
| 441 |
+
response_content = str(response)
|
| 442 |
+
|
| 443 |
+
print(f"🔄 LLM RESPONSE: {response_content[:200]}...")
|
| 444 |
+
|
| 445 |
+
except Exception as e:
|
| 446 |
+
print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
|
| 447 |
+
return {}
|
| 448 |
+
|
| 449 |
+
# Parse JSON response
|
| 450 |
+
import json
|
| 451 |
+
import re
|
| 452 |
+
try:
|
| 453 |
+
print(f"🔄 JSON PARSING: Attempting to parse LLM response...")
|
| 454 |
+
|
| 455 |
+
# Clean the response to extract JSON from markdown
|
| 456 |
+
response_text = response_content.strip()
|
| 457 |
+
|
| 458 |
+
# Remove markdown formatting if present
|
| 459 |
+
if "```json" in response_text:
|
| 460 |
+
# Extract JSON from markdown code block
|
| 461 |
+
start_marker = "```json"
|
| 462 |
+
end_marker = "```"
|
| 463 |
+
start_idx = response_text.find(start_marker)
|
| 464 |
+
if start_idx != -1:
|
| 465 |
+
start_idx += len(start_marker)
|
| 466 |
+
end_idx = response_text.find(end_marker, start_idx)
|
| 467 |
+
if end_idx != -1:
|
| 468 |
+
response_text = response_text[start_idx:end_idx].strip()
|
| 469 |
+
|
| 470 |
+
# Try to find JSON object in the response
|
| 471 |
+
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
| 472 |
+
if json_match:
|
| 473 |
+
response_text = json_match.group(0)
|
| 474 |
+
|
| 475 |
+
print(f"🔄 JSON PARSING: Cleaned response: {response_text[:200]}...")
|
| 476 |
+
|
| 477 |
+
# Parse JSON
|
| 478 |
+
filters = json.loads(response_text)
|
| 479 |
+
print(f"✅ JSON PARSING SUCCESS: Parsed filters: {filters}")
|
| 480 |
+
|
| 481 |
+
# Validate filters
|
| 482 |
+
if not isinstance(filters, dict):
|
| 483 |
+
print(f"❌ JSON VALIDATION FAILED: Response is not a dictionary")
|
| 484 |
+
return {}
|
| 485 |
+
|
| 486 |
+
# Check if any filters were inferred
|
| 487 |
+
has_filters = any(filters.get(key) for key in ['sources', 'years', 'filenames'])
|
| 488 |
+
if not has_filters:
|
| 489 |
+
print(f"⚠️ QUERY DIFFICULT: LLM could not determine appropriate filters from query")
|
| 490 |
+
return {}
|
| 491 |
+
|
| 492 |
+
# print(f"✅ FILTER INFERENCE SUCCESS: Inferred filters: {filters}")
|
| 493 |
+
return filters
|
| 494 |
+
|
| 495 |
+
except json.JSONDecodeError as e:
|
| 496 |
+
print(f"❌ JSON PARSING FAILED: Invalid JSON format - {e}")
|
| 497 |
+
print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
|
| 498 |
+
return {}
|
| 499 |
+
except Exception as e:
|
| 500 |
+
print(f"❌ JSON PARSING FAILED: Unexpected error - {e}")
|
| 501 |
+
print(f"❌ JSON PARSING FAILED: Raw response: {response_text[:500]}...")
|
| 502 |
+
return {}
|
| 503 |
+
|
| 504 |
+
except Exception as e:
|
| 505 |
+
print(f"❌ LLM CALL FAILED: Error calling LLM - {e}")
|
| 506 |
+
return {}
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _infer_filters_rule_based(
|
| 510 |
+
query: str,
|
| 511 |
+
available_metadata: dict
|
| 512 |
+
) -> dict:
|
| 513 |
+
"""
|
| 514 |
+
Rule-based fallback for filter inference with improved logic.
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
query: User query
|
| 518 |
+
available_metadata: Available metadata values in the vectorstore
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
Dictionary of inferred filters
|
| 522 |
+
"""
|
| 523 |
+
print(f" RULE-BASED ANALYSIS: Starting rule-based inference for query: '{query[:50]}...'")
|
| 524 |
+
|
| 525 |
+
inferred = {}
|
| 526 |
+
query_lower = query.lower()
|
| 527 |
+
|
| 528 |
+
# SEMANTIC SOURCE INFERENCE - Use semantic understanding
|
| 529 |
+
source_matches = []
|
| 530 |
+
|
| 531 |
+
# Define semantic mappings for better source inference
|
| 532 |
+
source_keywords = {
|
| 533 |
+
'consolidated': ['consolidated', 'annual', 'oag', 'auditor general', 'government', 'financial statements', 'budget', 'expenditure', 'revenue'],
|
| 534 |
+
'military': ['military', 'defence', 'defense', 'army', 'navy', 'air force', 'security', 'defense ministry'],
|
| 535 |
+
'departmental': ['department', 'ministry', 'agency', 'authority', 'commission', 'board', 'directorate'],
|
| 536 |
+
'thematic': ['thematic', 'sector', 'program', 'project', 'initiative', 'development', 'infrastructure']
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
for source in available_metadata.get('sources', []):
|
| 540 |
+
source_lower = source.lower()
|
| 541 |
+
|
| 542 |
+
# Direct keyword match
|
| 543 |
+
if source_lower in query_lower:
|
| 544 |
+
source_matches.append(source)
|
| 545 |
+
print(f"✅ DIRECT MATCH: Found direct keyword match for '{source}'")
|
| 546 |
+
else:
|
| 547 |
+
# Semantic keyword matching
|
| 548 |
+
if source_lower in source_keywords:
|
| 549 |
+
keywords = source_keywords[source_lower]
|
| 550 |
+
matches = sum(1 for keyword in keywords if keyword in query_lower)
|
| 551 |
+
if matches >= 2: # Require at least 2 keyword matches for semantic inference
|
| 552 |
+
source_matches.append(source)
|
| 553 |
+
print(f"✅ SEMANTIC MATCH: Found {matches} semantic keywords for '{source}': {[k for k in keywords if k in query_lower]}")
|
| 554 |
+
|
| 555 |
+
if source_matches:
|
| 556 |
+
# Use SHOULD (OR logic) for multiple sources
|
| 557 |
+
inferred['sources_should'] = source_matches
|
| 558 |
+
print(f"✅ SOURCE INFERENCE: Found {len(source_matches)} sources with OR logic: {source_matches}")
|
| 559 |
+
else:
|
| 560 |
+
print("❌ SOURCE INFERENCE: No source keywords found in query")
|
| 561 |
+
|
| 562 |
+
# Infer year filters - use SHOULD (OR logic) for multiple years
|
| 563 |
+
import re
|
| 564 |
+
year_matches = []
|
| 565 |
+
for year in available_metadata.get('years', []):
|
| 566 |
+
if year in query or f"'{year}" in query:
|
| 567 |
+
year_matches.append(year)
|
| 568 |
+
|
| 569 |
+
if year_matches:
|
| 570 |
+
# Use SHOULD (OR logic) for multiple years
|
| 571 |
+
inferred['years_should'] = year_matches
|
| 572 |
+
print(f"✅ YEAR INFERENCE: Found {len(year_matches)} years with OR logic: {year_matches}")
|
| 573 |
+
else:
|
| 574 |
+
print("❌ YEAR INFERENCE: No year references found in query")
|
| 575 |
+
|
| 576 |
+
# Only infer filename filters if no year filter was found (to avoid conflicts)
|
| 577 |
+
if not year_matches:
|
| 578 |
+
filename_matches = []
|
| 579 |
+
for filename in available_metadata.get('filenames', []):
|
| 580 |
+
# Only match if multiple words from filename appear in query
|
| 581 |
+
filename_words = filename.lower().split()
|
| 582 |
+
matches = sum(1 for word in filename_words if word in query_lower)
|
| 583 |
+
if matches >= 2: # High confidence threshold
|
| 584 |
+
filename_matches.append(filename)
|
| 585 |
+
|
| 586 |
+
if filename_matches:
|
| 587 |
+
# Use SHOULD (OR logic) for multiple filenames
|
| 588 |
+
inferred['filenames_should'] = filename_matches
|
| 589 |
+
print(f"✅ FILENAME INFERENCE: Found {len(filename_matches)} filenames with OR logic: {filename_matches}")
|
| 590 |
+
else:
|
| 591 |
+
print("❌ FILENAME INFERENCE: No high-confidence filename matches found")
|
| 592 |
+
else:
|
| 593 |
+
print("ℹ️ FILENAME INFERENCE: Skipped (year filter already applied to avoid conflicts)")
|
| 594 |
+
|
| 595 |
+
print(f" RULE-BASED RESULT: {inferred}")
|
| 596 |
+
return inferred
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def _validate_inferred_filters(inferred_filters: dict) -> dict:
|
| 600 |
+
"""
|
| 601 |
+
Validate and normalize inferred filters to ensure they're in the expected format.
|
| 602 |
+
|
| 603 |
+
Args:
|
| 604 |
+
inferred_filters: Raw inferred filters dictionary
|
| 605 |
+
|
| 606 |
+
Returns:
|
| 607 |
+
Validated and normalized filters dictionary
|
| 608 |
+
"""
|
| 609 |
+
if not isinstance(inferred_filters, dict):
|
| 610 |
+
print(f"⚠️ FILTER VALIDATION: Inferred filters is not a dict: {type(inferred_filters)}")
|
| 611 |
+
return {}
|
| 612 |
+
|
| 613 |
+
validated = {}
|
| 614 |
+
|
| 615 |
+
# Normalize field names and validate values
|
| 616 |
+
for field_name in ['sources', 'sources_should', 'years', 'years_should', 'filenames', 'filenames_should']:
|
| 617 |
+
if field_name in inferred_filters and inferred_filters[field_name]:
|
| 618 |
+
value = inferred_filters[field_name]
|
| 619 |
+
if isinstance(value, list) and len(value) > 0:
|
| 620 |
+
# Remove any None or empty string values
|
| 621 |
+
clean_value = [v for v in value if v is not None and str(v).strip()]
|
| 622 |
+
if clean_value:
|
| 623 |
+
validated[field_name] = clean_value
|
| 624 |
+
print(f"✅ FILTER VALIDATION: {field_name} = {clean_value}")
|
| 625 |
+
elif isinstance(value, str) and value.strip():
|
| 626 |
+
validated[field_name] = [value.strip()]
|
| 627 |
+
print(f"✅ FILTER VALIDATION: {field_name} = [{value.strip()}]")
|
| 628 |
+
|
| 629 |
+
return validated
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def _build_qdrant_filter(inferred_filters: dict) -> rest.Filter:
|
| 633 |
+
"""
|
| 634 |
+
Build Qdrant filter from inferred filters.
|
| 635 |
+
|
| 636 |
+
Args:
|
| 637 |
+
inferred_filters: Dictionary with inferred filter values
|
| 638 |
+
|
| 639 |
+
Returns:
|
| 640 |
+
Qdrant Filter object
|
| 641 |
+
"""
|
| 642 |
+
try:
|
| 643 |
+
from qdrant_client.http import models as rest
|
| 644 |
+
|
| 645 |
+
# Validate and normalize the inferred filters first
|
| 646 |
+
validated_filters = _validate_inferred_filters(inferred_filters)
|
| 647 |
+
if not validated_filters:
|
| 648 |
+
print(f"⚠️ NO VALID FILTERS: All filters were invalid or empty")
|
| 649 |
+
return None, {}
|
| 650 |
+
|
| 651 |
+
conditions = []
|
| 652 |
+
filter_summary = {}
|
| 653 |
+
|
| 654 |
+
# Handle sources (use OR logic for multiple values)
|
| 655 |
+
# Support both 'sources' and 'sources_should' field names
|
| 656 |
+
source_values = None
|
| 657 |
+
if 'sources' in validated_filters and validated_filters['sources']:
|
| 658 |
+
source_values = validated_filters['sources']
|
| 659 |
+
elif 'sources_should' in validated_filters and validated_filters['sources_should']:
|
| 660 |
+
source_values = validated_filters['sources_should']
|
| 661 |
+
|
| 662 |
+
if source_values and isinstance(source_values, list) and len(source_values) > 0:
|
| 663 |
+
if len(source_values) == 1:
|
| 664 |
+
conditions.append(rest.FieldCondition(
|
| 665 |
+
key="metadata.source",
|
| 666 |
+
match=rest.MatchValue(value=source_values[0])
|
| 667 |
+
))
|
| 668 |
+
else:
|
| 669 |
+
# Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
|
| 670 |
+
conditions.append(rest.FieldCondition(
|
| 671 |
+
key="metadata.source",
|
| 672 |
+
match=rest.MatchAny(any=source_values)
|
| 673 |
+
))
|
| 674 |
+
filter_summary['sources'] = f"SHOULD: {source_values}"
|
| 675 |
+
|
| 676 |
+
# Handle years (use OR logic for multiple values)
|
| 677 |
+
# Support both 'years' and 'years_should' field names
|
| 678 |
+
year_values = None
|
| 679 |
+
if 'years' in validated_filters and validated_filters['years']:
|
| 680 |
+
year_values = validated_filters['years']
|
| 681 |
+
elif 'years_should' in validated_filters and validated_filters['years_should']:
|
| 682 |
+
year_values = validated_filters['years_should']
|
| 683 |
+
|
| 684 |
+
if year_values and isinstance(year_values, list) and len(year_values) > 0:
|
| 685 |
+
if len(year_values) == 1:
|
| 686 |
+
conditions.append(rest.FieldCondition(
|
| 687 |
+
key="metadata.year",
|
| 688 |
+
match=rest.MatchValue(value=year_values[0])
|
| 689 |
+
))
|
| 690 |
+
else:
|
| 691 |
+
# Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
|
| 692 |
+
conditions.append(rest.FieldCondition(
|
| 693 |
+
key="metadata.year",
|
| 694 |
+
match=rest.MatchAny(any=year_values)
|
| 695 |
+
))
|
| 696 |
+
filter_summary['years'] = f"SHOULD: {year_values}"
|
| 697 |
+
|
| 698 |
+
# Handle filenames (use OR logic for multiple values)
|
| 699 |
+
# Support both 'filenames' and 'filenames_should' field names
|
| 700 |
+
filename_values = None
|
| 701 |
+
if 'filenames' in validated_filters and validated_filters['filenames']:
|
| 702 |
+
filename_values = validated_filters['filenames']
|
| 703 |
+
elif 'filenames_should' in validated_filters and validated_filters['filenames_should']:
|
| 704 |
+
filename_values = validated_filters['filenames_should']
|
| 705 |
+
|
| 706 |
+
if filename_values and isinstance(filename_values, list) and len(filename_values) > 0:
|
| 707 |
+
if len(filename_values) == 1:
|
| 708 |
+
conditions.append(rest.FieldCondition(
|
| 709 |
+
key="metadata.filename",
|
| 710 |
+
match=rest.MatchValue(value=filename_values[0])
|
| 711 |
+
))
|
| 712 |
+
else:
|
| 713 |
+
# Use MatchAny instead of Filter(should=...) to avoid QueryPoints error
|
| 714 |
+
conditions.append(rest.FieldCondition(
|
| 715 |
+
key="metadata.filename",
|
| 716 |
+
match=rest.MatchAny(any=filename_values)
|
| 717 |
+
))
|
| 718 |
+
filter_summary['filenames'] = f"SHOULD: {filename_values}"
|
| 719 |
+
|
| 720 |
+
# Build final filter
|
| 721 |
+
if conditions:
|
| 722 |
+
# Always wrap conditions in a Filter object, even for single conditions
|
| 723 |
+
result_filter = rest.Filter(must=conditions)
|
| 724 |
+
|
| 725 |
+
# Print clean filter summary
|
| 726 |
+
print(f"✅ APPLIED FILTERS: {filter_summary}")
|
| 727 |
+
return result_filter, filter_summary
|
| 728 |
+
else:
|
| 729 |
+
print(f"⚠️ NO FILTERS APPLIED: All documents will be searched")
|
| 730 |
+
return None, {}
|
| 731 |
+
|
| 732 |
+
except Exception as e:
|
| 733 |
+
print(f"❌ FILTER BUILD ERROR: {str(e)}")
|
| 734 |
+
print(f"🔍 DEBUG: Original inferred filters keys: {list(inferred_filters.keys()) if isinstance(inferred_filters, dict) else 'Not a dict'}")
|
| 735 |
+
print(f"🔍 DEBUG: Original inferred filters content: {inferred_filters}")
|
| 736 |
+
print(f"🔍 DEBUG: Validated filters keys: {list(validated_filters.keys()) if isinstance(validated_filters, dict) else 'Not a dict'}")
|
| 737 |
+
print(f"🔍 DEBUG: Validated filters content: {validated_filters}")
|
| 738 |
+
# Return a safe fallback - no filter (search all documents)
|
| 739 |
+
return None, {}
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
class MetadataCache:
|
| 743 |
+
"""Cache for vectorstore metadata to avoid repeated queries."""
|
| 744 |
+
|
| 745 |
+
def __init__(self):
|
| 746 |
+
self._cache = None
|
| 747 |
+
self._last_updated = None
|
| 748 |
+
self._cache_ttl = 3600 # 1 hour TTL
|
| 749 |
+
|
| 750 |
+
def get_metadata(self, vectorstore) -> dict:
|
| 751 |
+
"""
|
| 752 |
+
Get metadata from cache or load it if not available/expired.
|
| 753 |
+
|
| 754 |
+
Args:
|
| 755 |
+
vectorstore: QdrantVectorStore instance
|
| 756 |
+
|
| 757 |
+
Returns:
|
| 758 |
+
Dictionary of available metadata values
|
| 759 |
+
"""
|
| 760 |
+
import time
|
| 761 |
+
|
| 762 |
+
# Check if cache is valid
|
| 763 |
+
if (self._cache is not None and
|
| 764 |
+
self._last_updated is not None and
|
| 765 |
+
time.time() - self._last_updated < self._cache_ttl):
|
| 766 |
+
print(f"✅ METADATA CACHE: Using cached metadata")
|
| 767 |
+
return self._cache
|
| 768 |
+
|
| 769 |
+
try:
|
| 770 |
+
print(f"🔄 METADATA CACHE: Loading metadata from vectorstore...")
|
| 771 |
+
|
| 772 |
+
# Get collection info
|
| 773 |
+
try:
|
| 774 |
+
collection_info = vectorstore._client.get_collection(vectorstore.collection_name)
|
| 775 |
+
print(f"✅ Collection info retrieved: {getattr(collection_info, 'name', 'unknown')}")
|
| 776 |
+
except Exception as e:
|
| 777 |
+
print(f"⚠️ Could not get collection info: {e}")
|
| 778 |
+
|
| 779 |
+
# Get ALL documents to extract complete metadata
|
| 780 |
+
print(f"📄 Scanning entire corpus for complete metadata extraction...")
|
| 781 |
+
|
| 782 |
+
# Get collection info to determine total size
|
| 783 |
+
try:
|
| 784 |
+
collection_info = vectorstore._client.get_collection(vectorstore.collection_name)
|
| 785 |
+
total_points = getattr(collection_info, 'points_count', 0)
|
| 786 |
+
print(f"📊 Total documents in corpus: {total_points}")
|
| 787 |
+
except Exception as e:
|
| 788 |
+
print(f"⚠️ Could not get collection size: {e}")
|
| 789 |
+
total_points = 0
|
| 790 |
+
|
| 791 |
+
# Extract unique metadata values from ALL documents
|
| 792 |
+
sources = set()
|
| 793 |
+
years = set()
|
| 794 |
+
filenames = set()
|
| 795 |
+
|
| 796 |
+
# Try to use scroll to get all documents in batches
|
| 797 |
+
batch_size = 1000 # Process in batches to avoid memory issues
|
| 798 |
+
offset = None
|
| 799 |
+
processed_count = 0
|
| 800 |
+
scroll_success = False
|
| 801 |
+
|
| 802 |
+
try:
|
| 803 |
+
while True:
|
| 804 |
+
# Scroll through all documents
|
| 805 |
+
scroll_result = vectorstore._client.scroll(
|
| 806 |
+
collection_name=vectorstore.collection_name,
|
| 807 |
+
limit=batch_size,
|
| 808 |
+
offset=offset,
|
| 809 |
+
with_payload=True,
|
| 810 |
+
with_vectors=False # We only need metadata
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
points = scroll_result[0] # Get the points
|
| 814 |
+
if not points:
|
| 815 |
+
break # No more documents
|
| 816 |
+
|
| 817 |
+
# Process each document
|
| 818 |
+
for i, point in enumerate(points):
|
| 819 |
+
if hasattr(point, 'payload') and point.payload:
|
| 820 |
+
payload = point.payload
|
| 821 |
+
|
| 822 |
+
# Debug: Log structure of first few documents
|
| 823 |
+
if processed_count + i < 2: # Only log first 2 documents
|
| 824 |
+
print(f"🔍 DEBUG Document {processed_count + i + 1} payload structure:")
|
| 825 |
+
print(f" Payload keys: {list(payload.keys()) if isinstance(payload, dict) else 'Not a dict'}")
|
| 826 |
+
if isinstance(payload, dict) and 'metadata' in payload:
|
| 827 |
+
print(f" Metadata keys: {list(payload['metadata'].keys()) if isinstance(payload['metadata'], dict) else 'Not a dict'}")
|
| 828 |
+
elif isinstance(payload, dict):
|
| 829 |
+
print(f" Top-level keys: {list(payload.keys())}")
|
| 830 |
+
print(f" Payload type: {type(payload)}")
|
| 831 |
+
print(f" Payload sample: {str(payload)[:200]}...")
|
| 832 |
+
print()
|
| 833 |
+
|
| 834 |
+
# Try different metadata structures
|
| 835 |
+
found_metadata = False
|
| 836 |
+
|
| 837 |
+
# Structure 1: payload['metadata']['source']
|
| 838 |
+
if isinstance(payload, dict) and 'metadata' in payload:
|
| 839 |
+
metadata = payload['metadata']
|
| 840 |
+
if isinstance(metadata, dict):
|
| 841 |
+
if 'source' in metadata:
|
| 842 |
+
sources.add(metadata['source'])
|
| 843 |
+
found_metadata = True
|
| 844 |
+
if 'year' in metadata:
|
| 845 |
+
years.add(metadata['year'])
|
| 846 |
+
found_metadata = True
|
| 847 |
+
if 'filename' in metadata:
|
| 848 |
+
filenames.add(metadata['filename'])
|
| 849 |
+
found_metadata = True
|
| 850 |
+
|
| 851 |
+
# Structure 2: payload['source'] (direct)
|
| 852 |
+
if isinstance(payload, dict):
|
| 853 |
+
if 'source' in payload:
|
| 854 |
+
sources.add(payload['source'])
|
| 855 |
+
found_metadata = True
|
| 856 |
+
if 'year' in payload:
|
| 857 |
+
years.add(payload['year'])
|
| 858 |
+
found_metadata = True
|
| 859 |
+
if 'filename' in payload:
|
| 860 |
+
filenames.add(payload['filename'])
|
| 861 |
+
found_metadata = True
|
| 862 |
+
|
| 863 |
+
# Structure 3: Check for nested structures
|
| 864 |
+
if not found_metadata and isinstance(payload, dict):
|
| 865 |
+
# Look for any nested dict that might contain metadata
|
| 866 |
+
for key, value in payload.items():
|
| 867 |
+
if isinstance(value, dict):
|
| 868 |
+
if 'source' in value:
|
| 869 |
+
sources.add(value['source'])
|
| 870 |
+
found_metadata = True
|
| 871 |
+
if 'year' in value:
|
| 872 |
+
years.add(value['year'])
|
| 873 |
+
found_metadata = True
|
| 874 |
+
if 'filename' in value:
|
| 875 |
+
filenames.add(value['filename'])
|
| 876 |
+
found_metadata = True
|
| 877 |
+
|
| 878 |
+
processed_count += len(points)
|
| 879 |
+
progress_pct = (processed_count / total_points * 100) if total_points > 0 else 0
|
| 880 |
+
print(f"📄 Processed {processed_count}/{total_points} documents ({progress_pct:.1f}%)... (sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")
|
| 881 |
+
|
| 882 |
+
# Update offset for next batch
|
| 883 |
+
offset = scroll_result[1] # Next offset
|
| 884 |
+
if offset is None:
|
| 885 |
+
break # No more documents
|
| 886 |
+
|
| 887 |
+
scroll_success = True
|
| 888 |
+
print(f"✅ Scroll method successful - processed {processed_count} documents")
|
| 889 |
+
|
| 890 |
+
except Exception as e:
|
| 891 |
+
print(f"❌ Scroll method failed: {e}")
|
| 892 |
+
print(f"🔄 Falling back to similarity search method...")
|
| 893 |
+
|
| 894 |
+
# Fallback: Use similarity search with multiple queries to get more coverage
|
| 895 |
+
fallback_queries = [
|
| 896 |
+
"", # Empty query
|
| 897 |
+
"audit", "report", "government", "ministry", "department",
|
| 898 |
+
"local", "consolidated", "annual", "financial", "budget",
|
| 899 |
+
"2020", "2021", "2022", "2023", "2024" # Year queries
|
| 900 |
+
]
|
| 901 |
+
|
| 902 |
+
processed_count = 0
|
| 903 |
+
for query in fallback_queries:
|
| 904 |
+
try:
|
| 905 |
+
# Get documents for this query
|
| 906 |
+
docs = vectorstore.similarity_search(query, k=1000) # Get more per query
|
| 907 |
+
|
| 908 |
+
for j, doc in enumerate(docs):
|
| 909 |
+
if hasattr(doc, 'metadata') and doc.metadata:
|
| 910 |
+
# Debug: Log structure of first few documents in fallback
|
| 911 |
+
if processed_count + j < 3: # Only log first 3 documents per query
|
| 912 |
+
print(f"🔍 DEBUG Fallback Document {processed_count + j + 1} (query: '{query}') metadata structure:")
|
| 913 |
+
print(f" Metadata keys: {list(doc.metadata.keys()) if isinstance(doc.metadata, dict) else 'Not a dict'}")
|
| 914 |
+
print(f" Metadata type: {type(doc.metadata)}")
|
| 915 |
+
print(f" Metadata sample: {str(doc.metadata)[:200]}...")
|
| 916 |
+
print()
|
| 917 |
+
|
| 918 |
+
if 'source' in doc.metadata:
|
| 919 |
+
sources.add(doc.metadata['source'])
|
| 920 |
+
if 'year' in doc.metadata:
|
| 921 |
+
years.add(doc.metadata['year'])
|
| 922 |
+
if 'filename' in doc.metadata:
|
| 923 |
+
filenames.add(doc.metadata['filename'])
|
| 924 |
+
|
| 925 |
+
processed_count += len(docs)
|
| 926 |
+
print(f"📄 Fallback query '{query}': {len(docs)} docs (total: {processed_count}, sources: {len(sources)}, years: {len(years)}, filenames: {len(filenames)})")
|
| 927 |
+
|
| 928 |
+
except Exception as query_error:
|
| 929 |
+
print(f"⚠️ Fallback query '{query}' failed: {query_error}")
|
| 930 |
+
continue
|
| 931 |
+
|
| 932 |
+
print(f"✅ Fallback method completed - processed {processed_count} documents")
|
| 933 |
+
|
| 934 |
+
print(f"✅ Completed scanning {processed_count} documents from entire corpus")
|
| 935 |
+
|
| 936 |
+
# Convert to sorted lists
|
| 937 |
+
metadata = {
|
| 938 |
+
'sources': sorted(list(sources)),
|
| 939 |
+
'years': sorted(list(years)),
|
| 940 |
+
'filenames': sorted(list(filenames))
|
| 941 |
+
}
|
| 942 |
+
|
| 943 |
+
# Cache the results
|
| 944 |
+
self._cache = metadata
|
| 945 |
+
self._last_updated = time.time()
|
| 946 |
+
|
| 947 |
+
print(f"✅ Complete metadata extracted from entire corpus: {len(sources)} sources, {len(years)} years, {len(filenames)} files")
|
| 948 |
+
|
| 949 |
+
# Debug: Show what was actually found
|
| 950 |
+
if sources:
|
| 951 |
+
print(f"📁 Sources found: {sorted(list(sources))}")
|
| 952 |
+
else:
|
| 953 |
+
print(f"❌ No sources found - check metadata structure")
|
| 954 |
+
|
| 955 |
+
if years:
|
| 956 |
+
print(f"📅 Years found: {sorted(list(years))}")
|
| 957 |
+
else:
|
| 958 |
+
print(f"❌ No years found - check metadata structure")
|
| 959 |
+
|
| 960 |
+
if filenames:
|
| 961 |
+
print(f"📄 Filenames found: {sorted(list(filenames))[:10]}{'...' if len(filenames) > 10 else ''}")
|
| 962 |
+
else:
|
| 963 |
+
print(f"❌ No filenames found - check metadata structure")
|
| 964 |
+
return metadata
|
| 965 |
+
|
| 966 |
+
except Exception as e:
|
| 967 |
+
print(f"❌ Error extracting metadata: {e}")
|
| 968 |
+
return {'sources': [], 'years': [], 'filenames': []}
|
| 969 |
+
|
| 970 |
+
# Global metadata cache
|
| 971 |
+
_metadata_cache = MetadataCache()
|
| 972 |
+
|
| 973 |
+
def get_available_metadata(vectorstore) -> dict:
|
| 974 |
+
"""Get available metadata values from the vectorstore efficiently."""
|
| 975 |
+
return _metadata_cache.get_metadata(vectorstore)
|
src/retrieval/hybrid.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hybrid search implementation combining vector and sparse retrieval."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import List, Dict, Any, Tuple
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from langchain.docstore.document import Document
|
| 8 |
+
from langchain_qdrant import QdrantVectorStore
|
| 9 |
+
from langchain_community.retrievers import BM25Retriever
|
| 10 |
+
from .filter import create_filter
|
| 11 |
+
import pickle
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class HybridRetriever:
|
| 16 |
+
"""
|
| 17 |
+
Hybrid retrieval system combining vector search (dense) and BM25 (sparse) search.
|
| 18 |
+
Supports configurable search modes: vector_only, sparse_only, or hybrid.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, config: Dict[str, Any]):
|
| 22 |
+
"""
|
| 23 |
+
Initialize hybrid retriever.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
config: Configuration dictionary with hybrid search settings
|
| 27 |
+
"""
|
| 28 |
+
self.config = config
|
| 29 |
+
self.bm25_retriever = None
|
| 30 |
+
self.documents = []
|
| 31 |
+
self._bm25_cache_file = None
|
| 32 |
+
|
| 33 |
+
def _get_bm25_cache_path(self) -> str:
|
| 34 |
+
"""Get path for BM25 cache file."""
|
| 35 |
+
cache_dir = Path("cache/bm25")
|
| 36 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 37 |
+
return str(cache_dir / "bm25_retriever.pkl")
|
| 38 |
+
|
| 39 |
+
def initialize_bm25(self, documents: List[Document], force_rebuild: bool = False) -> None:
|
| 40 |
+
"""
|
| 41 |
+
Initialize BM25 retriever with documents.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
documents: List of Document objects to index
|
| 45 |
+
force_rebuild: Whether to force rebuilding the BM25 index
|
| 46 |
+
"""
|
| 47 |
+
self.documents = documents
|
| 48 |
+
self._bm25_cache_file = self._get_bm25_cache_path()
|
| 49 |
+
|
| 50 |
+
# Try to load cached BM25 retriever
|
| 51 |
+
if not force_rebuild and os.path.exists(self._bm25_cache_file):
|
| 52 |
+
try:
|
| 53 |
+
print("Loading cached BM25 retriever...")
|
| 54 |
+
with open(self._bm25_cache_file, 'rb') as f:
|
| 55 |
+
self.bm25_retriever = pickle.load(f)
|
| 56 |
+
print(f"✅ Loaded cached BM25 retriever with {len(self.documents)} documents")
|
| 57 |
+
return
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"⚠️ Failed to load cached BM25 retriever: {e}")
|
| 60 |
+
print("Building new BM25 index...")
|
| 61 |
+
|
| 62 |
+
# Build new BM25 retriever
|
| 63 |
+
print("Building BM25 index...")
|
| 64 |
+
try:
|
| 65 |
+
# Use langchain's BM25Retriever
|
| 66 |
+
self.bm25_retriever = BM25Retriever.from_documents(documents)
|
| 67 |
+
|
| 68 |
+
# Configure BM25 parameters
|
| 69 |
+
bm25_config = self.config.get("bm25", {})
|
| 70 |
+
k = bm25_config.get("top_k", 20)
|
| 71 |
+
self.bm25_retriever.k = k
|
| 72 |
+
|
| 73 |
+
# Cache the BM25 retriever
|
| 74 |
+
with open(self._bm25_cache_file, 'wb') as f:
|
| 75 |
+
pickle.dump(self.bm25_retriever, f)
|
| 76 |
+
print(f"✅ Built and cached BM25 retriever with {len(documents)} documents")
|
| 77 |
+
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"❌ Failed to build BM25 retriever: {e}")
|
| 80 |
+
print("BM25 search will be disabled")
|
| 81 |
+
self.bm25_retriever = None
|
| 82 |
+
|
| 83 |
+
def _filter_documents_by_metadata(
|
| 84 |
+
self,
|
| 85 |
+
documents: List[Document],
|
| 86 |
+
reports: List[str] = None,
|
| 87 |
+
sources: str = None,
|
| 88 |
+
subtype: List[str] = None,
|
| 89 |
+
year: List[str] = None
|
| 90 |
+
) -> List[Document]:
|
| 91 |
+
"""
|
| 92 |
+
Filter documents by metadata criteria.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
documents: List of documents to filter
|
| 96 |
+
reports: List of specific report filenames
|
| 97 |
+
sources: Source category
|
| 98 |
+
subtype: List of subtypes
|
| 99 |
+
year: List of years
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Filtered list of documents
|
| 103 |
+
"""
|
| 104 |
+
if not any([reports, sources, subtype, year]):
|
| 105 |
+
return documents
|
| 106 |
+
|
| 107 |
+
filtered_docs = []
|
| 108 |
+
for doc in documents:
|
| 109 |
+
metadata = doc.metadata
|
| 110 |
+
|
| 111 |
+
# Filter by reports
|
| 112 |
+
if reports:
|
| 113 |
+
filename = metadata.get('filename', '')
|
| 114 |
+
if not any(report in filename for report in reports):
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
# Filter by sources
|
| 118 |
+
if sources:
|
| 119 |
+
doc_source = metadata.get('source', '')
|
| 120 |
+
if sources != doc_source:
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
# Filter by subtype
|
| 124 |
+
if subtype:
|
| 125 |
+
doc_subtype = metadata.get('subtype', '')
|
| 126 |
+
if doc_subtype not in subtype:
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
# Filter by year
|
| 130 |
+
if year:
|
| 131 |
+
doc_year = str(metadata.get('year', ''))
|
| 132 |
+
if doc_year not in year:
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
filtered_docs.append(doc)
|
| 136 |
+
|
| 137 |
+
return filtered_docs
|
| 138 |
+
|
| 139 |
+
def _bm25_search(
|
| 140 |
+
self,
|
| 141 |
+
query: str,
|
| 142 |
+
k: int = 20,
|
| 143 |
+
reports: List[str] = None,
|
| 144 |
+
sources: str = None,
|
| 145 |
+
subtype: List[str] = None,
|
| 146 |
+
year: List[str] = None
|
| 147 |
+
) -> List[Tuple[Document, float]]:
|
| 148 |
+
"""
|
| 149 |
+
Perform BM25 sparse search.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
query: Search query
|
| 153 |
+
k: Number of documents to retrieve
|
| 154 |
+
reports: List of specific report filenames
|
| 155 |
+
sources: Source category
|
| 156 |
+
subtype: List of subtypes
|
| 157 |
+
year: List of years
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
List of (Document, score) tuples
|
| 161 |
+
"""
|
| 162 |
+
if not self.bm25_retriever:
|
| 163 |
+
print("⚠️ BM25 retriever not available")
|
| 164 |
+
return []
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
# Get BM25 results
|
| 168 |
+
self.bm25_retriever.k = k
|
| 169 |
+
bm25_docs = self.bm25_retriever.invoke(query)
|
| 170 |
+
|
| 171 |
+
# Apply metadata filtering
|
| 172 |
+
if any([reports, sources, subtype, year]):
|
| 173 |
+
bm25_docs = self._filter_documents_by_metadata(
|
| 174 |
+
bm25_docs, reports, sources, subtype, year
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# BM25Retriever doesn't return scores directly, so we'll use placeholder scores
|
| 178 |
+
# In a production system, you'd want to access the actual BM25 scores
|
| 179 |
+
results = []
|
| 180 |
+
for i, doc in enumerate(bm25_docs):
|
| 181 |
+
# Assign decreasing scores based on rank (higher rank = higher score)
|
| 182 |
+
# Normalize to [0, 1] range for consistency with vector search
|
| 183 |
+
score = max(0.1, 1.0 - (i / max(len(bm25_docs), 1)))
|
| 184 |
+
results.append((doc, score))
|
| 185 |
+
|
| 186 |
+
return results
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
print(f"❌ BM25 search failed: {e}")
|
| 190 |
+
return []
|
| 191 |
+
|
| 192 |
+
def _vector_search(
|
| 193 |
+
self,
|
| 194 |
+
vectorstore: QdrantVectorStore,
|
| 195 |
+
query: str,
|
| 196 |
+
k: int = 20,
|
| 197 |
+
reports: List[str] = None,
|
| 198 |
+
sources: str = None,
|
| 199 |
+
subtype: List[str] = None,
|
| 200 |
+
year: List[str] = None
|
| 201 |
+
) -> List[Tuple[Document, float]]:
|
| 202 |
+
"""
|
| 203 |
+
Perform vector similarity search.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
vectorstore: QdrantVectorStore instance
|
| 207 |
+
query: Search query
|
| 208 |
+
k: Number of documents to retrieve
|
| 209 |
+
reports: List of specific report filenames
|
| 210 |
+
sources: Source category
|
| 211 |
+
subtype: List of subtypes
|
| 212 |
+
year: List of years
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
List of (Document, score) tuples
|
| 216 |
+
"""
|
| 217 |
+
try:
|
| 218 |
+
# Create filter
|
| 219 |
+
filter_obj = create_filter(
|
| 220 |
+
reports=reports,
|
| 221 |
+
sources=sources,
|
| 222 |
+
subtype=subtype,
|
| 223 |
+
year=year
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Perform vector search
|
| 227 |
+
if filter_obj:
|
| 228 |
+
results = vectorstore.similarity_search_with_score(
|
| 229 |
+
query, k=k, filter=filter_obj
|
| 230 |
+
)
|
| 231 |
+
else:
|
| 232 |
+
results = vectorstore.similarity_search_with_score(query, k=k)
|
| 233 |
+
|
| 234 |
+
return results
|
| 235 |
+
|
| 236 |
+
except Exception as e:
|
| 237 |
+
print(f"❌ Vector search failed: {e}")
|
| 238 |
+
return []
|
| 239 |
+
|
| 240 |
+
def _normalize_scores(self, results: List[Tuple[Document, float]], method: str = "min_max") -> List[Tuple[Document, float]]:
|
| 241 |
+
"""
|
| 242 |
+
Normalize scores to [0, 1] range.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
results: List of (Document, score) tuples
|
| 246 |
+
method: Normalization method ('min_max' or 'z_score')
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
List of (Document, normalized_score) tuples
|
| 250 |
+
"""
|
| 251 |
+
if not results:
|
| 252 |
+
return results
|
| 253 |
+
|
| 254 |
+
scores = [score for _, score in results]
|
| 255 |
+
|
| 256 |
+
if method == "min_max":
|
| 257 |
+
min_score = min(scores)
|
| 258 |
+
max_score = max(scores)
|
| 259 |
+
if max_score == min_score:
|
| 260 |
+
normalized_results = [(doc, 1.0) for doc, _ in results]
|
| 261 |
+
else:
|
| 262 |
+
normalized_results = [
|
| 263 |
+
(doc, (score - min_score) / (max_score - min_score))
|
| 264 |
+
for doc, score in results
|
| 265 |
+
]
|
| 266 |
+
elif method == "z_score":
|
| 267 |
+
mean_score = np.mean(scores)
|
| 268 |
+
std_score = np.std(scores)
|
| 269 |
+
if std_score == 0:
|
| 270 |
+
normalized_results = [(doc, 1.0) for doc, _ in results]
|
| 271 |
+
else:
|
| 272 |
+
normalized_results = [
|
| 273 |
+
(doc, max(0, (score - mean_score) / std_score))
|
| 274 |
+
for doc, score in results
|
| 275 |
+
]
|
| 276 |
+
else:
|
| 277 |
+
normalized_results = results
|
| 278 |
+
|
| 279 |
+
return normalized_results
|
| 280 |
+
|
| 281 |
+
def _combine_results(
|
| 282 |
+
self,
|
| 283 |
+
vector_results: List[Tuple[Document, float]],
|
| 284 |
+
bm25_results: List[Tuple[Document, float]],
|
| 285 |
+
alpha: float = 0.5
|
| 286 |
+
) -> List[Tuple[Document, float]]:
|
| 287 |
+
"""
|
| 288 |
+
Combine vector and BM25 results with weighted scoring.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
vector_results: Vector search results
|
| 292 |
+
bm25_results: BM25 search results
|
| 293 |
+
alpha: Weight for vector scores (1-alpha for BM25 scores)
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
Combined and ranked results
|
| 297 |
+
"""
|
| 298 |
+
# Normalize scores
|
| 299 |
+
vector_results = self._normalize_scores(vector_results)
|
| 300 |
+
bm25_results = self._normalize_scores(bm25_results)
|
| 301 |
+
|
| 302 |
+
# Create document ID mapping for both result sets
|
| 303 |
+
vector_docs = {id(doc): (doc, score) for doc, score in vector_results}
|
| 304 |
+
bm25_docs = {id(doc): (doc, score) for doc, score in bm25_results}
|
| 305 |
+
|
| 306 |
+
# Combine scores
|
| 307 |
+
combined_scores = {}
|
| 308 |
+
all_doc_ids = set(vector_docs.keys()) | set(bm25_docs.keys())
|
| 309 |
+
|
| 310 |
+
for doc_id in all_doc_ids:
|
| 311 |
+
vector_score = vector_docs.get(doc_id, (None, 0.0))[1]
|
| 312 |
+
bm25_score = bm25_docs.get(doc_id, (None, 0.0))[1]
|
| 313 |
+
|
| 314 |
+
# Weighted combination
|
| 315 |
+
combined_score = alpha * vector_score + (1 - alpha) * bm25_score
|
| 316 |
+
|
| 317 |
+
# Get document object
|
| 318 |
+
doc = vector_docs.get(doc_id, bm25_docs.get(doc_id))[0]
|
| 319 |
+
combined_scores[doc_id] = (doc, combined_score)
|
| 320 |
+
|
| 321 |
+
# Sort by combined score (descending)
|
| 322 |
+
sorted_results = sorted(
|
| 323 |
+
combined_scores.values(),
|
| 324 |
+
key=lambda x: x[1],
|
| 325 |
+
reverse=True
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
return sorted_results
|
| 329 |
+
|
| 330 |
+
def retrieve(
|
| 331 |
+
self,
|
| 332 |
+
vectorstore: QdrantVectorStore,
|
| 333 |
+
query: str,
|
| 334 |
+
mode: str = "hybrid",
|
| 335 |
+
reports: List[str] = None,
|
| 336 |
+
sources: str = None,
|
| 337 |
+
subtype: List[str] = None,
|
| 338 |
+
year: List[str] = None,
|
| 339 |
+
alpha: float = 0.5,
|
| 340 |
+
k: int = None
|
| 341 |
+
) -> List[Document]:
|
| 342 |
+
"""
|
| 343 |
+
Retrieve documents using the specified search mode.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
vectorstore: QdrantVectorStore instance
|
| 347 |
+
query: Search query
|
| 348 |
+
mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
|
| 349 |
+
reports: List of specific report filenames
|
| 350 |
+
sources: Source category
|
| 351 |
+
subtype: List of subtypes
|
| 352 |
+
year: List of years
|
| 353 |
+
alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
|
| 354 |
+
k: Number of documents to retrieve
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
List of relevant Document objects
|
| 358 |
+
"""
|
| 359 |
+
if k is None:
|
| 360 |
+
k = self.config.get("retriever", {}).get("top_k", 20)
|
| 361 |
+
|
| 362 |
+
results = []
|
| 363 |
+
|
| 364 |
+
if mode == "vector_only":
|
| 365 |
+
# Vector search only
|
| 366 |
+
vector_results = self._vector_search(
|
| 367 |
+
vectorstore, query, k, reports, sources, subtype, year
|
| 368 |
+
)
|
| 369 |
+
results = [(doc, score) for doc, score in vector_results]
|
| 370 |
+
|
| 371 |
+
elif mode == "sparse_only":
|
| 372 |
+
# BM25 search only
|
| 373 |
+
bm25_results = self._bm25_search(
|
| 374 |
+
query, k, reports, sources, subtype, year
|
| 375 |
+
)
|
| 376 |
+
results = [(doc, score) for doc, score in bm25_results]
|
| 377 |
+
|
| 378 |
+
elif mode == "hybrid":
|
| 379 |
+
# Hybrid search - combine both
|
| 380 |
+
# Get more results from each method to have better fusion
|
| 381 |
+
retrieval_k = min(k * 2, 50) # Get more candidates for fusion
|
| 382 |
+
|
| 383 |
+
vector_results = self._vector_search(
|
| 384 |
+
vectorstore, query, retrieval_k, reports, sources, subtype, year
|
| 385 |
+
)
|
| 386 |
+
bm25_results = self._bm25_search(
|
| 387 |
+
query, retrieval_k, reports, sources, subtype, year
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
results = self._combine_results(vector_results, bm25_results, alpha)
|
| 391 |
+
|
| 392 |
+
else:
|
| 393 |
+
raise ValueError(f"Unknown search mode: {mode}")
|
| 394 |
+
|
| 395 |
+
# Limit to top k results
|
| 396 |
+
results = results[:k]
|
| 397 |
+
|
| 398 |
+
# Return just the documents
|
| 399 |
+
return [doc for doc, score in results]
|
| 400 |
+
|
| 401 |
+
def retrieve_with_scores(
|
| 402 |
+
self,
|
| 403 |
+
vectorstore: QdrantVectorStore,
|
| 404 |
+
query: str,
|
| 405 |
+
mode: str = "hybrid",
|
| 406 |
+
reports: List[str] = None,
|
| 407 |
+
sources: str = None,
|
| 408 |
+
subtype: List[str] = None,
|
| 409 |
+
year: List[str] = None,
|
| 410 |
+
alpha: float = 0.5,
|
| 411 |
+
k: int = None
|
| 412 |
+
) -> List[Tuple[Document, float]]:
|
| 413 |
+
"""
|
| 414 |
+
Retrieve documents with scores using the specified search mode.
|
| 415 |
+
|
| 416 |
+
Args:
|
| 417 |
+
vectorstore: QdrantVectorStore instance
|
| 418 |
+
query: Search query
|
| 419 |
+
mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
|
| 420 |
+
reports: List of specific report filenames
|
| 421 |
+
sources: Source category
|
| 422 |
+
subtype: List of subtypes
|
| 423 |
+
year: List of years
|
| 424 |
+
alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
|
| 425 |
+
k: Number of documents to retrieve
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
List of (Document, score) tuples
|
| 429 |
+
"""
|
| 430 |
+
if k is None:
|
| 431 |
+
k = self.config.get("retriever", {}).get("top_k", 20)
|
| 432 |
+
|
| 433 |
+
results = []
|
| 434 |
+
|
| 435 |
+
if mode == "vector_only":
|
| 436 |
+
# Vector search only
|
| 437 |
+
results = self._vector_search(
|
| 438 |
+
vectorstore, query, k, reports, sources, subtype, year
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
elif mode == "sparse_only":
|
| 442 |
+
# BM25 search only
|
| 443 |
+
results = self._bm25_search(
|
| 444 |
+
query, k, reports, sources, subtype, year
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
elif mode == "hybrid":
|
| 448 |
+
# Hybrid search - combine both
|
| 449 |
+
# Get more results from each method to have better fusion
|
| 450 |
+
retrieval_k = min(k * 2, 50) # Get more candidates for fusion
|
| 451 |
+
|
| 452 |
+
vector_results = self._vector_search(
|
| 453 |
+
vectorstore, query, retrieval_k, reports, sources, subtype, year
|
| 454 |
+
)
|
| 455 |
+
bm25_results = self._bm25_search(
|
| 456 |
+
query, retrieval_k, reports, sources, subtype, year
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
results = self._combine_results(vector_results, bm25_results, alpha)
|
| 460 |
+
|
| 461 |
+
else:
|
| 462 |
+
raise ValueError(f"Unknown search mode: {mode}")
|
| 463 |
+
|
| 464 |
+
# Limit to top k results
|
| 465 |
+
return results[:k]
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def get_available_search_modes() -> List[str]:
|
| 469 |
+
"""Get list of available search modes."""
|
| 470 |
+
return ["vector_only", "sparse_only", "hybrid"]
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def get_search_mode_description() -> Dict[str, str]:
|
| 474 |
+
"""Get descriptions for each search mode."""
|
| 475 |
+
return {
|
| 476 |
+
"vector_only": "Semantic search using dense embeddings - good for conceptual matching",
|
| 477 |
+
"sparse_only": "Keyword search using BM25 - good for exact term matching",
|
| 478 |
+
"hybrid": "Combined semantic and keyword search - balanced approach"
|
| 479 |
+
}
|
src/vectorstore.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Vector store management and operations."""
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, Any, List, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from langchain_qdrant import QdrantVectorStore
|
| 8 |
+
from langchain.docstore.document import Document
|
| 9 |
+
from langchain_core.embeddings import Embeddings
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MatryoshkaEmbeddings(Embeddings):
|
| 15 |
+
"""Custom embeddings class that supports Matryoshka dimension truncation."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, model_name: str, truncate_dim: int = None, **kwargs):
|
| 18 |
+
"""
|
| 19 |
+
Initialize Matryoshka embeddings.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
model_name: Name of the model
|
| 23 |
+
truncate_dim: Dimension to truncate to (for Matryoshka models)
|
| 24 |
+
**kwargs: Additional arguments (ignored for Matryoshka models)
|
| 25 |
+
"""
|
| 26 |
+
self.model_name = model_name
|
| 27 |
+
self.truncate_dim = truncate_dim
|
| 28 |
+
|
| 29 |
+
if truncate_dim and "matryoshka" in model_name.lower():
|
| 30 |
+
# Use SentenceTransformer directly for Matryoshka models
|
| 31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
+
self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
|
| 33 |
+
print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
|
| 34 |
+
else:
|
| 35 |
+
# Use standard HuggingFaceEmbeddings
|
| 36 |
+
self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)
|
| 37 |
+
|
| 38 |
+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
| 39 |
+
"""Embed documents."""
|
| 40 |
+
if self.truncate_dim and "matryoshka" in self.model_name.lower():
|
| 41 |
+
embeddings = self.model.encode(texts, normalize_embeddings=True)
|
| 42 |
+
return embeddings.tolist()
|
| 43 |
+
else:
|
| 44 |
+
return self.model.embed_documents(texts)
|
| 45 |
+
|
| 46 |
+
def embed_query(self, text: str) -> List[float]:
|
| 47 |
+
"""Embed query."""
|
| 48 |
+
if self.truncate_dim and "matryoshka" in self.model_name.lower():
|
| 49 |
+
embedding = self.model.encode([text], normalize_embeddings=True)
|
| 50 |
+
return embedding[0].tolist()
|
| 51 |
+
else:
|
| 52 |
+
return self.model.embed_query(text)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class VectorStoreManager:
|
| 56 |
+
"""Manages vector store operations and connections."""
|
| 57 |
+
|
| 58 |
+
def __init__(self, config: Dict[str, Any]):
|
| 59 |
+
"""
|
| 60 |
+
Initialize vector store manager.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
config: Configuration dictionary
|
| 64 |
+
"""
|
| 65 |
+
self.config = config
|
| 66 |
+
self.embeddings = self._create_embeddings()
|
| 67 |
+
self.vectorstore = None
|
| 68 |
+
|
| 69 |
+
# Define metadata fields that need payload indexes for filtering
|
| 70 |
+
self.metadata_fields = [
|
| 71 |
+
("metadata.year", "keyword"),
|
| 72 |
+
("metadata.source", "keyword"),
|
| 73 |
+
("metadata.filename", "keyword"),
|
| 74 |
+
# Add more metadata fields as needed
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
def _create_embeddings(self) -> HuggingFaceEmbeddings:
|
| 78 |
+
"""Create embeddings model from configuration."""
|
| 79 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
+
|
| 81 |
+
model_name = self.config["retriever"]["model"]
|
| 82 |
+
normalize = self.config["retriever"]["normalize"]
|
| 83 |
+
|
| 84 |
+
model_kwargs = {"device": device}
|
| 85 |
+
encode_kwargs = {
|
| 86 |
+
"normalize_embeddings": normalize,
|
| 87 |
+
"batch_size": 100,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
# For Matryoshka models, check if we need to truncate dimensions
|
| 91 |
+
if "matryoshka" in model_name.lower():
|
| 92 |
+
# Check if we have a specific dimension requirement
|
| 93 |
+
collection_name = self.config.get("qdrant", {}).get("collection_name", "")
|
| 94 |
+
|
| 95 |
+
if "modernbert-embed-base-akryl-matryoshka" in collection_name:
|
| 96 |
+
# This collection expects 768 dimensions
|
| 97 |
+
truncate_dim = 768
|
| 98 |
+
print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
|
| 99 |
+
|
| 100 |
+
# Use custom MatryoshkaEmbeddings
|
| 101 |
+
embeddings = MatryoshkaEmbeddings(
|
| 102 |
+
model_name=model_name,
|
| 103 |
+
truncate_dim=truncate_dim,
|
| 104 |
+
model_kwargs=model_kwargs,
|
| 105 |
+
encode_kwargs=encode_kwargs,
|
| 106 |
+
show_progress=True,
|
| 107 |
+
)
|
| 108 |
+
return embeddings
|
| 109 |
+
|
| 110 |
+
# Use standard HuggingFaceEmbeddings for non-Matryoshka models
|
| 111 |
+
embeddings = HuggingFaceEmbeddings(
|
| 112 |
+
model_name=model_name,
|
| 113 |
+
model_kwargs=model_kwargs,
|
| 114 |
+
encode_kwargs=encode_kwargs,
|
| 115 |
+
show_progress=True,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
return embeddings
|
| 119 |
+
|
| 120 |
+
def ensure_metadata_indexes(self) -> None:
|
| 121 |
+
"""
|
| 122 |
+
Create payload indexes for all required metadata fields.
|
| 123 |
+
This ensures filtering works properly, especially in Qdrant Cloud.
|
| 124 |
+
"""
|
| 125 |
+
if not self.vectorstore:
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
qdrant_config = self.config["qdrant"]
|
| 129 |
+
collection_name = qdrant_config["collection_name"]
|
| 130 |
+
|
| 131 |
+
for field_name, field_type in self.metadata_fields:
|
| 132 |
+
try:
|
| 133 |
+
self.vectorstore.client.create_payload_index(
|
| 134 |
+
collection_name=collection_name,
|
| 135 |
+
field_name=field_name,
|
| 136 |
+
field_type=field_type
|
| 137 |
+
)
|
| 138 |
+
print(f"Created payload index for {field_name} ({field_type})")
|
| 139 |
+
except Exception as e:
|
| 140 |
+
# Index might already exist or other error - log but continue
|
| 141 |
+
print(f"Index creation for {field_name} ({field_type}): {str(e)}")
|
| 142 |
+
|
| 143 |
+
def connect_to_existing(self, force_recreate: bool = False) -> QdrantVectorStore:
|
| 144 |
+
"""
|
| 145 |
+
Connect to existing Qdrant collection.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
force_recreate: If True, recreate the collection if dimension mismatch occurs
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
QdrantVectorStore instance
|
| 152 |
+
"""
|
| 153 |
+
qdrant_config = self.config["qdrant"]
|
| 154 |
+
|
| 155 |
+
kwargs_qdrant = {
|
| 156 |
+
"url": qdrant_config["url"],
|
| 157 |
+
"collection_name": qdrant_config["collection_name"],
|
| 158 |
+
"prefer_grpc": qdrant_config.get("prefer_grpc", True),
|
| 159 |
+
"api_key": qdrant_config.get("api_key", None),
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
if force_recreate:
|
| 163 |
+
kwargs_qdrant["force_recreate"] = True
|
| 164 |
+
|
| 165 |
+
self.vectorstore = QdrantVectorStore.from_existing_collection(
|
| 166 |
+
embedding=self.embeddings,
|
| 167 |
+
**kwargs_qdrant
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Ensure payload indexes exist for metadata filtering
|
| 171 |
+
self.ensure_metadata_indexes()
|
| 172 |
+
|
| 173 |
+
return self.vectorstore
|
| 174 |
+
|
| 175 |
+
def create_from_documents(self, documents: List[Document]) -> QdrantVectorStore:
|
| 176 |
+
"""
|
| 177 |
+
Create new Qdrant collection from documents.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
documents: List of Document objects
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
QdrantVectorStore instance
|
| 184 |
+
"""
|
| 185 |
+
qdrant_config = self.config["qdrant"]
|
| 186 |
+
|
| 187 |
+
kwargs_qdrant = {
|
| 188 |
+
"url": qdrant_config["url"],
|
| 189 |
+
"collection_name": qdrant_config["collection_name"],
|
| 190 |
+
"prefer_grpc": qdrant_config.get("prefer_grpc", True),
|
| 191 |
+
"api_key": qdrant_config.get("api_key", None),
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
self.vectorstore = QdrantVectorStore.from_documents(
|
| 195 |
+
documents=documents,
|
| 196 |
+
embedding=self.embeddings,
|
| 197 |
+
**kwargs_qdrant
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Ensure payload indexes exist for metadata filtering
|
| 201 |
+
self.ensure_metadata_indexes()
|
| 202 |
+
|
| 203 |
+
return self.vectorstore
|
| 204 |
+
|
| 205 |
+
def delete_collection(self) -> None:
|
| 206 |
+
"""
|
| 207 |
+
Delete the current Qdrant collection.
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
QdrantVectorStore instance
|
| 211 |
+
"""
|
| 212 |
+
qdrant_config = self.config["qdrant"]
|
| 213 |
+
collection_name = qdrant_config.get("collection_name")
|
| 214 |
+
|
| 215 |
+
self.vectorstore.client.delete_collection(
|
| 216 |
+
collection_name=collection_name
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
return self.vectorstore
|
| 220 |
+
|
| 221 |
+
def get_vectorstore(self) -> Optional[QdrantVectorStore]:
|
| 222 |
+
"""Get current vectorstore instance."""
|
| 223 |
+
return self.vectorstore
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_local_qdrant(config: Dict[str, Any]) -> QdrantVectorStore:
|
| 227 |
+
"""
|
| 228 |
+
Get local Qdrant vector store (legacy function for compatibility).
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
config: Configuration dictionary
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
QdrantVectorStore instance
|
| 235 |
+
"""
|
| 236 |
+
manager = VectorStoreManager(config)
|
| 237 |
+
return manager.connect_to_existing()
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def create_vectorstore(config: Dict[str, Any], documents: List[Document]) -> QdrantVectorStore:
|
| 241 |
+
"""
|
| 242 |
+
Create new vector store from documents.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
config: Configuration dictionary
|
| 246 |
+
documents: List of Document objects
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
QdrantVectorStore instance
|
| 250 |
+
"""
|
| 251 |
+
manager = VectorStoreManager(config)
|
| 252 |
+
return manager.create_from_documents(documents)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def get_embeddings_model(config: Dict[str, Any]) -> HuggingFaceEmbeddings:
|
| 256 |
+
"""
|
| 257 |
+
Create embeddings model from configuration (legacy function).
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
config: Configuration dictionary
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
HuggingFaceEmbeddings instance
|
| 264 |
+
"""
|
| 265 |
+
manager = VectorStoreManager(config)
|
| 266 |
+
return manager.embeddings
|
utils.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import dataclasses
|
| 3 |
+
from uuid import UUID
|
| 4 |
+
from typing import Any
|
| 5 |
+
from datetime import datetime, date
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import configparser
|
| 9 |
+
from torch import cuda
|
| 10 |
+
from qdrant_client.http import models as rest
|
| 11 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 12 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_config(fp):
|
| 16 |
+
config = configparser.ConfigParser()
|
| 17 |
+
config.read_file(open(fp))
|
| 18 |
+
return config
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_embeddings_model(config):
|
| 22 |
+
device = "cuda" if cuda.is_available() else "cpu"
|
| 23 |
+
|
| 24 |
+
# Define embedding model
|
| 25 |
+
model_name = config.get("retriever", "MODEL")
|
| 26 |
+
model_kwargs = {"device": device}
|
| 27 |
+
normalize_embeddings = bool(int(config.get("retriever", "NORMALIZE")))
|
| 28 |
+
encode_kwargs = {
|
| 29 |
+
"normalize_embeddings": normalize_embeddings,
|
| 30 |
+
"batch_size": 100,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
embeddings = HuggingFaceEmbeddings(
|
| 34 |
+
show_progress=True,
|
| 35 |
+
model_name=model_name,
|
| 36 |
+
model_kwargs=model_kwargs,
|
| 37 |
+
encode_kwargs=encode_kwargs,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
return embeddings
|
| 41 |
+
|
| 42 |
+
# Create a search filter for Qdrant
|
| 43 |
+
def create_filter(
|
| 44 |
+
reports: list = [], sources: str = None, subtype: str = None, year: str = None
|
| 45 |
+
):
|
| 46 |
+
if len(reports) == 0:
|
| 47 |
+
print(f"defining filter for sources:{sources}, subtype:{subtype}")
|
| 48 |
+
filter = rest.Filter(
|
| 49 |
+
must=[
|
| 50 |
+
rest.FieldCondition(
|
| 51 |
+
key="metadata.source", match=rest.MatchValue(value=sources)
|
| 52 |
+
),
|
| 53 |
+
rest.FieldCondition(
|
| 54 |
+
key="metadata.filename", match=rest.MatchAny(any=subtype)
|
| 55 |
+
),
|
| 56 |
+
# rest.FieldCondition(
|
| 57 |
+
# key="metadata.year",
|
| 58 |
+
# match=rest.MatchAny(any=year)
|
| 59 |
+
]
|
| 60 |
+
)
|
| 61 |
+
else:
|
| 62 |
+
print(f"defining filter for allreports:{reports}")
|
| 63 |
+
filter = rest.Filter(
|
| 64 |
+
must=[
|
| 65 |
+
rest.FieldCondition(
|
| 66 |
+
key="metadata.filename", match=rest.MatchAny(any=reports)
|
| 67 |
+
)
|
| 68 |
+
]
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
return filter
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_json(fp):
|
| 75 |
+
with open(fp, "r") as f:
|
| 76 |
+
docs = json.load(f)
|
| 77 |
+
return docs
|
| 78 |
+
|
| 79 |
+
def get_timestamp():
|
| 80 |
+
now = datetime.datetime.now()
|
| 81 |
+
timestamp = now.strftime("%Y%m%d%H%M%S")
|
| 82 |
+
return timestamp
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# A custom class to help with recursive serialization.
|
| 87 |
+
# This approach avoids modifying the original object.
|
| 88 |
+
class _RecursiveSerializer(json.JSONEncoder):
|
| 89 |
+
"""A custom JSONEncoder that handles complex types by converting them to dicts or strings."""
|
| 90 |
+
def default(self, obj):
|
| 91 |
+
# Prefer the pydantic method if it exists for the most robust serialization.
|
| 92 |
+
if hasattr(obj, 'model_dump'):
|
| 93 |
+
return obj.model_dump()
|
| 94 |
+
|
| 95 |
+
# Handle dataclasses
|
| 96 |
+
if dataclasses.is_dataclass(obj):
|
| 97 |
+
return dataclasses.asdict(obj)
|
| 98 |
+
|
| 99 |
+
# Handle other non-serializable but common types.
|
| 100 |
+
if isinstance(obj, (datetime, date, UUID)):
|
| 101 |
+
return str(obj)
|
| 102 |
+
|
| 103 |
+
# Fallback for general objects with a __dict__
|
| 104 |
+
if hasattr(obj, '__dict__'):
|
| 105 |
+
return obj.__dict__
|
| 106 |
+
|
| 107 |
+
# Default fallback to JSONEncoder's behavior
|
| 108 |
+
return super().default(obj)
|
| 109 |
+
|
| 110 |
+
def to_json_string(obj: Any, **kwargs) -> str:
|
| 111 |
+
"""
|
| 112 |
+
Serializes a Python object into a JSON-formatted string.
|
| 113 |
+
|
| 114 |
+
This function is a comprehensive utility that can handle:
|
| 115 |
+
- Standard Python types (lists, dicts, strings, numbers, bools, None).
|
| 116 |
+
- Pydantic models (using `model_dump()`).
|
| 117 |
+
- Dataclasses (using `dataclasses.asdict()`).
|
| 118 |
+
- Standard library types not natively JSON-serializable (e.g., datetime, UUID).
|
| 119 |
+
- Custom classes with a `__dict__`.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
obj (Any): The Python object to serialize.
|
| 123 |
+
**kwargs: Additional keyword arguments to pass to `json.dumps`.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
str: A JSON-formatted string.
|
| 127 |
+
|
| 128 |
+
Example:
|
| 129 |
+
>>> from datetime import datetime
|
| 130 |
+
>>> from pydantic import BaseModel
|
| 131 |
+
>>> from dataclasses import dataclass
|
| 132 |
+
|
| 133 |
+
>>> class Address(BaseModel):
|
| 134 |
+
... street: str
|
| 135 |
+
... city: str
|
| 136 |
+
|
| 137 |
+
>>> @dataclass
|
| 138 |
+
... class Product:
|
| 139 |
+
... id: int
|
| 140 |
+
... name: str
|
| 141 |
+
|
| 142 |
+
>>> class Order(BaseModel):
|
| 143 |
+
... user_address: Address
|
| 144 |
+
... item: Product
|
| 145 |
+
|
| 146 |
+
>>> order_obj = Order(
|
| 147 |
+
... user_address=Address(street="123 Main St", city="Example City"),
|
| 148 |
+
... item=Product(id=1, name="Laptop")
|
| 149 |
+
... )
|
| 150 |
+
|
| 151 |
+
>>> print(to_json_string(order_obj, indent=2))
|
| 152 |
+
{
|
| 153 |
+
"user_address": {
|
| 154 |
+
"street": "123 Main St",
|
| 155 |
+
"city": "Example City"
|
| 156 |
+
},
|
| 157 |
+
"item": {
|
| 158 |
+
"id": 1,
|
| 159 |
+
"name": "Laptop"
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
"""
|
| 163 |
+
return json.dumps(obj, cls=_RecursiveSerializer, **kwargs)
|