"""
Intelligent Audit Report Chatbot UI
"""
import os
import re
import time
import json
import uuid
import logging
import traceback
from pathlib import Path
from collections import Counter
from typing import List, Dict, Any, Optional

# Configure logging first (stdlib only) so the environment setup below can
# report problems instead of failing silently.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def _sanitize_omp_num_threads(raw: str) -> str:
    """Return a positive integer string suitable for ``OMP_NUM_THREADS``.

    * Plain numbers are truncated to an int: ``"4"`` -> ``"4"``, ``"2.5"`` -> ``"2"``.
    * A Kubernetes-style CPU quantity in millicores is converted to whole
      cores, rounded down: ``"3500m"`` (3.5 CPUs) -> ``"3"`` rather than 3500.
    * Other text falls back to its first run of digits (``"4,2"`` -> ``"4"``).
    * Empty, non-numeric, zero or negative values become ``"1"``.
    """
    value = (raw or "").strip().lower()
    quantity = re.fullmatch(r"([+-]?\d+(?:\.\d+)?)(m?)", value)
    if quantity:
        number = float(quantity.group(1))
        # A trailing "m" means millicores: 1000m == 1 CPU.
        threads = int(number / 1000) if quantity.group(2) else int(number)
    else:
        first_digits = re.search(r"\d+", value)
        threads = int(first_digits.group()) if first_digits else 0
    return str(max(threads, 1))


# ===== CRITICAL: Fix OMP_NUM_THREADS FIRST, before ANY third-party imports =====
# Some libraries load at import time and will fail if OMP_NUM_THREADS is
# invalid, so only standard-library modules are imported above this line.
os.environ["OMP_NUM_THREADS"] = _sanitize_omp_num_threads(os.environ.get("OMP_NUM_THREADS", ""))

# Project path settings decide where the HuggingFace cache lives.
# NOTE(review): assumes src.config.paths does not itself import model libraries.
from src.config.paths import (  # noqa: E402
    IS_DEPLOYED,
    PROJECT_DIR,
    HF_CACHE_DIR,
    FEEDBACK_DIR,
    CONVERSATIONS_DIR,
)

# ===== Setup HuggingFace cache directories BEFORE any model imports =====
# Only override cache directories in deployed environment (local uses defaults)
if IS_DEPLOYED and HF_CACHE_DIR:
    cache_dir = str(HF_CACHE_DIR)
    os.environ["HF_HOME"] = cache_dir
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_DATASETS_CACHE"] = cache_dir
    os.environ["HF_HUB_CACHE"] = cache_dir
    os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
    # Ensure cache directory exists (created in Dockerfile, but ensure it's there)
    try:
        os.makedirs(cache_dir, mode=0o755, exist_ok=True)
    except OSError as exc:  # PermissionError is a subclass of OSError
        # It may already exist from the Dockerfile; continue, but leave a trace.
        logger.warning("Could not create HuggingFace cache dir %s: %s", cache_dir, exc)

# ===== Third-party / model-loading imports =====
# Deliberately placed after the environment setup above (hence the E402 waivers).
import pandas as pd  # noqa: E402
import streamlit as st  # noqa: E402
import plotly.express as px  # noqa: E402
from langchain_core.messages import HumanMessage, AIMessage  # noqa: E402
from multi_agent_chatbot import get_multi_agent_chatbot  # noqa: E402
from smart_chatbot import get_chatbot as get_smart_chatbot  # noqa: E402
from src.reporting.snowflake_connector import save_to_snowflake  # noqa: E402
from src.reporting.feedback_schema import create_feedback_from_dict  # noqa: E402

# Log environment setup for debugging
logger.info(f"🌍 Environment: {'DEPLOYED' if IS_DEPLOYED else 'LOCAL'}")
logger.info(f"📁 PROJECT_DIR: {PROJECT_DIR}")
logger.info(f"📁 HuggingFace cache: {os.environ.get('HF_HOME', 'DEFAULT (not overridden)')}")
logger.info(f"🔧 OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', 'NOT SET')}")

# Page config (must be the first Streamlit command executed)
st.set_page_config(
    layout="wide",
    page_icon="🤖",
    initial_sidebar_state="expanded",
    page_title="Intelligent Audit Report Chatbot"
)

# Custom CSS
st.markdown("""
""", unsafe_allow_html=True)
def get_system_type():
    """Return the display label of the configured chatbot backend.

    Reads the ``CHATBOT_SYSTEM`` environment variable (default
    ``'multi-agent'``). Only the exact value ``'smart'`` yields the smart
    chatbot label; every other value is reported as the multi-agent system.
    """
    is_smart = os.environ.get('CHATBOT_SYSTEM', 'multi-agent') == 'smart'
    return "Smart Chatbot System" if is_smart else "Multi-Agent System"
def get_chatbot():
    """Create the chatbot backend selected by ``CHATBOT_SYSTEM``.

    ``'smart'`` builds the smart chatbot; any other value (default
    ``'multi-agent'``) builds the multi-agent chatbot. The selected factory's
    return value is passed through unchanged.
    """
    factory = (
        get_smart_chatbot
        if os.environ.get('CHATBOT_SYSTEM', 'multi-agent') == 'smart'
        else get_multi_agent_chatbot
    )
    return factory()
def serialize_messages(messages):
    """Turn chat message objects into plain ``{"type", "content"}`` dicts.

    ``type`` is the message's class name (e.g. ``"HumanMessage"``) and
    ``content`` its ``content`` attribute converted with ``str``. Objects
    without a ``content`` attribute are skipped.
    """
    return [
        {"type": type(message).__name__, "content": str(message.content)}
        for message in messages
        if hasattr(message, 'content')
    ]
def serialize_documents(sources):
    """Serialize retrieved document objects into plain dicts, dropping duplicates.

    Args:
        sources: Iterable of document-like objects. Text is read from
            ``page_content`` (falling back to ``content``, then ``""``) and
            metadata from ``metadata``; a missing or ``None`` metadata is
            treated as an empty dict.

    Returns:
        One dict per distinct content value (first occurrence wins) holding
        the content, the full metadata, and commonly used metadata fields
        flattened to the top level (``'unknown'`` / ``0.0`` / ``None`` when
        absent).
    """
    serialized = []
    seen_content = set()
    for doc in sources:
        content = getattr(doc, 'page_content', getattr(doc, 'content', ''))
        # Skip if we've seen this exact content before
        if content in seen_content:
            continue
        seen_content.add(content)
        # `or {}` also covers documents whose metadata attribute is explicitly None.
        metadata = getattr(doc, 'metadata', None) or {}
        serialized.append({
            "content": content,
            "metadata": metadata,
            # Prefer the reranker's score; fall back to the original retrieval score.
            "score": metadata.get('reranked_score', metadata.get('original_score', 0.0)),
            "id": metadata.get('_id', 'unknown'),
            "source": metadata.get('source', 'unknown'),
            "year": metadata.get('year', 'unknown'),
            "district": metadata.get('district', 'unknown'),
            "page": metadata.get('page', 'unknown'),
            "chunk_id": metadata.get('chunk_id', 'unknown'),
            "page_label": metadata.get('page_label', 'unknown'),
            "original_score": metadata.get('original_score', 0.0),
            "reranked_score": metadata.get('reranked_score', None),
        })
    return serialized
def extract_transcript(messages: List[Any]) -> List[Dict[str, str]]:
    """Return the user/assistant turns as ``{"role", "content"}`` dicts.

    ``HumanMessage`` objects map to role ``"user"`` and ``AIMessage`` objects
    to ``"assistant"``; every other message type is left out. Content is
    converted with ``str``.
    """
    transcript = []
    for message in messages:
        if isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            role = "assistant"
        else:
            continue
        text = str(message.content) if hasattr(message, 'content') else str(message)
        transcript.append({"role": role, "content": text})
    return transcript
def build_retrievals_structure(
    rag_retrieval_history: List[Dict[str, Any]],
    messages: List[Any],
    content_preview_chars: int = 100,
) -> List[Dict[str, Any]]:
    """Summarise each recorded RAG retrieval for the feedback payload.

    Args:
        rag_retrieval_history: Retrieval entries, each with
            ``conversation_up_to`` (serialized ``{"type", "content"}``
            messages) and ``docs_retrieved`` (list of document dicts).
        messages: Full conversation. Used only as a fallback when an entry's
            ``conversation_up_to`` has no ``HumanMessage``.
        content_preview_chars: Maximum number of characters of each
            document's string ``content`` to keep (default 100).

    Returns:
        One ``{"retrieved_docs": [...], "user_message_trigger": str}`` dict
        per history entry, in history order. Document dicts are shallow
        copies, so the history itself is not modified.
    """
    retrievals = []
    user_msgs = None  # built lazily; only the fallback path needs it
    # enumerate() gives each entry its real position. list.index() would be
    # O(n^2) and returns the first *equal* entry when duplicates exist.
    for retrieval_idx, entry in enumerate(rag_retrieval_history):
        # The trigger is the last user message in the conversation snapshot
        # recorded with this retrieval.
        user_message_trigger = next(
            (
                msg_dict.get("content", "")
                for msg_dict in reversed(entry.get("conversation_up_to", []))
                if msg_dict.get("type") == "HumanMessage"
            ),
            "",
        )
        if not user_message_trigger:
            # Fallback: assume the N-th retrieval was triggered by the N-th
            # user message; use the last user message if there are fewer.
            if user_msgs is None:
                user_msgs = [msg for msg in messages if isinstance(msg, HumanMessage)]
            if retrieval_idx < len(user_msgs):
                user_message_trigger = str(user_msgs[retrieval_idx].content)
            elif user_msgs:
                user_message_trigger = str(user_msgs[-1].content)
        retrieved_docs = []
        for doc in entry.get("docs_retrieved", []):
            doc_copy = doc.copy()
            content = doc_copy.get("content")
            # Truncate text only; other fields (and non-string content such
            # as None) are kept unchanged.
            if isinstance(content, str):
                doc_copy["content"] = content[:content_preview_chars]
            retrieved_docs.append(doc_copy)
        retrievals.append({
            "retrieved_docs": retrieved_docs,
            "user_message_trigger": user_message_trigger
        })
    return retrievals
def build_feedback_score_related_retrieval_docs(
    is_feedback_about_last_retrieval: bool,
    messages: List[Any],
    rag_retrieval_history: List[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Build the ``feedback_score_related_retrieval_docs`` feedback field.

    Args:
        is_feedback_about_last_retrieval: Kept for interface stability. Both
            values currently select the most recent retrieval.
        messages: Not used; kept so existing callers keep working.
        rag_retrieval_history: Recorded retrieval entries, each with
            ``conversation_up_to`` and ``docs_retrieved``.

    Returns:
        ``None`` when the history is empty. Otherwise a dict containing:
        the latest entry's conversation snapshot as ``{"role", "content"}``
        pairs (Human/AI messages only), and its documents untruncated.
    """
    if not rag_retrieval_history:
        return None
    # Both feedback modes use the most recent retrieval.
    # NOTE(review): revisit if "all retrievals" feedback should aggregate entries.
    latest = rag_retrieval_history[-1]
    conversation_up_to_point = []
    for msg_dict in latest.get("conversation_up_to", []):
        msg_type = msg_dict.get("type")
        if msg_type == "HumanMessage":
            role = "user"
        elif msg_type == "AIMessage":
            role = "assistant"
        else:
            continue
        conversation_up_to_point.append({"role": role, "content": msg_dict.get("content", "")})
    return {
        "conversation_up_to_point": conversation_up_to_point,
        "retrieved_docs": latest.get("docs_retrieved", []),
    }
def extract_chunk_statistics(
    sources: List[Any],
    min_year: int = 1900,
    max_year: int = 2030,
) -> Dict[str, Any]:
    """Extract distribution statistics from retrieved chunks.

    Args:
        sources: Document-like objects. Only their ``metadata`` mapping is
            read (a missing or ``None`` metadata counts as empty).
        min_year: Smallest year accepted as valid (inclusive).
        max_year: Largest year accepted as valid (inclusive). Years outside
            the window, or values that are not numeric, are counted as
            ``'Unknown'``.

    Returns:
        ``{}`` for empty input. Otherwise a dict with:
        * ``total_chunks``;
        * ``unique_*`` counts (``unique_years`` and ``unique_districts``
          exclude ``'Unknown'``);
        * ``*_distribution`` count dicts;
        * per-chunk lists ``sources``, ``years``, ``filenames`` and
          ``districts``.
        Missing or empty source, filename and district values are normalised
        to ``'Unknown'``.
    """
    if not sources:
        return {}
    sources_list = []
    years = []
    filenames = []
    districts = []
    for doc in sources:
        metadata = getattr(doc, 'metadata', None) or {}
        sources_list.append(metadata.get('source') or 'Unknown')
        year = metadata.get('year', 'Unknown')
        year_label = 'Unknown'
        if year and year != 'Unknown':
            try:
                # Accept ints, floats and numeric strings such as "2021.0".
                year_int = int(float(year))
            except (ValueError, TypeError, OverflowError):  # OverflowError: "inf"
                pass
            else:
                if min_year <= year_int <= max_year:
                    year_label = str(year_int)
        years.append(year_label)
        # Normalised like district so downstream display code always gets a string.
        filenames.append(metadata.get('filename') or 'Unknown')
        districts.append(metadata.get('district') or 'Unknown')
    source_counts = Counter(sources_list)
    year_counts = Counter(years)
    filename_counts = Counter(filenames)
    district_counts = Counter(districts)
    return {
        'total_chunks': len(sources),
        'unique_sources': len(source_counts),
        'unique_years': sum(1 for y in year_counts if y != 'Unknown'),
        'unique_filenames': len(filename_counts),
        'unique_districts': sum(1 for d in district_counts if d != 'Unknown'),
        'source_distribution': dict(source_counts),
        'year_distribution': dict(year_counts),
        'filename_distribution': dict(filename_counts),
        'district_distribution': dict(district_counts),
        'sources': sources_list,
        'years': years,
        'filenames': filenames,
        'districts': districts
    }
def display_chunk_statistics_charts(stats: Dict[str, Any], title: str = "Retrieval Statistics"):
    """Render retrieval statistics as summary metrics plus Plotly bar charts.

    The caller uses this for larger result sets (10+ chunks).

    Args:
        stats: Dict produced by ``extract_chunk_statistics``. Nothing is
            rendered when it is empty or reports zero chunks.
        title: Heading shown above the summary metrics.
    """
    if not stats or stats.get('total_chunks', 0) == 0:
        return
    # A Streamlit container really groups the elements. An HTML <div> opened
    # in one st.markdown call cannot wrap elements emitted by later calls.
    with st.container():
        st.subheader(f"📊 {title}")
        # Headline numbers as metrics. Bare label/value lines would be merged
        # into a single Markdown paragraph.
        summary = [
            ("Total Chunks", 'total_chunks'),
            ("Unique Sources", 'unique_sources'),
            ("Unique Years", 'unique_years'),
            ("Unique Files", 'unique_filenames'),
        ]
        for column, (label, key) in zip(st.columns(len(summary)), summary):
            column.metric(label, stats.get(key, 0))

        # Charts - three columns to include Districts
        col1, col2, col3 = st.columns(3)
        with col1:
            source_dist = stats.get('source_distribution') or {}
            if source_dist:
                source_df = pd.DataFrame(list(source_dist.items()), columns=['Source', 'Count'])
                fig_source = px.bar(
                    source_df,
                    x='Count',
                    y='Source',
                    orientation='h',
                    title='Distribution by Source',
                    color='Count',
                    color_continuous_scale='viridis'
                )
                fig_source.update_layout(height=400, showlegend=False)
                st.plotly_chart(fig_source, use_container_width=True)
        with col2:
            year_dist = stats.get('year_distribution') or {}
            if year_dist:
                # 'Unknown' years are not plotted
                year_dist_filtered = {k: v for k, v in year_dist.items() if k != 'Unknown'}
                if year_dist_filtered:
                    year_df = pd.DataFrame(list(year_dist_filtered.items()), columns=['Year', 'Count'])
                    # Sort by year as integer but keep as string for categorical display
                    year_df['Year_Int'] = year_df['Year'].astype(int)
                    year_df = year_df.sort_values('Year_Int').drop('Year_Int', axis=1)
                    fig_year = px.bar(
                        year_df,
                        x='Year',
                        y='Count',
                        title='Distribution by Year',
                        color='Count',
                        color_continuous_scale='plasma'
                    )
                    # Treat years as discrete categories, not a continuous axis
                    fig_year.update_xaxes(type='category')
                    fig_year.update_layout(height=400, showlegend=False)
                    st.plotly_chart(fig_year, use_container_width=True)
                else:
                    st.info("No valid years found in the results")
        with col3:
            district_dist = stats.get('district_distribution') or {}
            if district_dist:
                district_dist_filtered = {k: v for k, v in district_dist.items() if k != 'Unknown'}
                if district_dist_filtered:
                    district_df = pd.DataFrame(
                        list(district_dist_filtered.items()),
                        columns=['District', 'Count']
                    ).sort_values('Count', ascending=False)
                    fig_district = px.bar(
                        district_df,
                        x='Count',
                        y='District',
                        orientation='h',
                        title='Distribution by District',
                        color='Count',
                        color_continuous_scale='blues'
                    )
                    fig_district.update_layout(height=400, showlegend=False)
                    st.plotly_chart(fig_district, use_container_width=True)
                else:
                    st.info("No valid districts found in the results")
def display_chunk_statistics_table(
    stats: Dict[str, Any],
    title: str = "Retrieval Distribution",
    max_files: int = 5,
):
    """Render retrieval statistics as four side-by-side tables.

    The caller uses this for smaller result sets, and below the charts for
    larger ones.

    Args:
        stats: Dict produced by ``extract_chunk_statistics``. Nothing is
            rendered when it is empty or reports zero chunks.
        title: Sub-heading shown above the tables.
        max_files: How many of the most frequent files to list (default 5).
    """
    if not stats or stats.get('total_chunks', 0) == 0:
        return
    # One Streamlit container keeps the heading and the four columns together.
    # An HTML wrapper split across st.markdown calls cannot do this.
    with st.container():
        st.subheader(f"📊 {title}")
        # 4 equal columns for consistent alignment
        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.markdown("**🏘️ Districts**")
            districts = {
                k: v for k, v in (stats.get('district_distribution') or {}).items() if k != 'Unknown'
            }
            if districts:
                district_df = pd.DataFrame(
                    {"District": list(districts.keys()), "Count": list(districts.values())}
                ).sort_values('Count', ascending=False)
                st.dataframe(district_df, hide_index=True, use_container_width=True)
            else:
                st.write("No district data")

        with col2:
            st.markdown("**📂 Sources**")
            sources = stats.get('source_distribution') or {}
            if sources:
                source_df = pd.DataFrame(
                    {"Source": list(sources.keys()), "Count": list(sources.values())}
                ).sort_values('Count', ascending=False)
                st.dataframe(source_df, hide_index=True, use_container_width=True)
            else:
                st.write("No source data")

        with col3:
            st.markdown("**📅 Years**")
            years = {
                k: v for k, v in (stats.get('year_distribution') or {}).items() if k != 'Unknown'
            }
            if years:
                year_df = pd.DataFrame({"Year": list(years.keys()), "Count": list(years.values())})
                # Sort by year as integer but display as string
                year_df['Year_Int'] = year_df['Year'].astype(int)
                year_df = year_df.sort_values('Year_Int')[['Year', 'Count']]
                st.dataframe(year_df, hide_index=True, use_container_width=True)
            else:
                st.write("No year data")

        with col4:
            st.markdown("**📄 Files**")
            top_files = sorted(
                (stats.get('filename_distribution') or {}).items(),
                key=lambda item: item[1],
                reverse=True,
            )[:max_files]
            if top_files:
                # str() guards against non-string filenames (e.g. None) in metadata.
                names = [str(name) for name, _ in top_files]
                file_df = pd.DataFrame({
                    "File": [n[:30] + "..." if len(n) > 30 else n for n in names],
                    "Count": [count for _, count in top_files],
                })
                st.dataframe(file_df, hide_index=True, use_container_width=True)
            else:
                st.write("No file data")
@st.cache_data
def load_filter_options():
    """Load the sidebar filter choices from ``src/config/filter_options.json``.

    Returns:
        A dict that always contains ``sources``, ``years``, ``districts`` and
        ``filenames``; keys missing from the file default to empty lists.
        When the file is missing or does not hold a JSON object, an error is
        shown in the UI and all four lists are empty.

    Note:
        ``st.cache_data`` memoises the result, including the empty fallback,
        until the cache is cleared.
    """
    defaults = {"sources": [], "years": [], "districts": [], "filenames": []}
    config_dir = PROJECT_DIR / "src" / "config"
    filter_options_path = config_dir / "filter_options.json"
    try:
        with open(filter_options_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        st.info(f"Looking for filter_options.json in: {config_dir}")
        st.error("filter_options.json not found. Please run the metadata analysis script.")
        return defaults
    except json.JSONDecodeError as exc:
        st.error(f"filter_options.json is not valid JSON ({exc}). Please re-run the metadata analysis script.")
        return defaults
    if not isinstance(loaded, dict):
        st.error("filter_options.json must contain a JSON object. Please re-run the metadata analysis script.")
        return defaults
    # Fill in any keys the file omits so callers can index them directly.
    return {**defaults, **loaded}
def main():
# Initialize session state
if 'messages' not in st.session_state:
st.session_state.messages = []
if 'conversation_id' not in st.session_state:
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
if 'session_start_time' not in st.session_state:
st.session_state.session_start_time = time.time()
if 'active_filters' not in st.session_state:
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
# Track RAG retrieval history for feedback
if 'rag_retrieval_history' not in st.session_state:
st.session_state.rag_retrieval_history = []
# Initialize chatbot only once per app session (cached)
if 'chatbot' not in st.session_state:
with st.spinner("🔄 Loading AI models and connecting to database..."):
st.session_state.chatbot = get_chatbot()
st.success("✅ AI system ready!")
# Reset conversation history if needed (but keep chatbot cached)
if 'reset_conversation' in st.session_state and st.session_state.reset_conversation:
st.session_state.messages = []
st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
st.session_state.session_start_time = time.time()
st.session_state.rag_retrieval_history = []
st.session_state.feedback_submitted = False
st.session_state.reset_conversation = False
st.rerun()
# Header - fully center aligned
st.markdown('🤖 Intelligent Audit Report Chatbot
', unsafe_allow_html=True)
st.markdown('Ask questions about audit reports. Use the sidebar filters to narrow down your search!
', unsafe_allow_html=True)
# Session info
duration = int(time.time() - st.session_state.session_start_time)
duration_str = f"{duration // 60}m {duration % 60}s"
st.markdown(f'''
Session Info: Messages: {len(st.session_state.messages)} | Duration: {duration_str} | Status: Active | ID: {st.session_state.conversation_id}
''', unsafe_allow_html=True)
# Load filter options
filter_options = load_filter_options()
# Sidebar for filters
with st.sidebar:
# Instructions section (collapsible)
with st.expander("📖 How to Use", expanded=False):
st.markdown("""
#### 🎯 Using Filters
1. **Select filters** from the sidebar to narrow your search:
2. **Leave filters empty** to search across all data
3. **Type your question** in the chat input at the bottom
4. **Click "Send"** to submit your question
#### 💡 Tips
- Use specific questions for better results
- Combine multiple filters for precise searches
- Check the "Retrieved Documents" tab to see source material
#### ⚠️ Important
**When finished, please close the browser window** to free up computational resources.
---
For more detailed help, see the example questions at the bottom of the page.
""")
st.markdown("### 🔍 Search Filters")
st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
st.markdown('', unsafe_allow_html=True)
st.markdown('
📄 Specific Reports (Filename Filter)
', unsafe_allow_html=True)
st.markdown('
⚠️ Selecting specific reports will ignore all other filters
', unsafe_allow_html=True)
selected_filenames = st.multiselect(
"Select specific reports:",
options=filter_options.get('filenames', []),
default=st.session_state.active_filters.get('filenames', []),
key="filenames_filter",
help="Choose specific reports to search. When enabled, all other filters are ignored."
)
st.markdown('
', unsafe_allow_html=True)
# Determine if filename filter is active
filename_mode = len(selected_filenames) > 0
# Sources filter
st.markdown('', unsafe_allow_html=True)
st.markdown('
📊 Sources
', unsafe_allow_html=True)
selected_sources = st.multiselect(
"Select sources:",
options=filter_options['sources'],
default=st.session_state.active_filters['sources'],
disabled = filename_mode,
key="sources_filter",
help="Choose which types of reports to search"
)
st.markdown('
', unsafe_allow_html=True)
# Years filter
# st.markdown('', unsafe_allow_html=True)
st.markdown('
📅 Years
', unsafe_allow_html=True)
selected_years = st.multiselect(
"Select years:",
options=filter_options['years'],
default=st.session_state.active_filters['years'],
disabled = filename_mode,
key="years_filter",
help="Choose which years to search"
)
st.markdown('
', unsafe_allow_html=True)
# Districts filter
# st.markdown('', unsafe_allow_html=True)
st.markdown('
🏘️ Districts
', unsafe_allow_html=True)
selected_districts = st.multiselect(
"Select districts:",
options=filter_options['districts'],
default=st.session_state.active_filters['districts'],
disabled = filename_mode,
key="districts_filter",
help="Choose which districts to search"
)
st.markdown('
', unsafe_allow_html=True)
# Update active filters
st.session_state.active_filters = {
'sources': selected_sources if not filename_mode else [],
'years': selected_years if not filename_mode else [],
'districts': selected_districts if not filename_mode else [],
'filenames': selected_filenames
}
# Clear filters button
if st.button("🗑️ Clear All Filters", key="clear_filters_button"):
st.session_state.active_filters = {'sources': [], 'years': [], 'districts': [], 'filenames': []}
st.rerun()
# Main content area with tabs
tab1, tab2 = st.tabs(["💬 Chat", "📄 Retrieved Documents"])
with tab1:
# Chat container
chat_container = st.container()
with chat_container:
# Display conversation history
for message in st.session_state.messages:
if isinstance(message, HumanMessage):
st.markdown(f'{message.content}
', unsafe_allow_html=True)
elif isinstance(message, AIMessage):
st.markdown(f'{message.content}
', unsafe_allow_html=True)
# Input area
st.markdown("
", unsafe_allow_html=True)
# Create two columns for input and button
col1, col2 = st.columns([4, 1])
with col1:
# Use a counter to force input clearing
if 'input_counter' not in st.session_state:
st.session_state.input_counter = 0
# Handle pending question from example questions section
if 'pending_question' in st.session_state and st.session_state.pending_question:
default_value = st.session_state.pending_question
# Increment counter to force new input widget
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
del st.session_state.pending_question
key_suffix = st.session_state.input_counter
else:
default_value = ""
key_suffix = st.session_state.input_counter
user_input = st.text_input(
"Type your message here...",
placeholder="Ask about budget allocations, expenditures, or audit findings...",
key=f"user_input_{key_suffix}",
label_visibility="collapsed",
value=default_value if default_value else None
)
with col2:
send_button = st.button("Send", key="send_button", use_container_width=True)
# Clear chat button
if st.button("🗑️ Clear Chat", key="clear_chat_button"):
st.session_state.reset_conversation = True
# Clear all conversation files
conversations_path = CONVERSATIONS_DIR
if conversations_path.exists():
for file in conversations_path.iterdir():
if file.suffix == '.json':
file.unlink()
st.rerun()
# Handle user input
if send_button and user_input:
# Construct filter context string
filter_context_str = ""
if selected_filenames:
filter_context_str += "FILTER CONTEXT:\n"
filter_context_str += f"Filenames: {', '.join(selected_filenames)}\n"
filter_context_str += "USER QUERY:\n"
elif selected_sources or selected_years or selected_districts:
filter_context_str += "FILTER CONTEXT:\n"
if selected_sources:
filter_context_str += f"Sources: {', '.join(selected_sources)}\n"
if selected_years:
filter_context_str += f"Years: {', '.join(selected_years)}\n"
if selected_districts:
filter_context_str += f"Districts: {', '.join(selected_districts)}\n"
filter_context_str += "USER QUERY:\n"
full_query = filter_context_str + user_input
# Add user message to history
st.session_state.messages.append(HumanMessage(content=user_input))
# Get chatbot response
with st.spinner("🤔 Thinking..."):
try:
# Pass the full query with filter context
chat_result = st.session_state.chatbot.chat(full_query, st.session_state.conversation_id)
# Handle both old format (string) and new format (dict)
if isinstance(chat_result, dict):
response = chat_result['response']
rag_result = chat_result.get('rag_result')
st.session_state.last_rag_result = rag_result
# Track RAG retrieval for feedback
if rag_result:
sources = rag_result.get('sources', []) if isinstance(rag_result, dict) else (rag_result.sources if hasattr(rag_result, 'sources') else [])
# Get the actual RAG query
actual_rag_query = chat_result.get('actual_rag_query', '')
if actual_rag_query:
# Format it like the log message
timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
formatted_query = f"{timestamp} - INFO - 🔍 ACTUAL RAG QUERY: '{actual_rag_query}'"
else:
formatted_query = "No RAG query available"
retrieval_entry = {
"conversation_up_to": serialize_messages(st.session_state.messages),
"rag_query_expansion": formatted_query,
"docs_retrieved": serialize_documents(sources)
}
st.session_state.rag_retrieval_history.append(retrieval_entry)
else:
response = chat_result
st.session_state.last_rag_result = None
# Add bot response to history
st.session_state.messages.append(AIMessage(content=response))
except Exception as e:
error_msg = f"Sorry, I encountered an error: {str(e)}"
st.session_state.messages.append(AIMessage(content=error_msg))
# Clear input and rerun
st.session_state.input_counter += 1 # This will clear the input
st.rerun()
with tab2:
# Document retrieval panel
if hasattr(st.session_state, 'last_rag_result') and st.session_state.last_rag_result:
rag_result = st.session_state.last_rag_result
# Handle both PipelineResult object and dictionary formats
sources = None
if hasattr(rag_result, 'sources'):
# PipelineResult object format
sources = rag_result.sources
elif isinstance(rag_result, dict) and 'sources' in rag_result:
# Dictionary format from multi-agent system
sources = rag_result['sources']
if sources and len(sources) > 0:
# Count unique filenames
unique_filenames = set()
for doc in sources:
filename = getattr(doc, 'metadata', {}).get('filename', 'Unknown')
unique_filenames.add(filename)
st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 20):**")
if len(unique_filenames) < len(sources):
st.info(f"💡 **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
# Extract and display statistics
stats = extract_chunk_statistics(sources)
# Show charts for 10+ results, tables for fewer
if len(sources) >= 10:
display_chunk_statistics_charts(stats, "Retrieval Statistics")
# Also show tables below charts for detailed view
st.markdown("---")
display_chunk_statistics_table(stats, "Retrieval Distribution")
else:
display_chunk_statistics_table(stats, "Retrieval Distribution")
st.markdown("---")
st.markdown("### 📄 Document Details")
for i, doc in enumerate(sources): # Show all documents
# Get relevance score and ID if available
metadata = getattr(doc, 'metadata', {})
score = metadata.get('reranked_score', metadata.get('original_score', None))
chunk_id = metadata.get('_id', 'Unknown')
score_text = f" (Score: {score:.3f}, ID: {chunk_id[:8]}...)" if score is not None else f" (ID: {chunk_id[:8]}...)"
with st.expander(f"📄 Document {i+1}: {getattr(doc, 'metadata', {}).get('filename', 'Unknown')[:50]}...{score_text}"):
# Display document metadata with emojis
metadata = getattr(doc, 'metadata', {})
col1, col2, col3, col4 = st.columns([2, 1.5, 1, 1])
with col1:
st.write(f"📄 **File:** {metadata.get('filename', 'Unknown')}")
with col2:
st.write(f"🏛️ **Source:** {metadata.get('source', 'Unknown')}")
with col3:
st.write(f"📅 **Year:** {metadata.get('year', 'Unknown')}")
with col4:
# Display page number and chunk ID
page = metadata.get('page_label', metadata.get('page', 'Unknown'))
chunk_id = metadata.get('_id', 'Unknown')
st.write(f"📖 **Page:** {page}")
st.write(f"🆔 **ID:** {chunk_id}")
# Display full content (no truncation)
content = getattr(doc, 'page_content', 'No content available')
st.write(f"**Full Content:**")
st.text_area("Full Content", value=content, height=300, disabled=True, label_visibility="collapsed", key=f"preview_{i}")
else:
st.info("No documents were retrieved for the last query.")
else:
st.info("No documents have been retrieved yet. Start a conversation to see retrieved documents here.")
# Feedback Dashboard Section
st.markdown("---")
st.markdown("### 💬 Feedback Dashboard")
# Check if there's any conversation to provide feedback on
has_conversation = len(st.session_state.messages) > 0
has_retrievals = len(st.session_state.rag_retrieval_history) > 0
if not has_conversation:
st.info("💡 Start a conversation to provide feedback!")
st.markdown("The feedback dashboard will be enabled once you begin chatting.")
else:
st.markdown("Help us improve by providing feedback on this conversation.")
# Initialize feedback state if not exists
if 'feedback_submitted' not in st.session_state:
st.session_state.feedback_submitted = False
# Feedback form - only show if feedback not already submitted
if not st.session_state.feedback_submitted:
with st.form("feedback_form", clear_on_submit=False):
col1, col2 = st.columns([1, 1])
with col1:
feedback_score = st.slider(
"Rate this conversation (1-5)",
min_value=1,
max_value=5,
help="How satisfied are you with the conversation?"
)
with col2:
is_feedback_about_last_retrieval = st.checkbox(
"Feedback about last retrieval only",
value=True,
help="If checked, feedback applies to the most recent document retrieval"
)
open_ended_feedback = st.text_area(
"Your feedback (optional)",
placeholder="Tell us what went well or what could be improved...",
height=100
)
# Disable submit if no score selected
submit_disabled = feedback_score is None
submitted = st.form_submit_button(
"📤 Submit Feedback",
use_container_width=True,
disabled=submit_disabled
)
if submitted:
# Log the feedback data being submitted
print("=" * 80)
print("🔄 FEEDBACK SUBMISSION: Starting...")
print("=" * 80)
st.write("🔍 **Debug: Feedback Data Being Submitted:**")
# Extract transcript from messages
transcript = extract_transcript(st.session_state.messages)
# Build retrievals structure
retrievals = build_retrievals_structure(
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else [],
st.session_state.messages
)
# Build feedback_score_related_retrieval_docs
feedback_score_related_retrieval_docs = build_feedback_score_related_retrieval_docs(
is_feedback_about_last_retrieval,
st.session_state.messages,
st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
)
# Preserve old retrieved_data format for backward compatibility
retrieved_data_old_format = st.session_state.rag_retrieval_history.copy() if st.session_state.rag_retrieval_history else []
# Create feedback data dictionary
feedback_dict = {
"open_ended_feedback": open_ended_feedback,
"score": feedback_score,
"is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
"conversation_id": st.session_state.conversation_id,
"timestamp": time.time(),
"message_count": len(st.session_state.messages),
"has_retrievals": has_retrievals,
"retrieval_count": len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0,
"transcript": transcript,
"retrievals": retrievals,
"feedback_score_related_retrieval_docs": feedback_score_related_retrieval_docs,
"retrieved_data": retrieved_data_old_format # Preserved old column
}
print(f"📝 FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(st.session_state.rag_retrieval_history) if st.session_state.rag_retrieval_history else 0}")
# Create UserFeedback dataclass instance
feedback_obj = None # Initialize outside try block
try:
feedback_obj = create_feedback_from_dict(feedback_dict)
print(f"✅ FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
st.write(f"✅ **Feedback Object Created**")
st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
st.write(f"- Score: {feedback_obj.score}/5")
st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
# Convert back to dict for JSON serialization
feedback_data = feedback_obj.to_dict()
except Exception as e:
print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
st.error(f"Failed to create feedback object: {e}")
feedback_data = feedback_dict
# Display the data being submitted
st.json(feedback_data)
# Save feedback to file - use PROJECT_DIR to ensure writability
feedback_dir = FEEDBACK_DIR
try:
# Ensure directory exists with write permissions (777 for compatibility)
feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
except (PermissionError, OSError) as e:
logger.warning(f"Could not create feedback directory at {feedback_dir}: {e}")
# Fallback to relative path
feedback_dir = Path("feedback")
feedback_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
try:
# Ensure parent directory exists before writing
feedback_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
# Save to local file first
print(f"💾 FEEDBACK SAVE: Saving to local file: {feedback_file}")
with open(feedback_file, 'w') as f:
json.dump(feedback_data, f, indent=2, default=str)
print(f"✅ FEEDBACK SAVE: Local file saved successfully")
# Save to Snowflake if enabled and credentials available
logger.info("🔄 FEEDBACK SAVE: Starting Snowflake save process...")
logger.info(f"📊 FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
snowflake_success = False
try:
snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
logger.info(f"🔍 SNOWFLAKE CHECK: enabled={snowflake_enabled}")
if snowflake_enabled:
if feedback_obj:
try:
logger.info("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
print("📤 SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
snowflake_success = save_to_snowflake(feedback_obj)
if snowflake_success:
logger.info("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
print("✅ SNOWFLAKE UI: Successfully saved to Snowflake")
else:
logger.warning("⚠️ SNOWFLAKE UI: Save failed")
print("⚠️ SNOWFLAKE UI: Save failed")
except Exception as e:
logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
print(f"❌ SNOWFLAKE UI ERROR: {e}")
traceback.print_exc()
snowflake_success = False
else:
logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
snowflake_success = False
else:
logger.info("💡 SNOWFLAKE UI: Integration disabled")
print("💡 SNOWFLAKE UI: Integration disabled")
# If Snowflake is disabled, consider it successful (local save only)
snowflake_success = True
except Exception as e:
logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
snowflake_success = False
# Only show success if Snowflake save succeeded (or if Snowflake is disabled)
if snowflake_success:
st.success("✅ Thank you for your feedback! It has been saved successfully.")
st.balloons()
else:
st.warning("⚠️ Feedback saved locally, but Snowflake save failed. Please check logs.")
# Mark feedback as submitted to prevent resubmission
st.session_state.feedback_submitted = True
print("=" * 80)
print(f"✅ FEEDBACK SUBMISSION: Completed successfully")
print("=" * 80)
# Log file location
st.info(f"📁 Feedback saved to: {feedback_file}")
except Exception as e:
print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
traceback.print_exc()
st.error(f"❌ Error saving feedback: {e}")
st.write(f"Debug error: {str(e)}")
else:
# Feedback already submitted - show success message and reset option
st.success("✅ Feedback already submitted for this conversation!")
col1, col2 = st.columns([1, 1])
with col1:
if st.button("🔄 Submit New Feedback", key="new_feedback_button", use_container_width=True):
try:
st.session_state.feedback_submitted = False
st.rerun()
except Exception as e:
# Handle any Streamlit API exceptions gracefully
logger.error(f"Error resetting feedback state: {e}")
st.error(f"Error resetting feedback. Please refresh the page.")
with col2:
if st.button("📋 View Conversation", key="view_conversation_button", use_container_width=True):
# Scroll to conversation - this is handled by the auto-scroll at bottom
pass
# Display retrieval history stats
if st.session_state.rag_retrieval_history:
st.markdown("---")
st.markdown("#### 📊 Retrieval History")
with st.expander(f"View {len(st.session_state.rag_retrieval_history)} retrieval entries", expanded=False):
for idx, entry in enumerate(st.session_state.rag_retrieval_history, 1):
st.markdown(f"**Retrieval #{idx}**")
# Display the actual RAG query
rag_query_expansion = entry.get("rag_query_expansion", "No query available")
st.code(rag_query_expansion, language="text")
# Display summary stats
st.json({
"conversation_length": len(entry.get("conversation_up_to", [])),
"documents_retrieved": len(entry.get("docs_retrieved", []))
})
st.markdown("---")
# Example Questions Section
st.markdown("---")
st.markdown("### 💡 Example Questions")
st.markdown("Click on any question below to use it, or modify the editable examples:")
# Initialize example question state
if 'custom_question_1' not in st.session_state:
st.session_state.custom_question_1 = "How were administrative costs managed in the PDM implementation, and what issues arose with budget execution regarding staff salaries?"
if 'custom_question_2' not in st.session_state:
st.session_state.custom_question_2 = "What did the National Coordinator say about the release of funds for PDM administrative costs in the letter dated 29th September 2022 and how did the funding received affect the activities of the PDCs and PDM SACCOs in the FY 2022/23?"
# Question 1: Filename insights (fixed, clickable)
st.markdown("#### 📄 Question 1: List insights from a specific file")
col1, col2 = st.columns([3, 1])
with col1:
example_q1 = "List couple of insights from the filename."
st.markdown(f"**Example:** `{example_q1}`")
st.info("💡 **Filter to apply:** Select a Filename from the sidebar panel before asking this question.")
with col2:
if st.button("📋 Use This Question", key="use_example_1", use_container_width=True):
st.session_state.pending_question = example_q1
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
st.rerun()
st.markdown("---")
# Questions 2 & 3: Editable examples
st.markdown("#### ✏️ Customizable Questions (Edit and use)")
# Question 2
# st.markdown("**Question 2:**")
custom_q1 = st.text_area(
"Edit question 2:",
value=st.session_state.custom_question_1,
height=80,
key="edit_question_2",
help="Modify this question to fit your needs, then click 'Use This Question'"
)
col1, col2 = st.columns([1, 4])
with col1:
if st.button("📋 Use Question 2", key="use_custom_1", use_container_width=True):
if custom_q1.strip():
st.session_state.pending_question = custom_q1.strip()
st.session_state.custom_question_1 = custom_q1.strip()
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
st.rerun()
else:
st.warning("Please enter a question first!")
with col2:
st.caption("💡 Tip: Add specific details like dates, names, or amounts to get more precise answers")
st.info("💡 **Filter to apply:** Select District(s) and Year(s) sidebar panel before asking this question.")
st.markdown("---")
# Question 3
# st.markdown("**Question 3:**")
custom_q2 = st.text_area(
"Edit question 3:",
value=st.session_state.custom_question_2,
height=80,
key="edit_question_3",
help="Modify this question to fit your needs, then click 'Use This Question'"
)
col1, col2 = st.columns([1, 4])
with col1:
if st.button("📋 Use Question 3", key="use_custom_2", use_container_width=True):
if custom_q2.strip():
st.session_state.pending_question = custom_q2.strip()
st.session_state.custom_question_2 = custom_q2.strip()
st.session_state.input_counter = (st.session_state.get('input_counter', 0) + 1) % 1000
st.rerun()
else:
st.warning("Please enter a question first!")
with col2:
st.caption("💡 Tip: Use specific terms from the documents (e.g., 'PDM', 'SACCOs', 'FY 2022/23')")
# Store selected question for next render (handled in input section above)
# This ensures the question populates the input field correctly
# Auto-scroll to bottom
st.markdown("""
""", unsafe_allow_html=True)
if __name__ == "__main__":
    import sys

    def _exit_with_launch_hint(*extra_lines):
        """Print the 'launch with `streamlit run`' banner and exit with status 1.

        Args:
            *extra_lines: Optional additional lines printed after the
                standard hint and before the closing rule.
        """
        rule = "=" * 80
        print(rule)
        print("⚠️ WARNING: This is a Streamlit app!")
        print(rule)
        print("\nPlease run this app using:")
        print(" streamlit run app.py")
        print("\nNot: python app.py")
        for line in extra_lines:
            print(line)
        print(rule)
        sys.exit(1)

    # Refuse to run under plain `python app.py`: without a Streamlit
    # ScriptRunContext, main() is not being driven by the Streamlit runtime.
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx
    except ImportError:
        # `streamlit` itself was already imported at the top of this module,
        # so an ImportError here only means this Streamlit version does not
        # expose the helper at this path -- it says nothing about how the app
        # was launched. Skip the guard rather than aborting a valid
        # `streamlit run` session.
        get_script_run_ctx = None

    if get_script_run_ctx is not None and get_script_run_ctx() is None:
        # Not inside the Streamlit runtime - show helpful message and stop.
        _exit_with_launch_hint(
            "\nThe app will not function correctly when run with 'python app.py'"
        )
    main()