|
|
import io
import os
import re
import traceback
import uuid
from typing import Any, List, Sequence

import fitz  # PyMuPDF
from flask import Flask, Response, jsonify, render_template, request, session
from gtts import gTTS
from werkzeug.utils import secure_filename

from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import Docx2txtLoader, TextLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers.cross_encoder import CrossEncoder

from rag_processor import create_rag_chain

app = Flask(__name__)
app.config['SECRET_KEY'] = os.urandom(24)
|
# Human-readable labels for the temperature presets exposed in the UI.
TEMPERATURE_LABELS = {
    "0.2": "Precise",
    "0.4": "Confident",
    "0.6": "Balanced",
    "0.8": "Flexible",
    "1.0": "Creative"
}
|
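# Cross-encoder reranking: the bi-encoder used for FAISS embeds query and
# document separately, while a cross-encoder scores each (query, document)
# pair jointly. That is far more accurate but too slow for a whole corpus, so
# it is applied only to the small candidate set the ensemble retriever
# returns. A minimal standalone sketch of this compressor (illustrative only):
#
#     model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")
#     reranker = LocalReranker(model=model, top_n=2)
#     docs = [Document(page_content="Paris is the capital of France."),
#             Document(page_content="Bananas are rich in potassium.")]
#     top = reranker.compress_documents(docs, query="What is France's capital?")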
class LocalReranker(BaseDocumentCompressor):
    """Keep only the top_n documents as scored by a local cross-encoder."""

    model: Any  # a sentence-transformers CrossEncoder instance
    top_n: int = 5

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self, documents: Sequence[Document], query: str, callbacks=None
    ) -> Sequence[Document]:
        if not documents:
            return []
        # Score every (query, document) pair jointly, then keep the top_n.
        pairs = [[query, doc.page_content] for doc in documents]
        scores = self.model.predict(pairs, show_progress_bar=False)
        doc_scores = list(zip(documents, scores))
        sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
        top_docs = []
        for doc, score in sorted_doc_scores[: self.top_n]:
            doc.metadata["rerank_score"] = float(score)
            top_docs.append(doc)
        return top_docs
|
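# Parent-child chunking: small child chunks (~350 chars) are what gets
# embedded and searched, since short passages give sharper similarity scores;
# the larger parents (~900 chars) are what the LLM eventually receives, so
# answers keep their surrounding context. The doc_id in each child's metadata
# links it back to its parent in the docstore.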
def create_optimized_parent_child_chunks(all_docs):
    """Split documents into large parent chunks and small child chunks linked by doc_id."""
    if not all_docs:
        print("❌ CHUNKING: No input documents provided!")
        return [], [], []

    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900, chunk_overlap=200, separators=["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""]
    )
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350, chunk_overlap=80, separators=["\n", ". ", "! ", "? ", "; ", ", ", " ", ""]
    )
    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
    child_docs = []

    for i, parent_doc in enumerate(parent_docs):
        parent_id = doc_ids[i]
        children = child_splitter.split_documents([parent_doc])
        for j, child in enumerate(children):
            child.metadata.update({
                "doc_id": parent_id, "chunk_index": j, "total_chunks": len(children),
                "is_first_chunk": j == 0, "is_last_chunk": j == len(children) - 1,
            })
            # Mark boundary chunks so the model can tell where a passage sits
            # within its parent.
            if len(children) > 1:
                if j == 0:
                    child.page_content = "[Beginning] " + child.page_content
                elif j == len(children) - 1:
                    child.page_content = "[Continues...] " + child.page_content
            child_docs.append(child)

    print(f"✅ CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return parent_docs, child_docs, doc_ids
|
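# Retrieval returns child chunks; this helper swaps them for their parents,
# using the number of matching children per parent as a cheap relevance
# signal and appending the matched excerpts so the most pertinent passages
# stay visible inside the broader parent context.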
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Map retrieved child chunks back to their parents, enriched with the matching excerpts."""
    if not docs:
        return []
    parent_scores, child_content_by_parent = {}, {}
    for doc in docs:
        parent_id = doc.metadata.get("doc_id")
        if parent_id:
            parent_scores[parent_id] = parent_scores.get(parent_id, 0) + 1
            if parent_id not in child_content_by_parent:
                child_content_by_parent[parent_id] = []
            child_content_by_parent[parent_id].append(doc.page_content)

    parent_ids = list(parent_scores.keys())
    parents = store.mget(parent_ids)
    enhanced_parents = []

    for i, parent in enumerate(parents):
        if parent is not None:
            parent_id = parent_ids[i]
            if parent_id in child_content_by_parent:
                # Append up to three matching child excerpts so the LLM sees
                # both the broad context and the specific passages that matched.
                child_excerpts = "\n".join(child_content_by_parent[parent_id][:3])
                enhanced_content = f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}"
                enhanced_parent = Document(
                    page_content=enhanced_content,
                    metadata={
                        **parent.metadata,
                        "child_relevance_score": parent_scores[parent_id],
                        "matching_children": len(child_content_by_parent[parent_id]),
                    },
                )
                enhanced_parents.append(enhanced_parent)
        else:
            print(f"❌ PARENT_FETCH: Parent {parent_ids[i]} not found in store!")

    # Parents whose children matched most often come first.
    enhanced_parents.sort(key=lambda p: p.metadata.get("child_relevance_score", 0), reverse=True)
    return enhanced_parents
|
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
app.config['UPLOAD_FOLDER'] = '/tmp/uploads' if is_hf_spaces else 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"📁 Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    print(f"❌ Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Per-session RAG chains and chat histories, kept in process memory.
session_data = {}
message_histories = {}
|
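# all-MiniLM-L6-v2 is a small bi-encoder that runs acceptably on CPU. With
# normalize_embeddings=True all vectors are unit length, so L2 distance and
# cosine similarity produce the same ranking in the FAISS index. Both models
# below are loaded once at startup and shared across sessions.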
print("π Loading embedding model...") |
|
|
try: |
|
|
EMBEDDING_MODEL = HuggingFaceEmbeddings( |
|
|
model_name="sentence-transformers/all-MiniLM-L6-v2", |
|
|
model_kwargs={'device': 'cpu'}, |
|
|
encode_kwargs={'normalize_embeddings': True} |
|
|
) |
|
|
print("β
Embedding model loaded.") |
|
|
except Exception as e: print(f"β FATAL: Could not load embedding model. Error: {e}"); raise e |
|
|
|
|
|
print("π Loading reranker model...") |
|
|
try: |
|
|
RERANKER_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu') |
|
|
print("β
Reranker model loaded.") |
|
|
except Exception as e: print(f"β FATAL: Could not load reranker model. Error: {e}"); raise e |
|
|
|
|
|
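# PDFs are read directly with PyMuPDF (fitz) rather than a generic loader: it
# yields one Document per non-empty page, and the page number recorded in
# metadata survives chunking, which is useful for citing sources later.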
def load_pdf_with_fallback(filepath):
    """Extract text from a PDF with PyMuPDF, one Document per non-empty page."""
    try:
        docs = []
        with fitz.open(filepath) as pdf_doc:
            for page_num, page in enumerate(pdf_doc):
                text = page.get_text()
                if text.strip():
                    docs.append(Document(page_content=text, metadata={"source": os.path.basename(filepath), "page": page_num + 1}))
        if docs:
            print(f"✅ Loaded PDF: {os.path.basename(filepath)} - {len(docs)} pages")
            return docs
        else:
            raise ValueError("No text content found in PDF.")
    except Exception as e:
        print(f"❌ PyMuPDF failed for {filepath}: {e}")
        raise
|
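# Extension-based loader dispatch. TextLoader and Docx2txtLoader are loader
# classes (instantiate, then call .load()); load_pdf_with_fallback is a plain
# function that returns Documents directly, which is why /upload special-cases
# the ".pdf" branch.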
LOADER_MAPPING = {".txt": TextLoader, ".pdf": load_pdf_with_fallback, ".docx": Docx2txtLoader}


def get_session_history(session_id: str) -> ChatMessageHistory:
    if session_id not in message_histories:
        message_histories[session_id] = ChatMessageHistory()
    return message_histories[session_id]


@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({'status': 'healthy'}), 200


@app.route('/', methods=['GET'])
def index():
    return render_template('index.html')
|
@app.route('/upload', methods=['POST'])
def upload_files():
    files = request.files.getlist('file')

    temperature_str = request.form.get('temperature', '0.2')
    try:
        temperature = float(temperature_str)
    except ValueError:
        # Fall back to the default if the form sends a non-numeric value.
        temperature_str, temperature = '0.2', 0.2
    model_name = request.form.get('model_name', 'moonshotai/kimi-k2-instruct')
    print(f"⚙️ UPLOAD: Model: {model_name}, Temp: {temperature}")

    if not files or all(f.filename == '' for f in files):
        return jsonify({'status': 'error', 'message': 'No selected files.'}), 400

    all_docs, processed_files, failed_files = [], [], []
    print(f"📄 Processing {len(files)} file(s)...")
    for file in files:
        if file and file.filename:
            filename = secure_filename(file.filename)
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                file.save(filepath)
                file_ext = os.path.splitext(filename)[1].lower()
                if file_ext not in LOADER_MAPPING:
                    raise ValueError("Unsupported file format.")
                loader_func = LOADER_MAPPING[file_ext]
                # The PDF entry is a plain function; the others are loader classes.
                docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
                if not docs:
                    raise ValueError("No content extracted.")
                all_docs.extend(docs)
                processed_files.append(filename)
            except Exception as e:
                print(f"❌ Error processing {filename}: {e}")
                failed_files.append(f"{filename} ({e})")

    if not all_docs:
        return jsonify({'status': 'error', 'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"}), 400

    print(f"✅ UPLOAD: Processed {len(processed_files)} files.")
    try:
        print("🔧 Starting RAG pipeline setup...")
        parent_docs, child_docs, doc_ids = create_optimized_parent_child_chunks(all_docs)
        if not child_docs:
            raise ValueError("No child documents created during chunking.")
|
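        # Child chunks go into the FAISS index for dense search; the full
        # parents live in an InMemoryStore keyed by doc_id, so matching on
        # small chunks can still hand the LLM the larger parent passage.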
        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        print(f"✅ Indexed {len(child_docs)} document chunks.")
|
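        # Hybrid retrieval: BM25 (keyword overlap) and FAISS (semantic
        # similarity) each fetch 12 candidates, fused 0.6/0.4 in favour of
        # BM25; the cross-encoder then keeps the best 5, and the final pipe
        # swaps those children for their enriched parents.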
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 12
        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 12})
        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.6, 0.4])
        reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)

        def get_parents(docs: List[Document]) -> List[Document]:
            return get_context_aware_parents(docs, store)

        compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)
        final_retriever = compression_retriever | get_parents

        session_id = str(uuid.uuid4())
        rag_chain, api_key_manager = create_rag_chain(
            retriever=final_retriever,
            get_session_history_func=get_session_history,
            model_name=model_name,
            temperature=temperature
        )

        session_data[session_id] = {
            'chain': rag_chain,
            'model_name': model_name,
            'temperature': temperature,
            'api_key_manager': api_key_manager
        }

        success_msg = f"Processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f". Failed: {', '.join(failed_files)}"

        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"✅ UPLOAD COMPLETE: Session {session_id} is ready.")

        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id,
            'model_name': model_name,
            'mode': mode_label
        })
    except Exception as e:
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'RAG setup failed: {e}'}), 500
|
@app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json()
    question, session_id = data.get('question'), data.get('session_id') or session.get('session_id')

    if not question:
        return jsonify({'status': 'error', 'message': 'No question provided.'}), 400
    if not session_id or session_id not in session_data:
        print(f"❌ CHAT: Invalid session {session_id}.")
        return jsonify({'status': 'error', 'message': 'Invalid session. Please upload documents first.'}), 400

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']
        temperature_str = str(session_info['temperature'])
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
|
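        # The session_id passed via `configurable` lets the chain built in
        # rag_processor fetch this session's ChatMessageHistory through
        # get_session_history, so follow-up questions retain prior turns.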
print(f"π¬ CHAT: Invoking chain for session {session_id}...") |
|
|
answer = rag_chain.invoke({"question": question}, config={"configurable": {"session_id": session_id}}) |
|
|
print(f"β
CHAT: Answer generated.") |
|
|
|
|
|
|
|
|
return jsonify({ |
|
|
'answer': answer, |
|
|
'model_name': model_name, |
|
|
'mode': mode_label |
|
|
}) |
|
|
|
|
|
except Exception as e: |
|
|
import traceback; traceback.print_exc() |
|
|
return jsonify({'status': 'error', 'message': f'Error during chat: {e}'}), 500 |
|
|
|
|
|
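# gTTS reads its input verbatim, so Markdown markup would be spoken aloud.
# This strips emphasis, inline code, headings, list markers, blockquotes and
# horizontal rules before synthesis.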
def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown syntax so the TTS engine does not read markup aloud."""
    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)  # bold / italic asterisks
    text = re.sub(r'\_(.*?)\_', r'\1', text)  # underscore emphasis
    text = re.sub(r'`(.*?)`', r'\1', text)  # inline code
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)  # headings
    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)  # bullet markers
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)  # numbered lists
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)  # blockquotes
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)  # horizontal rules
    text = re.sub(r'\n+', ' ', text)  # collapse newlines into spaces
    return text.strip()
|
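# Audio is synthesized into an in-memory buffer and streamed straight back to
# the client; nothing is written to disk.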
@app.route('/tts', methods=['POST'])
def text_to_speech():
    data = request.get_json()
    text = data.get('text')
    if not text:
        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
    try:
        clean_text = clean_markdown_for_tts(text)
        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)
        return Response(mp3_fp, mimetype='audio/mpeg')
    except Exception as e:
        print(f"❌ TTS Error: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
|
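# Port 7860 is the default expected by Hugging Face Spaces; PORT overrides it
# in other environments.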
if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))
    print(f"🚀 Starting Flask app on port {port}")
    app.run(host="0.0.0.0", port=port, debug=False, threaded=False)