"""Flask entry point for a document-QA (RAG) web application.

Uploaded documents are split into parent/child chunks, indexed with a
hybrid BM25 + FAISS retriever over the child chunks, and wired into a
per-session RAG chain (built by ``rag_processor.create_rag_chain``).
"""
import os
import time
import uuid
from flask import Flask, request, render_template, session, jsonify, Response
from werkzeug.utils import secure_filename
from rag_processor import create_rag_chain
from typing import Sequence, Any, List
import fitz
import re
import io
from gtts import gTTS
from langchain_core.documents import Document
from langchain_community.document_loaders import (
    TextLoader,
    Docx2txtLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.storage import InMemoryStore

app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)

# Hugging Face Spaces containers only guarantee a writable /tmp, so choose
# the upload directory based on where we appear to be running.
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    # Fall back to /tmp if the preferred location is not writable.
    print(f"Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")

# Per-session state, keyed by the UUID issued at upload time.
rag_chains = {}
message_histories = {}
doc_stores = {}  # To hold the InMemoryStore for each session

print("Loading embedding model...")
try:
    # Read but not passed explicitly; huggingface_hub picks HF_TOKEN up from
    # the environment when downloading gated models.
    hf_token = os.getenv("HF_TOKEN")
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name="google/embeddinggemma-300m",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"FATAL: Could not load embedding model.\nError: {e}")
    raise


def load_pdf_with_fallback(filepath):
    """Extract per-page text from a PDF using PyMuPDF.

    Returns a list of ``Document`` objects, one per non-empty page, with
    ``source`` (basename) and 1-based ``page`` metadata.

    Raises if the file cannot be opened or contains no extractable text,
    so the caller can report the file as failed.
    """
    try:
        docs = []
        with fitz.open(filepath) as pdf_doc:
            for page_num, page in enumerate(pdf_doc):
                text = page.get_text()
                if text.strip():
                    docs.append(Document(
                        page_content=text,
                        metadata={
                            "source": os.path.basename(filepath),
                            "page": page_num + 1,
                        }
                    ))
        if docs:
            print(f"Successfully loaded PDF with PyMuPDF: {filepath}")
            return docs
        else:
            raise ValueError("No text content found in PDF.")
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise


# Extension -> loader. ``.pdf`` maps to a plain function that already returns
# Documents; the others are LangChain loader classes whose instances need
# ``.load()`` called (handled in upload_files).
LOADER_MAPPING = {
    ".txt": TextLoader,
    ".pdf": load_pdf_with_fallback,
    ".docx": Docx2txtLoader,
}


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return (creating on first use) the chat history for a session."""
    if session_id not in message_histories:
        message_histories[session_id] = ChatMessageHistory()
    return message_histories[session_id]


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe."""
    return jsonify({'status': 'healthy'}), 200


@app.route('/', methods=['GET'])
def index():
    """Serve the single-page chat UI."""
    return render_template('index.html')


@app.route('/upload', methods=['POST'])
def upload_files():
    """Ingest uploaded files and build a per-session RAG pipeline.

    Saves each file, loads it via LOADER_MAPPING, splits the combined
    documents into parent chunks (stored in an InMemoryStore) and child
    chunks (indexed in FAISS + BM25), then creates the RAG chain and a
    new session id. Returns JSON with per-file success/failure details.
    """
    files = request.files.getlist('file')
    if not files or all(f.filename == '' for f in files):
        return jsonify({'status': 'error', 'message': 'No selected files.'}), 400

    all_docs = []
    processed_files, failed_files = [], []
    for file in files:
        if file and file.filename:
            filename = secure_filename(file.filename)
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                file.save(filepath)
                file_ext = os.path.splitext(filename)[1].lower()
                if file_ext not in LOADER_MAPPING:
                    raise ValueError("Unsupported file format.")
                loader_func = LOADER_MAPPING[file_ext]
                # PDF loader is a plain function; others are loader classes.
                docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
                if not docs:
                    raise ValueError("No content extracted.")
                all_docs.extend(docs)
                processed_files.append(filename)
                # Fix: report the actual filename instead of a placeholder.
                print(f"✓ Successfully processed: {filename}")
            except Exception as e:
                error_msg = str(e)
                print(f"✗ Error processing {filename}: {error_msg}")
                failed_files.append(f"{filename} ({error_msg})")

    if not all_docs:
        error_summary = "Failed to process all files."
        if failed_files:
            error_summary += " Reasons: " + ", ".join(failed_files)
        return jsonify({'status': 'error', 'message': error_summary}), 400

    try:
        print("Starting RAG pipeline setup...")
        # Parent chunks give the LLM wide context; child chunks give the
        # retriever precise targets (parent-document retrieval pattern).
        parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=300,
            separators=["\n\n", "\n", ". ", " ", ""],  # Prioritize natural breaks
            length_function=len,
        )
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

        parent_docs = parent_splitter.split_documents(all_docs)
        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

        # Tag every child with its parent's id so the parent can be fetched
        # after retrieval.
        child_docs = []
        for i, doc in enumerate(parent_docs):
            _id = doc_ids[i]
            sub_docs = child_splitter.split_documents([doc])
            for child in sub_docs:
                child.metadata["doc_id"] = _id
            child_docs.extend(sub_docs)

        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        print(f"Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")

        # Hybrid retrieval: lexical (BM25) + semantic (FAISS), weighted 40/60.
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 3
        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.4, 0.6]
        )
        print("Created Hybrid Retriever for child documents.")

        session_id = str(uuid.uuid4())
        doc_stores[session_id] = store
        rag_chain_components = create_rag_chain(
            ensemble_retriever, get_session_history, EMBEDDING_MODEL, store)
        rag_chains[session_id] = rag_chain_components
        session['session_id'] = session_id

        success_msg = f"Successfully processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f"\nFailed to process: {', '.join(failed_files)}"
        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500
@app.route('/chat', methods=['POST'])
def chat():
    """Answer a question against the session's indexed documents.

    Expects JSON ``{"question": ..., "session_id": ...}`` (session id may
    also come from the Flask session cookie). Returns ``{"answer": ...}``
    or a JSON error with 400/500.
    """
    # Fix: get_json() aborts with 415/400 on a missing or non-JSON body
    # before our own validation can run; silent=True lets the guard below
    # return the intended 400 response instead.
    data = request.get_json(silent=True) or {}
    question = data.get('question')
    session_id = session.get('session_id') or data.get('session_id')
    if not question or not session_id or session_id not in rag_chains:
        return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400
    try:
        chain_components = rag_chains[session_id]
        config = {"configurable": {"session_id": session_id}}

        print("\n" + "="*50)
        print("--- STARTING DIAGNOSTIC RUN ---")
        print(f"Original Question: {question}")
        print("="*50 + "\n")

        # NOTE(review): the three steps below re-run the rewrite/HyDE/
        # retrieval stages purely for diagnostics — their results are never
        # fed to ``final_chain``, which performs its own retrieval. They cost
        # extra LLM calls per request; consider gating behind a debug flag.
        rewritten_query = chain_components["rewriter"].invoke(
            {"question": question,
             "chat_history": get_session_history(session_id).messages})
        hyde_doc = chain_components["hyde"].invoke({"question": rewritten_query})
        # ``invoke`` replaces the deprecated ``get_relevant_documents``.
        final_retrieved_docs = chain_components["base_retriever"].invoke(hyde_doc)
        final_context_docs = chain_components["parent_fetcher"].invoke(final_retrieved_docs)

        # The real answer: the chain manages retrieval and chat history
        # itself via the session_id in ``config``.
        answer_string = chain_components["final_chain"].invoke(
            {"question": question}, config=config)
        return jsonify({'answer': answer_string})
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500


def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax so the TTS engine reads plain prose.

    Removes emphasis markers, inline code ticks, heading hashes, list
    bullets/numbers, blockquote markers, and horizontal rules, then
    collapses newlines into spaces.
    """
    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)          # **bold** / *italic*
    text = re.sub(r'\_(.*?)\_', r'\1', text)                 # _emphasis_
    text = re.sub(r'`(.*?)`', r'\1', text)                   # `inline code`
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)      # headings
    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)      # bullets
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)       # numbered lists
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)           # blockquotes
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)  # horizontal rules
    text = re.sub(r'\n+', ' ', text)                         # newlines -> spaces
    return text.strip()


@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Convert posted text to MP3 speech via gTTS and stream it back."""
    # Fix: tolerate missing/non-JSON bodies (see chat()).
    data = request.get_json(silent=True) or {}
    text = data.get('text')
    if not text:
        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
    try:
        clean_text = clean_markdown_for_tts(text)
        tts = gTTS(clean_text, lang='en')
        # Render into memory; no temp files needed.
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)
        return Response(mp3_fp, mimetype='audio/mpeg')
    except Exception as e:
        print(f"Error in TTS generation: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500


if __name__ == '__main__':
    # PORT is set by most hosting platforms; 7860 is the HF Spaces default.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)