import streamlit as st
import os
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.chains import RetrievalQA
from supabase import create_client
from langchain.prompts import PromptTemplate
from langchain.agents import Tool, AgentExecutor, create_react_agent
from langchain.tools.retriever import create_retriever_tool
from langchain.memory import ConversationSummaryBufferMemory
from langchain.schema import HumanMessage, AIMessage
from langchain.cache import InMemoryCache
from langchain.globals import set_llm_cache
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
import uuid
from datetime import datetime
import json
import time
from collections import defaultdict
from tenacity import retry, stop_after_attempt, wait_exponential

# Page configuration
st.set_page_config(
    page_title="AI Document Assistant",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Enable LLM caching for faster responses
set_llm_cache(InMemoryCache())

# Custom CSS for professional design (style rules omitted here)
st.markdown("""
""", unsafe_allow_html=True)

# Rate Limiter Class
class RateLimiter:
    def __init__(self, max_requests=10, time_window=60):
        self.requests = defaultdict(list)
        self.max_requests = max_requests
        self.time_window = time_window

    def check_limit(self, session_id):
        now = time.time()
        # Drop requests that fell outside the sliding time window
        self.requests[session_id] = [
            t for t in self.requests[session_id]
            if now - t < self.time_window
        ]
        if len(self.requests[session_id]) >= self.max_requests:
            return False, "Rate limit exceeded. Please wait before sending more messages."
        self.requests[session_id].append(now)
        return True, ""

# Initialize session state
if 'initialized' not in st.session_state:
    st.session_state.initialized = False
    st.session_state.agent_executor = None
    st.session_state.chat_sessions = {}
    st.session_state.current_session_id = None
    st.session_state.connection_status = "Not Connected"
    st.session_state.sidebar_collapsed = False
    st.session_state.rate_limiter = RateLimiter(max_requests=20, time_window=60)
    st.session_state.supabase = None

# Keys configuration: read secrets from the environment rather than
# hardcoding them in source control
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
SUPABASE_URL = os.environ.get("SUPABASE_URL", "https://oszztarojsrnckvlqhhx.supabase.co")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY", "")

def validate_input(user_input: str) -> tuple:
    """Validate user input before it reaches the agent"""
    if not user_input or len(user_input.strip()) < 3:
        return False, "Query too short. Please provide more details (at least 3 characters)."
    if len(user_input) > 2000:
        return False, "Query too long. Please keep it under 2000 characters."
    # Reject obvious code-injection patterns
    dangerous_patterns = ['__import__', 'exec(', 'eval(', 'os.system', 'subprocess']
    if any(pattern in user_input.lower() for pattern in dangerous_patterns):
        return False, "Invalid input detected. Please rephrase your question."
    return True, ""
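# Illustrative only: a minimal sketch of how RateLimiter and validate_input
# are meant to compose before a message is handed to the agent. This helper
# is hypothetical (it is not called anywhere in the original app) and assumes
# a session_id string like the one stored in st.session_state.
def _example_input_guard(session_id: str, user_input: str) -> tuple:
    """Return (ok, error_message) after rate-limit and validation checks."""
    allowed, limit_msg = st.session_state.rate_limiter.check_limit(session_id)
    if not allowed:
        return False, limit_msg
    valid, validation_msg = validate_input(user_input)
    if not valid:
        return False, validation_msg
    return True, ""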
return True, "" def save_session_to_db(session_id, session_data): """Save session to Supabase""" try: if st.session_state.supabase is None: return # Prepare messages for JSON serialization messages_json = [] for msg in session_data['messages']: msg_copy = msg.copy() if 'timestamp' in msg_copy: msg_copy['timestamp'] = msg_copy['timestamp'].isoformat() messages_json.append(msg_copy) st.session_state.supabase.table('chat_sessions').upsert({ 'id': session_id, 'name': session_data['name'], 'created_at': session_data['created_at'].isoformat(), 'messages': json.dumps(messages_json), 'updated_at': datetime.now().isoformat() }).execute() except Exception as e: st.warning(f"Could not save session to database: {str(e)}") def load_sessions_from_db(): """Load all sessions from database""" try: if st.session_state.supabase is None: return {} response = st.session_state.supabase.table('chat_sessions').select('*').order('created_at', desc=True).execute() sessions = {} for session in response.data: session_id = session['id'] messages = json.loads(session['messages']) if session['messages'] else [] # Convert timestamp strings back to datetime for msg in messages: if 'timestamp' in msg and isinstance(msg['timestamp'], str): msg['timestamp'] = datetime.fromisoformat(msg['timestamp']) sessions[session_id] = { 'id': session_id, 'name': session['name'], 'created_at': datetime.fromisoformat(session['created_at']), 'messages': messages, 'session_memory': [], 'history': [] } # Rebuild session memory from messages for msg in messages: if msg['type'] == 'user': sessions[session_id]['session_memory'].append(HumanMessage(content=msg['content'])) else: sessions[session_id]['session_memory'].append(AIMessage(content=msg['content'])) return sessions except Exception as e: st.warning(f"Could not load sessions from database: {str(e)}") return {} @st.cache_resource def initialize_agent(): """Initialize the LangChain agent with caching""" try: # Connect to Supabase supabase = create_client(SUPABASE_URL, SUPABASE_KEY) embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY) # Reconnect to existing vector store vector_store = SupabaseVectorStore( client=supabase, embedding=embeddings, table_name="documents" ) # LLM setup with streaming llm = ChatOpenAI( model="gpt-4o-mini", temperature=0, openai_api_key=OPENAI_API_KEY, streaming=False ) # Create base retriever with better search parameters base_retriever = vector_store.as_retriever( search_type="similarity", search_kwargs={ "k": 3, } ) # Add contextual compression for better retrieval compressor = LLMChainExtractor.from_llm(llm) compression_retriever = ContextualCompressionRetriever( base_compressor=compressor, base_retriever=base_retriever ) # QA Chain for better answers qa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type="stuff", retriever=base_retriever, return_source_documents=True ) def qa_with_sources(query): """Question answering with source tracking""" try: result = qa_chain.invoke({"query": query}) return result["result"] except Exception as e: return f"Error retrieving information: {str(e)}" # Retriever tool (core RAG function) retriever_tool = create_retriever_tool( retriever=base_retriever, name="retriever", description=( "Use this tool to answer ANY question that might be related to or found in the uploaded or provided documents. " "Always call this tool FIRST whenever the question could possibly require information from those documents. 
" "If the question asks about facts, data, summaries, policies, reports, or anything that may come from the user's documents, " "use this tool to retrieve the relevant content before answering." ), ) Retriver_tool = Tool( name="retriever", func=retriever_tool, description=( "Retrieves relevant context from the user's uploaded or stored documents. " "Use this tool for any question that might involve the content of the documents, " "such as document summaries, factual answers, or topic-specific details." ), ) # QA tool qa_tool = Tool( name="Question Answering", func=llm.invoke, description=( "A general-purpose question answering tool. " "Use this ONLY for casual or open-ended questions that are NOT related to the provided documents. " "Examples: greetings, opinions, or general world knowledge questions (e.g., 'How are you?', 'What is AI?'). " "Do NOT use this if the question might depend on the document contents." ), ) # Summary tool summary_tool = Tool( name="Summary", func=llm.invoke, description="Summarizes long text passages into concise summaries using a structured summarization prompt.", prompt=PromptTemplate( input_variables=["input"], template=""" You are a summarization assistant. Follow these steps to summarize the text: 1. Read the text carefully. 2. Identify the main points and key details. 3. Write a concise summary that captures the essence of the text. Text: {input} Summary: """, ), ) # Explanation tool explanation_tool = Tool( name="Explanation", func=llm.invoke, description="Explains complex concepts in simple, clear terms using examples or analogies when appropriate.", prompt=PromptTemplate( input_variables=["input"], template=""" You are an explanation assistant. Follow these steps to explain the concept: 1. Understand the concept thoroughly. 2. Break down the concept into simpler parts. 3. Provide a clear and detailed explanation with examples. Concept: {input} Explanation: """, ), ) # Tool list (retriever first for prioritization) tools = [Retriver_tool, summary_tool, explanation_tool, qa_tool] tool_names = ", ".join([tool.name for tool in tools]) example = """ Example: Thought: I should use the retriever tool to find relevant info. Action: retriever Action Input: current head of the American Red Cross Observation: The documents do not mention the head of the American Red Cross. Thought: The information is not in the documents. Final Answer: I'm sorry, but I couldn’t find information about that in the provided documents. """ # Custom ReAct prompt react_prompt = PromptTemplate.from_template( example + """ You are a retrieval-augmented assistant that answers questions ONLY using the information found in the user's provided documents. You have access to the following tools: {tools} Follow this reasoning format: Thought: Think about what the question is asking and whether you can find the answer in the user's documents. Action: The action to take, must be one of [{tool_names}] Action Input: The input to the action (be specific) Observation: The result of the action ... (You may repeat this Thought/Action/Observation cycle as needed) Final Answer: Your final grounded answer to the user's question. ### Important Grounding Rules: - You MUST first use the 'retriever' tool to search for relevant information in the user's documents. - Only use the information retrieved from the documents to answer the question. - If the retrieved information does not contain a clear or relevant answer, respond with: "I'm sorry, but I couldn’t find information about that in the provided documents." 
- Do NOT use your own general knowledge or external world knowledge.
- Use the 'Question Answering' tool only for generic greetings (like 'hi', 'how are you') or clarification.
- You may use multiple tools in sequence before providing the final answer.

Previous conversation:
{chat_history}

Question: {input}
{agent_scratchpad}
"""
        ).partial(
            tools="\n".join([f"{tool.name}: {tool.description}" for tool in tools]),
            tool_names=tool_names
        )

        # Create agent
        custom_agent = create_react_agent(llm=llm, tools=tools, prompt=react_prompt)

        return custom_agent, tools, supabase, "Connected Successfully"
    except Exception as e:
        return None, None, None, f"Connection Error: {str(e)}"

def create_new_session():
    """Create a new chat session"""
    session_id = str(uuid.uuid4())
    session_name = f"Chat {len(st.session_state.chat_sessions) + 1}"

    # Initialize session data
    st.session_state.chat_sessions[session_id] = {
        "id": session_id,
        "name": session_name,
        "created_at": datetime.now(),
        "messages": [],
        "session_memory": [],
        "history": []
    }
    st.session_state.current_session_id = session_id

    # Save to database
    save_session_to_db(session_id, st.session_state.chat_sessions[session_id])
    return session_id

def get_recent_context(session_data, max_messages=10):
    """Keep only the most recent messages to avoid context overflow"""
    memory = session_data["session_memory"]
    # Each exchange is two messages (user + assistant), hence max_messages * 2
    return memory[-max_messages * 2:] if len(memory) > max_messages * 2 else memory

def get_agent_executor_for_session(session_id):
    """Build an agent executor with session-specific memory"""
    if not st.session_state.initialized:
        return None

    session_data = st.session_state.chat_sessions[session_id]

    # Trim context so the model is not overwhelmed
    recent_memory = get_recent_context(session_data, max_messages=8)

    # Summary buffer memory for this session
    memory = ConversationSummaryBufferMemory(
        llm=ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY),
        memory_key="chat_history",
        return_messages=True,
        output_key="output",
        max_token_limit=1000
    )
    # Restore recent session memory
    memory.chat_memory.messages = recent_memory

    # Create agent executor
    agent_executor = AgentExecutor(
        agent=st.session_state.agent,
        tools=st.session_state.tools,
        memory=memory,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it follows the correct format.",
        return_intermediate_steps=True,
        max_iterations=5,
        max_execution_time=45,
    )
    return agent_executor

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
def get_agent_response(agent_executor, user_input):
    """Get a response, retrying with exponential backoff on failure"""
    return agent_executor.invoke({"input": user_input})

def get_response_with_fallback(agent_executor, user_input):
    """Try progressively simpler strategies if the initial response fails"""
    try:
        # Primary attempt (with retries)
        return get_agent_response(agent_executor, user_input)
    except Exception:
        st.warning("Primary attempt failed, trying simplified approach...")
        try:
            # Fallback 1: retry with a simpler prompt
            simplified_input = f"Please answer briefly: {user_input}"
            return agent_executor.invoke({"input": simplified_input})
        except Exception:
            st.warning("Simplified approach failed, using direct LLM...")
            try:
                # Fallback 2: direct LLM call without tools
                llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY)
                response_content = llm.invoke(user_input).content
                return {"output": response_content, "intermediate_steps": []}
            except Exception as e3:
                raise Exception(f"All attempts failed: {str(e3)}")
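# Illustrative only: a hypothetical end-to-end handler showing how the pieces
# above compose for a single user message. The real chat loop lives in main();
# the names here mirror the session-state keys used throughout this file, and
# _example_input_guard is the sketch defined earlier.
def _example_handle_message(user_input: str):
    session_id = st.session_state.current_session_id or create_new_session()
    ok, error = _example_input_guard(session_id, user_input)
    if not ok:
        return {"output": error, "intermediate_steps": []}
    executor = get_agent_executor_for_session(session_id)
    if executor is None:
        return {"output": "Agent not initialized.", "intermediate_steps": []}
    return get_response_with_fallback(executor, user_input)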
def track_metrics(session_data):
    """Track conversation metrics"""
    total_messages = len(session_data["messages"])
    user_messages = sum(1 for m in session_data["messages"] if m["type"] == "user")
    bot_messages = total_messages - user_messages

    # Calculate session duration (total_seconds avoids wrapping after 24 hours)
    if session_data["messages"]:
        first_msg = session_data["messages"][0]["timestamp"]
        last_msg = session_data["messages"][-1]["timestamp"]
        duration = int((last_msg - first_msg).total_seconds())
    else:
        duration = 0

    return {
        "total_messages": total_messages,
        "user_messages": user_messages,
        "bot_messages": bot_messages,
        "session_duration": duration
    }

def main():
    # Header
    st.markdown("""
Intelligent document analysis powered by LangChain