import os
import io
import re
import uuid

from flask import Flask, request, render_template, session, jsonify, Response
from werkzeug.utils import secure_filename
from gtts import gTTS
import fitz  # PyMuPDF

from rag_processor import create_rag_chain

from langchain_core.documents import Document
from langchain_community.document_loaders import (
    TextLoader,
    Docx2txtLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.storage import InMemoryStore

app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)

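# On Hugging Face Spaces only /tmp is reliably writable, so uploads go there.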
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    print(f"Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")

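# Per-session state, keyed by the session id issued at upload time.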
rag_chains = {}          # session_id -> chain components from create_rag_chain
message_histories = {}   # session_id -> ChatMessageHistory
doc_stores = {}          # session_id -> InMemoryStore of parent documents

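# The embedding model is loaded once at startup and shared by every session.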
print("Loading embedding model...") |
|
|
try: |
|
|
hf_token = os.getenv("HF_TOKEN") |
|
|
EMBEDDING_MODEL = HuggingFaceEmbeddings( |
|
|
model_name="google/embeddinggemma-300m", |
|
|
model_kwargs={'device': 'cpu'}, |
|
|
encode_kwargs={'normalize_embeddings': True}, |
|
|
) |
|
|
print("Embedding model loaded successfully.") |
|
|
except Exception as e: |
|
|
print(f"FATAL: Could not load embedding model. Error: {e}") |
|
|
raise |
|
|
|
|
|
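# PDFs are loaded with PyMuPDF directly so each non-empty page becomes a
# Document carrying its source filename and page number in metadata.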
def load_pdf_with_fallback(filepath):
    """Load a PDF with PyMuPDF, returning one Document per non-empty page."""
    try:
        docs = []
        with fitz.open(filepath) as pdf_doc:
            for page_num, page in enumerate(pdf_doc):
                text = page.get_text()
                if text.strip():
                    docs.append(Document(
                        page_content=text,
                        metadata={
                            "source": os.path.basename(filepath),
                            "page": page_num + 1,
                        },
                    ))
        if docs:
            print(f"Successfully loaded PDF with PyMuPDF: {filepath}")
            return docs
        raise ValueError("No text content found in PDF.")
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise

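# Loaders by extension. The PDF entry is a plain function that returns
# Documents directly; the others are LangChain loader classes.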
LOADER_MAPPING = {
    ".txt": TextLoader,
    ".pdf": load_pdf_with_fallback,
    ".docx": Docx2txtLoader,
}

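# History factory handed to the RAG chain; one ChatMessageHistory per session.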
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for a session, creating it on first use."""
    if session_id not in message_histories:
        message_histories[session_id] = ChatMessageHistory()
    return message_histories[session_id]

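# Simple liveness endpoint.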
@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({'status': 'healthy'}), 200


@app.route('/', methods=['GET'])
def index():
    return render_template('index.html')

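# Upload endpoint: save each file, extract its text, chunk it, and build the
# per-session retrievers and RAG chain.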
@app.route('/upload', methods=['POST'])
def upload_files():
    files = request.files.getlist('file')
    if not files or all(f.filename == '' for f in files):
        return jsonify({'status': 'error', 'message': 'No selected files.'}), 400

    all_docs = []
    processed_files, failed_files = [], []

    for file in files:
        if file and file.filename:
            filename = secure_filename(file.filename)
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                file.save(filepath)
                file_ext = os.path.splitext(filename)[1].lower()
                if file_ext not in LOADER_MAPPING:
                    raise ValueError("Unsupported file format.")

                # The PDF loader is a plain function; the other loaders are
                # classes that must be instantiated and .load()-ed.
                loader_func = LOADER_MAPPING[file_ext]
                docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()

                if not docs:
                    raise ValueError("No content extracted.")

                all_docs.extend(docs)
                processed_files.append(filename)
                print(f"✓ Successfully processed: {filename}")
            except Exception as e:
                error_msg = str(e)
                print(f"✗ Error processing {filename}: {error_msg}")
                failed_files.append(f"{filename} ({error_msg})")

    if not all_docs:
        error_summary = "Failed to process all files."
        if failed_files:
            error_summary += " Reasons: " + ", ".join(failed_files)
        return jsonify({'status': 'error', 'message': error_summary}), 400

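    # Small-to-big retrieval: index small child chunks for precise matching,
    # but hand their larger parent chunks to the LLM for fuller context.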
    try:
        print("Starting RAG pipeline setup...")

        parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=300,
            separators=["\n\n", "\n", ". ", " ", ""],
            length_function=len,
        )
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

        parent_docs = parent_splitter.split_documents(all_docs)
        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

        # Tag every child chunk with its parent's id so the parent can be
        # fetched after retrieval.
        child_docs = []
        for i, doc in enumerate(parent_docs):
            _id = doc_ids[i]
            sub_docs = child_splitter.split_documents([doc])
            for child in sub_docs:
                child.metadata["doc_id"] = _id
            child_docs.extend(sub_docs)

        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))

        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)

        print(f"Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")

        # Hybrid retrieval: BM25 for keyword matches, FAISS for semantic ones.
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 3

        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.4, 0.6]
        )
        print("Created Hybrid Retriever for child documents.")

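        # Issue a session id and register the per-session state under it.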
        session_id = str(uuid.uuid4())

        doc_stores[session_id] = store

        rag_chain_components = create_rag_chain(ensemble_retriever, get_session_history, EMBEDDING_MODEL, store)

        rag_chains[session_id] = rag_chain_components
        session['session_id'] = session_id

        success_msg = f"Successfully processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f"\nFailed to process: {', '.join(failed_files)}"

        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500

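# Chat endpoint: runs the question through the session's RAG chain.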
@app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json()
    question = data.get('question')
    session_id = session.get('session_id') or data.get('session_id')

    if not question or not session_id or session_id not in rag_chains:
        return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400

    try:
        chain_components = rag_chains[session_id]
        config = {"configurable": {"session_id": session_id}}

        print("\n" + "=" * 50)
        print("--- STARTING DIAGNOSTIC RUN ---")
        print(f"Original Question: {question}")
        print("=" * 50 + "\n")

        # Diagnostic pass: trace each retrieval stage individually.
        # 1. Rewrite the question using the conversation so far.
        rewritten_query = chain_components["rewriter"].invoke({"question": question, "chat_history": get_session_history(session_id).messages})

        # 2. HyDE: generate a hypothetical answer document to retrieve against.
        hyde_doc = chain_components["hyde"].invoke({"question": rewritten_query})

        # 3. Retrieve child chunks for the hypothetical document.
        final_retrieved_docs = chain_components["base_retriever"].invoke(hyde_doc)

        # 4. Swap the retrieved children for their parent chunks.
        final_context_docs = chain_components["parent_fetcher"].invoke(final_retrieved_docs)

        # NOTE: the intermediate results above are computed only for
        # diagnostics; the composed chain receives just the original question
        # plus the session config.
        answer_string = chain_components["final_chain"].invoke({"question": question}, config=config)

        return jsonify({'answer': answer_string})

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500

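# Strip markdown syntax so the TTS engine doesn't read the markup aloud.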
def clean_markdown_for_tts(text: str) -> str:
    """Remove common markdown markup, returning plain speakable text."""
    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)                    # bold / italics
    text = re.sub(r'\_(.*?)\_', r'\1', text)                           # underscore emphasis
    text = re.sub(r'`(.*?)`', r'\1', text)                             # inline code
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)      # headings
    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)      # bullet markers
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)       # numbered lists
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)           # blockquotes
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)  # horizontal rules
    text = re.sub(r'\n+', ' ', text)                                   # collapse newlines
    return text.strip()

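# TTS endpoint: turn an answer into MP3 audio with gTTS.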
@app.route('/tts', methods=['POST'])
def text_to_speech():
    data = request.get_json()
    text = data.get('text')

    if not text:
        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400

    try:
        clean_text = clean_markdown_for_tts(text)
        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)
        # Send the buffer's bytes rather than the BytesIO object itself so
        # Flask returns a single binary body.
        return Response(mp3_fp.read(), mimetype='audio/mpeg')
    except Exception as e:
        print(f"Error in TTS generation: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500

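# 7860 is the default port expected by Hugging Face Spaces.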
if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)