riteshraut
fix
f7d42c1
raw
history blame
14.5 kB
import os
import uuid
from flask import Flask, request, render_template, session, jsonify, Response
from werkzeug.utils import secure_filename
from rag_processor import create_rag_chain
from typing import Sequence, Any, List
import fitz
import re
import io
from gtts import gTTS
from langchain_core.documents import Document
from langchain_community.document_loaders import TextLoader, Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_community.retrievers import BM25Retriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.storage import InMemoryStore
from sentence_transformers.cross_encoder import CrossEncoder
app = Flask(__name__)
# Flask signs session cookies with SECRET_KEY. A key generated with
# os.urandom() changes on every process start, so cookies stop validating
# after a restart and are not shared between worker processes. Prefer a
# stable key supplied through the FLASK_SECRET_KEY environment variable and
# keep the random per-process key as the fallback when none is configured.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY') or os.urandom(24)

# Maps temperature strings to the mode labels returned to the frontend.
# The keys are deliberately STRINGS: lookups are done with the textual form
# of the temperature (e.g. "0.2"), not with a float.
TEMPERATURE_LABELS = {
    "0.2": "Precise",
    "0.4": "Confident",
    "0.6": "Balanced",
    "0.8": "Flexible",
    "1.0": "Creative",
}
class LocalReranker(BaseDocumentCompressor):
    """Document compressor that reranks candidates with a local scoring model.

    Each candidate is scored against the query as a ``[query, text]`` pair and
    only the ``top_n`` highest-scoring documents are kept.
    """

    # Scoring model. It must expose ``predict(pairs, show_progress_bar=...)``
    # and return one score per ``[query, text]`` pair.
    model: Any
    # Maximum number of documents returned by ``compress_documents``.
    top_n: int = 5

    class Config:
        # Allow the non-pydantic model object to be stored in ``model``.
        arbitrary_types_allowed = True

    def compress_documents(
        self, documents: Sequence[Document], query: str, callbacks=None
    ) -> Sequence[Document]:
        """Return the ``top_n`` documents most relevant to ``query``.

        Args:
            documents: Candidate documents to rerank.
            query: The query every candidate is scored against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            Up to ``top_n`` documents ordered by descending score. Each one is
            a copy of the input document whose metadata additionally carries
            ``rerank_score`` (a float). An empty input yields an empty list.
        """
        if not documents:
            return []

        pairs = [[query, doc.page_content] for doc in documents]
        scores = self.model.predict(pairs, show_progress_bar=False)
        # sorted() is stable, so equally scored documents keep their input order.
        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)

        # Build fresh Document objects instead of writing into the inputs:
        # the candidates belong to the caller (and may be the very objects a
        # retriever keeps in its index), so mutating their metadata would leak
        # one query's score into later queries.
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, "rerank_score": float(score)},
            )
            for doc, score in ranked[: self.top_n]
        ]
def create_optimized_parent_child_chunks(all_docs):
    """Split documents into parent chunks and smaller, linked child chunks.

    Parents are produced with a 900-character splitter (200 overlap); every
    parent is then split again with a 350-character splitter (80 overlap).
    Each child records the id of its parent and its position among its
    siblings in its metadata.

    Args:
        all_docs: Documents to split. A falsy value short-circuits.

    Returns:
        A ``(parent_docs, child_docs, doc_ids)`` tuple where ``doc_ids[i]`` is
        the generated UUID string of ``parent_docs[i]``. All three are empty
        lists when ``all_docs`` is falsy.
    """
    if not all_docs:
        print("❌ CHUNKING: No input documents provided!")
        return [], [], []

    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900,
        chunk_overlap=200,
        separators=["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""],
    )
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350,
        chunk_overlap=80,
        separators=["\n", ". ", "! ", "? ", "; ", ", ", " ", ""],
    )

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

    child_docs = []
    for parent_id, parent_doc in zip(doc_ids, parent_docs):
        children = child_splitter.split_documents([parent_doc])
        total = len(children)
        last_index = total - 1
        for position, child in enumerate(children):
            child.metadata.update({
                "doc_id": parent_id,
                "chunk_index": position,
                "total_chunks": total,
                "is_first_chunk": position == 0,
                "is_last_chunk": position == last_index,
            })
            # Position markers are only added when a parent has several
            # children. NOTE(review): only the first and the LAST child get a
            # prefix (the last one is labelled "[Continues...]"); middle
            # children stay unmarked -- confirm this is intended.
            if total > 1:
                if position == 0:
                    child.page_content = "[Beginning] " + child.page_content
                elif position == last_index:
                    child.page_content = "[Continues...] " + child.page_content
            child_docs.append(child)

    print(f"βœ… CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return parent_docs, child_docs, doc_ids
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Swap retrieved child chunks for their parents, enriched with excerpts.

    Children are grouped by the ``doc_id`` in their metadata (children without
    one are ignored). Each parent found in ``store`` is returned as a new
    Document whose text is the parent text followed by up to three of the
    matching child texts.

    Args:
        docs: Retrieved child documents.
        store: Key/value store queried with ``mget`` for the parent documents.

    Returns:
        Enriched parent documents, ordered by how many of the given children
        point at them (most first). Parents missing from the store are logged
        and skipped. Empty when ``docs`` is empty.
    """
    if not docs:
        return []

    # parent id -> texts of the retrieved children, in first-seen order.
    excerpts_by_parent = {}
    for child in docs:
        parent_id = child.metadata.get("doc_id")
        if not parent_id:
            continue
        excerpts_by_parent.setdefault(parent_id, []).append(child.page_content)

    parent_ids = list(excerpts_by_parent)
    enriched = []
    for parent_id, parent in zip(parent_ids, store.mget(parent_ids)):
        if parent is None:
            print(f"❌ PARENT_FETCH: Parent {parent_id} not found in store!")
            continue
        excerpts = excerpts_by_parent[parent_id]
        hit_count = len(excerpts)
        child_excerpts = "\n".join(excerpts[:3])
        enriched.append(
            Document(
                page_content=f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}",
                metadata={
                    **parent.metadata,
                    "child_relevance_score": hit_count,
                    "matching_children": hit_count,
                },
            )
        )

    # Stable sort: parents with the same score keep their first-seen order.
    return sorted(
        enriched,
        key=lambda parent: parent.metadata.get("child_relevance_score", 0),
        reverse=True,
    )
# The app is treated as running on Hugging Face Spaces when either of these
# environment variables is set; uploads then go under /tmp.
is_hf_spaces = any(os.getenv(name) for name in ("SPACE_ID", "SPACES_ZERO_GPU"))
if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"πŸ“ Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    # The configured folder could not be created: retry with /tmp/uploads.
    # A failure of this second attempt is not caught and aborts start-up.
    print(f"❌ Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Process-local state, keyed by session id. Entries are never removed here.
session_data = {}        # session id -> chain, model name, temperature, key manager
message_histories = {}   # session id -> chat message history
# Both models are loaded once at import time, on CPU, and shared by every
# request. A failure to load either one is logged and re-raised, which stops
# the application from starting.
print("πŸ”„ Loading embedding model...")
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    print("βœ… Embedding model loaded.")
except Exception as e:
    print(f"❌ FATAL: Could not load embedding model. Error: {e}")
    raise e

print("πŸ”„ Loading reranker model...")
try:
    RERANKER_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
    print("βœ… Reranker model loaded.")
except Exception as e:
    print(f"❌ FATAL: Could not load reranker model. Error: {e}")
    raise e
def load_pdf_with_fallback(filepath):
    """Extract the text of a PDF, one Document per non-blank page.

    Args:
        filepath: Path of the PDF file to open with PyMuPDF (``fitz``).

    Returns:
        A list of Documents whose metadata holds ``source`` (the file's base
        name) and ``page`` (1-based page number). Pages whose text is only
        whitespace are skipped.

    Raises:
        ValueError: If no page yielded any text.
        Exception: Any error raised while opening or reading the file is
            logged and re-raised unchanged.
    """
    try:
        with fitz.open(filepath) as pdf_doc:
            source_name = os.path.basename(filepath)
            page_texts = ((number, page.get_text()) for number, page in enumerate(pdf_doc, start=1))
            docs = [
                Document(page_content=text, metadata={"source": source_name, "page": number})
                for number, text in page_texts
                if text.strip()
            ]
        if not docs:
            raise ValueError("No text content found in PDF.")
        print(f"βœ… Loaded PDF: {os.path.basename(filepath)} - {len(docs)} pages")
        return docs
    except Exception as e:
        # The "no text" ValueError raised above also passes through here, so it
        # is logged with the same message before propagating.
        print(f"❌ PyMuPDF failed for {filepath}: {e}")
        raise
# File extension (lower-case, with the dot) -> loader. The ".pdf" entry is a
# plain function that returns documents directly; the other entries are loader
# classes that are instantiated with the path and then asked to ``load()``.
LOADER_MAPPING = {
    ".txt": TextLoader,
    ".pdf": load_pdf_with_fallback,
    ".docx": Docx2txtLoader,
}
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Histories live in the module-level ``message_histories`` dict, so the same
    object is returned on every later call with the same id.
    """
    try:
        return message_histories[session_id]
    except KeyError:
        history = ChatMessageHistory()
        message_histories[session_id] = history
        return history
@app.route('/health', methods=['GET'])
def health_check():
    """Health probe: always answers HTTP 200 with ``{"status": "healthy"}``."""
    payload = {'status': 'healthy'}
    return jsonify(payload), 200
@app.route('/', methods=['GET'])
def index():
    """Render the ``index.html`` template for the root URL."""
    page = render_template('index.html')
    return page
@app.route('/upload', methods=['POST'])
def upload_files():
    """Ingest uploaded documents and build a RAG chain for a new session.

    Form fields:
        file: One or more uploaded files. Only extensions present in
            ``LOADER_MAPPING`` are accepted.
        temperature: LLM temperature as text (default ``'0.2'``).
        model_name: LLM identifier (default ``'moonshotai/kimi-k2-instruct'``).

    Returns:
        200 with ``status``, ``filename`` (a processed/failed summary),
        ``session_id``, ``model_name`` and ``mode`` on success;
        400 when the temperature is not a number, when no file was selected,
        or when no file could be processed;
        500 when building the retrieval pipeline fails.
    """
    files = request.files.getlist('file')

    temperature_str = request.form.get('temperature', '0.2')
    try:
        temperature = float(temperature_str)  # the LLM needs a float
    except ValueError:
        # A malformed form value is a client error, not a server crash.
        return jsonify({'status': 'error', 'message': f'Invalid temperature: {temperature_str!r}'}), 400

    model_name = request.form.get('model_name', 'moonshotai/kimi-k2-instruct')
    print(f"βš™οΈ UPLOAD: Model: {model_name}, Temp: {temperature}")

    if not files or all(f.filename == '' for f in files):
        return jsonify({'status': 'error', 'message': 'No selected files.'}), 400

    all_docs, processed_files, failed_files = [], [], []
    print(f"πŸ“ Processing {len(files)} file(s)...")
    for file in files:
        if not (file and file.filename):
            continue
        filename = secure_filename(file.filename)
        # secure_filename() can return '' (nothing safe left in the name);
        # fall back to the submitted name for the error report only.
        label = filename or file.filename
        try:
            if not filename:
                raise ValueError("Invalid file name.")
            file_ext = os.path.splitext(filename)[1].lower()
            # Validate the extension BEFORE saving so that unsupported files
            # are never written to the upload folder.
            if file_ext not in LOADER_MAPPING:
                raise ValueError("Unsupported file format.")
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            file.save(filepath)
            loader_func = LOADER_MAPPING[file_ext]
            # The PDF entry is a function returning documents; the others are
            # loader classes that need an explicit load().
            docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
            if not docs:
                raise ValueError("No content extracted.")
            all_docs.extend(docs)
            processed_files.append(filename)
        except Exception as e:
            # Best effort: one bad file must not abort the whole upload.
            print(f"βœ— Error processing {label}: {e}")
            failed_files.append(f"{label} ({e})")

    if not all_docs:
        return jsonify({'status': 'error', 'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"}), 400
    print(f"βœ… UPLOAD: Processed {len(processed_files)} files.")

    try:
        print("πŸ”„ Starting RAG pipeline setup...")
        parent_docs, child_docs, doc_ids = create_optimized_parent_child_chunks(all_docs)
        if not child_docs:
            raise ValueError("No child documents created during chunking.")

        # Children are indexed for search; parents are kept in a key/value
        # store so they can be fetched by id after retrieval.
        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        print(f"βœ… Indexed {len(child_docs)} document chunks.")

        # Hybrid retrieval (BM25 + FAISS), reranked, then mapped to parents.
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 12
        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 12})
        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.6, 0.4])
        reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)

        def get_parents(docs: List[Document]) -> List[Document]:
            return get_context_aware_parents(docs, store)

        compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)
        final_retriever = compression_retriever | get_parents

        session_id = str(uuid.uuid4())
        rag_chain, api_key_manager = create_rag_chain(
            retriever=final_retriever, get_session_history_func=get_session_history,
            model_name=model_name, temperature=temperature
        )
        # The temperature is stored as a float.
        session_data[session_id] = {'chain': rag_chain, 'model_name': model_name, 'temperature': temperature, 'api_key_manager': api_key_manager}

        success_msg = f"Processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f". Failed: {', '.join(failed_files)}"

        # Look the label up with str(float), exactly like /chat does, so a
        # form value such as "1" or "0.20" maps to the same label in both
        # endpoints; unknown temperatures fall back to the submitted text.
        mode_label = TEMPERATURE_LABELS.get(str(temperature), temperature_str)
        print(f"βœ… UPLOAD COMPLETE: Session {session_id} is ready.")
        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id,
            'model_name': model_name,
            'mode': mode_label
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'RAG setup failed: {e}'}), 500
@app.route('/chat', methods=['POST'])
def chat():
    """Answer a question with the RAG chain of an existing session.

    Expects a JSON object with ``question`` and ``session_id`` (the latter
    falls back to ``session_id`` from the Flask session cookie).

    Returns:
        200 with ``answer``, ``model_name`` and ``mode`` on success;
        400 when the body is not a JSON object, the question is missing, or
        the session is unknown;
        500 when invoking the chain fails.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # raising, so every bad request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'status': 'error', 'message': 'Request body must be a JSON object.'}), 400

    question = data.get('question')
    session_id = data.get('session_id') or session.get('session_id')
    if not question:
        return jsonify({'status': 'error', 'message': 'No question provided.'}), 400
    # The isinstance check keeps an unhashable JSON value (list/object) from
    # raising TypeError in the dictionary membership test.
    if not isinstance(session_id, str) or session_id not in session_data:
        print(f"❌ CHAT: Invalid session {session_id}.")
        return jsonify({'status': 'error', 'message': 'Invalid session. Please upload documents first.'}), 400

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        # The temperature is stored as a float; the label table is keyed by
        # its string form. Unknown temperatures are reported as that string.
        temperature_str = str(session_info['temperature'])
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"πŸ’¬ CHAT: Invoking chain for session {session_id}...")
        answer = rag_chain.invoke({"question": question}, config={"configurable": {"session_id": session_id}})
        print("βœ… CHAT: Answer generated.")
        return jsonify({
            'answer': answer,
            'model_name': model_name,
            'mode': mode_label
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'Error during chat: {e}'}), 500
# Line-level Markdown markers. They are removed BEFORE inline emphasis so that
# a leading "* " bullet or a "***" / "___" rule is never mistaken for (half of)
# an emphasis pair.
_TTS_LINE_MARKERS = tuple(
    re.compile(pattern, re.MULTILINE)
    for pattern in (
        r'^\s*[-*_]{3,}\s*$',  # horizontal rules
        r'^\s*#{1,6}\s+',      # headings
        r'^\s*[\*\-]\s+',      # bullet markers
        r'^\s*\d+\.\s+',       # numbered-list markers
        r'^\s*>\s?',           # blockquote markers
    )
)

# Inline markup. Every pattern captures the opening delimiter in group 1 and
# the text to keep in group 2.
_TTS_INLINE_MARKUP = (
    # *italic*, **bold**, ***both***
    re.compile(r'(\*{1,3})(.*?)\1'),
    # _italic_ / __bold__, but only when the underscores sit at word
    # boundaries, so identifiers such as snake_case_name are left intact.
    re.compile(r'(?<!\w)(_{1,3})(.+?)\1(?!\w)'),
    # `inline code`
    re.compile(r'(`)(.*?)\1'),
)


def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown markup so the text can be read aloud.

    Removes horizontal rules, heading / bullet / numbered-list / blockquote
    markers, then unwraps asterisk emphasis, underscore emphasis and inline
    code, and finally joins all lines with single spaces.

    Args:
        text: Markdown-formatted text.

    Returns:
        The plain text on one line, without leading or trailing whitespace.
    """
    for marker in _TTS_LINE_MARKERS:
        text = marker.sub('', text)
    for markup in _TTS_INLINE_MARKUP:
        text = markup.sub(r'\2', text)
    return re.sub(r'\n+', ' ', text).strip()
@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Convert the posted text to spoken English audio (MP3).

    Expects a JSON object with a non-empty string ``text``. Markdown markup
    is stripped before synthesis.

    Returns:
        200 with an ``audio/mpeg`` body on success;
        400 when the body is not a JSON object, ``text`` is missing or not a
        string, or nothing speakable remains after stripping the markup;
        500 when audio generation fails.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # raising, so bad requests get a JSON error rather than a crash.
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if not text or not isinstance(text, str):
        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400

    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            # Input that is pure markup leaves nothing to speak; report it as
            # a client error instead of letting synthesis fail.
            return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        # Send the bytes themselves rather than the BytesIO object, which
        # would be streamed by iterating it "line" by "line".
        return Response(mp3_fp.getvalue(), mimetype='audio/mpeg')
    except Exception as e:
        print(f"❌ TTS Error: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
if __name__ == '__main__':
    # Script entry point: listen on every interface, on $PORT when it is set
    # and on 7860 otherwise, with the debugger and threading both disabled.
    port = int(os.getenv("PORT", 7860))
    print(f"πŸš€ Starting Flask app on port {port}")
    app.run(
        host="0.0.0.0",
        port=port,
        debug=False,
        threaded=False,
    )