RAG_Chatbot / src /streamlit_app.py
shahzad4894's picture
Update src/streamlit_app.py
5172692 verified
import streamlit as st
import os
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.chains import RetrievalQA
from supabase import create_client
from langchain.prompts import PromptTemplate
from langchain.agents import Tool, create_react_agent
from langchain.tools.retriever import create_retriever_tool
from langchain.memory import ConversationSummaryBufferMemory
from langchain.agents import AgentExecutor
from langchain.schema import HumanMessage, AIMessage
from langchain.cache import InMemoryCache
from langchain.globals import set_llm_cache
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
import uuid
from datetime import datetime
import json
import time
from collections import defaultdict
from tenacity import retry, stop_after_attempt, wait_exponential
# Page configuration: browser-tab title and icon, wide layout, sidebar open.
# The icon must be a real emoji character; a mis-encoded byte sequence is not
# a valid emoji and would not render as the intended robot icon.
st.set_page_config(
    page_title="AI Document Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Install a process-wide in-memory LLM cache so repeated identical LLM calls
# can be answered without another API round trip.
set_llm_cache(InMemoryCache())
# Custom CSS for professional design.
# The whole stylesheet is injected once per script run as a raw <style> tag;
# unsafe_allow_html=True is what lets the tag through instead of being shown
# as text. The classes defined here (header-*, sidebar-title, message*,
# status*, thinking) are the ones referenced by HTML snippets emitted
# elsewhere in this file.
# NOTE(review): `.css-1d391kg` looks like an auto-generated Streamlit class
# name (presumably the sidebar container) - confirm it still matches the
# Streamlit version in use.
st.markdown("""
<style>
/* Import clean font */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap');
/* Global styles */
* {
font-family: 'Inter', sans-serif;
}
/* Remove default padding/margins */
.main > div {
padding-top: 2rem;
}
/* Header styling */
.header-container {
background: #ffffff;
border-bottom: 1px solid #e5e7eb;
padding: 1.5rem 0;
margin-bottom: 0;
position: sticky;
top: 0;
z-index: 100;
}
.header-title {
font-size: 1.5rem;
font-weight: 600;
color: #111827;
margin: 0;
}
.header-subtitle {
color: #6b7280;
font-size: 0.875rem;
margin: 0.25rem 0 0 0;
}
/* Sidebar styling */
.css-1d391kg {
background-color: #f9fafb;
}
.sidebar-title {
font-weight: 600;
color: #374151;
margin-bottom: 1rem;
}
/* Session buttons */
.session-btn {
background: white;
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 12px;
margin: 6px 0;
width: 100%;
text-align: left;
cursor: pointer;
transition: all 0.2s;
color: #374151;
}
.session-btn:hover {
border-color: #3b82f6;
background: #f8fafc;
}
.session-btn.active {
background: #eff6ff;
border-color: #3b82f6;
color: #1d4ed8;
}
/* Chat container */
.chat-container {
background: #ffffff;
border: 1px solid #e5e7eb;
border-radius: 12px;
height: 500px;
overflow-y: auto;
padding: 1rem;
margin-bottom: 1rem;
}
/* Message styling */
.message {
margin-bottom: 1rem;
display: flex;
}
.message.user {
justify-content: flex-end;
}
.message-content {
max-width: 70%;
padding: 12px 16px;
border-radius: 12px;
line-height: 1.5;
}
.message.user .message-content {
background: #3b82f6;
color: white;
border-bottom-right-radius: 4px;
}
.message.bot .message-content {
background: #f3f4f6;
color: #374151;
border: 1px solid #e5e7eb;
border-bottom-left-radius: 4px;
}
.message-label {
font-size: 0.75rem;
font-weight: 500;
margin-bottom: 4px;
opacity: 0.7;
}
.message-tools {
font-size: 0.75rem;
opacity: 0.6;
margin-top: 4px;
}
/* Input area */
.input-container {
background: white;
border: 1px solid #e5e7eb;
border-radius: 12px;
padding: 1rem;
}
/* Buttons */
.stButton > button {
background: #3b82f6;
color: white;
border: none;
border-radius: 8px;
font-weight: 500;
padding: 0.5rem 1rem;
transition: background 0.2s;
}
.stButton > button:hover {
background: #2563eb;
}
/* Status indicators */
.status {
font-size: 0.875rem;
padding: 4px 8px;
border-radius: 6px;
font-weight: 500;
}
.status.connected {
background: #dcfce7;
color: #166534;
}
.status.error {
background: #fee2e2;
color: #dc2626;
}
/* Thinking indicator */
.thinking {
background: #f3f4f6;
padding: 8px 12px;
border-radius: 8px;
color: #6b7280;
font-size: 0.875rem;
margin-bottom: 1rem;
display: inline-block;
}
/* Hide streamlit elements */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Custom scrollbar */
.chat-container::-webkit-scrollbar {
width: 6px;
}
.chat-container::-webkit-scrollbar-track {
background: #f1f5f9;
border-radius: 3px;
}
.chat-container::-webkit-scrollbar-thumb {
background: #cbd5e1;
border-radius: 3px;
}
.chat-container::-webkit-scrollbar-thumb:hover {
background: #94a3b8;
}
</style>
""", unsafe_allow_html=True)
# Rate Limiter Class
class RateLimiter:
    """Sliding-window request limiter keyed by session id.

    Each session may make at most ``max_requests`` accepted calls within any
    ``time_window``-second span. Accepted call times are kept in
    ``self.requests[session_id]``.
    """

    def __init__(self, max_requests=10, time_window=60):
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests = defaultdict(list)

    def check_limit(self, session_id):
        """Record one request for ``session_id`` if it is within the limit.

        Returns:
            ``(True, "")`` when the request is accepted (and recorded), or
            ``(False, message)`` when the session is over its limit.
        """
        current_time = time.time()

        # Keep only the timestamps that are still inside the window.
        still_active = []
        for stamp in self.requests[session_id]:
            if current_time - stamp < self.time_window:
                still_active.append(stamp)
        self.requests[session_id] = still_active

        if len(still_active) < self.max_requests:
            still_active.append(current_time)
            return True, ""
        return False, "Rate limit exceeded. Please wait before sending more messages."
# Seed the per-user session state the first time the script runs for a
# browser session; later reruns find 'initialized' present and skip this.
if 'initialized' not in st.session_state:
    st.session_state.update({
        'initialized': False,
        'agent_executor': None,
        'chat_sessions': {},
        'current_session_id': None,
        'connection_status': "Not Connected",
        'sidebar_collapsed': False,
        'rate_limiter': RateLimiter(max_requests=20, time_window=60),
        'supabase': None,
    })
# Keys configuration.
# Credentials are read from environment variables so that no secret lives in
# source control. On hosted platforms, set them as secrets/variables named
# OPENAI_API_KEY, SUPABASE_URL and SUPABASE_KEY.
# NOTE(review): any key that was ever committed to this file should be
# treated as compromised and rotated.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
# The project URL is not a secret, so the existing value stays as the default.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "https://oszztarojsrnckvlqhhx.supabase.co")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY", "")
def validate_input(user_input: str) -> tuple:
    """Check a chat query before it is sent to the agent.

    Returns:
        ``(True, "")`` when the query is acceptable, otherwise
        ``(False, reason)`` where ``reason`` is a user-facing message.
    """
    # Length after trimming decides "too short"; empty/None counts as zero.
    trimmed_length = len(user_input.strip()) if user_input else 0
    if trimmed_length < 3:
        return False, "Query too short. Please provide more details (at least 3 characters)."

    # The upper bound is measured on the raw, untrimmed text.
    if len(user_input) > 2000:
        return False, "Query too long. Please keep it under 2000 characters."

    # Reject queries containing any of these code-execution related substrings
    # (matched case-insensitively).
    lowered = user_input.lower()
    for blocked in ('__import__', 'exec(', 'eval(', 'os.system', 'subprocess'):
        if blocked in lowered:
            return False, "Invalid input detected. Please rephrase your question."

    return True, ""
def save_session_to_db(session_id, session_data):
    """Upsert one chat session into the Supabase ``chat_sessions`` table.

    Does nothing when no Supabase client is stored in session state. Any
    failure is reported with ``st.warning`` instead of being raised.
    """
    try:
        client = st.session_state.supabase
        if client is None:
            return

        table = client.table('chat_sessions')

        # Datetimes are not JSON serializable: emit copies of the messages
        # whose 'timestamp' (when present) is an ISO-8601 string.
        serialized_messages = [
            {**entry, 'timestamp': entry['timestamp'].isoformat()}
            if 'timestamp' in entry
            else entry.copy()
            for entry in session_data['messages']
        ]

        row = {
            'id': session_id,
            'name': session_data['name'],
            'created_at': session_data['created_at'].isoformat(),
            'messages': json.dumps(serialized_messages),
            'updated_at': datetime.now().isoformat(),
        }
        table.upsert(row).execute()
    except Exception as exc:
        st.warning(f"Could not save session to database: {str(exc)}")
def load_sessions_from_db():
    """Read every stored chat session from Supabase, newest first.

    Returns:
        A dict mapping session id to a session record with the keys ``id``,
        ``name``, ``created_at``, ``messages``, ``session_memory`` and
        ``history``. An empty dict is returned when no client is available or
        when anything goes wrong (the error is shown via ``st.warning``).
    """
    try:
        client = st.session_state.supabase
        if client is None:
            return {}

        query = client.table('chat_sessions').select('*').order('created_at', desc=True)
        rows = query.execute().data

        loaded = {}
        for row in rows:
            stored_id = row['id']
            raw_messages = row['messages']
            history = json.loads(raw_messages) if raw_messages else []

            # Timestamps were stored as ISO strings; turn them back into datetimes.
            for item in history:
                stamp = item.get('timestamp')
                if isinstance(stamp, str):
                    item['timestamp'] = datetime.fromisoformat(stamp)

            record = {
                'id': stored_id,
                'name': row['name'],
                'created_at': datetime.fromisoformat(row['created_at']),
                'messages': history,
                'session_memory': [],
                'history': [],
            }
            loaded[stored_id] = record

            # Rebuild the LangChain message objects: 'user' entries become
            # HumanMessage, everything else AIMessage.
            record['session_memory'].extend(
                HumanMessage(content=item['content'])
                if item['type'] == 'user'
                else AIMessage(content=item['content'])
                for item in history
            )
        return loaded
    except Exception as exc:
        st.warning(f"Could not load sessions from database: {str(exc)}")
        return {}
@st.cache_resource
def _build_agent_resources():
    """Build the ReAct agent, its tools and the Supabase client (cached).

    Any failure propagates as an exception so that only a successful build is
    stored by ``st.cache_resource``; the caller converts failures into a
    status string.

    Returns:
        ``(agent, tools, supabase_client)``.
    """
    # Connect to Supabase and to the existing "documents" vector table.
    supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
    embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
    vector_store = SupabaseVectorStore(
        client=supabase,
        embedding=embeddings,
        table_name="documents",
    )

    # Deterministic, non-streaming chat model shared by the agent and tools.
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        openai_api_key=OPENAI_API_KEY,
        streaming=False,
    )

    # Similarity search returning the 3 closest chunks.
    base_retriever = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )

    # Retriever tool (core RAG function). create_retriever_tool already
    # returns a ready-to-use tool, so it is used directly rather than being
    # wrapped in a second Tool.
    retriever_tool = create_retriever_tool(
        retriever=base_retriever,
        name="retriever",
        description=(
            "Retrieves relevant context from the user's uploaded or stored documents. "
            "Use this tool for any question that might involve the content of the documents, "
            "such as document summaries, factual answers, or topic-specific details."
        ),
    )

    summary_prompt = PromptTemplate(
        input_variables=["input"],
        template="""
You are a summarization assistant. Follow these steps to summarize the text:
1. Read the text carefully.
2. Identify the main points and key details.
3. Write a concise summary that captures the essence of the text.
Text: {input}
Summary:
""",
    )
    explanation_prompt = PromptTemplate(
        input_variables=["input"],
        template="""
You are an explanation assistant. Follow these steps to explain the concept:
1. Understand the concept thoroughly.
2. Break down the concept into simpler parts.
3. Provide a clear and detailed explanation with examples.
Concept: {input}
Explanation:
""",
    )

    # Tool functions must hand plain text back to the agent, so each one
    # unwraps the chat model's reply. The prompt templates are applied here
    # because Tool itself has no notion of a prompt.
    def answer_directly(text):
        return llm.invoke(text).content

    def summarize(text):
        return llm.invoke(summary_prompt.format(input=text)).content

    def explain(text):
        return llm.invoke(explanation_prompt.format(input=text)).content

    qa_tool = Tool(
        name="Question Answering",
        func=answer_directly,
        description=(
            "A general-purpose question answering tool. "
            "Use this ONLY for casual or open-ended questions that are NOT related to the provided documents. "
            "Examples: greetings, opinions, or general world knowledge questions (e.g., 'How are you?', 'What is AI?'). "
            "Do NOT use this if the question might depend on the document contents."
        ),
    )
    summary_tool = Tool(
        name="Summary",
        func=summarize,
        description="Summarizes long text passages into concise summaries using a structured summarization prompt.",
    )
    explanation_tool = Tool(
        name="Explanation",
        func=explain,
        description="Explains complex concepts in simple, clear terms using examples or analogies when appropriate.",
    )

    # Tool list (retriever first for prioritization).
    tools = [retriever_tool, summary_tool, explanation_tool, qa_tool]
    tool_names = ", ".join([tool.name for tool in tools])

    example = """
Example:
Thought: I should use the retriever tool to find relevant info.
Action: retriever
Action Input: current head of the American Red Cross
Observation: The documents do not mention the head of the American Red Cross.
Thought: The information is not in the documents.
Final Answer: I'm sorry, but I couldn’t find information about that in the provided documents.
"""

    # Custom ReAct prompt: worked example first, then the grounding rules.
    react_prompt = PromptTemplate.from_template(
        example + """
You are a retrieval-augmented assistant that answers questions ONLY using the information
found in the user's provided documents.
You have access to the following tools:
{tools}
Follow this reasoning format:
Thought: Think about what the question is asking and whether you can find the answer in the user's documents.
Action: The action to take, must be one of [{tool_names}]
Action Input: The input to the action (be specific)
Observation: The result of the action
... (You may repeat this Thought/Action/Observation cycle as needed)
Final Answer: Your final grounded answer to the user's question.
### Important Grounding Rules:
- You MUST first use the 'retriever' tool to search for relevant information in the user's documents.
- Only use the information retrieved from the documents to answer the question.
- If the retrieved information does not contain a clear or relevant answer, respond with:
"I'm sorry, but I couldn’t find information about that in the provided documents."
- Do NOT use your own general knowledge or external world knowledge.
- Use the 'Question Answering' tool only for generic greetings (like 'hi', 'how are you') or clarification.
- You may use multiple tools in sequence before providing the final answer.
Previous conversation:
{chat_history}
Question: {input}
{agent_scratchpad}
"""
    ).partial(
        tools="\n".join([f"{tool.name}: {tool.description}" for tool in tools]),
        tool_names=tool_names,
    )

    custom_agent = create_react_agent(llm=llm, tools=tools, prompt=react_prompt)
    return custom_agent, tools, supabase


def initialize_agent():
    """Initialize the LangChain agent, reusing the cached build when available.

    Returns:
        ``(agent, tools, supabase_client, status)``. On success ``status`` is
        ``"Connected Successfully"``; on failure the first three items are
        ``None`` and ``status`` is ``"Connection Error: <reason>"``.

    A failed attempt is not cached, so calling this again retries the build.
    """
    try:
        custom_agent, tools, supabase = _build_agent_resources()
    except Exception as e:
        return None, None, None, f"Connection Error: {str(e)}"
    return custom_agent, tools, supabase, "Connected Successfully"
def create_new_session():
    """Create an empty chat session, make it current, persist it, return its id."""
    sessions = st.session_state.chat_sessions
    new_id = str(uuid.uuid4())

    # Sessions are numbered by how many currently exist.
    record = {
        "id": new_id,
        "name": f"Chat {len(sessions) + 1}",
        "created_at": datetime.now(),
        "messages": [],
        "session_memory": [],
        "history": [],
    }
    sessions[new_id] = record
    st.session_state.current_session_id = new_id

    # Save to database
    save_session_to_db(new_id, record)
    return new_id
def get_recent_context(session_data, max_messages=10):
    """Return the most recent messages of a session as a new list.

    Args:
        session_data: Session record; its ``"session_memory"`` sequence is read.
        max_messages: Number of exchanges to keep. Up to ``2 * max_messages``
            trailing entries are returned (one user and one assistant message
            per exchange). Zero or a negative value yields an empty list.

    Returns:
        A fresh list, never the stored sequence itself, so callers can mutate
        the result without touching the session's own memory.
    """
    if max_messages <= 0:
        # Guard explicitly: a slice starting at -0 would select everything.
        return []
    return list(session_data["session_memory"][-max_messages * 2:])
def get_agent_executor_for_session(session_id):
    """Build an AgentExecutor whose memory is primed with one session's history.

    Args:
        session_id: Key into ``st.session_state.chat_sessions``.

    Returns:
        A new ``AgentExecutor``, or ``None`` when the agent has not been
        initialized yet.
    """
    if not st.session_state.initialized:
        return None
    session_data = st.session_state.chat_sessions[session_id]

    # Only recent exchanges are replayed, to avoid overwhelming the model.
    recent_memory = get_recent_context(session_data, max_messages=8)

    # Summary buffer memory for this session.
    memory = ConversationSummaryBufferMemory(
        llm=ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY),
        memory_key="chat_history",
        return_messages=True,
        output_key="output",
        max_token_limit=1000,
    )

    # Give the memory its own list. The executor's memory records each turn
    # into this list; if it were the same object as the session's stored
    # history, those writes would change the stored history behind the
    # application's back and produce duplicated messages.
    memory.chat_memory.messages = list(recent_memory)

    agent_executor = AgentExecutor(
        agent=st.session_state.agent,
        tools=st.session_state.tools,
        memory=memory,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it follows the correct format.",
        return_intermediate_steps=True,
        max_iterations=5,
        max_execution_time=45,
    )
    return agent_executor
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
def get_agent_response(agent_executor, user_input):
    """Invoke the agent once; the decorator retries up to 3 attempts with exponential back-off."""
    payload = {"input": user_input}
    return agent_executor.invoke(payload)
def get_response_with_fallback(agent_executor, user_input):
    """Answer ``user_input``, degrading through three strategies.

    1. The agent with retries.
    2. The agent once more with a "Please answer briefly" prefix.
    3. A direct chat-model call without tools.

    Returns the agent's result dict, or for stage 3 a dict with ``"output"``
    and an empty ``"intermediate_steps"``. Raises ``Exception`` when all three
    stages fail.
    """
    # Stage 1: primary attempt.
    try:
        return get_agent_response(agent_executor, user_input)
    except Exception:
        st.warning("Primary attempt failed, trying simplified approach...")

    # Stage 2: same agent, simpler prompt.
    try:
        brief_request = {"input": f"Please answer briefly: {user_input}"}
        return agent_executor.invoke(brief_request)
    except Exception:
        st.warning("Simplified approach failed, using direct LLM...")

    # Stage 3: direct LLM call without tools.
    try:
        direct_llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY)
        reply = direct_llm.invoke(user_input)
        return {"output": reply.content, "intermediate_steps": []}
    except Exception as last_error:
        raise Exception(f"All attempts failed: {str(last_error)}")
def track_metrics(session_data):
    """Compute simple conversation metrics for one session.

    Args:
        session_data: Session record whose ``"messages"`` list holds dicts
            with a ``"type"`` and a datetime ``"timestamp"``.

    Returns:
        A dict with ``total_messages``, ``user_messages``, ``bot_messages``
        and ``session_duration`` (whole seconds between the first and last
        message; 0 for an empty session).
    """
    messages = session_data["messages"]
    total_messages = len(messages)
    user_messages = sum(1 for m in messages if m["type"] == "user")

    duration = 0
    if messages:
        elapsed = messages[-1]["timestamp"] - messages[0]["timestamp"]
        # total_seconds() includes whole days; the .seconds attribute alone
        # wraps around every 24 hours and under-reports resumed sessions.
        duration = int(elapsed.total_seconds())

    return {
        "total_messages": total_messages,
        "user_messages": user_messages,
        "bot_messages": total_messages - user_messages,
        "session_duration": duration,
    }
def render_message_html(message):
    """Return the HTML chat bubble for one message, with dynamic text escaped.

    Args:
        message: Dict with ``"type"`` (``"user"`` or anything else for the
            assistant) and ``"content"``; assistant messages may also carry
            ``"tools_used"`` (list of names) and ``"sources"``.

    Returns:
        A single-line HTML string. Message text and tool names are
        HTML-escaped because the result is rendered with
        ``unsafe_allow_html=True``; line breaks become ``<br>``.
    """
    import html

    body = "<br>".join(html.escape(str(message["content"])).splitlines())
    is_user = message["type"] == "user"
    role_class = "user" if is_user else "bot"
    label = "You" if is_user else "Assistant"

    extras = ""
    if not is_user:
        if message.get("tools_used"):
            tool_names = html.escape(", ".join(str(t) for t in message["tools_used"]))
            extras += f'<div class="message-tools">🔧 Tools: {tool_names}</div>'
        if message.get("sources"):
            extras += (
                '<div class="message-tools">📚 Sources: '
                f'{len(message["sources"])} documents</div>'
            )

    return (
        f'<div class="message {role_class}">'
        '<div class="message-content">'
        f'<div class="message-label">{label}</div>'
        f'{body}{extras}'
        '</div></div>'
    )


def _render_sidebar():
    """Draw the sidebar: connection status and session create/switch/rename/delete."""
    import html

    state = st.session_state
    with st.sidebar:
        st.markdown('<p class="sidebar-title">💬 Chat Sessions</p>', unsafe_allow_html=True)

        # Connection status. A failure status embeds raw exception text, so it
        # is escaped before being placed into HTML.
        connected = state.connection_status == "Connected Successfully"
        status_class = "connected" if connected else "error"
        if connected:
            status_text = "🟢 Connected"
        else:
            status_text = f"🔴 {html.escape(str(state.connection_status))}"
        st.markdown(f'<div class="status {status_class}">{status_text}</div>', unsafe_allow_html=True)
        st.markdown("---")

        # New session button
        if st.button("+ New Chat", use_container_width=True):
            create_new_session()
            st.rerun()

        # One button per session; clicking switches to it.
        for session_id, session_data in state.chat_sessions.items():
            if st.button(
                f"{session_data['name']}\n{len(session_data['messages'])} messages",
                key=f"session_{session_id}",
                use_container_width=True,
            ):
                state.current_session_id = session_id
                st.rerun()

        # Actions on the current session (skipped when there is none, or when
        # the stored id no longer matches a session).
        current_id = state.current_session_id
        current = state.chat_sessions.get(current_id) if current_id else None
        if current is None:
            return

        st.markdown("---")
        new_name = st.text_input("Session Name:", value=current["name"])
        if st.button("Save Name", key="save_name"):
            current["name"] = new_name
            save_session_to_db(current_id, current)
            st.success("Name updated!")
            st.rerun()

        if current["messages"]:
            metrics = track_metrics(current)
            st.markdown("---")
            st.markdown("**Session Stats:**")
            st.text(f"📊 Messages: {metrics['total_messages']}")
            st.text(f"⏱️ Duration: {metrics['session_duration']}s")

        # Deleting is only offered while at least one other session remains.
        if len(state.chat_sessions) > 1:
            st.markdown("---")
            if st.button("🗑️ Delete Chat", key="delete_session"):
                if state.supabase:
                    try:
                        state.supabase.table('chat_sessions').delete().eq('id', current_id).execute()
                    except Exception as exc:
                        # Best effort: the local delete still proceeds, but the
                        # failure is reported instead of silently swallowed.
                        st.warning(f"Could not delete session from database: {exc}")
                del state.chat_sessions[current_id]
                state.current_session_id = next(iter(state.chat_sessions))
                st.rerun()


def _handle_user_message(current_session, user_input):
    """Validate, rate-limit, answer and store one submitted message.

    Returns without rerunning when the input is rejected, so the error
    message stays on screen; otherwise triggers a rerun to show the reply.
    """
    session_id = st.session_state.current_session_id

    is_valid, error_msg = validate_input(user_input)
    if not is_valid:
        st.error(error_msg)
        return

    can_proceed, rate_limit_msg = st.session_state.rate_limiter.check_limit(session_id)
    if not can_proceed:
        st.error(rate_limit_msg)
        return

    current_session["messages"].append({
        "type": "user",
        "content": user_input,
        "timestamp": datetime.now(),
    })

    # Snapshot the prior history. The new question is added to memory only
    # after the agent has run, so it is not presented to the model both as
    # history and as the current input; and the stored history is rebuilt from
    # this snapshot in case the executor's memory shares and mutates the list.
    memory_before = list(current_session["session_memory"])

    thinking_placeholder = st.empty()
    thinking_placeholder.markdown('<div class="thinking">🤔 Thinking...</div>', unsafe_allow_html=True)
    try:
        agent_executor = get_agent_executor_for_session(session_id)
        response = get_response_with_fallback(agent_executor, user_input)
        answer = response["output"]

        # Names of the tools the agent called, de-duplicated in call order.
        tools_used = list(dict.fromkeys(
            step[0].tool
            for step in response.get("intermediate_steps", [])
            if len(step) > 0 and hasattr(step[0], 'tool')
        ))
        reply = {
            "type": "bot",
            "content": answer,
            "timestamp": datetime.now(),
            "tools_used": tools_used,
        }
        succeeded = True
    except Exception as e:
        reply = {
            "type": "bot",
            "content": f"❌ I encountered an error processing your request. Please try rephrasing your question or try again later.\n\nError: {str(e)}",
            "timestamp": datetime.now(),
        }
        succeeded = False
    finally:
        thinking_placeholder.empty()

    current_session["messages"].append(reply)
    current_session["session_memory"] = memory_before + [
        HumanMessage(content=user_input),
        AIMessage(content=reply["content"]),
    ]
    if succeeded:
        save_session_to_db(session_id, current_session)
    st.rerun()


def main():
    """Streamlit entry point: header, agent start-up, sidebar, chat view, input form."""
    st.markdown(
        '<div class="header-container">'
        '<h1 class="header-title">🤖 AI Document Assistant</h1>'
        '<p class="header-subtitle">Intelligent document analysis powered by LangChain</p>'
        '</div>',
        unsafe_allow_html=True,
    )

    # Initialize agent if not done
    if not st.session_state.initialized:
        with st.spinner("Initializing AI Agent..."):
            agent, tools, supabase, status = initialize_agent()
            st.session_state.connection_status = status
            if agent is not None:
                st.session_state.agent = agent
                st.session_state.tools = tools
                st.session_state.supabase = supabase
                st.session_state.initialized = True
                # Load existing sessions; the first one returned becomes current.
                loaded_sessions = load_sessions_from_db()
                if loaded_sessions:
                    st.session_state.chat_sessions = loaded_sessions
                    st.session_state.current_session_id = next(iter(loaded_sessions))

    _render_sidebar()

    # Main content
    if not st.session_state.initialized:
        st.error("⚠️ Agent initialization failed. Please check your configuration.")
        return

    # Create default session if none exists
    if not st.session_state.chat_sessions:
        create_new_session()

    # Ensure current session exists
    if st.session_state.current_session_id not in st.session_state.chat_sessions:
        st.session_state.current_session_id = next(iter(st.session_state.chat_sessions))
    current_session = st.session_state.chat_sessions[st.session_state.current_session_id]

    # Chat messages display
    if current_session["messages"]:
        for message in current_session["messages"]:
            st.markdown(render_message_html(message), unsafe_allow_html=True)
    else:
        st.markdown(
            '<div style="text-align: center; color: #6b7280; padding: 2rem;">'
            '👋 Start a conversation by asking a question about your documents'
            '</div>',
            unsafe_allow_html=True,
        )

    # Input area
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([5, 1])
        with col1:
            user_input = st.text_area(
                "Message",
                placeholder="Ask a question about your documents...",
                height=80,
                label_visibility="collapsed",
            )
        with col2:
            # Spacer so the button lines up with the text area.
            st.markdown("<div style='height: 20px;'></div>", unsafe_allow_html=True)
            submit_button = st.form_submit_button("Send", use_container_width=True)

    # Process user input
    if submit_button and user_input.strip():
        _handle_user_message(current_session, user_input)
if __name__ == "__main__":
main()