File size: 10,924 Bytes
becc8f7
 
 
 
 
 
 
 
 
 
 
 
57bb94b
becc8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57bb94b
becc8f7
2c5dd57
57bb94b
 
becc8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a743656
 
 
dbb87ba
becc8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af77b5d
becc8f7
af77b5d
becc8f7
 
 
af77b5d
becc8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27481ac
becc8f7
 
27481ac
becc8f7
 
27481ac
 
 
 
becc8f7
 
27481ac
 
 
becc8f7
27481ac
 
 
becc8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import os
import time
import uuid
from flask import Flask, request, render_template, session, jsonify, Response
from werkzeug.utils import secure_filename
from rag_processor import create_rag_chain
from typing import Sequence, Any, List
import fitz
import re
import io
from gtts import gTTS
from langchain_core.documents import Document

from langchain_community.document_loaders import (
    TextLoader,
    Docx2txtLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.storage import InMemoryStore


app = Flask(__name__)
# NOTE(review): os.urandom secret is regenerated on every restart, which
# invalidates all existing cookie sessions — fine for a demo, confirm for prod.
app.config['SECRET_KEY'] = os.urandom(24)

# On Hugging Face Spaces only /tmp is reliably writable.
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    # Preferred folder not writable — fall back to /tmp and retry.
    print(f"Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")

# Per-session state, keyed by session_id. All in-process: state is lost on
# restart and not shared between workers.
rag_chains = {}        # session_id -> components dict from create_rag_chain
message_histories = {} # session_id -> ChatMessageHistory
doc_stores = {} # To hold the InMemoryStore for each session

print("Loading embedding model...")
try:
    hf_token = os.getenv("HF_TOKEN")  # NOTE(review): read but never used below
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name="google/embeddinggemma-300m",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    print("Embedding model loaded successfully.")
except Exception as e:
    # The whole app depends on embeddings — fail fast at import time.
    print(f"FATAL: Could not load embedding model. Error: {e}")
    raise

def load_pdf_with_fallback(filepath):
    """Extract per-page text from a PDF using PyMuPDF.

    Returns a list of Document objects, one per non-empty page, with the
    source file name and 1-based page number in the metadata. Raises if the
    file cannot be opened or yields no text at all.
    """
    try:
        pages = []
        with fitz.open(filepath) as pdf_doc:
            for index, page in enumerate(pdf_doc):
                content = page.get_text()
                if not content.strip():
                    continue  # skip blank / image-only pages
                metadata = {
                    "source": os.path.basename(filepath),
                    "page": index + 1,
                }
                pages.append(Document(page_content=content, metadata=metadata))
        if not pages:
            raise ValueError("No text content found in PDF.")
        print(f"Successfully loaded PDF with PyMuPDF: {filepath}")
        return pages
    except Exception as e:
        print(f"PyMuPDF failed for {filepath}: {e}")
        raise

# Extension -> loader dispatch table. Note the value convention: .txt/.docx
# map to LangChain loader *classes* (instantiate, then call .load()), while
# .pdf maps to a plain *function* returning Documents directly — upload_files
# special-cases ".pdf" accordingly.
LOADER_MAPPING = {
    ".txt": TextLoader,
    ".pdf": load_pdf_with_fallback,
    ".docx": Docx2txtLoader,
}

def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use."""
    history = message_histories.get(session_id)
    if history is None:
        history = ChatMessageHistory()
        message_histories[session_id] = history
    return history

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always reports healthy with HTTP 200."""
    payload = {'status': 'healthy'}
    return jsonify(payload), 200

@app.route('/', methods=['GET'])
def index():
    """Serve the single-page chat UI."""
    template_name = 'index.html'
    return render_template(template_name)

@app.route('/upload', methods=['POST'])
def upload_files():
    """Ingest uploaded documents and build a per-session RAG pipeline.

    Saves each uploaded file, loads it with the extension-appropriate loader,
    splits it into parent chunks (stored in an InMemoryStore) and child chunks
    (indexed in FAISS + BM25 hybrid retrieval), then creates a session-scoped
    RAG chain. Returns the new session_id and a per-file status summary.
    """
    files = request.files.getlist('file')
    if not files or all(f.filename == '' for f in files):
        return jsonify({'status': 'error', 'message': 'No selected files.'}), 400

    all_docs = []
    processed_files, failed_files = [], []

    for file in files:
        if file and file.filename:
            filename = secure_filename(file.filename)
            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            try:
                file.save(filepath)
                file_ext = os.path.splitext(filename)[1].lower()
                if file_ext not in LOADER_MAPPING:
                    raise ValueError("Unsupported file format.")

                loader_func = LOADER_MAPPING[file_ext]
                # ".pdf" maps to a plain function returning Documents; other
                # entries are loader classes that need .load().
                docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()

                if not docs:
                    raise ValueError("No content extracted.")

                all_docs.extend(docs)
                processed_files.append(filename)
                # BUG FIX: these three messages printed the literal string
                # "(unknown)" instead of the file actually being processed,
                # so users saw useless failure reports.
                print(f"✓ Successfully processed: {filename}")
            except Exception as e:
                error_msg = str(e)
                print(f"✗ Error processing {filename}: {error_msg}")
                failed_files.append(f"{filename} ({error_msg})")

    if not all_docs:
        error_summary = "Failed to process all files."
        if failed_files:
            error_summary += " Reasons: " + ", ".join(failed_files)
        return jsonify({'status': 'error', 'message': error_summary}), 400

    try:
        print("Starting RAG pipeline setup...")

        # Parent chunks give the LLM broad context; child chunks give the
        # retrievers precise targets. Children carry their parent's doc_id.
        parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=300,
            separators=["\n\n", "\n", ". ", " ", ""],  # prioritize natural breaks
            length_function=len,
        )
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

        parent_docs = parent_splitter.split_documents(all_docs)
        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

        child_docs = []
        for i, doc in enumerate(parent_docs):
            _id = doc_ids[i]
            sub_docs = child_splitter.split_documents([doc])
            for child in sub_docs:
                child.metadata["doc_id"] = _id
            child_docs.extend(sub_docs)

        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))

        vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)

        print(f"Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")

        # Hybrid retrieval: lexical (BM25) + dense (FAISS), weighted 40/60.
        bm25_retriever = BM25Retriever.from_documents(child_docs)
        bm25_retriever.k = 3

        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.4, 0.6]
        )
        print("Created Hybrid Retriever for child documents.")

        session_id = str(uuid.uuid4())

        doc_stores[session_id] = store

        rag_chain_components = create_rag_chain(ensemble_retriever, get_session_history, EMBEDDING_MODEL, store)

        rag_chains[session_id] = rag_chain_components
        session['session_id'] = session_id

        success_msg = f"Successfully processed: {', '.join(processed_files)}"
        if failed_files:
            success_msg += f"\nFailed to process: {', '.join(failed_files)}"

        return jsonify({
            'status': 'success',
            'filename': success_msg,
            'session_id': session_id
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500

@app.route('/chat', methods=['POST'])
def chat():
    """Answer a question against the session's indexed documents.

    Per request: rewrite the question using chat history, generate a HyDE
    pseudo-document, retrieve child chunks with it, expand them to parent
    docs, then invoke the final answer chain (which records history via
    the per-session `config`).
    """
    data = request.get_json()
    question = data.get('question')
    session_id = session.get('session_id') or data.get('session_id')

    if not question or not session_id or session_id not in rag_chains:
        return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400

    try:
        chain_components = rag_chains[session_id]
        config = {"configurable": {"session_id": session_id}}

        print("\n" + "="*50)
        print("--- STARTING DIAGNOSTIC RUN ---")
        print(f"Original Question: {question}")
        print("="*50 + "\n")

        rewritten_query = chain_components["rewriter"].invoke({"question": question, "chat_history": get_session_history(session_id).messages})

        hyde_doc = chain_components["hyde"].invoke({"question": rewritten_query})

        # FIX: get_relevant_documents() is deprecated — retrievers are
        # Runnables in modern LangChain, so use invoke() instead.
        final_retrieved_docs = chain_components["base_retriever"].invoke(hyde_doc)

        # NOTE(review): final_context_docs is computed but never passed to
        # final_chain below — presumably final_chain re-retrieves internally.
        # Kept for parity; confirm against create_rag_chain before removing.
        final_context_docs = chain_components["parent_fetcher"].invoke(final_retrieved_docs)

        answer_string = chain_components["final_chain"].invoke({"question": question}, config=config)

        return jsonify({'answer': answer_string})

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500

def clean_markdown_for_tts(text: str) -> str:
    """Strip Markdown markup so the TTS engine reads only plain prose.

    Removes images, links (keeping the anchor text), emphasis, inline code,
    headings, list/quote markers, and horizontal rules, then collapses
    newlines into single spaces.
    """
    # FIX: images and links were previously left intact, so gTTS would read
    # the bracket syntax and raw URLs aloud. Strip images entirely (their
    # ![alt](src) form would otherwise half-match the link rule), then keep
    # only the anchor text of links.
    text = re.sub(r'!\[[^\]]*\]\([^)]*\)', '', text)
    text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
    # Bold/italic: the backreference on the optional second '*' handles both
    # *italic* and **bold**.
    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
    text = re.sub(r'\_(.*?)\_', r'\1', text)
    text = re.sub(r'`(.*?)`', r'\1', text)
    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)      # headings
    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)      # bullets
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)       # numbered lists
    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)           # blockquotes
    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)  # horizontal rules
    text = re.sub(r'\n+', ' ', text)
    return text.strip()

@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Synthesize the POSTed text to MP3 audio with gTTS."""
    data = request.get_json()
    text = data.get('text')

    if not text:
        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400

    try:
        # Strip Markdown first so the engine speaks prose, not markup.
        speakable = clean_markdown_for_tts(text)
        audio_buffer = io.BytesIO()
        gTTS(speakable, lang='en').write_to_fp(audio_buffer)
        audio_buffer.seek(0)
        return Response(audio_buffer, mimetype='audio/mpeg')
    except Exception as e:
        print(f"Error in TTS generation: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500

if __name__ == '__main__':
    # 7860 is the port Hugging Face Spaces expects by default.
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)