rdune71 committed
Commit bb60cf1 · 1 parent: 89431ea

Add RAG capability with document upload and management

app.py CHANGED
@@ -15,8 +15,13 @@ from modules.citation import generate_citations, format_citations
 from modules.server_cache import get_cached_result, cache_result
 from modules.status_logger import log_request
 from modules.server_monitor import ServerMonitor
+from modules.rag.rag_chain import RAGChain
+from modules.rag.vector_store import VectorStore
+from langchain.docstore.document import Document
 
 server_monitor = ServerMonitor()
+rag_chain = RAGChain()
+vector_store = VectorStore()
 
 # Cat-themed greeting function
 def get_cat_greeting():
@@ -123,8 +128,8 @@ def run_startup_check():
     return wrapper
 
 # Enhanced streaming with markdown support
-async def research_assistant(query, history):
-    log_request("Research started", query=query)
+async def research_assistant(query, history, use_rag=False):
+    log_request("Research started", query=query, use_rag=use_rag)
 
     # Add typing indicator
     history.append((query, "🔄 Searching for information..."))
@@ -188,9 +193,19 @@ async def research_assistant(query, history):
     if any(keyword in lower_query for keyword in space_keywords):
         context_section += f"\nSpace Weather Context: {space_weather_data}"
 
-    # Build the enriched input with context only when needed
+    # Build the enriched input
     enriched_input = f"{validated_query}\n\n{answer_content}Search Results:\n{search_content}{context_section}"
 
+    # If RAG is enabled, use it
+    if use_rag:
+        history[-1] = (query, "📚 Searching document database...")
+        yield history
+
+        rag_result = rag_chain.query(validated_query)
+        if rag_result["status"] == "success":
+            enriched_input = rag_result["prompt"]
+            context_section += f"\n\nDocument Context:\n" + "\n\n".join([doc.page_content for doc in rag_result["context_docs"][:2]])
+
     server_status = server_monitor.check_server_status()
     if not server_status["available"]:
         wait_time = server_status["estimated_wait"]
@@ -298,11 +313,39 @@ class AsyncGeneratorWrapper:
             raise StopIteration
         return item
 
-def research_assistant_wrapper(query, history):
-    async_gen = research_assistant(query, history)
+def research_assistant_wrapper(query, history, use_rag):
+    async_gen = research_assistant(query, history, use_rag)
     wrapper = AsyncGeneratorWrapper(async_gen)
     return wrapper
 
+# Document upload function
+def upload_documents(files):
+    """Upload and process documents for RAG"""
+    try:
+        documents = []
+        for file in files:
+            # For PDF files
+            if file.name.endswith('.pdf'):
+                from PyPDF2 import PdfReader
+                reader = PdfReader(file.name)
+                text = ""
+                for page in reader.pages:
+                    text += page.extract_text()
+                documents.append(Document(page_content=text, metadata={"source": file.name}))
+            # For text files
+            else:
+                with open(file.name, 'r') as f:
+                    text = f.read()
+                documents.append(Document(page_content=text, metadata={"source": file.name}))
+
+        result = vector_store.add_documents(documents)
+        if result["status"] == "success":
+            return f"✅ Successfully added {result['count']} document chunks to the knowledge base!"
+        else:
+            return f"❌ Error adding documents: {result['message']}"
+    except Exception as e:
+        return f"❌ Error processing documents: {str(e)}"
+
 # Performance dashboard data
 def get_performance_stats():
     """Get performance statistics from Redis"""
@@ -344,14 +387,16 @@ with gr.Blocks(
     gr.Markdown("## How to Use")
     gr.Markdown("""
     1. Enter a research question in the input box
-    2. Click Submit or press Enter
-    3. Watch as the response streams in real-time
-    4. Review sources at the end of each response
+    2. Toggle 'Use Document Knowledge' to enable RAG
+    3. Click Submit or press Enter
+    4. Watch as the response streams in real-time
+    5. Review sources at the end of each response
 
     ## Features
     - 🔍 Web search integration
     - 🌤️ Context-aware weather data (only when relevant)
     - 🌌 Context-aware space weather data (only when relevant)
+    - 📚 RAG (Retrieval-Augmented Generation) with document database
     - 📚 Real-time citations
     - ⚡ Streaming output
     """)
@@ -368,6 +413,11 @@ with gr.Blocks(
         placeholder="Ask a complex research question...",
         lines=3
     )
+    use_rag = gr.Checkbox(
+        label="📚 Use Document Knowledge (RAG)",
+        value=False,
+        info="Enable to search uploaded documents for context"
+    )
     with gr.Row():
         submit_btn = gr.Button("Submit Research Query", variant="primary")
         clear_btn = gr.Button("Clear Conversation")
@@ -384,6 +434,25 @@ with gr.Blocks(
         label="Example Questions"
     )
 
+    with gr.TabItem("📚 Document Management"):
+        gr.Markdown("## Upload Documents for RAG")
+        gr.Markdown("Upload PDF or text files to add them to the knowledge base for document-based queries.")
+        file_upload = gr.File(
+            file_types=[".pdf", ".txt"],
+            file_count="multiple",
+            label="Upload Documents"
+        )
+        upload_btn = gr.Button("📤 Upload Documents")
+        upload_output = gr.Textbox(label="Upload Status", interactive=False)
+        clear_docs_btn = gr.Button("🗑️ Clear All Documents")
+
+        gr.Markdown("## Current Documents")
+        doc_list = gr.Textbox(
+            label="Document List",
+            value="No documents uploaded yet",
+            interactive=False
+        )
+
     with gr.TabItem("📊 Performance"):
         perf_refresh_btn = gr.Button("🔄 Refresh Stats")
         perf_display = gr.JSON(label="System Statistics")
@@ -432,9 +501,9 @@ While you wait, why not prepare some treats? I'll be ready to hunt for knowledge
         startup_check_result = run_startup_check()
         return update_status()
 
-    def respond(message, history):
+    def respond(message, history, use_rag_flag):
         # Get streaming response
-        for updated_history in research_assistant_wrapper(message, history):
+        for updated_history in research_assistant_wrapper(message, history, use_rag_flag):
             yield updated_history, update_status()
 
     def clear_conversation():
@@ -452,17 +521,21 @@ While you wait, why not prepare some treats? I'll be ready to hunt for knowledge
     check_btn.click(refresh_status, outputs=status_display)
     submit_btn.click(
         respond,
-        [msg, chat_history],
+        [msg, chat_history, use_rag],
        [chatbot, status_display]
     )
     msg.submit(
         respond,
-        [msg, chat_history],
+        [msg, chat_history, use_rag],
         [chatbot, status_display]
     )
 
     clear_btn.click(clear_conversation, outputs=[chat_history, chatbot])
 
+    # Document management
+    upload_btn.click(upload_documents, file_upload, upload_output)
+    clear_docs_btn.click(lambda: vector_store.delete_collection(), None, upload_output)
+
     # Performance dashboard
    perf_refresh_btn.click(update_performance_stats, outputs=perf_display)
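The new `use_rag` flag travels from the checkbox through `respond` and `research_assistant_wrapper` into the async generator, which Gradio consumes via the pre-existing `AsyncGeneratorWrapper` (only its last two lines appear as diff context above). A minimal sketch of how such a sync-over-async bridge is typically built; this is a hypothetical reconstruction, since the real class body is not part of this diff:

```python
import asyncio

class AsyncGeneratorWrapper:
    """Hypothetical sketch: expose an async generator as a plain iterator
    so Gradio's synchronous `for ... in ...` streaming loop can drive it.
    The real class in app.py is not shown in full by this commit."""

    def __init__(self, async_gen):
        self.async_gen = async_gen
        self.loop = asyncio.new_event_loop()  # assumption: the wrapper owns its own loop

    def __iter__(self):
        return self

    def __next__(self):
        try:
            # Advance the async generator one yield at a time
            item = self.loop.run_until_complete(self.async_gen.__anext__())
        except StopAsyncIteration:
            self.loop.close()
            raise StopIteration
        return item
```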
modules/rag/rag_chain.py ADDED
@@ -0,0 +1,55 @@
+from langchain.chains import RetrievalQA
+from langchain.llms import OpenAI
+from langchain.prompts import PromptTemplate
+from modules.rag.vector_store import VectorStore
+from modules.analyzer import client
+import os
+
+class RAGChain:
+    def __init__(self):
+        self.vector_store = VectorStore()
+        self.retriever = self.vector_store.vector_store.as_retriever(
+            search_type="similarity",
+            search_kwargs={"k": 5}
+        )
+
+        # Custom prompt template
+        self.prompt_template = """
+        You are an AI research assistant with access to a document database.
+        Use the following pieces of context to answer the question at the end.
+        If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+        Context: {context}
+
+        Question: {question}
+
+        Answer:
+        """
+
+        self.prompt = PromptTemplate(
+            template=self.prompt_template,
+            input_variables=["context", "question"]
+        )
+
+    def query(self, question):
+        """Query the RAG system"""
+        try:
+            # Search for relevant documents
+            search_result = self.vector_store.search(question)
+            if search_result["status"] != "success":
+                return {"status": "error", "message": search_result["message"]}
+
+            # Format context
+            context = "\n\n".join([doc.page_content for doc in search_result["documents"]])
+
+            # Create enhanced prompt
+            enhanced_prompt = self.prompt.format(context=context, question=question)
+
+            # For streaming, we'll return the prompt for the analyzer to handle
+            return {
+                "status": "success",
+                "prompt": enhanced_prompt,
+                "context_docs": search_result["documents"]
+            }
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
modules/rag/vector_store.py ADDED
@@ -0,0 +1,82 @@
+import os
+import chromadb
+from chromadb.utils import embedding_functions
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import Chroma
+from langchain.embeddings import SentenceTransformerEmbeddings
+from langchain.docstore.document import Document
+import uuid
+
+class VectorStore:
+    def __init__(self):
+        # Initialize embedding function
+        self.embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+
+        # Initialize ChromaDB client
+        self.client = chromadb.PersistentClient(path="./chroma_db")
+
+        # Create or get collection
+        self.collection = self.client.get_or_create_collection(
+            name="research_documents",
+            embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
+                model_name="all-MiniLM-L6-v2"
+            )
+        )
+
+        # Initialize LangChain vector store
+        self.vector_store = Chroma(
+            collection_name="research_documents",
+            embedding_function=self.embedding_function,
+            persist_directory="./chroma_db"
+        )
+
+        # Initialize text splitter
+        self.text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=200,
+            length_function=len,
+        )
+
+    def add_documents(self, documents):
+        """Add documents to the vector store"""
+        try:
+            # Split documents into chunks
+            split_docs = []
+            for doc in documents:
+                splits = self.text_splitter.split_text(doc.page_content)
+                for i, split in enumerate(splits):
+                    split_docs.append(Document(
+                        page_content=split,
+                        metadata={**doc.metadata, "chunk": i}
+                    ))
+
+            # Add to vector store
+            ids = [str(uuid.uuid4()) for _ in split_docs]
+            self.vector_store.add_documents(split_docs, ids=ids)
+
+            return {"status": "success", "count": len(split_docs)}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+
+    def search(self, query, k=5):
+        """Search for relevant documents"""
+        try:
+            # Perform similarity search
+            docs = self.vector_store.similarity_search(query, k=k)
+            return {"status": "success", "documents": docs}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+
+    def delete_collection(self):
+        """Delete the entire collection"""
+        try:
+            self.client.delete_collection("research_documents")
+            self.collection = self.client.get_or_create_collection(
+                name="research_documents",
+                embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
+                    model_name="all-MiniLM-L6-v2"
+                )
+            )
+            return {"status": "success"}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
requirements.txt CHANGED
@@ -5,3 +5,10 @@ redis
 aiohttp
 requests
 python-dotenv
+langchain
+langchain-community
+langchain-openai
+chromadb
+sentence-transformers
+pypdf
+python-multipart
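One dependency note: `upload_documents()` in app.py imports `PyPDF2`, while this file pins `pypdf` (its maintained successor), so the import only works where PyPDF2 happens to be installed as well. For reference, a sketch of the equivalent extraction with the pinned package, plus the community import paths that recent `langchain`/`langchain-community` releases expect in place of the bare `langchain.*` ones used in `vector_store.py` (the file path is illustrative):

```python
# PDF text extraction with pypdf, mirroring the PyPDF2 loop in upload_documents()
from pypdf import PdfReader

reader = PdfReader("example.pdf")  # illustrative path
text = "".join(page.extract_text() or "" for page in reader.pages)

# Import locations under langchain>=0.1 with langchain-community installed
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
```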