# audit_assistant / app.py
# Author: Ara Yeroyan
# Commit ce77124: "create UI"
# (Header captured from the repository file viewer: raw · history · blame · 31.9 kB)
"""
Intelligent Audit Report Chatbot UI
"""
import argparse
import html
import json
import logging
import os
import sys
import time
import traceback
import uuid
from pathlib import Path

import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage

from multi_agent_chatbot import get_multi_agent_chatbot
from smart_chatbot import get_chatbot as get_smart_chatbot
from src.reporting.feedback_schema import create_feedback_from_dict
# Configure root logging for the app: INFO and above, each line prefixed with a
# timestamp and the level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Page config: browser-tab title and icon, wide layout, sidebar open on load.
_PAGE_CONFIG = {
    "page_title": "Intelligent Audit Report Chatbot",
    "page_icon": "πŸ€–",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS: classes used by the header, session bar, chat bubbles, filter
# sections and feedback/retrieval panels rendered with unsafe_allow_html below.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 2.5rem;
font-weight: bold;
color: #1f77b4;
text-align: center;
margin-bottom: 1rem;
}
.subtitle {
font-size: 1.2rem;
color: #666;
text-align: center;
margin-bottom: 2rem;
}
.session-info {
background-color: #f0f2f6;
padding: 10px;
border-radius: 5px;
margin-bottom: 20px;
font-size: 0.9rem;
}
.user-message {
background-color: #007bff;
color: white;
padding: 12px 16px;
border-radius: 18px 18px 4px 18px;
margin: 8px 0;
margin-left: 20%;
word-wrap: break-word;
}
.bot-message {
background-color: #f1f3f4;
color: #333;
padding: 12px 16px;
border-radius: 18px 18px 18px 4px;
margin: 8px 0;
margin-right: 20%;
word-wrap: break-word;
border: 1px solid #e0e0e0;
}
.filter-section {
margin-bottom: 20px;
padding: 15px;
background-color: #f8f9fa;
border-radius: 8px;
border: 1px solid #e9ecef;
}
.filter-title {
font-weight: bold;
margin-bottom: 10px;
color: #495057;
}
.feedback-section {
background-color: #f8f9fa;
padding: 20px;
border-radius: 10px;
margin-top: 30px;
border: 2px solid #dee2e6;
}
.retrieval-history {
background-color: #ffffff;
padding: 15px;
border-radius: 5px;
margin: 10px 0;
border-left: 4px solid #007bff;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def get_system_type():
    """Return a display label for the backend selected by CHATBOT_SYSTEM.

    Returns "Smart Chatbot System" when the variable equals 'smart', and
    "Multi-Agent System" for any other value or when it is unset.
    """
    is_smart = os.environ.get('CHATBOT_SYSTEM', 'multi-agent') == 'smart'
    return "Smart Chatbot System" if is_smart else "Multi-Agent System"
def get_chatbot():
    """Create the chatbot backend named by the CHATBOT_SYSTEM environment variable.

    'smart' selects ``get_smart_chatbot``. Any other value (default
    'multi-agent') selects ``get_multi_agent_chatbot``. Returns whatever the
    chosen factory returns.
    """
    backend = os.environ.get('CHATBOT_SYSTEM', 'multi-agent')
    factory = get_smart_chatbot if backend == 'smart' else get_multi_agent_chatbot
    return factory()
def serialize_messages(messages):
    """Convert chat messages into JSON-friendly dicts.

    Objects without a ``content`` attribute are skipped. Each kept message
    becomes ``{"type": <class name>, "content": str(content)}``, in input order.
    """
    return [
        {"type": type(message).__name__, "content": str(message.content)}
        for message in messages
        if hasattr(message, 'content')
    ]
def serialize_documents(sources):
    """Serialize retrieved document objects into plain dicts, dropping duplicate texts.

    Args:
        sources: Iterable of document-like objects, or None (treated as empty).
            Text is read from ``page_content``, falling back to ``content`` and
            then "". Metadata is read from ``metadata``; a missing or None
            attribute is treated as {}.

    Returns:
        A list with one dict per distinct text, in first-seen order. Each dict
        holds the raw ``metadata`` plus these flattened fields:
        - ``score``: ``reranked_score`` if that key is present, otherwise
          ``original_score`` (default 0.0).
        - ``id``, ``source``, ``year``, ``district``, ``page``, ``chunk_id``,
          ``page_label``: default "unknown".
        - ``original_score``: default 0.0.
        - ``reranked_score``: default None.
    """
    serialized = []
    seen_content = set()
    for doc in sources or []:
        content = getattr(doc, 'page_content', getattr(doc, 'content', ''))
        # Skip chunks whose exact text was already serialized (the first one wins).
        if content in seen_content:
            continue
        seen_content.add(content)
        # Read the metadata once; None (or a missing attribute) would otherwise crash .get().
        metadata = getattr(doc, 'metadata', None) or {}
        serialized.append({
            "content": content,
            "metadata": metadata,
            "score": metadata.get('reranked_score', metadata.get('original_score', 0.0)),
            "id": metadata.get('_id', 'unknown'),
            "source": metadata.get('source', 'unknown'),
            "year": metadata.get('year', 'unknown'),
            "district": metadata.get('district', 'unknown'),
            "page": metadata.get('page', 'unknown'),
            "chunk_id": metadata.get('chunk_id', 'unknown'),
            "page_label": metadata.get('page_label', 'unknown'),
            "original_score": metadata.get('original_score', 0.0),
            "reranked_score": metadata.get('reranked_score', None),
        })
    return serialized
@st.cache_data
def load_filter_options():
    """Load the sidebar filter choices from ``filter_options.json``.

    The file is read from the current working directory as UTF-8 JSON.

    Returns:
        A dict that always contains the keys 'sources', 'years', 'districts'
        and 'filenames'. Keys missing from the file default to []. If the file
        is missing, is not valid JSON, or is not a JSON object, an error is
        shown in the UI and all four keys are [].

    NOTE: because of ``st.cache_data`` the fallback result is cached too.
    Creating the file afterwards only takes effect once the cache is cleared
    (e.g. after an app restart).
    """
    empty = {"sources": [], "years": [], "districts": [], 'filenames': []}
    try:
        with open("filter_options.json", "r", encoding="utf-8") as f:
            options = json.load(f)
    except FileNotFoundError:
        # Debug aid: list the JSON files that *are* present in the working directory.
        st.info([x for x in os.listdir() if x.endswith('.json')])
        st.error("filter_options.json not found. Please run the metadata analysis script.")
        return empty
    except json.JSONDecodeError as exc:
        st.error(f"filter_options.json is not valid JSON ({exc}). Please re-run the metadata analysis script.")
        return empty
    if not isinstance(options, dict):
        st.error("filter_options.json must contain a JSON object. Please re-run the metadata analysis script.")
        return empty
    # Guarantee every key the sidebar reads, so a partial file cannot raise KeyError.
    return {**empty, **options}
# Session-state keys of the sidebar filter multiselects, indexed by the filter
# names used as keys of st.session_state.active_filters.
_FILTER_WIDGET_KEYS = {
    'filenames': 'filenames_filter',
    'sources': 'sources_filter',
    'years': 'years_filter',
    'districts': 'districts_filter',
}


def _empty_filters():
    """Return a new filter dict with no filter selected."""
    return {'sources': [], 'years': [], 'districts': [], 'filenames': []}


def _doc_metadata(doc):
    """Return ``doc.metadata``, or {} when the attribute is missing or None."""
    return getattr(doc, 'metadata', None) or {}


def _extract_sources(rag_result):
    """Return the retrieved documents held by *rag_result*, or [] when there are none.

    Accepts both result shapes handled by this app: a dict with a 'sources' key
    (multi-agent system) and an object with a ``sources`` attribute (PipelineResult).
    """
    if isinstance(rag_result, dict):
        return rag_result.get('sources') or []
    return getattr(rag_result, 'sources', None) or []


def _message_html(content):
    """Escape chat text for an HTML bubble, keeping line breaks as <br>.

    Message text comes from the user and from the model, so it must not be
    injected as raw HTML. Converting newlines also keeps a blank line from
    ending the Markdown HTML block in the middle of the bubble.
    """
    return html.escape(str(content)).replace("\n", "<br>")


def _build_filter_context(selected_filenames, selected_sources, selected_years, selected_districts):
    """Return the "FILTER CONTEXT" preamble prepended to the user's query.

    When any filename is selected, only the filenames are listed, because they
    override the other filters. Returns "" when no filter is selected. Values
    are converted with str() so that non-string options (e.g. integer years)
    can be joined.
    """
    if selected_filenames:
        lines = [f"Filenames: {', '.join(map(str, selected_filenames))}"]
    else:
        lines = [
            f"{label}: {', '.join(map(str, values))}"
            for label, values in (
                ("Sources", selected_sources),
                ("Years", selected_years),
                ("Districts", selected_districts),
            )
            if values
        ]
    if not lines:
        return ""
    return "FILTER CONTEXT:\n" + "".join(line + "\n" for line in lines) + "USER QUERY:\n"


def _init_session_state():
    """Create per-session state on first run and apply a pending "Clear Chat" reset."""
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    if 'conversation_id' not in st.session_state:
        st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
    if 'session_start_time' not in st.session_state:
        st.session_state.session_start_time = time.time()
    if 'active_filters' not in st.session_state:
        st.session_state.active_filters = _empty_filters()
    # Track RAG retrieval history for feedback
    if 'rag_retrieval_history' not in st.session_state:
        st.session_state.rag_retrieval_history = []
    # Create the chatbot once and keep it in this session's state.
    if 'chatbot' not in st.session_state:
        with st.spinner("πŸ”„ Loading AI models and connecting to database..."):
            st.session_state.chatbot = get_chatbot()
        st.success("βœ… AI system ready!")
    # Reset the conversation if requested (the chatbot object is kept).
    if st.session_state.get('reset_conversation'):
        st.session_state.messages = []
        st.session_state.conversation_id = f"session_{uuid.uuid4().hex[:8]}"
        st.session_state.session_start_time = time.time()
        st.session_state.rag_retrieval_history = []
        st.session_state.feedback_submitted = False
        # Drop the previous answer's documents so the Retrieved Documents tab
        # does not keep showing results from the cleared conversation.
        st.session_state.last_rag_result = None
        st.session_state.reset_conversation = False
        st.rerun()


def _render_header():
    """Render the title, the backend badge, the subtitle and the session info bar."""
    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown('<h1 class="main-header">πŸ€– Intelligent Audit Report Chatbot</h1>', unsafe_allow_html=True)
    with col2:
        system_type = get_system_type()
        if "Multi-Agent" in system_type:
            st.success(f"πŸ”§ {system_type}")
        else:
            st.info(f"πŸ”§ {system_type}")
    st.markdown('<p class="subtitle">Ask questions about audit reports. Use the sidebar filters to narrow down your search!</p>', unsafe_allow_html=True)
    # Session info
    duration = int(time.time() - st.session_state.session_start_time)
    duration_str = f"{duration // 60}m {duration % 60}s"
    st.markdown(f'''
<div class="session-info">
<strong>Session Info:</strong> Messages: {len(st.session_state.messages)} | Duration: {duration_str} | Status: Active | ID: {st.session_state.conversation_id}
</div>
''', unsafe_allow_html=True)


def _clear_filters():
    """``on_click`` callback of "Clear All Filters": empty every filter widget.

    Callbacks run before the next script run, i.e. before the widgets are
    created again. That is the point at which Streamlit allows widget values to
    be set through st.session_state.
    """
    for widget_key in _FILTER_WIDGET_KEYS.values():
        st.session_state[widget_key] = []
    st.session_state.active_filters = _empty_filters()


def _render_filter_multiselect(title, label, options, widget_key, disabled, help_text):
    """Render one titled filter section containing a multiselect; return its selection."""
    st.markdown('<div class="filter-section">', unsafe_allow_html=True)
    st.markdown(f'<div class="filter-title">{title}</div>', unsafe_allow_html=True)
    selection = st.multiselect(
        label,
        options=options,
        disabled=disabled,
        key=widget_key,
        help=help_text
    )
    st.markdown('</div>', unsafe_allow_html=True)
    return selection


def _render_filter_sidebar(filter_options):
    """Render the sidebar filters and return the current selections.

    Returns:
        (selected_filenames, selected_sources, selected_years, selected_districts).

    Also updates st.session_state.active_filters. While any filename is selected,
    sources, years and districts are stored as [] there, because a filename
    selection overrides the other filters.
    """
    # Seed each filter widget from active_filters the first time it is rendered.
    # After that, the widgets are driven only by their session-state keys (no
    # ``default=``). Resetting active_filters alone is not a reliable way to clear
    # a keyed widget; the clear button instead sets the keys in a callback.
    for name, widget_key in _FILTER_WIDGET_KEYS.items():
        if widget_key not in st.session_state:
            st.session_state[widget_key] = list(st.session_state.active_filters.get(name, []))

    with st.sidebar:
        st.markdown("### πŸ” Search Filters")
        st.markdown("Select filters to narrow down your search. Leave empty to search all data.")
        st.markdown('<div class="filter-section">', unsafe_allow_html=True)
        st.markdown('<div class="filter-title">πŸ“„ Specific Reports (Filename Filter)</div>', unsafe_allow_html=True)
        st.markdown('<p style="font-size: 0.85em; color: #666;">⚠️ Selecting specific reports will ignore all other filters</p>', unsafe_allow_html=True)
        selected_filenames = st.multiselect(
            "Select specific reports:",
            options=filter_options.get('filenames', []),
            key=_FILTER_WIDGET_KEYS['filenames'],
            help="Choose specific reports to search. When enabled, all other filters are ignored."
        )
        st.markdown('</div>', unsafe_allow_html=True)
        # A filename selection disables (and overrides) the remaining filters.
        filename_mode = len(selected_filenames) > 0

        selected_sources = _render_filter_multiselect(
            title="πŸ“Š Sources",
            label="Select sources:",
            options=filter_options.get('sources', []),
            widget_key=_FILTER_WIDGET_KEYS['sources'],
            disabled=filename_mode,
            help_text="Choose which types of reports to search",
        )
        selected_years = _render_filter_multiselect(
            title="πŸ“… Years",
            label="Select years:",
            options=filter_options.get('years', []),
            widget_key=_FILTER_WIDGET_KEYS['years'],
            disabled=filename_mode,
            help_text="Choose which years to search",
        )
        selected_districts = _render_filter_multiselect(
            title="🏘️ Districts",
            label="Select districts:",
            options=filter_options.get('districts', []),
            widget_key=_FILTER_WIDGET_KEYS['districts'],
            disabled=filename_mode,
            help_text="Choose which districts to search",
        )

        st.session_state.active_filters = {
            'sources': selected_sources if not filename_mode else [],
            'years': selected_years if not filename_mode else [],
            'districts': selected_districts if not filename_mode else [],
            'filenames': selected_filenames
        }
        # The callback resets the widgets; the click itself triggers the rerun.
        st.button("πŸ—‘οΈ Clear All Filters", key="clear_filters_button", on_click=_clear_filters)

    return selected_filenames, selected_sources, selected_years, selected_districts


def _build_retrieval_entry(rag_result, actual_rag_query):
    """Snapshot one RAG retrieval (conversation so far, query, documents) for feedback."""
    if actual_rag_query:
        # Format it like the log message
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        formatted_query = f"{timestamp} - INFO - πŸ” ACTUAL RAG QUERY: '{actual_rag_query}'"
    else:
        formatted_query = "No RAG query available"
    return {
        "conversation_up_to": serialize_messages(st.session_state.messages),
        "rag_query_expansion": formatted_query,
        "docs_retrieved": serialize_documents(_extract_sources(rag_result))
    }


def _answer_query(user_input, full_query):
    """Send *full_query* to the chatbot and record the exchange in session state.

    The chat history stores the user's raw input, not the filter preamble.
    Errors are logged and shown to the user as an assistant message.
    """
    st.session_state.messages.append(HumanMessage(content=user_input))
    with st.spinner("πŸ€” Thinking..."):
        try:
            chat_result = st.session_state.chatbot.chat(full_query, st.session_state.conversation_id)
            # Handle both old format (string) and new format (dict)
            if isinstance(chat_result, dict):
                response = chat_result['response']
                rag_result = chat_result.get('rag_result')
                st.session_state.last_rag_result = rag_result
                # Track RAG retrieval for feedback
                if rag_result:
                    st.session_state.rag_retrieval_history.append(
                        _build_retrieval_entry(rag_result, chat_result.get('actual_rag_query', ''))
                    )
            else:
                response = chat_result
                st.session_state.last_rag_result = None
            st.session_state.messages.append(AIMessage(content=response))
        except Exception as e:
            logger.exception("Chatbot request failed")
            error_msg = f"Sorry, I encountered an error: {str(e)}"
            st.session_state.messages.append(AIMessage(content=error_msg))


def _render_chat_tab(selected_filenames, selected_sources, selected_years, selected_districts):
    """Render the chat history, the input row and Clear Chat; answer a sent message."""
    with st.container():
        for message in st.session_state.messages:
            if isinstance(message, HumanMessage):
                st.markdown(f'<div class="user-message">{_message_html(message.content)}</div>', unsafe_allow_html=True)
            elif isinstance(message, AIMessage):
                st.markdown(f'<div class="bot-message">{_message_html(message.content)}</div>', unsafe_allow_html=True)

    # Input area
    st.markdown("<br>", unsafe_allow_html=True)
    col1, col2 = st.columns([4, 1])
    with col1:
        # The input's key embeds a counter; bumping it after a send renders a
        # fresh, empty input on the next run.
        if 'input_counter' not in st.session_state:
            st.session_state.input_counter = 0
        user_input = st.text_input(
            "Type your message here...",
            placeholder="Ask about budget allocations, expenditures, or audit findings...",
            key=f"user_input_{st.session_state.input_counter}",
            label_visibility="collapsed"
        )
    with col2:
        send_button = st.button("Send", key="send_button", use_container_width=True)

    if st.button("πŸ—‘οΈ Clear Chat", key="clear_chat_button"):
        st.session_state.reset_conversation = True
        # NOTE(review): this deletes every *.json file under conversations/, not only
        # files belonging to this conversation_id. If other sessions persist their
        # conversations there, confirm that this is intended.
        conversations_dir = "conversations"
        if os.path.exists(conversations_dir):
            for file in os.listdir(conversations_dir):
                if file.endswith('.json'):
                    os.remove(os.path.join(conversations_dir, file))
        st.rerun()

    if send_button and user_input:
        filter_context = _build_filter_context(
            selected_filenames, selected_sources, selected_years, selected_districts
        )
        _answer_query(user_input, filter_context + user_input)
        st.session_state.input_counter += 1  # This will clear the input
        st.rerun()


def _score_suffix(score, chunk_id):
    """Return the " (Score: x.xxx, ID: abcdefgh...)" suffix used in document titles.

    The score is omitted when it is None or not convertible to float, instead of
    crashing the page on format().
    """
    short_id = str(chunk_id)[:8]
    try:
        return f" (Score: {float(score):.3f}, ID: {short_id}...)"
    except (TypeError, ValueError):
        return f" (ID: {short_id}...)"


def _render_document_chunk(i, doc):
    """Render one retrieved chunk as an expander with its metadata and full text."""
    metadata = _doc_metadata(doc)
    score = metadata.get('reranked_score', metadata.get('original_score', None))
    score_text = _score_suffix(score, metadata.get('_id', 'Unknown'))
    filename = str(metadata.get('filename', 'Unknown'))
    with st.expander(f"πŸ“„ Document {i+1}: {filename[:50]}...{score_text}"):
        col1, col2, col3, col4 = st.columns([2, 1.5, 1, 1])
        with col1:
            st.write(f"πŸ“„ **File:** {metadata.get('filename', 'Unknown')}")
        with col2:
            st.write(f"πŸ›οΈ **Source:** {metadata.get('source', 'Unknown')}")
        with col3:
            st.write(f"πŸ“… **Year:** {metadata.get('year', 'Unknown')}")
        with col4:
            # Prefer the printed page label over the raw page index.
            page = metadata.get('page_label', metadata.get('page', 'Unknown'))
            st.write(f"πŸ“– **Page:** {page}")
            st.write(f"πŸ†” **ID:** {metadata.get('_id', 'Unknown')}")
        # Display full content (no truncation)
        content = getattr(doc, 'page_content', 'No content available')
        st.write("**Full Content:**")
        st.text_area("Full Content", value=content, height=300, disabled=True, label_visibility="collapsed", key=f"preview_{i}")


def _render_documents_tab():
    """Show the documents retrieved for the most recent answer (top 10 chunks)."""
    rag_result = st.session_state.get('last_rag_result')
    if not rag_result:
        st.info("No documents have been retrieved yet. Start a conversation to see retrieved documents here.")
        return
    sources = _extract_sources(rag_result)
    if not sources:
        st.info("No documents were retrieved for the last query.")
        return
    unique_filenames = {_doc_metadata(doc).get('filename', 'Unknown') for doc in sources}
    st.markdown(f"**Found {len(sources)} document chunks from {len(unique_filenames)} unique documents (showing top 10):**")
    if len(unique_filenames) < len(sources):
        st.info(f"πŸ’‘ **Note**: Each document is split into multiple chunks. You're seeing {len(sources)} chunks from {len(unique_filenames)} documents.")
    for i, doc in enumerate(sources[:10]):  # Show top 10
        _render_document_chunk(i, doc)


def _save_feedback_to_snowflake(feedback_obj):
    """Send *feedback_obj* to Snowflake when SNOWFLAKE_ENABLED=true (best effort).

    Every outcome (disabled, skipped, saved, failed) is logged, printed and shown
    in the UI. Exceptions are reported as warnings, not raised, because the
    local JSON copy has already been written by the caller.
    """
    logger.info("πŸ”„ FEEDBACK SAVE: Starting Snowflake save process...")
    logger.info(f"πŸ“Š FEEDBACK SAVE: feedback_obj={'exists' if feedback_obj else 'None'}")
    try:
        snowflake_enabled = os.getenv("SNOWFLAKE_ENABLED", "false").lower() == "true"
        logger.info(f"πŸ” SNOWFLAKE CHECK: enabled={snowflake_enabled}")
        if not snowflake_enabled:
            logger.info("πŸ’‘ SNOWFLAKE UI: Integration disabled")
            print("πŸ’‘ SNOWFLAKE UI: Integration disabled")  # Also print to terminal
            st.info("πŸ’‘ Snowflake integration disabled (set SNOWFLAKE_ENABLED=true to enable)")
            return
        if not feedback_obj:
            logger.warning("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")
            print("⚠️ SNOWFLAKE UI: Skipping (feedback object not created)")  # Also print to terminal
            st.warning("⚠️ Skipping Snowflake save (feedback object not created)")
            return
        try:
            # NOTE(review): this imports from ``auditqa.reporting``, while this file imports
            # the feedback schema from ``src.reporting``. Confirm the package path; if it
            # is wrong, the import fails and is reported as a warning below.
            from auditqa.reporting.snowflake_connector import save_to_snowflake
            logger.info("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")
            print("πŸ“€ SNOWFLAKE UI: Attempting to save feedback to Snowflake...")  # Also print to terminal
            if save_to_snowflake(feedback_obj):
                logger.info("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")
                print("βœ… SNOWFLAKE UI: Successfully saved to Snowflake")  # Also print to terminal
                st.success("βœ… Feedback also saved to Snowflake!")
            else:
                logger.warning("⚠️ SNOWFLAKE UI: Save failed")
                print("⚠️ SNOWFLAKE UI: Save failed")  # Also print to terminal
                st.warning("⚠️ Snowflake save failed, but local save succeeded")
        except Exception as e:
            logger.error(f"❌ SNOWFLAKE UI ERROR: {e}")
            print(f"❌ SNOWFLAKE UI ERROR: {e}")  # Also print to terminal
            traceback.print_exc()  # Print full traceback to terminal
            st.warning(f"⚠️ Could not save to Snowflake: {e}")
    except NameError as e:
        traceback.print_exc()
        logger.error(f"❌ NameError in Snowflake save: {e}")
        print(f"❌ NameError in Snowflake save: {e}")  # Also print to terminal
        st.warning(f"⚠️ Snowflake save error: {e}")
    except Exception as e:
        logger.error(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")
        print(f"❌ Exception in Snowflake save: {type(e).__name__}: {e}")  # Also print to terminal
        st.warning(f"⚠️ Snowflake save error: {e}")


def _submit_feedback(feedback_score, is_feedback_about_last_retrieval, open_ended_feedback, has_retrievals):
    """Build the feedback record for this conversation, display it, and persist it.

    The record is written to ``feedback/feedback_<conversation_id>_<unix time>.json``
    and, when enabled, also sent to Snowflake. On success ``feedback_submitted``
    is set, which blocks resubmission until the user asks for a new one. Save
    errors are shown in the UI, not raised.
    """
    # Log the feedback data being submitted
    print("=" * 80)
    print("πŸ”„ FEEDBACK SUBMISSION: Starting...")
    print("=" * 80)
    st.write("πŸ” **Debug: Feedback Data Being Submitted:**")
    history = st.session_state.rag_retrieval_history
    feedback_dict = {
        "open_ended_feedback": open_ended_feedback,
        "score": feedback_score,
        "is_feedback_about_last_retrieval": is_feedback_about_last_retrieval,
        "retrieved_data": history.copy() if history else [],
        "conversation_id": st.session_state.conversation_id,
        "timestamp": time.time(),
        "message_count": len(st.session_state.messages),
        "has_retrievals": has_retrievals,
        "retrieval_count": len(history)
    }
    print(f"πŸ“ FEEDBACK SUBMISSION: Score={feedback_score}, Retrievals={len(history) if history else 0}")

    # Build the UserFeedback dataclass; fall back to the raw dict if that fails.
    feedback_obj = None
    try:
        feedback_obj = create_feedback_from_dict(feedback_dict)
        print(f"βœ… FEEDBACK SUBMISSION: Feedback object created - ID={feedback_obj.feedback_id}")
        st.write("βœ… **Feedback Object Created**")
        st.write(f"- Feedback ID: {feedback_obj.feedback_id}")
        st.write(f"- Score: {feedback_obj.score}/5")
        st.write(f"- Has Retrievals: {feedback_obj.has_retrievals}")
        # Convert back to dict for JSON serialization
        feedback_data = feedback_obj.to_dict()
    except Exception as e:
        print(f"❌ FEEDBACK SUBMISSION: Failed to create feedback object: {e}")
        st.error(f"Failed to create feedback object: {e}")
        feedback_data = feedback_dict

    # Display the data being submitted
    st.json(feedback_data)

    feedback_dir = Path("feedback")
    feedback_file = feedback_dir / f"feedback_{st.session_state.conversation_id}_{int(time.time())}.json"
    try:
        # Creating the directory is inside the try so a permission error is reported, not fatal.
        feedback_dir.mkdir(exist_ok=True)
        print(f"πŸ’Ύ FEEDBACK SAVE: Saving to local file: {feedback_file}")
        with open(feedback_file, 'w', encoding='utf-8') as f:
            json.dump(feedback_data, f, indent=2, default=str)
        print("βœ… FEEDBACK SAVE: Local file saved successfully")
        st.success("βœ… Thank you for your feedback! It has been saved locally.")
        st.balloons()

        _save_feedback_to_snowflake(feedback_obj)

        # Mark feedback as submitted to prevent resubmission
        st.session_state.feedback_submitted = True
        print("=" * 80)
        print("βœ… FEEDBACK SUBMISSION: Completed successfully")
        print("=" * 80)
        st.info(f"πŸ“ Feedback saved to: {feedback_file}")
    except Exception as e:
        print(f"❌ FEEDBACK SUBMISSION: Error saving feedback: {e}")
        print(f"❌ FEEDBACK SUBMISSION: Error type: {type(e).__name__}")
        traceback.print_exc()
        st.error(f"❌ Error saving feedback: {e}")
        st.write(f"Debug error: {str(e)}")


def _render_retrieval_history():
    """List the recorded RAG retrievals (query and counts) in a collapsed expander."""
    history = st.session_state.rag_retrieval_history
    if not history:
        return
    st.markdown("---")
    st.markdown("#### πŸ“Š Retrieval History")
    with st.expander(f"View {len(history)} retrieval entries", expanded=False):
        for idx, entry in enumerate(history, 1):
            st.markdown(f"**Retrieval #{idx}**")
            # Display the actual RAG query
            st.code(entry.get("rag_query_expansion", "No query available"), language="text")
            # Display summary stats
            st.json({
                "conversation_length": len(entry.get("conversation_up_to", [])),
                "documents_retrieved": len(entry.get("docs_retrieved", []))
            })
            st.markdown("---")


def _render_feedback_dashboard():
    """Render the feedback form (once a conversation exists) and the retrieval history."""
    st.markdown("---")
    st.markdown("### πŸ’¬ Feedback Dashboard")
    has_conversation = len(st.session_state.messages) > 0
    has_retrievals = len(st.session_state.rag_retrieval_history) > 0

    if not has_conversation:
        st.info("πŸ’‘ Start a conversation to provide feedback!")
        st.markdown("The feedback dashboard will be enabled once you begin chatting.")
    else:
        st.markdown("Help us improve by providing feedback on this conversation.")
        if 'feedback_submitted' not in st.session_state:
            st.session_state.feedback_submitted = False

        with st.form("feedback_form", clear_on_submit=False):
            col1, col2 = st.columns([1, 1])
            with col1:
                # st.slider always returns a value (it starts at min_value), so a
                # score is always present and the submit button is never disabled.
                feedback_score = st.slider(
                    "Rate this conversation (1-5)",
                    min_value=1,
                    max_value=5,
                    help="How satisfied are you with the conversation?"
                )
            with col2:
                is_feedback_about_last_retrieval = st.checkbox(
                    "Feedback about last retrieval only",
                    value=True,
                    help="If checked, feedback applies to the most recent document retrieval"
                )
            open_ended_feedback = st.text_area(
                "Your feedback (optional)",
                placeholder="Tell us what went well or what could be improved...",
                height=100
            )
            submitted = st.form_submit_button(
                "πŸ“€ Submit Feedback",
                use_container_width=True
            )

        # Handled outside the form: st.button is not allowed inside st.form.
        if submitted and not st.session_state.feedback_submitted:
            _submit_feedback(feedback_score, is_feedback_about_last_retrieval, open_ended_feedback, has_retrievals)
        elif st.session_state.feedback_submitted:
            st.success("βœ… Feedback already submitted for this conversation!")
            if st.button("πŸ”„ Submit New Feedback", key="new_feedback_button"):
                st.session_state.feedback_submitted = False
                st.rerun()

    _render_retrieval_history()


def main():
    """Render one run of the chatbot page.

    Each run does the following, in order:
    - initializes or resets session state;
    - draws the header and the sidebar filters;
    - draws the Chat and Retrieved Documents tabs, answering a sent message;
    - draws the feedback dashboard.
    """
    _init_session_state()
    _render_header()

    filter_options = load_filter_options()
    selected_filenames, selected_sources, selected_years, selected_districts = _render_filter_sidebar(filter_options)

    tab1, tab2 = st.tabs(["πŸ’¬ Chat", "πŸ“„ Retrieved Documents"])
    with tab1:
        _render_chat_tab(selected_filenames, selected_sources, selected_years, selected_districts)
    with tab2:
        _render_documents_tab()

    _render_feedback_dashboard()

    # Auto-scroll to bottom.
    # NOTE(review): <script> tags passed to st.markdown are most likely not executed,
    # so this probably has no effect. Confirm, or switch to an HTML component.
    st.markdown("""
<script>
window.scrollTo(0, document.body.scrollHeight);
</script>
""", unsafe_allow_html=True)
if __name__ == "__main__":
main()