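"""Flask app for a document-chat RAG service.

Users upload .txt, .pdf, or .docx files, which are split, embedded, and indexed
per session; /chat answers questions over the index and /tts reads answers aloud.
"""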
import os
import io
import re
import uuid
from flask import Flask, request, render_template, session, jsonify, Response
from werkzeug.utils import secure_filename
from rag_processor import create_rag_chain
import fitz  # PyMuPDF
from gtts import gTTS
from langchain_core.documents import Document
from langchain_community.document_loaders import (
    TextLoader,
    Docx2txtLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.storage import InMemoryStore
app = Flask(__name__)
# NOTE: a random per-process key regenerates on every restart, invalidating existing sessions.
app.config['SECRET_KEY'] = os.urandom(24)
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
if is_hf_spaces:
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
app.config['UPLOAD_FOLDER'] = 'uploads'
try:
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
print(f"Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
rag_chains = {}         # session_id -> RAG chain components
message_histories = {}  # session_id -> ChatMessageHistory
doc_stores = {}         # session_id -> InMemoryStore of parent documents
print("Loading embedding model...")
try:
    # huggingface_hub reads HF_TOKEN from the environment automatically; a token
    # is required for gated models such as embeddinggemma.
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name="google/embeddinggemma-300m",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
print("Embedding model loaded successfully.")
except Exception as e:
print(f"FATAL: Could not load embedding model. Error: {e}")
raise
def load_pdf_with_fallback(filepath):
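    """Load a PDF page by page with PyMuPDF; raise if no extractable text is found."""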
try:
docs = []
with fitz.open(filepath) as pdf_doc:
for page_num, page in enumerate(pdf_doc):
text = page.get_text()
if text.strip():
docs.append(Document(
page_content=text,
metadata={
"source": os.path.basename(filepath),
"page": page_num + 1,
}
))
if docs:
print(f"Successfully loaded PDF with PyMuPDF: {filepath}")
return docs
else:
raise ValueError("No text content found in PDF.")
except Exception as e:
print(f"PyMuPDF failed for {filepath}: {e}")
raise
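# Supported extensions: .pdf maps to the PyMuPDF function above; the others are
# LangChain loader classes instantiated per file.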
LOADER_MAPPING = {
".txt": TextLoader,
".pdf": load_pdf_with_fallback,
".docx": Docx2txtLoader,
}
def get_session_history(session_id: str) -> ChatMessageHistory:
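    """Return the chat history for a session, creating it on first use."""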
if session_id not in message_histories:
message_histories[session_id] = ChatMessageHistory()
return message_histories[session_id]
@app.route('/health', methods=['GET'])
def health_check():
return jsonify({'status': 'healthy'}), 200
@app.route('/', methods=['GET'])
def index():
return render_template('index.html')
@app.route('/upload', methods=['POST'])
def upload_files():
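    """Accept file uploads, extract their text, and build a per-session RAG pipeline."""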
files = request.files.getlist('file')
if not files or all(f.filename == '' for f in files):
return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
all_docs = []
processed_files, failed_files = [], []
for file in files:
if file and file.filename:
filename = secure_filename(file.filename)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
try:
file.save(filepath)
file_ext = os.path.splitext(filename)[1].lower()
if file_ext not in LOADER_MAPPING:
raise ValueError("Unsupported file format.")
                loader_func = LOADER_MAPPING[file_ext]
                # The .pdf entry is a plain function returning docs; the other
                # entries are loader classes that must be instantiated and .load()ed.
                docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
if not docs:
raise ValueError("No content extracted.")
all_docs.extend(docs)
processed_files.append(filename)
print(f"✓ Successfully processed: {filename}")
except Exception as e:
error_msg = str(e)
print(f"✗ Error processing {filename}: {error_msg}")
failed_files.append(f"{filename} ({error_msg})")
if not all_docs:
error_summary = "Failed to process all files."
if failed_files:
error_summary += " Reasons: " + ", ".join(failed_files)
return jsonify({'status': 'error', 'message': error_summary}), 400
try:
print("Starting RAG pipeline setup...")
        parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=300,
            separators=["\n\n", "\n", ". ", " ", ""],  # prioritize natural breaks
            length_function=len,
        )
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
parent_docs = parent_splitter.split_documents(all_docs)
doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
child_docs = []
for i, doc in enumerate(parent_docs):
_id = doc_ids[i]
sub_docs = child_splitter.split_documents([doc])
for child in sub_docs:
child.metadata["doc_id"] = _id
child_docs.extend(sub_docs)
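        # Keep parent docs in a per-session docstore keyed by doc_id so retrieved
        # children can be mapped back to their parents at answer time.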
store = InMemoryStore()
store.mset(list(zip(doc_ids, parent_docs)))
vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
print(f"Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")
bm25_retriever = BM25Retriever.from_documents(child_docs)
bm25_retriever.k = 3
faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, faiss_retriever],
weights=[0.4, 0.6]
)
print("Created Hybrid Retriever for child documents.")
session_id = str(uuid.uuid4())
doc_stores[session_id] = store
rag_chain_components = create_rag_chain(ensemble_retriever, get_session_history, EMBEDDING_MODEL, store)
rag_chains[session_id] = rag_chain_components
session['session_id'] = session_id
success_msg = f"Successfully processed: {', '.join(processed_files)}"
if failed_files:
success_msg += f"\nFailed to process: {', '.join(failed_files)}"
return jsonify({
'status': 'success',
'filename': success_msg,
'session_id': session_id
})
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500
@app.route('/chat', methods=['POST'])
def chat():
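    """Answer a question against the documents indexed for this session."""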
data = request.get_json()
question = data.get('question')
session_id = session.get('session_id') or data.get('session_id')
if not question or not session_id or session_id not in rag_chains:
return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400
try:
chain_components = rag_chains[session_id]
config = {"configurable": {"session_id": session_id}}
print("\n" + "="*50)
print("--- STARTING DIAGNOSTIC RUN ---")
print(f"Original Question: {question}")
print("="*50 + "\n")
rewritten_query = chain_components["rewriter"].invoke({"question": question, "chat_history": get_session_history(session_id).messages})
#print(f"--- 1. Rewritten Query ---\n{rewritten_query}\n")
hyde_doc = chain_components["hyde"].invoke({"question": rewritten_query})
#print(f"--- 2. HyDE Document ---\n{hyde_doc}\n")
final_retrieved_docs = chain_components["base_retriever"].get_relevant_documents(hyde_doc)
#print(f"--- 3. Retrieved Top {len(final_retrieved_docs)} Child Docs ---")
#for i, doc in enumerate(final_retrieved_docs):
#print(f" Doc {i+1}: {doc.page_content[:150]}... (Source: {doc.metadata.get('source')})")
#print("\n")
final_context_docs = chain_components["parent_fetcher"].invoke(final_retrieved_docs)
#print(f"--- 4. Final {len(final_context_docs)} Parent Docs for LLM ---")
#for i, doc in enumerate(final_context_docs):
#print(f" Final Doc {i+1} (Source: {doc.metadata.get('source')}, Page: {doc.metadata.get('page')}):\n '{doc.page_content[:300]}...'\n---")
#print("="*50)
#print("--- INVOKING FINAL CHAIN ---")
#print("="*50 + "\n")
answer_string = chain_components["final_chain"].invoke({"question": question}, config=config)
return jsonify({'answer': answer_string})
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500
def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax so the TTS engine reads plain prose."""
    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)  # bold / italic asterisks
    text = re.sub(r'\_(.*?)\_', r'\1', text)  # italic underscores
    text = re.sub(r'`(.*?)`', r'\1', text)  # inline code
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)  # headings
    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)  # bullet markers
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)  # numbered lists
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)  # blockquotes
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)  # horizontal rules
    text = re.sub(r'\n+', ' ', text)  # collapse newlines into spaces
    return text.strip()
@app.route('/tts', methods=['POST'])
def text_to_speech():
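    """Render answer text to MP3 audio with gTTS."""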
data = request.get_json()
text = data.get('text')
if not text:
return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
try:
clean_text = clean_markdown_for_tts(text)
tts = gTTS(clean_text, lang='en')
mp3_fp = io.BytesIO()
tts.write_to_fp(mp3_fp)
mp3_fp.seek(0)
        return Response(mp3_fp.read(), mimetype='audio/mpeg')
except Exception as e:
print(f"Error in TTS generation: {e}")
return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
if __name__ == '__main__':
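    # 7860 is the default port expected by Hugging Face Spaces.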
port = int(os.environ.get("PORT", 7860))
app.run(host="0.0.0.0", port=port, debug=False)