riteshraut committed on
Commit
f7d42c1
·
1 Parent(s): 2f22e27
Files changed (8) hide show
  1. - Copy.gitattributes +0 -35
  2. - Copy.gitignore +0 -6
  3. .env - Copy.example +0 -2
  4. app.py +207 -172
  5. evaluate.py +205 -0
  6. query_expansion.py +525 -0
  7. rag_processor.py +446 -60
  8. templates/index.html +230 -345
- Copy.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
- Copy.gitignore DELETED
@@ -1,6 +0,0 @@
1
- .env
2
- /uploads/
3
- /vectorstores/
4
- /.cache/
5
- __pycache__/
6
- *.pyc
 
 
 
 
 
 
 
.env - Copy.example DELETED
@@ -1,2 +0,0 @@
1
- # Copy this file to .env and fill in your API key
2
- GROQ_API_KEY=your_groq_api_key_here
 
 
 
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import time
3
  import uuid
4
  from flask import Flask, request, render_template, session, jsonify, Response
5
  from werkzeug.utils import secure_filename
@@ -10,55 +9,143 @@ import re
10
  import io
11
  from gtts import gTTS
12
  from langchain_core.documents import Document
13
-
14
- from langchain_community.document_loaders import (
15
- TextLoader,
16
- Docx2txtLoader,
17
- )
18
  from langchain.text_splitter import RecursiveCharacterTextSplitter
19
- from langchain_experimental.text_splitter import SemanticChunker
20
  from langchain_huggingface import HuggingFaceEmbeddings
21
  from langchain_community.vectorstores import FAISS
22
- from langchain.retrievers import EnsembleRetriever
 
23
  from langchain_community.retrievers import BM25Retriever
24
  from langchain_community.chat_message_histories import ChatMessageHistory
25
  from langchain.storage import InMemoryStore
26
-
27
 
28
  app = Flask(__name__)
29
  app.config['SECRET_KEY'] = os.urandom(24)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
32
- if is_hf_spaces:
33
- app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
34
- else:
35
- app.config['UPLOAD_FOLDER'] = 'uploads'
36
 
37
  try:
38
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
39
- print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
40
  except Exception as e:
41
- print(f"Failed to create upload folder {app.config['UPLOAD_FOLDER']}: {e}")
42
  app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
43
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
44
- print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
45
 
46
- rag_chains = {}
47
  message_histories = {}
48
- doc_stores = {} # To hold the InMemoryStore for each session
49
 
50
- print("Loading embedding model...")
51
  try:
52
- hf_token = os.getenv("HF_TOKEN")
53
  EMBEDDING_MODEL = HuggingFaceEmbeddings(
54
- model_name="google/embeddinggemma-300m",
55
  model_kwargs={'device': 'cpu'},
56
- encode_kwargs={'normalize_embeddings': True},
57
  )
58
- print("Embedding model loaded successfully.")
59
- except Exception as e:
60
- print(f"FATAL: Could not load embedding model. Error: {e}")
61
- raise
 
 
 
 
62
 
63
  def load_pdf_with_fallback(filepath):
64
  try:
@@ -66,51 +153,40 @@ def load_pdf_with_fallback(filepath):
66
  with fitz.open(filepath) as pdf_doc:
67
  for page_num, page in enumerate(pdf_doc):
68
  text = page.get_text()
69
- if text.strip():
70
- docs.append(Document(
71
- page_content=text,
72
- metadata={
73
- "source": os.path.basename(filepath),
74
- "page": page_num + 1,
75
- }
76
- ))
77
  if docs:
78
- print(f"Successfully loaded PDF with PyMuPDF: {filepath}")
79
  return docs
80
- else:
81
- raise ValueError("No text content found in PDF.")
82
- except Exception as e:
83
- print(f"PyMuPDF failed for {filepath}: {e}")
84
- raise
85
 
86
- LOADER_MAPPING = {
87
- ".txt": TextLoader,
88
- ".pdf": load_pdf_with_fallback,
89
- ".docx": Docx2txtLoader,
90
- }
91
 
92
  def get_session_history(session_id: str) -> ChatMessageHistory:
93
- if session_id not in message_histories:
94
- message_histories[session_id] = ChatMessageHistory()
95
  return message_histories[session_id]
96
 
97
  @app.route('/health', methods=['GET'])
98
- def health_check():
99
- return jsonify({'status': 'healthy'}), 200
100
 
101
  @app.route('/', methods=['GET'])
102
- def index():
103
- return render_template('index.html')
104
 
105
  @app.route('/upload', methods=['POST'])
106
  def upload_files():
107
  files = request.files.getlist('file')
 
 
 
 
 
 
 
108
  if not files or all(f.filename == '' for f in files):
109
  return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
110
 
111
- all_docs = []
112
- processed_files, failed_files = [], []
113
-
114
  for file in files:
115
  if file and file.filename:
116
  filename = secure_filename(file.filename)
@@ -118,169 +194,128 @@ def upload_files():
118
  try:
119
  file.save(filepath)
120
  file_ext = os.path.splitext(filename)[1].lower()
121
- if file_ext not in LOADER_MAPPING:
122
- raise ValueError("Unsupported file format.")
123
-
124
  loader_func = LOADER_MAPPING[file_ext]
125
  docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
126
-
127
- if not docs:
128
- raise ValueError("No content extracted.")
129
-
130
  all_docs.extend(docs)
131
  processed_files.append(filename)
132
- print(f"✓ Successfully processed: {filename}")
133
  except Exception as e:
134
- error_msg = str(e)
135
- print(f"✗ Error processing {filename}: {error_msg}")
136
- failed_files.append(f"{filename} ({error_msg})")
137
 
138
  if not all_docs:
139
- error_summary = "Failed to process all files."
140
- if failed_files:
141
- error_summary += " Reasons: " + ", ".join(failed_files)
142
- return jsonify({'status': 'error', 'message': error_summary}), 400
143
 
 
144
  try:
145
- print("Starting RAG pipeline setup...")
146
-
147
- parent_splitter =RecursiveCharacterTextSplitter(chunk_size=1000,chunk_overlap=300,
148
- separators=["\n\n", "\n", ". ", " ", ""], # Prioritize natural breaks
149
- length_function=len)
150
- child_splitter = RecursiveCharacterTextSplitter(chunk_size=500,chunk_overlap=100)
151
-
152
- parent_docs = parent_splitter.split_documents(all_docs)
153
- doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
154
-
155
- child_docs = []
156
- for i, doc in enumerate(parent_docs):
157
- _id = doc_ids[i]
158
- sub_docs = child_splitter.split_documents([doc])
159
- for child in sub_docs:
160
- child.metadata["doc_id"] = _id
161
- child_docs.extend(sub_docs)
162
-
163
- store = InMemoryStore()
164
- store.mset(list(zip(doc_ids, parent_docs)))
165
 
166
  vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
167
-
168
- print(f"Stored {len(parent_docs)} parent docs and indexed {len(child_docs)} child docs.")
169
-
170
- bm25_retriever = BM25Retriever.from_documents(child_docs)
171
- bm25_retriever.k = 3
172
 
173
- faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
 
 
 
 
 
 
174
 
175
- ensemble_retriever = EnsembleRetriever(
176
- retrievers=[bm25_retriever, faiss_retriever],
177
- weights=[0.4, 0.6]
178
- )
179
- print("Created Hybrid Retriever for child documents.")
180
-
181
  session_id = str(uuid.uuid4())
 
 
 
 
 
 
182
 
183
- doc_stores[session_id] = store
 
184
 
185
- rag_chain_components = create_rag_chain(ensemble_retriever, get_session_history, EMBEDDING_MODEL, store)
 
186
 
187
- rag_chains[session_id] = rag_chain_components
188
- session['session_id'] = session_id
189
-
190
- success_msg = f"Successfully processed: {', '.join(processed_files)}"
191
- if failed_files:
192
- success_msg += f"\nFailed to process: {', '.join(failed_files)}"
193
-
194
  return jsonify({
195
  'status': 'success',
196
- 'filename': success_msg,
197
- 'session_id': session_id
 
 
198
  })
199
-
200
  except Exception as e:
201
- import traceback
202
- traceback.print_exc()
203
- return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500
204
 
205
  @app.route('/chat', methods=['POST'])
206
  def chat():
207
  data = request.get_json()
208
- question = data.get('question')
209
- session_id = session.get('session_id') or data.get('session_id')
210
 
211
- if not question or not session_id or session_id not in rag_chains:
212
- return jsonify({'status': 'error', 'message': 'Invalid session or no question provided.'}), 400
 
 
213
 
214
  try:
215
- chain_components = rag_chains[session_id]
216
- config = {"configurable": {"session_id": session_id}}
217
 
218
- print("\n" + "="*50)
219
- print("--- STARTING DIAGNOSTIC RUN ---")
220
- print(f"Original Question: {question}")
221
- print("="*50 + "\n")
222
-
223
- rewritten_query = chain_components["rewriter"].invoke({"question": question, "chat_history": get_session_history(session_id).messages})
224
- #print(f"--- 1. Rewritten Query ---\n{rewritten_query}\n")
225
-
226
- hyde_doc = chain_components["hyde"].invoke({"question": rewritten_query})
227
- #print(f"--- 2. HyDE Document ---\n{hyde_doc}\n")
228
 
229
- final_retrieved_docs = chain_components["base_retriever"].get_relevant_documents(hyde_doc)
230
- #print(f"--- 3. Retrieved Top {len(final_retrieved_docs)} Child Docs ---")
231
- #for i, doc in enumerate(final_retrieved_docs):
232
- #print(f" Doc {i+1}: {doc.page_content[:150]}... (Source: {doc.metadata.get('source')})")
233
- #print("\n")
234
-
235
- final_context_docs = chain_components["parent_fetcher"].invoke(final_retrieved_docs)
236
- #print(f"--- 4. Final {len(final_context_docs)} Parent Docs for LLM ---")
237
- #for i, doc in enumerate(final_context_docs):
238
- #print(f" Final Doc {i+1} (Source: {doc.metadata.get('source')}, Page: {doc.metadata.get('page')}):\n '{doc.page_content[:300]}...'\n---")
239
 
240
- #print("="*50)
241
- #print("--- INVOKING FINAL CHAIN ---")
242
- #print("="*50 + "\n")
243
 
244
- answer_string = chain_components["final_chain"].invoke({"question": question}, config=config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- return jsonify({'answer': answer_string})
247
-
248
  except Exception as e:
249
- import traceback
250
- traceback.print_exc()
251
- return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500
252
 
253
  def clean_markdown_for_tts(text: str) -> str:
254
- text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
255
- text = re.sub(r'\_(.*?)\_', r'\1', text)
256
- text = re.sub(r'`(.*?)`', r'\1', text)
257
- text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
258
- text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
259
- text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
260
- text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
261
- text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
262
  text = re.sub(r'\n+', ' ', text)
263
  return text.strip()
264
 
265
  @app.route('/tts', methods=['POST'])
266
  def text_to_speech():
267
- data = request.get_json()
268
- text = data.get('text')
269
-
270
- if not text:
271
- return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
272
-
273
  try:
274
- clean_text = clean_markdown_for_tts(text)
275
- tts = gTTS(clean_text, lang='en')
276
- mp3_fp = io.BytesIO()
277
- tts.write_to_fp(mp3_fp)
278
- mp3_fp.seek(0)
279
  return Response(mp3_fp, mimetype='audio/mpeg')
280
  except Exception as e:
281
- print(f"Error in TTS generation: {e}")
282
  return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
283
 
284
  if __name__ == '__main__':
285
  port = int(os.environ.get("PORT", 7860))
286
- app.run(host="0.0.0.0", port=port, debug=False)
 
 
1
  import os
 
2
  import uuid
3
  from flask import Flask, request, render_template, session, jsonify, Response
4
  from werkzeug.utils import secure_filename
 
9
  import io
10
  from gtts import gTTS
11
  from langchain_core.documents import Document
12
+ from langchain_community.document_loaders import TextLoader, Docx2txtLoader
 
 
 
 
13
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
  from langchain_community.vectorstores import FAISS
16
+ from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
17
+ from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
18
  from langchain_community.retrievers import BM25Retriever
19
  from langchain_community.chat_message_histories import ChatMessageHistory
20
  from langchain.storage import InMemoryStore
21
+ from sentence_transformers.cross_encoder import CrossEncoder
22
 
23
  app = Flask(__name__)
24
  app.config['SECRET_KEY'] = os.urandom(24)
25
 
26
# Display label for each supported temperature setting.
# Keys are strings because lookups are done with the string form of the
# temperature (the raw form value, or str() of the stored float).
TEMPERATURE_LABELS = {
    "0.2": "Precise",
    "0.4": "Confident",
    "0.6": "Balanced",
    "0.8": "Flexible",
    "1.0": "Creative",
}
35
+
36
class LocalReranker(BaseDocumentCompressor):
    """Document compressor that keeps the best-scoring documents for a query.

    Attributes are declared as fields on the class:
    ``model`` is any object exposing ``predict(pairs, show_progress_bar=...)``
    that returns one score per ``[query, text]`` pair, and ``top_n`` is the
    maximum number of documents returned (default 5).
    """

    model: Any
    top_n: int = 5

    class Config:
        # ``model`` is an arbitrary object, so field validation must allow it.
        arbitrary_types_allowed = True

    def compress_documents(
        self, documents: Sequence[Document], query: str, callbacks=None
    ) -> Sequence[Document]:
        """Score every document against ``query`` and return the top ``top_n``.

        Each returned document has ``metadata["rerank_score"]`` set to its
        score as a float (the input documents are mutated in place).
        An empty input yields an empty list without calling the model.
        """
        if not documents:
            return []

        relevance = self.model.predict(
            [[query, candidate.page_content] for candidate in documents],
            show_progress_bar=False,
        )
        # Highest score first; ties keep their original relative order.
        ranked = sorted(
            zip(documents, relevance), key=lambda pair: pair[1], reverse=True
        )

        kept = []
        for candidate, value in ranked[: self.top_n]:
            candidate.metadata["rerank_score"] = float(value)
            kept.append(candidate)
        return kept
57
+
58
def create_optimized_parent_child_chunks(all_docs):
    """Split documents into parent chunks and smaller child chunks.

    Parents are produced with a 900-character splitter (200 overlap) and each
    parent is split again with a 350-character splitter (80 overlap). Every
    parent gets a fresh UUID; each child records that UUID and its position
    in its metadata.

    Returns:
        A ``(parent_docs, child_docs, doc_ids)`` tuple, where ``doc_ids[i]``
        is the id of ``parent_docs[i]``. Three empty lists are returned when
        ``all_docs`` is empty.
    """
    if not all_docs:
        print("❌ CHUNKING: No input documents provided!")
        return [], [], []

    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=900,
        chunk_overlap=200,
        separators=["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""],
    )
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=350,
        chunk_overlap=80,
        separators=["\n", ". ", "! ", "? ", "; ", ", ", " ", ""],
    )

    parent_docs = parent_splitter.split_documents(all_docs)
    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
    child_docs = []

    for parent_id, parent_doc in zip(doc_ids, parent_docs):
        pieces = child_splitter.split_documents([parent_doc])
        piece_count = len(pieces)
        final_index = piece_count - 1

        for position, piece in enumerate(pieces):
            piece.metadata.update({
                "doc_id": parent_id,
                "chunk_index": position,
                "total_chunks": piece_count,
                "is_first_chunk": position == 0,
                "is_last_chunk": position == final_index,
            })
            # Only multi-piece parents get a marker: the first piece and the
            # last piece are tagged, pieces in between are left untouched.
            if piece_count > 1:
                if position == 0:
                    piece.page_content = "[Beginning] " + piece.page_content
                elif position == final_index:
                    piece.page_content = "[Continues...] " + piece.page_content
            child_docs.append(piece)

    print(f"✅ CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
    return parent_docs, child_docs, doc_ids
88
+
89
def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
    """Map retrieved child chunks back to enriched copies of their parents.

    Children are grouped by ``metadata["doc_id"]`` (children without one are
    ignored). Each parent found in ``store`` is returned as a new Document
    whose text is the parent text followed by up to three matching child
    excerpts, and whose metadata adds ``child_relevance_score`` and
    ``matching_children`` (both the number of children that hit the parent).
    Parents missing from the store are reported and skipped.

    Returns:
        Enriched parents ordered by descending number of matching children;
        an empty list when ``docs`` is empty.
    """
    if not docs:
        return []

    # parent id -> text of every retrieved child, in retrieval order
    excerpts_by_parent = {}
    for hit in docs:
        owner = hit.metadata.get("doc_id")
        if not owner:
            continue
        excerpts_by_parent.setdefault(owner, []).append(hit.page_content)

    parent_ids = list(excerpts_by_parent)
    enhanced_parents = []

    for owner, parent in zip(parent_ids, store.mget(parent_ids)):
        if parent is None:
            print(f"❌ PARENT_FETCH: Parent {owner} not found in store!")
            continue

        excerpts = excerpts_by_parent[owner]
        hit_count = len(excerpts)
        child_excerpts = "\n".join(excerpts[:3])
        enhanced_parents.append(
            Document(
                page_content=f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}",
                metadata={
                    **parent.metadata,
                    "child_relevance_score": hit_count,
                    "matching_children": hit_count,
                },
            )
        )

    # Stable sort: parents with equal counts keep first-seen order.
    return sorted(
        enhanced_parents,
        key=lambda p: p.metadata.get("child_relevance_score", 0),
        reverse=True,
    )
119
+
120
# is_hf_spaces is true when either SPACE_ID or SPACES_ZERO_GPU is set in the
# environment; in that case uploads go under /tmp instead of ./uploads.
is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
if is_hf_spaces:
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
else:
    app.config['UPLOAD_FOLDER'] = 'uploads'

# Create the upload folder up front; if that fails for any reason, switch to
# /tmp/uploads and create that instead.
try:
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    print(f"📁 Upload folder ready: {app.config['UPLOAD_FOLDER']}")
except Exception as e:
    print(f"Failed to create upload folder, falling back to /tmp: {e}")
    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
 
130
 
131
# In-process state, keyed by session id.
# session_data holds what /upload stores for a session (chain, model name,
# temperature, API key manager); message_histories holds chat histories.
session_data = {}
message_histories = {}

# Both models are loaded once at import time on CPU. A failure is logged and
# re-raised, so the module does not finish importing without them.
print("🔄 Loading embedding model...")
try:
    EMBEDDING_MODEL = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs=dict(device='cpu'),
        encode_kwargs=dict(normalize_embeddings=True),
    )
    print("Embedding model loaded.")
except Exception as e:
    print(f"❌ FATAL: Could not load embedding model. Error: {e}")
    raise e

print("🔄 Loading reranker model...")
try:
    RERANKER_MODEL = CrossEncoder(
        "cross-encoder/ms-marco-MiniLM-L-6-v2",
        device='cpu',
    )
    print("✅ Reranker model loaded.")
except Exception as e:
    print(f"❌ FATAL: Could not load reranker model. Error: {e}")
    raise e
149
 
150
  def load_pdf_with_fallback(filepath):
151
  try:
 
153
  with fitz.open(filepath) as pdf_doc:
154
  for page_num, page in enumerate(pdf_doc):
155
  text = page.get_text()
156
+ if text.strip(): docs.append(Document(page_content=text, metadata={"source": os.path.basename(filepath), "page": page_num + 1}))
 
 
 
 
 
 
 
157
  if docs:
158
+ print(f" Loaded PDF: {os.path.basename(filepath)} - {len(docs)} pages")
159
  return docs
160
+ else: raise ValueError("No text content found in PDF.")
161
+ except Exception as e: print(f" PyMuPDF failed for {filepath}: {e}"); raise
 
 
 
162
 
163
# Lower-cased file extension -> loader. The ".pdf" entry is a plain function
# called with the path; the other entries are loader classes whose instances
# expose .load() (see the call site in upload_files).
LOADER_MAPPING = {
    ".txt": TextLoader,
    ".pdf": load_pdf_with_fallback,
    ".docx": Docx2txtLoader,
}
 
 
 
 
164
 
165
def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Histories live in the module-level ``message_histories`` dict, so the
    same object is returned on every later call for that id.
    """
    try:
        return message_histories[session_id]
    except KeyError:
        history = ChatMessageHistory()
        message_histories[session_id] = history
        return history
168
 
169
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness endpoint: always answers 200 with ``{"status": "healthy"}``."""
    payload = {'status': 'healthy'}
    return jsonify(payload), 200
 
171
 
172
@app.route('/', methods=['GET'])
def index():
    """Serve the single-page UI rendered from ``index.html``."""
    page = render_template('index.html')
    return page
 
174
 
175
  @app.route('/upload', methods=['POST'])
176
  def upload_files():
177
  files = request.files.getlist('file')
178
+
179
+ # Get temperature as a string for the dictionary key
180
+ temperature_str = request.form.get('temperature', '0.2')
181
+ temperature = float(temperature_str) # Convert to float for the LLM
182
+ model_name = request.form.get('model_name', 'moonshotai/kimi-k2-instruct')
183
+ print(f"⚙️ UPLOAD: Model: {model_name}, Temp: {temperature}")
184
+
185
  if not files or all(f.filename == '' for f in files):
186
  return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
187
 
188
+ all_docs, processed_files, failed_files = [], [], []
189
+ print(f"📁 Processing {len(files)} file(s)...")
 
190
  for file in files:
191
  if file and file.filename:
192
  filename = secure_filename(file.filename)
 
194
  try:
195
  file.save(filepath)
196
  file_ext = os.path.splitext(filename)[1].lower()
197
+ if file_ext not in LOADER_MAPPING: raise ValueError("Unsupported file format.")
 
 
198
  loader_func = LOADER_MAPPING[file_ext]
199
  docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
200
+ if not docs: raise ValueError("No content extracted.")
 
 
 
201
  all_docs.extend(docs)
202
  processed_files.append(filename)
 
203
  except Exception as e:
204
+ print(f"✗ Error processing {filename}: {e}")
205
+ failed_files.append(f"{filename} ({e})")
 
206
 
207
  if not all_docs:
208
+ return jsonify({'status': 'error', 'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"}), 400
 
 
 
209
 
210
+ print(f"✅ UPLOAD: Processed {len(processed_files)} files.")
211
  try:
212
+ print("🔄 Starting RAG pipeline setup...")
213
+ parent_docs, child_docs, doc_ids = create_optimized_parent_child_chunks(all_docs)
214
+ if not child_docs: raise ValueError("No child documents created during chunking.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
217
+ store = InMemoryStore(); store.mset(list(zip(doc_ids, parent_docs)))
218
+ print(f" Indexed {len(child_docs)} document chunks.")
 
 
 
219
 
220
+ bm25_retriever = BM25Retriever.from_documents(child_docs); bm25_retriever.k = 12
221
+ faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 12})
222
+ ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.6, 0.4])
223
+ reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)
224
+ def get_parents(docs: List[Document]) -> List[Document]: return get_context_aware_parents(docs, store)
225
+ compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)
226
+ final_retriever = compression_retriever | get_parents
227
 
 
 
 
 
 
 
228
  session_id = str(uuid.uuid4())
229
+ rag_chain, api_key_manager = create_rag_chain(
230
+ retriever=final_retriever, get_session_history_func=get_session_history,
231
+ model_name=model_name, temperature=temperature
232
+ )
233
+ # Store the float temperature
234
+ session_data[session_id] = {'chain': rag_chain, 'model_name': model_name, 'temperature': temperature, 'api_key_manager': api_key_manager}
235
 
236
+ success_msg = f"Processed: {', '.join(processed_files)}"
237
+ if failed_files: success_msg += f". Failed: {', '.join(failed_files)}"
238
 
239
+ # Get the mode label using the STRING key
240
+ mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
241
 
242
+ print(f"✅ UPLOAD COMPLETE: Session {session_id} is ready.")
243
+ # Return all info needed by the frontend
 
 
 
 
 
244
  return jsonify({
245
  'status': 'success',
246
+ 'filename': success_msg,
247
+ 'session_id': session_id,
248
+ 'model_name': model_name,
249
+ 'mode': mode_label
250
  })
 
251
  except Exception as e:
252
+ import traceback; traceback.print_exc()
253
+ return jsonify({'status': 'error', 'message': f'RAG setup failed: {e}'}), 500
 
254
 
255
@app.route('/chat', methods=['POST'])
def chat():
    """Answer one question using the RAG chain built for the caller's session.

    Expects a JSON object with ``question`` and, optionally, ``session_id``
    (the id stored in the Flask session is used when the body has none).

    Responses:
        200: ``{"answer", "model_name", "mode"}``.
        400: no question supplied, or the session id is unknown.
        500: the chain raised while producing the answer.
    """
    # silent=True makes a missing or malformed JSON body come back as None
    # instead of raising; anything that is not a JSON object is treated as
    # an empty payload so the validation below answers with a clean 400
    # rather than an AttributeError on ``.get``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    question = data.get('question')
    session_id = data.get('session_id') or session.get('session_id')

    if not question:
        return jsonify({'status': 'error', 'message': 'No question provided.'}), 400
    if not session_id or session_id not in session_data:
        print(f"❌ CHAT: Invalid session {session_id}.")
        return jsonify({'status': 'error', 'message': 'Invalid session. Please upload documents first.'}), 400

    try:
        session_info = session_data[session_id]
        rag_chain = session_info['chain']
        model_name = session_info['model_name']

        # The temperature is stored as a float; the label table is keyed by
        # strings, so convert before the lookup and fall back to the raw
        # string when there is no matching label.
        temperature_str = str(session_info['temperature'])
        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)

        print(f"💬 CHAT: Invoking chain for session {session_id}...")
        answer = rag_chain.invoke(
            {"question": question},
            config={"configurable": {"session_id": session_id}},
        )
        print("✅ CHAT: Answer generated.")

        return jsonify({
            'answer': answer,
            'model_name': model_name,
            'mode': mode_label
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'status': 'error', 'message': f'Error during chat: {e}'}), 500
 
297
 
298
def clean_markdown_for_tts(text: str) -> str:
    """Strip common Markdown markup so the text can be read aloud.

    Removes horizontal rules, heading markers, bullet and numbered-list
    prefixes, blockquote markers, ``*``/``**``/``_`` emphasis and inline
    backticks, then joins all lines with single spaces.

    Line-level markers are removed before inline emphasis. Doing it the
    other way round lets the emphasis pattern pair a leading bullet ``*``
    with the opening ``*`` of an emphasised word on the same line, and lets
    it eat two characters of a ``***`` / ``___`` rule so the rule pattern no
    longer matches and a stray marker is left behind.

    Args:
        text: Markdown-formatted text.

    Returns:
        The text with markup removed, newlines collapsed to spaces and
        surrounding whitespace stripped.
    """
    # (pattern, replacement, flags), applied in order.
    substitutions = (
        # --- line-level markup (needs MULTILINE so ^/$ work per line) ---
        (r'^\s*[-*_]{3,}\s*$', '', re.MULTILINE),   # horizontal rules
        (r'^\s*#{1,6}\s+', '', re.MULTILINE),       # headings
        (r'^\s*[\*\-]\s+', '', re.MULTILINE),       # bullet items
        (r'^\s*\d+\.\s+', '', re.MULTILINE),        # numbered items
        (r'^\s*>\s?', '', re.MULTILINE),            # blockquotes
        # --- inline markup ---
        (r'\*(\*?)(.*?)\1\*', r'\2', 0),            # *italic* / **bold**
        (r'\_(.*?)\_', r'\1', 0),                   # _italic_
        (r'`(.*?)`', r'\1', 0),                     # `code`
        # --- finally flatten to a single line ---
        (r'\n+', ' ', 0),
    )
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.strip()
305
 
306
@app.route('/tts', methods=['POST'])
def text_to_speech():
    """Convert posted text to spoken English audio.

    Expects a JSON object with a ``text`` string. Markdown markup is removed
    with ``clean_markdown_for_tts`` before synthesis.

    Responses:
        200: MP3 audio (``audio/mpeg``).
        400: no usable text was supplied.
        500: audio generation failed.
    """
    # silent=True returns None for a missing/malformed JSON body instead of
    # raising; a non-object body is handled the same way so the request is
    # rejected with a 400 rather than crashing on ``.get``.
    data = request.get_json(silent=True)
    text = data.get('text') if isinstance(data, dict) else None
    if not text or not isinstance(text, str):
        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400

    try:
        clean_text = clean_markdown_for_tts(text)
        if not clean_text:
            # Input consisted only of markup/whitespace: nothing to speak,
            # so report a client error instead of letting synthesis fail.
            return jsonify({'status': 'error', 'message': 'No text provided.'}), 400

        tts = gTTS(clean_text, lang='en')
        mp3_fp = io.BytesIO()
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)  # rewind so the response streams from the start
        return Response(mp3_fp, mimetype='audio/mpeg')
    except Exception as e:
        print(f" TTS Error: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
317
 
318
if __name__ == '__main__':
    # The PORT environment variable overrides the default port 7860.
    port = int(os.getenv("PORT", 7860))
    print(f"🚀 Starting Flask app on port {port}")
    app.run(
        host="0.0.0.0",
        port=port,
        debug=False,
        threaded=False,
    )
evaluate.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import uuid
4
+ from dotenv import load_dotenv
5
+ from datasets import Dataset
6
+ import pandas as pd
7
+ from typing import Sequence, Any, List
8
+
9
+ # Ragas and LangChain components
10
+ from ragas import evaluate
11
+ from ragas.metrics import (
12
+ faithfulness,
13
+ answer_relevancy,
14
+ context_recall,
15
+ context_precision,
16
+ )
17
+ from ragas.testset import TestsetGenerator
18
+ # NOTE: The 'evolutions' import has been completely removed.
19
+
20
+ # Your specific RAG components from app.py
21
+ from langchain_groq import ChatGroq
22
+ from langchain_community.document_loaders import PyMuPDFLoader
23
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
24
+ from langchain_huggingface import HuggingFaceEmbeddings
25
+ from langchain_community.vectorstores import FAISS
26
+ from langchain.storage import InMemoryStore
27
+ from langchain_community.retrievers import BM25Retriever
28
+ from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
29
+ from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
30
+ from langchain_core.documents import Document
31
+ from sentence_transformers.cross_encoder import CrossEncoder
32
+ from rag_processor import create_rag_chain
33
+ from langchain_community.chat_message_histories import ChatMessageHistory
34
+ import fitz
35
+
36
+ # Load environment variables
37
+ load_dotenv()
38
+
39
+ # --- Re-implementing LocalReranker from app.py ---
40
class LocalReranker(BaseDocumentCompressor):
    """Rerank retrieved documents with a local pairwise scoring model.

    ``model`` must expose ``predict(pairs, show_progress_bar=...)`` returning
    one score per ``[query, text]`` pair. ``top_n`` caps how many documents
    are returned (default 3 in this script).
    """

    model: Any
    top_n: int = 3

    class Config:
        # Allow the untyped model object as a field value.
        arbitrary_types_allowed = True

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks=None) -> Sequence[Document]:
        """Return the ``top_n`` highest-scoring documents for ``query``.

        The score is written to ``metadata['rerank_score']`` of each returned
        document (as a float). Empty input returns an empty list.
        """
        if not documents:
            return []

        query_text_pairs = [[query, item.page_content] for item in documents]
        relevance = self.model.predict(query_text_pairs, show_progress_bar=False)

        # Best score first; equal scores keep their incoming order.
        by_score = sorted(zip(documents, relevance), key=lambda entry: entry[1], reverse=True)

        selected = []
        for item, value in by_score[:self.top_n]:
            item.metadata['rerank_score'] = float(value)
            selected.append(item)
        return selected
56
+
57
+ # --- Helper Functions ---
58
def load_pdf_with_fallback(filepath):
    """Read a PDF with PyMuPDF and return one Document per non-blank page.

    Each Document carries the page text plus ``source`` (the file's base
    name) and ``page`` (1-based page number) in its metadata.

    Raises:
        ValueError: when no page contains any text.
        Exception: any error raised while opening or reading the file is
            logged and re-raised unchanged.
    """
    try:
        with fitz.open(filepath) as pdf_doc:
            source_name = os.path.basename(filepath)
            page_texts = ((number, page.get_text()) for number, page in enumerate(pdf_doc, start=1))
            docs = [
                Document(
                    page_content=content,
                    metadata={"source": source_name, "page": number},
                )
                for number, content in page_texts
                if content.strip()
            ]

        if not docs:
            raise ValueError("No text content found in PDF.")

        print(f"✓ Successfully loaded PDF: {filepath}")
        return docs
    except Exception as e:
        # Also reached for the ValueError above, so that case is logged too.
        print(f"✗ PyMuPDF failed for {filepath}: {e}")
        raise
78
+
79
async def main(pdf_path: str = "uploads/Unit_-_1_Introduction.pdf", testset_size: int = 5):
    """Build the RAG pipeline over one PDF and score it with Ragas.

    The steps are: initialise the models, build a parent/child retrieval
    pipeline (BM25 + FAISS ensemble, reranked, then mapped back to parent
    chunks), generate a synthetic testset, answer every generated question
    with the RAG chain, evaluate the answers, and write the scores to
    ``evaluation_results.csv``.

    Args:
        pdf_path: PDF to evaluate against. Defaults to the original
            hard-coded sample document.
        testset_size: Number of questions Ragas is asked to generate.

    Returns:
        None. Any failure inside the pipeline is printed together with its
        traceback instead of being propagated.
    """
    print("\n" + "="*60 + "\nSTARTING RAGAS EVALUATION\n" + "="*60)

    if not os.path.exists(pdf_path):
        print(f"✗ Error: PDF not found at {pdf_path}")
        return

    try:
        # --- 1. Setup Models ---
        print("\n--- 1. Initializing Models ---")
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key or groq_api_key == "your_groq_api_key_here":
            raise ValueError("GROQ_API_KEY not found or is a placeholder.")

        generator_llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_api_key)
        critic_llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api_key)
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
        print("✓ Models initialized.")

        # --- 2. Setup RAG Pipeline ---
        print("\n--- 2. Setting up RAG Pipeline ---")
        documents = load_pdf_with_fallback(pdf_path)

        # Split documents: small child chunks are indexed for search, the
        # larger parent chunks are what the chain finally receives.
        parent_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=400)
        child_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)
        parent_docs = parent_splitter.split_documents(documents)
        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]

        child_docs = []
        for _id, doc in zip(doc_ids, parent_docs):
            sub_docs = child_splitter.split_documents([doc])
            for child in sub_docs:
                # Link every child back to the parent it was cut from.
                child.metadata["doc_id"] = _id
            child_docs.extend(sub_docs)

        store = InMemoryStore()
        store.mset(list(zip(doc_ids, parent_docs)))
        vectorstore = FAISS.from_documents(child_docs, embedding_model)

        bm25_retriever = BM25Retriever.from_documents(child_docs, k=10)
        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.4, 0.6])

        reranker = LocalReranker(model=reranker_model, top_n=5)
        compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)

        def get_parents(docs: List[Document]) -> List[Document]:
            """Map reranked child chunks to their parent chunks."""
            # dict.fromkeys de-duplicates while keeping the reranker's
            # order; a set would shuffle the contexts between runs.
            parent_ids = list(dict.fromkeys(d.metadata["doc_id"] for d in docs))
            # Defensive: drop any id the store cannot resolve.
            return [parent for parent in store.mget(parent_ids) if parent is not None]

        final_retriever = compression_retriever | get_parents

        message_histories = {}

        def get_session_history(session_id: str):
            """Return the chat history for a session, creating it on first use."""
            if session_id not in message_histories:
                message_histories[session_id] = ChatMessageHistory()
            return message_histories[session_id]

        rag_chain = create_rag_chain(final_retriever, get_session_history)
        print("✓ RAG chain created successfully.")

        # --- 3. Generate Testset ---
        print("\n--- 3. Generating Test Questions ---")
        generator = TestsetGenerator.from_langchain(generator_llm, critic_llm, embedding_model)

        # Generate a simple test set without complex distributions
        testset = generator.generate_with_langchain_docs(documents, testset_size=testset_size)
        print("✓ Testset generated.")

        # --- 4. Run RAG Chain on Testset ---
        print("\n--- 4. Running RAG Chain to Generate Answers ---")
        # Convert once and reuse the records for both columns.
        records = testset.to_pandas().to_dict('records')
        test_questions = [item['question'] for item in records]
        ground_truths = [item['ground_truth'] for item in records]

        answers = []
        contexts = []

        for i, question in enumerate(test_questions):
            print(f" Processing question {i+1}/{len(test_questions)}...")
            # Retrieve contexts
            retrieved_docs = final_retriever.invoke(question)
            contexts.append([doc.page_content for doc in retrieved_docs])
            # Get answer from chain; a fresh session id keeps the questions
            # from sharing chat history.
            config = {"configurable": {"session_id": str(uuid.uuid4())}}
            answer = await rag_chain.ainvoke({"question": question}, config=config)
            answers.append(answer)

        # --- 5. Evaluate with Ragas ---
        print("\n--- 5. Evaluating Results with Ragas ---")
        eval_data = {
            'question': test_questions,
            'answer': answers,
            'contexts': contexts,
            'ground_truth': ground_truths
        }
        eval_dataset = Dataset.from_dict(eval_data)

        result = evaluate(
            eval_dataset,
            metrics=[faithfulness, answer_relevancy, context_precision, context_recall],
            llm=critic_llm,
            embeddings=embedding_model
        )

        print("\n" + "="*60 + "\nEVALUATION RESULTS\n" + "="*60)
        print(result)

        # --- 6. Save Results ---
        print("\n--- 6. Saving Results ---")
        results_df = result.to_pandas()
        results_df.to_csv("evaluation_results.csv", index=False)
        print("✓ Evaluation results saved to evaluation_results.csv")

        print("\n" + "="*60 + "\nEVALUATION COMPLETE!\n" + "="*60)

    except Exception as e:
        # Top-level boundary of the script: report the failure with its traceback.
        print(f"\n✗ An error occurred during the process: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())
query_expansion.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # query_expansion.py
2
+
3
+ """
4
+ Query Expansion System for CogniChat RAG Application
5
+
6
+ This module implements advanced query expansion techniques to improve retrieval quality:
7
+ - QueryAnalyzer: Extracts intent, entities, and keywords
8
+ - QueryRephraser: Generates natural language variations
9
+ - MultiQueryExpander: Creates diverse query formulations
10
+ - MultiHopReasoner: Connects concepts across documents
11
+ - FallbackStrategies: Handles edge cases gracefully
12
+
13
+ Author: CogniChat Team
14
+ Date: October 19, 2025
15
+ """
16
+
17
+ import re
18
+ from typing import List, Dict, Any, Optional
19
+ from dataclasses import dataclass
20
+ from enum import Enum
21
+
22
+
23
class QueryStrategy(Enum):
    """Expansion presets, listed from cheapest to most thorough."""

    # Fast path: 2 queries, minimal expansion.
    QUICK = "quick"
    # Middle ground: 3-4 queries.
    BALANCED = "balanced"
    # Maximum coverage: 5-6 queries.
    COMPREHENSIVE = "comprehensive"
28
+
29
+
30
@dataclass
class QueryAnalysis:
    """Structured result of analysing a single query."""

    # Intent label: question, definition, comparison, explanation, etc.
    intent: str
    # Named entities extracted from the query.
    entities: List[str]
    # Important keywords.
    keywords: List[str]
    # One of: simple, medium, complex.
    complexity: str
    # Technical domain, when one was detected.
    domain: Optional[str] = None
38
+
39
+
40
@dataclass
class ExpandedQuery:
    """Bundle of a query, its generated variations and how they were produced."""

    # The query exactly as the caller supplied it.
    original: str
    # The generated query variations.
    variations: List[str]
    # Strategy that produced the variations.
    strategy_used: QueryStrategy
    # Analysis the variations were derived from.
    analysis: QueryAnalysis
47
+
48
+
49
class QueryAnalyzer:
    """
    Analyzes queries to extract intent, entities, and key information.

    The analysis is rule-based (regular expressions and word lists). The
    optional ``llm`` is stored on the instance but is not consulted by any
    method of this class.
    """

    # Question words that are capitalized only because they open a sentence.
    _QUESTION_WORDS = frozenset({'what', 'how', 'why', 'when', 'where', 'which'})

    # Simplified stop-word list used by keyword extraction.
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
        'what', 'how', 'why', 'when', 'where', 'which', 'who',
        'do', 'does', 'did', 'can', 'could', 'should', 'would',
        'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
    })

    # Domains are tested in this order; the first one with a hit wins.
    _DOMAIN_KEYWORDS = {
        'programming': ['code', 'function', 'class', 'variable', 'algorithm', 'debug'],
        'data_science': ['model', 'dataset', 'training', 'prediction', 'accuracy'],
        'machine_learning': ['neural', 'network', 'learning', 'ai', 'deep learning'],
        'web': ['html', 'css', 'javascript', 'api', 'frontend', 'backend'],
        'database': ['sql', 'query', 'database', 'table', 'index'],
        'security': ['encryption', 'authentication', 'vulnerability', 'attack'],
    }

    # Punctuation stripped from the edges of a word before the entity test.
    _EDGE_PUNCTUATION = '.,;:!?"\'()[]{}'

    def __init__(self, llm=None):
        """
        Initialize QueryAnalyzer.

        Args:
            llm: Optional LangChain LLM, kept for callers that want to attach one
        """
        self.llm = llm
        # Checked in insertion order; the first matching pattern decides the intent.
        self.intent_patterns = {
            'definition': r'\b(what is|define|meaning of|definition)\b',
            'how_to': r'\b(how to|how do|how can|steps to)\b',
            'comparison': r'\b(compare|difference|versus|vs|better than)\b',
            'explanation': r'\b(why|explain|reason|cause)\b',
            'listing': r'\b(list|enumerate|what are|types of)\b',
            'example': r'\b(example|instance|sample|case)\b',
        }
        # One compiled alternation per domain, built once per instance.
        self._domain_patterns = {
            domain: re.compile('|'.join(self._keyword_pattern(kw) for kw in keywords))
            for domain, keywords in self._DOMAIN_KEYWORDS.items()
        }

    @staticmethod
    def _keyword_pattern(keyword: str) -> str:
        """Build the regex fragment that matches one domain keyword.

        Keywords of up to three letters ('ai', 'api', 'sql', 'css') must match
        a whole word (optionally plural), so they no longer fire inside words
        such as "explain" or "rapid". Longer keywords only need to start a
        word, so inflections like "models" or "debugging" still match.
        """
        escaped = re.escape(keyword)
        if len(keyword) <= 3:
            return rf'\b{escaped}s?\b'
        return rf'\b{escaped}'

    def analyze(self, query: str) -> "QueryAnalysis":
        """
        Analyze query to extract intent, entities, and keywords.

        Args:
            query: User's original query

        Returns:
            QueryAnalysis object with extracted information
        """
        query_lower = query.lower()
        entities = self._extract_entities(query)
        keywords = self._extract_keywords(query)

        return QueryAnalysis(
            intent=self._detect_intent(query_lower),
            entities=entities,
            keywords=keywords,
            complexity=self._assess_complexity(query, entities, keywords),
            domain=self._detect_domain(query_lower)
        )

    def _detect_intent(self, query_lower: str) -> str:
        """Return the first intent whose pattern matches, else 'general'."""
        for intent, pattern in self.intent_patterns.items():
            if re.search(pattern, query_lower):
                return intent
        return 'general'

    def _extract_entities(self, query: str) -> List[str]:
        """Extract candidate entities: capitalized words and quoted terms.

        Surrounding punctuation is stripped, so "Java," and "Java" are the
        same entity. The result keeps first-seen order, which makes slices
        such as ``entities[:2]`` reproducible between runs.
        """
        entities = []

        for word in query.split():
            token = word.strip(self._EDGE_PUNCTUATION)
            # Skip question words, which are capitalized only at sentence start.
            if token and token[0].isupper() and token.lower() not in self._QUESTION_WORDS:
                entities.append(token)

        # Quoted terms count as entities regardless of their casing.
        entities.extend(re.findall(r'"([^"]+)"', query))

        # Order-preserving de-duplication.
        return list(dict.fromkeys(entities))

    def _extract_keywords(self, query: str) -> List[str]:
        """Return up to 10 lowercase words that are longer than 2 characters and not stop words."""
        words = re.findall(r'\b\w+\b', query.lower())
        keywords = [w for w in words if w not in self._STOP_WORDS and len(w) > 2]

        return keywords[:10]  # Limit to top 10

    def _assess_complexity(self, query: str, entities: List[str], keywords: List[str]) -> str:
        """Classify the query as 'simple', 'medium' or 'complex' from a weighted count."""
        # Simple scoring: entities and keywords weigh more than plain words.
        score = len(query.split()) + (len(entities) * 2) + (len(keywords) * 1.5)

        if score < 15:
            return 'simple'
        if score < 30:
            return 'medium'
        return 'complex'

    def _detect_domain(self, query_lower: str) -> Optional[str]:
        """Return the first domain with a keyword in the lowercased query, or None."""
        for domain, pattern in self._domain_patterns.items():
            if pattern.search(query_lower):
                return domain

        return None
179
+
180
+
181
class QueryRephraser:
    """
    Generates natural language variations of queries using multiple strategies.

    All variations are produced by deterministic string rules. The optional
    ``llm`` is stored on the instance but is not used by any method here.
    """

    # Common synonym replacements
    _SYNONYMS = {
        'error': 'issue',
        'problem': 'issue',
        'fix': 'resolve',
        'use': 'utilize',
        'create': 'generate',
        'make': 'create',
        'get': 'retrieve',
        'show': 'display',
        'find': 'locate',
        'explain': 'describe',
    }
    # Whole-word, case-insensitive match, so "Error?" and "fix," are found too.
    _SYNONYM_PATTERN = re.compile(r'\b(' + '|'.join(_SYNONYMS) + r')\b', re.IGNORECASE)

    def __init__(self, llm=None):
        """
        Initialize QueryRephraser.

        Args:
            llm: LangChain LLM, kept for callers that want to attach one
        """
        self.llm = llm

    def generate_variations(
        self,
        query: str,
        analysis: QueryAnalysis,
        strategy: QueryStrategy = QueryStrategy.BALANCED
    ) -> List[str]:
        """
        Generate query variations based on strategy.

        Args:
            query: Original query
            analysis: Query analysis results
            strategy: Expansion strategy to use

        Returns:
            List of query variations, original first, without empty strings
            or duplicates
        """
        variations = [query]  # Always include original

        if strategy == QueryStrategy.QUICK:
            # Just add synonym variation
            variations.append(self._synonym_variation(query, analysis))

        elif strategy in (QueryStrategy.BALANCED, QueryStrategy.COMPREHENSIVE):
            # Synonym, expanded, and simplified versions are shared by both.
            variations.append(self._synonym_variation(query, analysis))
            variations.append(self._expanded_variation(query, analysis))
            variations.append(self._simplified_variation(query, analysis))

            if strategy == QueryStrategy.COMPREHENSIVE:
                variations.append(self._keyword_focused(query, analysis))
                variations.append(self._context_variation(query, analysis))
                # Add one more: alternate phrasing. Skipped without keywords,
                # which would otherwise yield a dangling "Guide to ".
                if analysis.intent in ['how_to', 'explanation'] and analysis.keywords:
                    variations.append(f"Guide to {' '.join(analysis.keywords[:3])}")

        # Drop empty/None entries, then de-duplicate while preserving order.
        variations = [v for v in variations if v]
        return list(dict.fromkeys(variations))

    def _synonym_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Replace the first word that has a known synonym.

        The rest of the query keeps its casing and punctuation. When no word
        has a synonym the query is returned unchanged, so the caller's
        de-duplication removes it instead of keeping a case-only variant.
        """
        def _swap(match):
            found = match.group(0)
            replacement = self._SYNONYMS[found.lower()]
            # Keep sentence-initial capitalization intact.
            return replacement.capitalize() if found[0].isupper() else replacement

        # count=1: only replace one word to keep the query natural.
        return self._SYNONYM_PATTERN.sub(_swap, query, count=1)

    def _expanded_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate expanded version with more detail."""
        # Without keywords the definition template would end in "of ", so
        # that case falls through to the generic form.
        if analysis.intent == 'definition' and analysis.keywords:
            return f"Detailed explanation and definition of {' '.join(analysis.keywords)}"
        if analysis.intent == 'how_to':
            return f"Step-by-step guide on {query.lower()}"
        if analysis.intent == 'comparison':
            return f"Comprehensive comparison: {query}"
        # Add qualifying words
        return f"Detailed information about {query.lower()}"

    def _simplified_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate simplified version focusing on core concepts."""
        # Use just the keywords
        if len(analysis.keywords) >= 2:
            return ' '.join(analysis.keywords[:3])
        return query

    def _keyword_focused(self, query: str, analysis: QueryAnalysis) -> str:
        """Create keyword-focused variation for BM25.

        Keywords are lowercase while entities keep their casing, so the same
        term may appear in both lists. Case-insensitive de-duplication stops
        the 5-term budget being spent on repeats.
        """
        seen = set()
        terms = []
        for term in analysis.keywords + analysis.entities:
            folded = term.lower()
            if folded not in seen:
                seen.add(folded)
                terms.append(term)
        return ' '.join(terms[:5])

    def _context_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Add contextual information if domain detected."""
        if analysis.domain:
            return f"{query} in {analysis.domain} context"
        return query
292
+
293
+
294
class MultiQueryExpander:
    """
    Main query expansion orchestrator that combines analysis and rephrasing.
    """

    def __init__(self, llm=None):
        """
        Initialize MultiQueryExpander.

        Args:
            llm: LangChain LLM, passed on to the analyzer and the rephraser
        """
        self.analyzer = QueryAnalyzer(llm)
        self.rephraser = QueryRephraser(llm)

    def expand(
        self,
        query: str,
        strategy: QueryStrategy = QueryStrategy.BALANCED,
        max_queries: int = 6
    ) -> ExpandedQuery:
        """
        Expand query into multiple variations.

        Args:
            query: Original user query
            strategy: Expansion strategy
            max_queries: Maximum number of queries to generate; values below
                1 are treated as 1 so the original query is always kept

        Returns:
            ExpandedQuery object with all variations
        """
        # Analyze query
        analysis = self.analyzer.analyze(query)

        # Generate variations
        variations = self.rephraser.generate_variations(query, analysis, strategy)

        # Limit to max_queries. Clamped because 0 would return nothing and a
        # negative value would slice from the end of the list.
        variations = variations[:max(1, max_queries)]

        return ExpandedQuery(
            original=query,
            variations=variations,
            strategy_used=strategy,
            analysis=analysis
        )
341
+
342
+
343
class MultiHopReasoner:
    """
    Implements multi-hop reasoning to connect concepts across documents.
    Useful for complex queries that require information from multiple sources.

    Sub-queries are derived with fixed templates. The optional ``llm`` is
    stored on the instance but is not used by any method here.
    """

    def __init__(self, llm=None):
        """
        Initialize MultiHopReasoner.

        Args:
            llm: LangChain LLM, kept for callers that want to attach one
        """
        self.llm = llm

    def generate_sub_queries(
        self,
        query: str,
        analysis: "QueryAnalysis",
        max_sub_queries: int = 5
    ) -> List[str]:
        """
        Break complex query into sub-queries for multi-hop reasoning.

        Args:
            query: Original complex query
            analysis: Query analysis
            max_sub_queries: Upper bound on the number of returned queries;
                values below 1 are treated as 1

        Returns:
            List of distinct sub-queries, starting with the original query
        """
        sub_queries = [query]

        # For comparison queries, create separate queries for each subject.
        if analysis.intent == 'comparison':
            if len(analysis.entities) >= 2:
                subjects = analysis.entities[:2]
            elif len(analysis.keywords) >= 2:
                # Fallback: use keywords if no entities found
                subjects = analysis.keywords[:2]
            else:
                subjects = []
            sub_queries.extend(f"Information about {subject}" for subject in subjects)

        # For how-to queries, break into steps
        if analysis.intent == 'how_to' and len(analysis.keywords) >= 2:
            main_topic = ' '.join(analysis.keywords[:2])
            sub_queries.append(f"Prerequisites for {main_topic}")
            sub_queries.append(f"Steps to {main_topic}")

        # For complex questions, create focused sub-queries
        if analysis.complexity == 'complex' and len(analysis.keywords) > 3:
            # Create queries focusing on different keyword groups
            mid = len(analysis.keywords) // 2
            sub_queries.append(' '.join(analysis.keywords[:mid]))
            sub_queries.append(' '.join(analysis.keywords[mid:]))

        # A keyword group can repeat the original query; drop duplicates but
        # keep the first-seen order.
        unique_queries = list(dict.fromkeys(sub_queries))
        return unique_queries[:max(1, max_sub_queries)]
394
+
395
+
396
class FallbackStrategies:
    """
    Implements fallback strategies for queries that don't retrieve good results.
    """

    # Compiled once; every pattern matches whole words, case-insensitively.
    _QUESTION_WORDS = re.compile(
        r'\b(what|how|why|when|where|which|who|can|could|should|would)\b', re.IGNORECASE)
    _FILLER_WORDS = re.compile(r'\b(is|are|was|were|be|been|the|a|an)\b', re.IGNORECASE)
    _CONSTRAINT_WORDS = re.compile(r'\b(specific|exactly|precisely|only|just)\b', re.IGNORECASE)
    _WHITESPACE = re.compile(r'\s+')

    @staticmethod
    def simplify_query(query: str) -> str:
        """Simplify query by removing modifiers and focusing on core terms.

        If the query consists only of question and filler words, the stripped
        original is returned instead of an empty string.
        """
        # Remove question words
        simplified = FallbackStrategies._QUESTION_WORDS.sub('', query)

        # Remove common phrases
        simplified = FallbackStrategies._FILLER_WORDS.sub('', simplified)

        # Clean up extra spaces
        simplified = FallbackStrategies._WHITESPACE.sub(' ', simplified).strip()

        return simplified or query.strip()

    @staticmethod
    def broaden_query(query: str, analysis: "QueryAnalysis") -> str:
        """Broaden query to increase recall.

        Returns "<keyword> overview" for the first keyword that is not itself
        a constraint word ("specific", "only", ...). Without such a keyword,
        the query with its constraint words removed is returned.
        """
        # Add general terms, skipping constraint words so the topic is never
        # something like "specific overview".
        for keyword in analysis.keywords:
            if not FallbackStrategies._CONSTRAINT_WORDS.fullmatch(keyword):
                return f"{keyword} overview"

        # Remove specific constraints and tidy the gaps they leave behind.
        broadened = FallbackStrategies._CONSTRAINT_WORDS.sub('', query)
        broadened = FallbackStrategies._WHITESPACE.sub(' ', broadened).strip()
        return broadened or query.strip()

    @staticmethod
    def focus_entities(analysis: "QueryAnalysis") -> str:
        """Create entity-focused query as fallback.

        Uses all entities when there are any, otherwise the first three
        keywords, otherwise an empty string.
        """
        if analysis.entities:
            return ' '.join(analysis.entities)
        if analysis.keywords:
            return ' '.join(analysis.keywords[:3])
        return ""
435
+
436
+
437
+ # Convenience function for easy integration
438
def expand_query_simple(
    query: str,
    strategy: str = "balanced",
    llm=None
) -> List[str]:
    """
    Simple function to expand a query without dealing with classes.

    Args:
        query: User's query to expand
        strategy: "quick", "balanced", or "comprehensive" (case and
            surrounding whitespace are ignored); a QueryStrategy member is
            accepted as well
        llm: Optional LangChain LLM

    Returns:
        List of expanded query variations, starting with the original query

    Raises:
        ValueError: If strategy does not name a QueryStrategy

    Example:
        >>> queries = expand_query_simple("How do I debug Python code?", strategy="balanced")
        >>> queries[0]
        'How do I debug Python code?'
    """
    if isinstance(strategy, QueryStrategy):
        strategy_enum = strategy
    else:
        # Normalize so "Balanced" or " quick " are not rejected.
        strategy_enum = QueryStrategy(str(strategy).strip().lower())

    expander = MultiQueryExpander(llm=llm)
    return expander.expand(query, strategy=strategy_enum).variations
463
+
464
+
465
+ # Example usage and testing
466
if __name__ == "__main__":
    # Example 1: Simple query expansion
    print("=" * 60)
    print("Example 1: Simple Query Expansion")
    print("=" * 60)

    query = "What is machine learning?"
    queries = expand_query_simple(query, strategy="balanced")

    print(f"\nOriginal: {query}")
    print(f"\nExpanded queries ({len(queries)}):")
    for i, q in enumerate(queries, 1):
        print(f" {i}. {q}")

    # Example 2: Complex query with full analysis
    print("\n" + "=" * 60)
    print("Example 2: Complex Query with Analysis")
    print("=" * 60)

    expander = MultiQueryExpander()
    query = "How do I compare the performance of different neural network architectures?"
    result = expander.expand(query, strategy=QueryStrategy.COMPREHENSIVE)

    print(f"\nOriginal: {result.original}")
    print("\nAnalysis:")
    print(f" Intent: {result.analysis.intent}")
    print(f" Entities: {result.analysis.entities}")
    print(f" Keywords: {result.analysis.keywords}")
    print(f" Complexity: {result.analysis.complexity}")
    print(f" Domain: {result.analysis.domain}")
    print(f"\nExpanded queries ({len(result.variations)}):")
    for i, q in enumerate(result.variations, 1):
        print(f" {i}. {q}")

    # Example 3: Multi-hop reasoning
    print("\n" + "=" * 60)
    print("Example 3: Multi-Hop Reasoning")
    print("=" * 60)

    reasoner = MultiHopReasoner()
    analyzer = QueryAnalyzer()

    query = "Compare Python and Java for web development"
    analysis = analyzer.analyze(query)
    sub_queries = reasoner.generate_sub_queries(query, analysis)

    print(f"\nOriginal: {query}")
    print("\nSub-queries for multi-hop reasoning:")
    for i, sq in enumerate(sub_queries, 1):
        print(f" {i}. {sq}")

    # Example 4: Fallback strategies
    print("\n" + "=" * 60)
    print("Example 4: Fallback Strategies")
    print("=" * 60)

    query = "What is the specific difference between supervised and unsupervised learning?"
    analysis = analyzer.analyze(query)

    # The example previously stopped after computing the analysis; show what
    # each fallback produces for it.
    print(f"\nOriginal: {query}")
    print("\nFallback queries:")
    print(f" Simplified: {FallbackStrategies.simplify_query(query)}")
    print(f" Broadened: {FallbackStrategies.broaden_query(query, analysis)}")
    print(f" Entity-focused: {FallbackStrategies.focus_entities(analysis)}")
524
+
525
+
rag_processor.py CHANGED
@@ -3,93 +3,479 @@ from dotenv import load_dotenv
3
  from operator import itemgetter
4
  from langchain_groq import ChatGroq
5
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
- from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables.history import RunnableWithMessageHistory
 
 
 
 
9
 
10
- def create_rag_chain(base_retriever, get_session_history_func, embedding_model, store):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  """
12
- Creates a dictionary of RAG chain components for inspection and a final runnable chain.
 
 
 
 
 
 
 
 
 
13
  """
14
- load_dotenv()
15
- api_key = os.getenv("GROQ_API_KEY")
16
- if not api_key or api_key == "your_groq_api_key_here":
17
- raise ValueError("GROQ_API_KEY not found or not configured properly.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- llm = ChatGroq(model_name="moonshotai/kimi-k2-instruct-0905", api_key=api_key, temperature=0.1)
20
 
21
- # 1. HyDE-like Document Generation Chain
22
- hyde_template = """As a document expert, write a concise, fact-based paragraph that directly answers the user's question. This will be used for a database search.
23
- Question: {question}
24
- Hypothetical Answer:"""
25
- hyde_prompt = ChatPromptTemplate.from_template(hyde_template)
26
- hyde_chain = hyde_prompt | llm | StrOutputParser()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # 2. Query Rewriting Chain
29
- rewrite_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question that is optimized for a vector database.
30
 
31
- **Chat History:**
32
- {chat_history}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- **Follow-up Question:**
35
- {question}
36
 
37
- **Standalone Question:**"""
38
- rewrite_prompt = ChatPromptTemplate.from_messages([
39
- ("system", rewrite_template),
40
- MessagesPlaceholder(variable_name="chat_history"),
41
- ("human", "Reformulate this question as a standalone query: {question}")
42
- ])
43
- query_rewriter_chain = rewrite_prompt | llm | StrOutputParser()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # 3. Parent Document Fetching Chain
46
- def get_parents(docs):
47
- parent_ids = {d.metadata.get("doc_id") for d in docs}
48
- return store.mget(list(parent_ids))
49
 
50
- parent_fetcher_chain = RunnableLambda(get_parents)
51
 
52
- # 4. Main Conversational RAG Chain
53
- rag_template = """You are CogniChat, an expert document analysis assistant. Your task is to answer the user's question based *only* on the provided context.
54
 
55
- **Instructions:**
56
- 1. Read the context carefully.
57
- 2. If the answer is in the context, provide a clear and concise answer.
58
- 3. If the answer is not in the context, you *must* state that you cannot find the information in the provided documents. Do not use any external knowledge.
59
- 4. Where appropriate, use formatting like lists or bold text to improve readability.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- **Context:**
62
- {context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  rag_prompt = ChatPromptTemplate.from_messages([
65
  ("system", rag_template),
66
  MessagesPlaceholder(variable_name="chat_history"),
67
  ("human", "{question}"),
68
  ])
69
 
70
- conversational_rag_chain = (
71
- RunnablePassthrough.assign(
72
- context=query_rewriter_chain | hyde_chain | base_retriever | parent_fetcher_chain
73
- )
74
- | rag_prompt
75
- | llm
76
- | StrOutputParser()
77
- )
78
-
79
- # 5. Final Chain with History (Simplified)
80
- final_chain = RunnableWithMessageHistory(
 
 
 
 
 
 
81
  conversational_rag_chain,
82
  get_session_history_func,
83
  input_messages_key="question",
84
  history_messages_key="chat_history",
85
  )
86
 
87
- print("\n✅ RAG chain and components successfully built.")
 
88
 
89
- return {
90
- "rewriter": query_rewriter_chain,
91
- "hyde": hyde_chain,
92
- "base_retriever": base_retriever,
93
- "parent_fetcher": parent_fetcher_chain,
94
- "final_chain": final_chain
95
- }
 
3
  from operator import itemgetter
4
  from langchain_groq import ChatGroq
5
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables.history import RunnableWithMessageHistory
9
+ from langchain_core.documents import Document
10
+ from query_expansion import expand_query_simple
11
+ from typing import List, Optional
12
+ import time
13
 
14
class GroqAPIKeyManager:
    """Manages multiple Groq API keys with automatic rotation and fallback."""

    def __init__(self, api_keys: List[str]):
        """
        Initialize with a list of API keys.

        Keys are stripped of surrounding whitespace. Empty entries, the
        "your_groq_api_key_here" placeholder and repeated keys are dropped,
        so each key has exactly one position and one pair of counters.

        Args:
            api_keys: List of Groq API keys to use

        Raises:
            ValueError: If no usable key remains
        """
        stripped = (key.strip() for key in (api_keys or []) if isinstance(key, str))
        # dict.fromkeys removes duplicates while keeping the given order.
        self.api_keys = list(dict.fromkeys(
            key for key in stripped if key and key != "your_groq_api_key_here"
        ))
        if not self.api_keys:
            raise ValueError("No valid API keys provided!")

        self.current_index = 0
        self.failed_keys = set()
        self.success_count = {key: 0 for key in self.api_keys}
        self.failure_count = {key: 0 for key in self.api_keys}

        print(f"🔑 API Key Manager: Loaded {len(self.api_keys)} API keys")

    def get_current_key(self) -> str:
        """Get the current API key."""
        return self.api_keys[self.current_index]

    def mark_success(self, api_key: str):
        """Count a success for a managed key and clear its failed status."""
        if api_key in self.success_count:
            self.success_count[api_key] += 1
            # Remove from failed keys if it was there
            if api_key in self.failed_keys:
                self.failed_keys.remove(api_key)
                print(f" ✅ API Key #{self.api_keys.index(api_key) + 1} recovered!")

    def mark_failure(self, api_key: str):
        """Count a failure for a managed key and flag it as failed."""
        if api_key in self.failure_count:
            self.failure_count[api_key] += 1
            self.failed_keys.add(api_key)

    def rotate_to_next_key(self) -> bool:
        """
        Rotate to the next API key that is not flagged as failed.

        The other keys are tried in order, starting after the current one.
        When every other key is flagged as failed, the current key is kept so
        the caller can retry it.

        Returns:
            Always True: there is always a key to try next, either a fresh
            one or the current one again
        """
        total_keys = len(self.api_keys)

        for offset in range(1, total_keys):
            candidate = (self.current_index + offset) % total_keys
            if self.api_keys[candidate] not in self.failed_keys:
                self.current_index = candidate
                print(f" 🔄 Switching to API Key #{self.current_index + 1}")
                return True

        # If we've tried all keys, allow a retry even with a failed one.
        print(f" ⚠️ All keys attempted, retrying with key #{self.current_index + 1}")
        return True

    def get_statistics(self) -> str:
        """Return one line per key with its success/failure counts and status."""
        stats = []
        for number, key in enumerate(self.api_keys, start=1):
            success = self.success_count[key]
            failure = self.failure_count[key]
            status = "❌ FAILED" if key in self.failed_keys else "✅ ACTIVE"
            # Only the first 8 and last 4 characters are shown; short keys are hidden entirely.
            masked_key = key[:8] + "..." + key[-4:] if len(key) > 12 else "***"
            stats.append(f" Key #{number} ({masked_key}): {success} success, {failure} failures [{status}]")
        return "\n".join(stats)
92
+
93
+
94
def load_api_keys_from_hf_secrets(secret_names: Optional[List[str]] = None) -> List[str]:
    """
    Load API keys from Hugging Face Spaces Secrets.

    HF Spaces secrets are available as environment variables. In your
    Hugging Face Space settings, add these secrets:
    - GROQ_API_KEY_1
    - GROQ_API_KEY_2
    - GROQ_API_KEY_3
    - GROQ_API_KEY_4

    Args:
        secret_names: Environment variable names to read, in priority order.
            Defaults to GROQ_API_KEY_1 through GROQ_API_KEY_4.

    Returns:
        List of API keys that were found, stripped of surrounding whitespace.
        Unset, blank and placeholder values are skipped.
    """
    if secret_names is None:
        secret_names = ['GROQ_API_KEY_1', 'GROQ_API_KEY_2', 'GROQ_API_KEY_3', 'GROQ_API_KEY_4']

    api_keys = []

    print("🔐 Loading API keys from Hugging Face Secrets...")

    for secret_name in secret_names:
        # Strip before comparing so a padded placeholder is rejected as well.
        api_key = (os.getenv(secret_name) or "").strip()

        if api_key and api_key != "your_groq_api_key_here":
            api_keys.append(api_key)
            print(f" ✅ Loaded: {secret_name}")
        else:
            print(f" ⚠️ Not found or empty: {secret_name}")

    return api_keys
127
 
 
128
 
129
def create_llm_with_fallback(
    api_key_manager: GroqAPIKeyManager,
    model_name: str,
    temperature: float,
    max_retries: int = 3
) -> ChatGroq:
    """
    Create a ChatGroq LLM, rotating through API keys until one works.

    For each attempt the manager's current key is used to build a
    ``ChatGroq`` client, which is then probed with a one-word request.
    On success the key is marked good and the client returned; on failure
    the key is marked failed and the manager is asked to rotate.

    Args:
        api_key_manager: Manager handling multiple API keys
        model_name: Name of the model to use
        temperature: Temperature setting
        max_retries: Maximum number of attempts (each may use a different key)

    Returns:
        A ChatGroq instance whose probe request succeeded.

    Raises:
        ValueError: If the manager refuses to rotate to another key, if
            ``max_retries`` attempts all fail, or if ``max_retries`` is
            not positive. The last underlying error is chained as the cause.
    """
    for attempt in range(max_retries):
        current_key = api_key_manager.get_current_key()

        try:
            llm = ChatGroq(
                model_name=model_name,
                api_key=current_key,
                temperature=temperature
            )
            # Probe the key with a minimal request; the response itself
            # is irrelevant, only whether the call raises.
            llm.invoke("test")
        except Exception as e:
            error_msg = str(e).lower()
            api_key_manager.mark_failure(current_key)
            key_number = api_key_manager.current_index + 1

            # Classify the failure for the log line only. Match on
            # specific markers: bare "rate"/"limit"/"api" substrings occur
            # in almost every error text and mislabel unrelated failures.
            rate_markers = ("rate limit", "rate_limit", "ratelimit", "429", "too many requests")
            auth_markers = ("401", "unauthorized", "authentication", "invalid api key", "invalid_api_key")
            if any(marker in error_msg for marker in rate_markers):
                print(f" ⚠️ Rate limit hit on API Key #{key_number}")
            elif any(marker in error_msg for marker in auth_markers):
                print(f" ❌ Authentication failed on API Key #{key_number}")
            else:
                print(f" ❌ Error with API Key #{key_number}: {str(e)[:50]}")

            # Out of attempts: surface the last error as the cause.
            if attempt >= max_retries - 1:
                raise ValueError(f"Failed to initialize LLM after {max_retries} attempts") from e

            # Try next key if available
            if not api_key_manager.rotate_to_next_key():
                raise ValueError("All API keys failed!") from e

            print(f" 🔄 Retrying with next API key (Attempt {attempt + 2}/{max_retries})...")
            time.sleep(1)  # Brief pause before retry
        else:
            # Kept outside the try so a bookkeeping error is never
            # misreported as a failure of a key that actually worked.
            api_key_manager.mark_success(current_key)
            return llm

    # Only reachable when max_retries <= 0 (the loop body never ran).
    raise ValueError("Failed to create LLM with any available API key")
184
 
 
 
185
 
186
def create_multi_query_retriever(base_retriever, llm, strategy: str = "balanced"):
    """Wrap a base retriever with query expansion capabilities.

    Args:
        base_retriever: Object exposing ``invoke(query)`` that returns documents.
        llm: Language model forwarded to ``expand_query_simple``.
        strategy: Expansion strategy name forwarded to ``expand_query_simple``.

    Returns:
        A function ``query -> List[Document]`` that runs every expanded
        query through ``base_retriever`` and merges the results.
    """
    def multi_query_retrieve(query: str) -> List[Document]:
        """Retrieve documents for all expanded variations of ``query``.

        Documents are returned in first-seen order with exact-duplicate
        ``page_content`` removed. A variation that raises is logged and
        skipped so the remaining variations still contribute.
        """
        query_variations = expand_query_simple(query, strategy=strategy, llm=llm)
        all_docs = []
        # Track the content itself rather than hash(content): two different
        # texts can share a hash, which would silently drop a document.
        seen_content = set()
        for i, query_var in enumerate(query_variations):
            try:
                docs = base_retriever.invoke(query_var)
                for doc in docs:
                    content = doc.page_content
                    if content not in seen_content:
                        seen_content.add(content)
                        all_docs.append(doc)
            except Exception as e:
                # Best-effort: one failing variation must not abort retrieval.
                print(f" ✗ Query Expansion Error (Query {i+1}): {str(e)[:50]}")
                continue
        print(f" 📊 Query Expansion: Retrieved {len(all_docs)} unique documents.")
        return all_docs
    return multi_query_retrieve
207
 
 
 
208
 
209
def get_system_prompt(temperature: float) -> str:
    """
    Return the system prompt template matching the temperature setting.

    Temperature bands (upper bounds inclusive):
    - temperature <= 0.4:       highly factual, structured, conservative
    - 0.4 < temperature <= 0.8: balanced approach with moderate creativity
    - temperature > 0.8:        creative, engaging, storytelling mode

    Args:
        temperature: Sampling temperature chosen for the LLM.

    Returns:
        A prompt template string containing a single ``{context}``
        placeholder for the retrieved documents.
    """

    if temperature <= 0.4:
        # Conservative, structured prompt
        return """You are CogniChat, an expert document analysis assistant specializing in comprehensive and well-structured answers.

RESPONSE GUIDELINES:

**Structure & Formatting:**
- Start with a direct answer to the question
- Use **bold** for key terms, important concepts, and technical terminology
- Use bullet points (•) for lists, features, or multiple items
- Use numbered lists (1., 2., 3.) for steps, procedures, or sequential information
- Use ### Headers to organize different sections or topics
- Add blank lines between sections for readability

**Source Citation:**
- Always cite information using: [Source: filename, Page: X]
- Place citations at the end of your final answer only
- Do not cite sources within the body of your answer
- Multiple sources: [Source: doc1.pdf, Page: 3; doc2.pdf, Page: 7]

**Completeness:**
- Provide thorough, detailed answers using ALL relevant information from context
- Summarize and properly elaborate each point for increased clarity
- If the question has multiple parts, address each part clearly

**Accuracy:**
- ONLY use information from the provided context documents below
- If information is incomplete, state what IS available and what ISN'T
- If the answer isn't in the context, clearly state: "I cannot find this information in the uploaded documents"
- Never make assumptions or add information not in the context

---

{context}

---

Now answer the following question comprehensively using the context above:"""

    elif temperature <= 0.8:
        # Balanced prompt
        return """You are CogniChat, an intelligent document analysis assistant that combines accuracy with engaging communication.

RESPONSE GUIDELINES:

**Communication Style:**
- Present information in a clear, engaging manner
- Use **bold** for emphasis on important concepts
- Balance structure with natural flow
- Make complex topics accessible and interesting

**Content Approach:**
- Ground your response firmly in the provided context
- Add helpful explanations and connections between concepts
- Use analogies or examples when they help clarify ideas (but keep them brief)
- Organize information logically with headers (###) and lists where appropriate

**Source Attribution:**
- Cite sources at the end: [Source: filename, Page: X]
- Be transparent about what the documents do and don't contain

**Accuracy:**
- Base your answer on the context documents provided
- If information is partial, explain what's available
- Acknowledge gaps: "The documents don't cover this aspect"

---

{context}

---

Now answer the following question in an engaging yet accurate way:"""

    else:  # temperature > 0.8
        # Creative, engaging prompt.
        # The "acknowledge missing information" rule is listed as an allowed
        # (✅) behaviour: under a ❌ marker it reads as an instruction NOT to
        # admit gaps, which invites fabricated answers.
        return """You are CogniChat, a creative and insightful document analyst who transforms information into engaging, memorable experiences.

🎨 CREATIVE RESPONSE GUIDELINES:

**Your Mission:**
Transform the document content into compelling, creative responses while staying true to the facts. Think of yourself as a skilled storyteller who brings information to life!

**Creative Techniques - Use Liberally:**
- **Vivid Language**: Use descriptive, evocative language that paints mental pictures
- **Analogies & Metaphors**: Create memorable comparisons that illuminate concepts
- **Narrative Flow**: Tell a story when appropriate - build tension, reveal insights progressively
- **Engaging Hooks**: Start with something intriguing that captures attention
- **Real-World Connections**: Bridge abstract concepts to tangible, relatable scenarios
- **Thought-Provoking Questions**: Pose rhetorical questions that spark curiosity
- **Dynamic Formatting**: Use varied structures - not just bullet points. Try prose paragraphs, short punchy sentences, strategic emphasis

**Creative Freedom:**
- Interpret and synthesize information creatively
- Make insightful connections between different pieces of information
- Present the same facts in novel, interesting ways
- Use formatting creatively: emojis (when appropriate), varied paragraph lengths, strategic **emphasis**
- Vary your tone based on content: enthusiastic for exciting topics, contemplative for complex ones

**Boundaries of Creativity:**
- ✅ Creative presentation, interpretation, and synthesis of facts
- ✅ Memorable analogies and explanatory examples
- ✅ Engaging narrative structure and compelling language
- ✅ Acknowledge when information isn't available (but do so creatively!)
- ❌ Never invent facts not in the documents
- ❌ Don't contradict the source material

**Source Attribution:**
- Weave citations naturally into your narrative
- End with: [Source: filename, Page: X]

---

{context}

---

Now, using your creative prowess and the context above, craft an engaging and memorable answer to this question:"""
337
+
338
+
339
def create_rag_chain(
    retriever,
    get_session_history_func,
    enable_query_expansion=True,
    expansion_strategy="balanced",
    model_name: str = "moonshotai/kimi-k2-instruct",
    temperature: float = 0.2,
    api_keys: Optional[List[str]] = None
):
    """
    Create a conversational RAG chain with temperature-adaptive prompting
    and API key fallback at initialization time.

    Pipeline per request: the question and chat history are rewritten into
    a standalone query, documents are retrieved (optionally via query
    expansion), formatted into a context block, and passed with the
    question and history to the LLM.

    Args:
        retriever: Document retriever
        get_session_history_func: Function to get session history
        enable_query_expansion: Whether to enable query expansion
        expansion_strategy: Strategy for query expansion
        model_name: Name of the LLM model
        temperature: Temperature setting (0.0-1.0)
        api_keys: Optional list of API keys. If None, loads from environment

    Returns:
        Tuple ``(chain_with_memory, api_key_manager)``: the history-aware
        chain and the key manager (returned so callers can read statistics).

    Raises:
        ValueError: If no API keys are available, or if no LLM could be
            initialized with any of them.
    """

    # Load API keys from HF Secrets
    if api_keys is None:
        api_keys = load_api_keys_from_hf_secrets()

    if not api_keys:
        raise ValueError(
            "No valid API keys found! Please set GROQ_API_KEY or GROQ_API_KEY_1, "
            "GROQ_API_KEY_2, GROQ_API_KEY_3, GROQ_API_KEY_4 in your .env file"
        )

    # Initialize API key manager
    api_key_manager = GroqAPIKeyManager(api_keys)

    print(f"⚙️ RAG: Initializing LLM - Model: {model_name}, Temp: {temperature}")

    # Display creativity mode based on temperature
    # (same thresholds as get_system_prompt).
    if temperature <= 0.4:
        creativity_mode = "FACTUAL & STRUCTURED"
    elif temperature <= 0.8:
        creativity_mode = "BALANCED & ENGAGING"
    else:
        creativity_mode = "CREATIVE & STORYTELLING"
    print(f"🎭 Creativity Mode: {creativity_mode}")

    # Create LLM with fallback
    llm = create_llm_with_fallback(api_key_manager, model_name, temperature)
    print(f"✅ LLM initialized with API Key #{api_key_manager.current_index + 1}")

    if enable_query_expansion:
        print(f"✨ RAG: Query Expansion ENABLED (Strategy: {expansion_strategy})")
        enhanced_retriever = create_multi_query_retriever(
            base_retriever=retriever,
            llm=llm,
            strategy=expansion_strategy
        )
    else:
        enhanced_retriever = retriever

    # NOTE(review): {chat_history} and {question} are interpolated into the
    # system text here AND supplied again as messages below - confirm the
    # duplication is intended.
    rewrite_template = """You are an expert at optimizing search queries for document retrieval.

Given the conversation history and a follow-up question, create a comprehensive standalone question that:
1. Incorporates all relevant context from the chat history
2. Expands abbreviations and resolves all pronouns (it, they, this, that, etc.)
3. Includes key technical terms and concepts that would help find relevant documents
4. Maintains the original intent, specificity, and detail level
5. If the question asks for comparison or multiple items, ensure all items are in the query

Chat History:
{chat_history}

Follow-up Question: {question}

Optimized Standalone Question:"""
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}")
    ])
    query_rewriter = rewrite_prompt | llm | StrOutputParser()

    def format_docs(docs):
        """Format retrieved documents with clear structure and metadata."""
        if not docs:
            return "No relevant documents found in the knowledge base."

        formatted_parts = []
        for i, doc in enumerate(docs, 1):
            source = doc.metadata.get('source', 'Unknown Document')
            page = doc.metadata.get('page', 'N/A')
            rerank_score = doc.metadata.get('rerank_score')
            content = doc.page_content.strip()

            doc_header = f"{'='*60}\nDOCUMENT {i}\n{'='*60}"
            metadata_line = f"Source: {source} | Page: {page}"
            # Compare against None: a score of exactly 0.0 is a real score
            # and must still be shown.
            if rerank_score is not None:
                metadata_line += f" | Relevance: {rerank_score:.3f}"

            formatted_parts.append(
                f"{doc_header}\n"
                f"{metadata_line}\n"
                f"{'-'*60}\n"
                f"{content}\n"
            )
        return f"RETRIEVED CONTEXT ({len(docs)} documents):\n\n" + "\n".join(formatted_parts)

    # Get temperature-adaptive system prompt
    rag_template = get_system_prompt(temperature)

    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    # Rewriter input construction
    rewriter_input = RunnableParallel({
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    })

    # Main retrieval pipeline: rewrite -> retrieve -> format as context text
    retrieval_chain = rewriter_input | query_rewriter | enhanced_retriever | format_docs

    # Final conversational RAG chain
    conversational_rag_chain = RunnableParallel({
        "context": retrieval_chain,
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    }) | rag_prompt | llm | StrOutputParser()

    chain_with_memory = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )

    print("✅ RAG: Chain created successfully.")
    print("\n" + api_key_manager.get_statistics())

    return chain_with_memory, api_key_manager  # Return manager for statistics
 
 
 
 
 
 
templates/index.html CHANGED
@@ -7,7 +7,7 @@
7
  <script src="https://cdn.tailwindcss.com"></script>
8
  <link rel="preconnect" href="https://fonts.googleapis.com">
9
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
10
- <link href="https://fonts.googleapis.com/css2?family=Google+Sans:wght@400;500;700&family=Roboto:wght@400;500&display=swap" rel="stylesheet">
11
  <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
12
  <style>
13
  :root {
@@ -20,23 +20,28 @@
20
  --input-bg: #e8f0fe;
21
  --user-bubble: #d9e7ff;
22
  --bot-bubble: #f1f3f4;
 
 
 
23
  }
24
 
25
- /* Dark mode styles */
26
  .dark {
27
- --background: #202124;
28
- --foreground: #e8eaed;
29
- --primary: #8ab4f8;
30
- --primary-hover: #99bdfa;
31
- --card: #303134;
32
- --card-border: #5f6368;
33
- --input-bg: #303134;
34
- --user-bubble: #3c4043;
35
- --bot-bubble: #3c4043;
 
 
 
36
  }
37
 
38
  body {
39
- font-family: 'Google Sans', 'Roboto', sans-serif;
40
  background-color: var(--background);
41
  color: var(--foreground);
42
  overflow: hidden;
@@ -44,15 +49,14 @@
44
 
45
  #chat-window::-webkit-scrollbar { width: 8px; }
46
  #chat-window::-webkit-scrollbar-track { background: transparent; }
47
- #chat-window::-webkit-scrollbar-thumb { background-color: #bdc1c6; border-radius: 20px; }
48
  .dark #chat-window::-webkit-scrollbar-thumb { background-color: #5f6368; }
49
 
50
  .drop-zone--over {
51
  border-color: var(--primary);
52
- box-shadow: 0 0 15px rgba(26, 115, 232, 0.3);
53
  }
54
 
55
- /* Loading Spinner */
56
  .loader {
57
  width: 48px;
58
  height: 48px;
@@ -82,7 +86,6 @@
82
  100% { transform: rotate(360deg); }
83
  }
84
 
85
- /* Typing Indicator Animation */
86
  .typing-indicator span {
87
  height: 10px;
88
  width: 10px;
@@ -98,173 +101,136 @@
98
  40% { transform: scale(1.0); }
99
  }
100
 
101
- /* Enhanced Markdown Styling for better readability and aesthetics */
102
- .markdown-content p {
103
- margin-bottom: 1rem;
104
- line-height: 1.75;
105
- }
106
- .markdown-content h1, .markdown-content h2, .markdown-content h3, .markdown-content h4 {
107
- font-family: 'Google Sans', sans-serif;
108
- font-weight: 700;
109
- margin-top: 1.75rem;
110
- margin-bottom: 1rem;
111
- line-height: 1.3;
112
- }
113
- .markdown-content h1 { font-size: 1.75em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.5rem; }
114
- .markdown-content h2 { font-size: 1.5em; }
115
- .markdown-content h3 { font-size: 1.25em; }
116
- .markdown-content h4 { font-size: 1.1em; }
117
- .markdown-content ul, .markdown-content ol {
118
- padding-left: 1.75rem;
119
- margin-bottom: 1rem;
120
- }
121
- .markdown-content li {
122
- margin-bottom: 0.5rem;
123
- }
124
- .dark .markdown-content ul > li::marker { color: var(--primary); }
125
- .markdown-content ul > li::marker { color: var(--primary); }
126
- .markdown-content a {
127
- color: var(--primary);
128
- text-decoration: none;
129
- font-weight: 500;
130
- border-bottom: 1px solid transparent;
131
- transition: all 0.2s ease-in-out;
132
- }
133
- .markdown-content a:hover {
134
- border-bottom-color: var(--primary-hover);
135
- }
136
- .markdown-content blockquote {
137
- margin: 1.5rem 0;
138
- padding-left: 1.5rem;
139
- border-left: 4px solid var(--card-border);
140
- color: #6c757d;
141
- font-style: italic;
142
- }
143
- .dark .markdown-content blockquote {
144
- color: #adb5bd;
145
- }
146
- .markdown-content hr {
147
- border: none;
148
- border-top: 1px solid var(--card-border);
149
- margin: 2rem 0;
150
- }
151
- .markdown-content table {
152
- width: 100%;
153
- border-collapse: collapse;
154
- margin: 1.5rem 0;
155
- font-size: 0.9em;
156
- box-shadow: 0 1px 3px rgba(0,0,0,0.05);
157
- border-radius: 8px;
158
- overflow: hidden;
159
- }
160
- .markdown-content th, .markdown-content td {
161
- border: 1px solid var(--card-border);
162
- padding: 0.75rem 1rem;
163
- text-align: left;
164
- }
165
- .markdown-content th {
166
- background-color: var(--bot-bubble);
167
- font-weight: 500;
168
- }
169
- .markdown-content code {
170
- background-color: rgba(0,0,0,0.05);
171
- padding: 0.2rem 0.4rem;
172
- border-radius: 0.25rem;
173
- font-family: 'Roboto Mono', monospace;
174
- font-size: 0.9em;
175
- }
176
- .dark .markdown-content code {
177
- background-color: rgba(255,255,255,0.1);
178
- }
179
- .markdown-content pre {
180
- position: relative;
181
- background-color: #f8f9fa;
182
- border: 1px solid var(--card-border);
183
- border-radius: 0.5rem;
184
- margin-bottom: 1rem;
185
- }
186
- .dark .markdown-content pre {
187
- background-color: #2e2f32;
188
- }
189
- .markdown-content pre code {
190
- background: none;
191
- padding: 1rem;
192
- display: block;
193
- overflow-x: auto;
194
- }
195
- .markdown-content pre .copy-code-btn {
196
- position: absolute;
197
- top: 0.5rem;
198
- right: 0.5rem;
199
- background-color: #e8eaed;
200
- border: 1px solid #dadce0;
201
- color: #5f6368;
202
- padding: 0.3rem 0.6rem;
203
- border-radius: 0.25rem;
204
- cursor: pointer;
205
- opacity: 0;
206
- transition: opacity 0.2s;
207
- font-size: 0.8em;
208
- }
209
- .dark .markdown-content pre .copy-code-btn {
210
- background-color: #3c4043;
211
- border-color: #5f6368;
212
- color: #e8eaed;
213
- }
214
- .markdown-content pre:hover .copy-code-btn {
215
- opacity: 1;
216
- }
217
 
218
- /* Spinner for the TTS button */
219
  .tts-button-loader {
220
  width: 16px;
221
  height: 16px;
222
- border: 2px solid currentColor; /* Use button's text color */
223
  border-radius: 50%;
224
  display: inline-block;
225
  box-sizing: border-box;
226
  animation: rotation 0.8s linear infinite;
227
- border-bottom-color: transparent; /* Makes it a half circle spinner */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  }
229
  </style>
230
  </head>
231
  <body class="w-screen h-screen dark">
232
  <main id="main-content" class="h-full flex flex-col transition-opacity duration-500">
233
  <div id="chat-container" class="hidden flex-1 flex flex-col w-full mx-auto overflow-hidden">
234
- <header class="text-center p-4 border-b border-[var(--card-border)] flex-shrink-0">
235
- <h1 class="text-xl font-medium">Chat with your Docs</h1>
236
- <p id="chat-filename" class="text-xs text-gray-500 dark:text-gray-400 mt-1"></p>
 
 
 
 
 
 
 
 
237
  </header>
 
 
238
  <div id="chat-window" class="flex-1 overflow-y-auto p-4 md:p-6 lg:p-10">
239
- <div id="chat-content" class="max-w-4xl mx-auto space-y-8">
240
- </div>
241
  </div>
242
- <div class="p-4 flex-shrink-0 bg-[var(--background)] border-t border-[var(--card-border)]">
243
- <form id="chat-form" class="max-w-4xl mx-auto bg-[var(--card)] rounded-full p-2 flex items-center shadow-sm border border-transparent focus-within:border-[var(--primary)] transition-colors">
244
  <input type="text" id="chat-input" placeholder="Ask a question about your documents..." class="flex-grow bg-transparent focus:outline-none px-4 text-sm" autocomplete="off">
245
- <button type="submit" id="chat-submit-btn" class="bg-[var(--primary)] hover:bg-[var(--primary-hover)] text-white p-2 rounded-full transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed disabled:bg-gray-500" title="Send">
246
- <svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M3.49941 11.5556L11.555 3.5L12.4438 4.38889L6.27721 10.5556H21.9994V11.5556H6.27721L12.4438 17.7222L11.555 18.6111L3.49941 10.5556V11.5556Z" transform="rotate(180, 12.7497, 11.0556)" fill="currentColor"></path></svg>
247
  </button>
248
  </form>
249
  </div>
250
  </div>
251
 
252
  <div id="upload-container" class="flex-1 flex flex-col items-center justify-center p-8 transition-opacity duration-300">
253
- <div class="text-center">
254
- <h1 class="text-5xl font-medium mb-4">Upload docs to chat</h1>
255
- <div id="drop-zone" class="w-full max-w-lg text-center border-2 border-dashed border-[var(--card-border)] rounded-2xl p-10 transition-all duration-300 cursor-pointer bg-[var(--card)] hover:border-[var(--primary)]">
256
- <input id="file-upload" type="file" class="hidden" accept=".pdf,.txt,.docx,.jpg,.jpeg,.png" multiple title="input">
257
- <svg class="mx-auto h-12 w-12 text-gray-400" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" ><path stroke-linecap="round" stroke-linejoin="round" d="M12 16.5V9.75m0 0l3-3m-3 3l-3 3M6.75 19.5a4.5 4.5 0 01-1.41-8.775 5.25 5.25 0 0110.233-2.33 3 3 0 013.758 3.848A3.752 3.752 0 0118 19.5H6.75z"></path></svg>
258
- <p class="mt-4 text-sm font-medium">Drag & drop files or click to upload</p>
259
- <p id="file-name" class="mt-2 text-xs text-gray-500"></p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  </div>
261
  </div>
262
  </div>
263
 
264
- <div id="loading-overlay" class="hidden fixed inset-0 bg-[var(--background)] bg-opacity-80 backdrop-blur-sm flex flex-col items-center justify-center z-50 text-center p-4">
265
  <div class="loader"></div>
266
  <p id="loading-text" class="mt-6 text-sm font-medium"></p>
267
- <p id="loading-subtext" class="mt-2 text-xs text-gray-500 dark:text-gray-400"></p>
268
  </div>
269
  </main>
270
 
@@ -278,94 +244,89 @@
278
  const loadingOverlay = document.getElementById('loading-overlay');
279
  const loadingText = document.getElementById('loading-text');
280
  const loadingSubtext = document.getElementById('loading-subtext');
281
-
282
  const chatForm = document.getElementById('chat-form');
283
  const chatInput = document.getElementById('chat-input');
284
  const chatSubmitBtn = document.getElementById('chat-submit-btn');
285
  const chatWindow = document.getElementById('chat-window');
286
  const chatContent = document.getElementById('chat-content');
 
 
287
  const chatFilename = document.getElementById('chat-filename');
 
288
 
289
- let sessionId = null;
290
- const storedSessionId = sessionStorage.getItem('cognichat_session_id');
291
- if (storedSessionId) {
292
- sessionId = storedSessionId;
293
- console.debug('Restored session ID from storage:', sessionId);
294
- }
295
 
296
- // --- File Upload Logic ---
297
  dropZone.addEventListener('click', () => fileUploadInput.click());
298
 
299
  ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
300
- dropZone.addEventListener(eventName, preventDefaults, false);
301
- document.body.addEventListener(eventName, preventDefaults, false);
302
- });
303
-
304
- ['dragenter', 'dragover'].forEach(eventName => {
305
- dropZone.addEventListener(eventName, () => dropZone.classList.add('drop-zone--over'));
306
  });
307
- ['dragleave', 'drop'].forEach(eventName => {
308
- dropZone.addEventListener(eventName, () => dropZone.classList.remove('drop-zone--over'));
309
- });
310
-
311
  dropZone.addEventListener('drop', (e) => {
312
- const files = e.dataTransfer.files;
313
- if (files.length > 0) handleFiles(files);
314
  });
315
-
316
  fileUploadInput.addEventListener('change', (e) => {
317
  if (e.target.files.length > 0) handleFiles(e.target.files);
318
  });
319
 
320
- function preventDefaults(e) { e.preventDefault(); e.stopPropagation(); }
321
-
322
  async function handleFiles(files) {
323
  const formData = new FormData();
324
- let fileNames = [];
325
- for (const file of files) {
326
- formData.append('file', file);
327
- fileNames.push(file.name);
328
- }
329
 
330
  fileNameSpan.textContent = `Selected: ${fileNames.join(', ')}`;
331
- await uploadAndProcessFiles(formData, fileNames);
332
  }
333
 
334
- async function uploadAndProcessFiles(formData, fileNames) {
335
  loadingOverlay.classList.remove('hidden');
336
- loadingText.textContent = `Processing ${fileNames.length} document(s)...`;
337
- loadingSubtext.textContent = "🤓Creating a knowledge base may take a minute or two. So please hold on tight";
338
 
339
  try {
340
  const response = await fetch('/upload', { method: 'POST', body: formData });
341
  const result = await response.json();
342
-
343
  if (!response.ok) throw new Error(result.message || 'Unknown error occurred.');
344
- if (result.session_id) {
345
- sessionId = result.session_id;
346
- sessionStorage.setItem('cognichat_session_id', sessionId);
347
- console.debug('Stored session ID from upload:', sessionId);
348
- } else {
349
- console.warn('Upload response missing session_id field.');
350
- }
351
 
352
- chatFilename.textContent = `Chatting with: ${result.filename}`;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  uploadContainer.classList.add('hidden');
354
  chatContainer.classList.remove('hidden');
355
- appendMessage("I've analyzed your documents. What would you like to know?", "bot");
356
 
357
  } catch (error) {
358
  console.error('Upload error:', error);
359
  alert(`Error: ${error.message}`);
360
  } finally {
361
  loadingOverlay.classList.add('hidden');
362
- loadingSubtext.textContent = '';
363
  fileNameSpan.textContent = '';
364
  fileUploadInput.value = '';
365
  }
366
  }
367
 
368
- // --- Chat Logic ---
369
  chatForm.addEventListener('submit', async (e) => {
370
  e.preventDefault();
371
  const question = chatInput.value.trim();
@@ -377,53 +338,36 @@
377
  chatSubmitBtn.disabled = true;
378
 
379
  const typingIndicator = showTypingIndicator();
380
- let botMessageContainer = null;
381
- let contentDiv = null;
382
-
383
  try {
384
- const requestBody = { question: question };
385
- if (sessionId) {
386
- requestBody.session_id = sessionId;
387
- }
388
-
389
  const response = await fetch('/chat', {
390
  method: 'POST',
391
  headers: { 'Content-Type': 'application/json' },
392
- body: JSON.stringify(requestBody),
393
  });
394
-
395
  if (!response.ok) throw new Error(`Server error: ${response.statusText}`);
396
-
397
- // ============================ MODIFICATION START ==============================
398
- // Parse the JSON response instead of reading a stream
399
- const result = await response.json();
400
- const answer = result.answer; // Extract the 'answer' field
401
-
402
- if (!answer) {
403
- throw new Error("Received an empty or invalid response from the server.");
404
- }
405
 
 
 
 
 
 
 
 
 
 
 
406
  typingIndicator.remove();
407
- botMessageContainer = appendMessage('', 'bot');
408
- contentDiv = botMessageContainer.querySelector('.markdown-content');
409
-
410
- // Use the extracted answer for rendering
411
- contentDiv.innerHTML = marked.parse(answer);
412
  contentDiv.querySelectorAll('pre').forEach(addCopyButton);
413
- scrollToBottom(); // Scroll after content is added
414
-
415
- // Use the extracted answer for TTS
416
- addTextToSpeechControls(botMessageContainer, answer);
417
- // ============================ MODIFICATION END ==============================
418
-
419
  } catch (error) {
420
  console.error('Chat error:', error);
421
- if (typingIndicator) typingIndicator.remove();
422
- if (contentDiv) {
423
- contentDiv.innerHTML = `<p class="text-red-500">Error: ${error.message}</p>`;
424
- } else {
425
- appendMessage(`Error: ${error.message}`, 'bot');
426
- }
427
  } finally {
428
  chatInput.disabled = false;
429
  chatSubmitBtn.disabled = false;
@@ -431,166 +375,108 @@
431
  }
432
  });
433
 
434
- // --- UI Helper Functions ---
435
-
436
- function appendMessage(text, sender) {
437
  const messageWrapper = document.createElement('div');
438
- messageWrapper.className = `flex items-start gap-4`;
439
-
440
- const iconSVG = sender === 'user'
441
- ? `<div class="bg-blue-100 dark:bg-gray-700 p-2.5 rounded-full flex-shrink-0 mt-1"><svg class="w-5 h-5 text-blue-600 dark:text-blue-300" viewBox="0 0 24 24"><path fill="currentColor" d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"></path></svg></div>`
442
  : `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;
443
 
444
- const messageBubble = document.createElement('div');
445
- messageBubble.className = `flex-1 pt-1`;
446
-
447
- const senderName = document.createElement('p');
448
- senderName.className = 'font-medium text-sm mb-1';
449
- senderName.textContent = sender === 'user' ? 'You' : 'CogniChat';
450
-
451
- const contentDiv = document.createElement('div');
452
- contentDiv.className = 'text-base markdown-content';
453
- // Only parse if text is not empty
454
- if (text) {
455
- contentDiv.innerHTML = marked.parse(text);
 
456
  }
457
 
458
- const controlsContainer = document.createElement('div');
459
- controlsContainer.className = 'tts-controls mt-2';
460
-
461
- messageBubble.appendChild(senderName);
462
- messageBubble.appendChild(contentDiv);
463
- messageBubble.appendChild(controlsContainer);
464
- messageWrapper.innerHTML = iconSVG;
465
- messageWrapper.appendChild(messageBubble);
466
-
467
  chatContent.appendChild(messageWrapper);
468
  scrollToBottom();
469
-
470
- return messageBubble;
471
  }
472
 
473
  function showTypingIndicator() {
474
- const indicatorWrapper = document.createElement('div');
475
- indicatorWrapper.className = `flex items-start gap-4`;
476
- indicatorWrapper.id = 'typing-indicator';
477
-
478
- const iconSVG = `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;
479
-
480
- const messageBubble = document.createElement('div');
481
- messageBubble.className = 'flex-1 pt-1';
482
-
483
- const senderName = document.createElement('p');
484
- senderName.className = 'font-medium text-sm mb-1';
485
- senderName.textContent = 'CogniChat is thinking...';
486
-
487
  const indicator = document.createElement('div');
488
- indicator.className = 'typing-indicator';
489
- indicator.innerHTML = '<span></span><span></span><span></span>';
490
-
491
- messageBubble.appendChild(senderName);
492
- messageBubble.appendChild(indicator);
493
- indicatorWrapper.innerHTML = iconSVG;
494
- indicatorWrapper.appendChild(messageBubble);
495
-
496
- chatContent.appendChild(indicatorWrapper);
497
  scrollToBottom();
498
-
499
- return indicatorWrapper;
500
  }
501
 
502
- function scrollToBottom() {
503
- chatWindow.scrollTo({
504
- top: chatWindow.scrollHeight,
505
- behavior: 'smooth'
506
- });
507
- }
508
-
509
  function addCopyButton(pre) {
510
  const button = document.createElement('button');
511
  button.className = 'copy-code-btn';
512
  button.textContent = 'Copy';
513
  pre.appendChild(button);
514
-
515
  button.addEventListener('click', () => {
516
- const code = pre.querySelector('code').innerText;
517
- navigator.clipboard.writeText(code).then(() => {
518
- button.textContent = 'Copied!';
519
- setTimeout(() => button.textContent = 'Copy', 2000);
520
- });
521
  });
522
  }
523
 
524
- // --- Text-to-Speech Logic ---
525
- let currentAudio = null;
526
- let currentPlayingButton = null;
527
-
528
  const playIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M8 5v14l11-7z"/></svg>`;
529
  const pauseIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M6 19h4V5H6v14zm8-14v14h4V5h-4z"/></svg>`;
530
-
531
-
532
  function addTextToSpeechControls(messageBubble, text) {
533
- const ttsControls = messageBubble.querySelector('.tts-controls');
534
- if (text.trim().length > 0) {
535
- const speakButton = document.createElement('button');
536
- speakButton.className = 'speak-btn px-4 py-2 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50 disabled:cursor-not-allowed';
537
- speakButton.title = 'Listen to this message';
538
- speakButton.setAttribute('data-state', 'play');
539
- speakButton.innerHTML = `${playIconSVG} <span>Play</span>`;
540
- ttsControls.appendChild(speakButton);
541
- speakButton.addEventListener('click', () => handleTTS(text, speakButton));
542
- }
543
  }
544
 
545
  async function handleTTS(text, button) {
546
  if (button === currentPlayingButton) {
547
  if (currentAudio && !currentAudio.paused) {
548
  currentAudio.pause();
549
- button.setAttribute('data-state', 'paused');
550
- button.innerHTML = `${playIconSVG} <span>Play</span>`;
551
  } else if (currentAudio && currentAudio.paused) {
552
  currentAudio.play();
553
- button.setAttribute('data-state', 'playing');
554
  button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
555
  }
556
  return;
557
  }
558
-
559
  resetAllSpeakButtons();
560
-
561
  currentPlayingButton = button;
562
- button.setAttribute('data-state', 'loading');
563
  button.innerHTML = `<div class="tts-button-loader"></div> <span>Loading...</span>`;
564
  button.disabled = true;
565
 
566
  try {
567
- const response = await fetch('/tts', {
568
- method: 'POST',
569
- headers: { 'Content-Type': 'application/json' },
570
- body: JSON.stringify({ text: text })
571
- });
572
  if (!response.ok) throw new Error('Failed to generate audio.');
573
-
574
  const blob = await response.blob();
575
- const audioUrl = URL.createObjectURL(blob);
576
- currentAudio = new Audio(audioUrl);
577
  currentAudio.play();
578
-
579
- button.setAttribute('data-state', 'playing');
580
  button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
581
-
582
- currentAudio.onended = () => {
583
- button.setAttribute('data-state', 'play');
584
- button.innerHTML = `${playIconSVG} <span>Play</span>`;
585
- currentAudio = null;
586
- currentPlayingButton = null;
587
- };
588
-
589
  } catch (error) {
590
  console.error('TTS Error:', error);
591
- button.setAttribute('data-state', 'error');
592
- button.innerHTML = `${playIconSVG} <span>Error</span>`;
593
- alert('Failed to play audio. Please try again.');
594
  resetAllSpeakButtons();
595
  } finally {
596
  button.disabled = false;
@@ -599,8 +485,7 @@
599
 
600
  function resetAllSpeakButtons() {
601
  document.querySelectorAll('.speak-btn').forEach(btn => {
602
- btn.setAttribute('data-state', 'play');
603
- btn.innerHTML = `${playIconSVG} <span>Play</span>`;
604
  btn.disabled = false;
605
  });
606
  if (currentAudio) {
 
7
  <script src="https://cdn.tailwindcss.com"></script>
8
  <link rel="preconnect" href="https://fonts.googleapis.com">
9
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
10
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Google+Sans:wght@400;500;700&family=Roboto:wght@400;500&display=swap" rel="stylesheet">
11
  <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
12
  <style>
13
  :root {
 
20
  --input-bg: #e8f0fe;
21
  --user-bubble: #d9e7ff;
22
  --bot-bubble: #f1f3f4;
23
+ --select-bg: #ffffff;
24
+ --select-border: #dadce0;
25
+ --select-text: #1f1f1f;
26
  }
27
 
 
28
  .dark {
29
+ --background: #111827;
30
+ --foreground: #e5e7eb;
31
+ --primary: #3b82f6;
32
+ --primary-hover: #60a5fa;
33
+ --card: #1f2937;
34
+ --card-border: #4b5563;
35
+ --input-bg: #374151;
36
+ --user-bubble: #374151;
37
+ --bot-bubble: #374151;
38
+ --select-bg: #374151;
39
+ --select-border: #6b7280;
40
+ --select-text: #f3f4f6;
41
  }
42
 
43
  body {
44
+ font-family: 'Inter', 'Google Sans', 'Roboto', sans-serif;
45
  background-color: var(--background);
46
  color: var(--foreground);
47
  overflow: hidden;
 
49
 
50
  #chat-window::-webkit-scrollbar { width: 8px; }
51
  #chat-window::-webkit-scrollbar-track { background: transparent; }
52
+ #chat-window::-webkit-scrollbar-thumb { background-color: #4b5563; border-radius: 20px; }
53
  .dark #chat-window::-webkit-scrollbar-thumb { background-color: #5f6368; }
54
 
55
  .drop-zone--over {
56
  border-color: var(--primary);
57
+ box-shadow: 0 0 20px rgba(59, 130, 246, 0.4);
58
  }
59
 
 
60
  .loader {
61
  width: 48px;
62
  height: 48px;
 
86
  100% { transform: rotate(360deg); }
87
  }
88
 
 
89
  .typing-indicator span {
90
  height: 10px;
91
  width: 10px;
 
101
  40% { transform: scale(1.0); }
102
  }
103
 
104
+ .markdown-content p { margin-bottom: 1rem; line-height: 1.75; }
105
+ .markdown-content h1, .markdown-content h2, .markdown-content h3 { font-weight: 600; margin-top: 1.5rem; margin-bottom: 0.75rem; line-height: 1.3; }
106
+ .markdown-content h1 { font-size: 1.5em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3rem;}
107
+ .markdown-content h2 { font-size: 1.25em; }
108
+ .markdown-content h3 { font-size: 1.1em; }
109
+ .markdown-content ul, .markdown-content ol { padding-left: 1.75rem; margin-bottom: 1rem; }
110
+ .markdown-content li { margin-bottom: 0.5rem; }
111
+ .markdown-content a { color: var(--primary); text-decoration: none; font-weight: 500; }
112
+ .markdown-content pre { position: relative; background-color: #2e2f32; border: 1px solid var(--card-border); border-radius: 0.5rem; margin-bottom: 1rem; font-size: 0.9em;}
113
+ .markdown-content pre code { background: none; padding: 1rem; display: block; overflow-x: auto; }
114
+ .markdown-content pre .copy-code-btn { position: absolute; top: 0.5rem; right: 0.5rem; background-color: #3c4043; border: 1px solid #5f6368; color: #e8eaed; padding: 0.3rem 0.6rem; border-radius: 0.25rem; cursor: pointer; opacity: 0; transition: opacity 0.2s; font-size: 0.8em;}
115
+ .markdown-content pre:hover .copy-code-btn { opacity: 1; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
 
117
  .tts-button-loader {
118
  width: 16px;
119
  height: 16px;
120
+ border: 2px solid currentColor;
121
  border-radius: 50%;
122
  display: inline-block;
123
  box-sizing: border-box;
124
  animation: rotation 0.8s linear infinite;
125
+ border-bottom-color: transparent;
126
+ }
127
+
128
+ .select-wrapper {
129
+ position: relative;
130
+ }
131
+ .select-wrapper select {
132
+ background-color: var(--select-bg);
133
+ border: 1px solid var(--select-border);
134
+ color: var(--select-text);
135
+ padding: 0.75rem 2.5rem 0.75rem 1rem;
136
+ border-radius: 0.75rem;
137
+ font-size: 0.875rem;
138
+ width: 100%;
139
+ appearance: none;
140
+ -webkit-appearance: none;
141
+ transition: all 0.2s ease-in-out;
142
+ cursor: pointer;
143
+ background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
144
+ background-position: right 0.75rem center;
145
+ background-repeat: no-repeat;
146
+ background-size: 1.25em 1.25em;
147
  }
148
  </style>
149
  </head>
150
  <body class="w-screen h-screen dark">
151
  <main id="main-content" class="h-full flex flex-col transition-opacity duration-500">
152
  <div id="chat-container" class="hidden flex-1 flex flex-col w-full mx-auto overflow-hidden">
153
+
154
+ <!-- --- CORRECT HEADER (Center/Right Layout) --- -->
155
+ <header class="p-4 border-b border-[var(--card-border)] flex-shrink-0 flex justify-between items-center w-full">
156
+ <div class="w-1/4"></div>
157
+ <div class="w-1/2 text-center">
158
+ <h1 class="text-xl font-medium tracking-wide">CogniChat</h1>
159
+ <p id="chat-filename" class="text-xs text-gray-400 mt-1 truncate"></p>
160
+ </div>
161
+ <div id="chat-session-info" class="w-1/4 text-right text-xs">
162
+ <!-- This will be populated by JavaScript -->
163
+ </div>
164
  </header>
165
+ <!-- --- END HEADER --- -->
166
+
167
  <div id="chat-window" class="flex-1 overflow-y-auto p-4 md:p-6 lg:p-10">
168
+ <div id="chat-content" class="max-w-4xl mx-auto space-y-8"></div>
 
169
  </div>
170
+ <div class="p-4 flex-shrink-0 bg-opacity-50 backdrop-blur-md border-t border-[var(--card-border)]">
171
+ <form id="chat-form" class="max-w-4xl mx-auto bg-[var(--card)] rounded-full p-2 flex items-center shadow-lg border border-[var(--card-border)] focus-within:ring-2 focus-within:ring-[var(--primary)] transition-all">
172
  <input type="text" id="chat-input" placeholder="Ask a question about your documents..." class="flex-grow bg-transparent focus:outline-none px-4 text-sm" autocomplete="off">
173
+ <button type="submit" id="chat-submit-btn" class="bg-[var(--primary)] hover:bg-[var(--primary-hover)] text-white p-2.5 rounded-full transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed" title="Send">
174
+ <svg class="w-5 h-5" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-8.707l-3-3a1 1 0 00-1.414 1.414L10.586 9H7a1 1 0 100 2h3.586l-1.293 1.293a1 1 0 101.414 1.414l3-3a1 1 0 000-1.414z" clip-rule="evenodd"></path></svg>
175
  </button>
176
  </form>
177
  </div>
178
  </div>
179
 
180
  <div id="upload-container" class="flex-1 flex flex-col items-center justify-center p-8 transition-opacity duration-300">
181
+ <div class="text-center max-w-xl w-full">
182
+ <h1 class="text-5xl font-bold mb-3 tracking-tight">CogniChat</h1>
183
+ <p class="text-lg text-gray-400 mb-8">Upload your documents to start a conversation.</p>
184
+ <div class="mb-8 p-5 bg-[var(--card)] rounded-2xl border border-[var(--card-border)] shadow-lg">
185
+ <div class="flex flex-col sm:flex-row items-center gap-6">
186
+ <div class="w-full sm:w-1/2">
187
+ <div class="flex items-center gap-2 mb-2">
188
+ <svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path d="M7 3a1 1 0 000 2h6a1 1 0 100-2H7zM4 7a1 1 0 011-1h10a1 1 0 110 2H5a1 1 0 01-1-1zM2 11a2 2 0 012-2h12a2 2 0 012 2v4a2 2 0 01-2 2H4a2 2 0 01-2-2v-4z" /></svg>
189
+ <label for="model-select" class="block text-sm font-medium text-gray-300">Model</label>
190
+ </div>
191
+ <div class="select-wrapper">
192
+ <select id="model-select" name="model_name">
193
+ <option value="moonshotai/kimi-k2-instruct" selected>Kimi Instruct</option>
194
+ <option value="openai/gpt-oss-20b">GPT OSS 20b</option>
195
+ <option value="llama-3.3-70b-versatile">Llama 3.3 70b</option>
196
+ <option value="llama-3.1-8b-instant">Llama 3.1 8b Instant</option>
197
+ </select>
198
+ </div>
199
+ </div>
200
+ <div class="w-full sm:w-1/2">
201
+ <div class="flex items-center gap-2 mb-2">
202
+ <svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M5.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7zM12 5.5a3.5 3.5 0 11-7 0 3.5 3.5 0 017 0zM14.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7z" clip-rule="evenodd" /></svg>
203
+ <label for="temperature-select" class="block text-sm font-medium text-gray-300">Mode</label>
204
+ </div>
205
+ <div class="select-wrapper">
206
+ <select id="temperature-select" name="temperature">
207
+ <option value="0.2" selected>0.2 - Precise</option>
208
+ <option value="0.4">0.4 - Confident</option>
209
+ <option value="0.6">0.6 - Balanced</option>
210
+ <option value="0.8">0.8 - Flexible</option>
211
+ <option value="1.0">1.0 - Creative</option>
212
+ </select>
213
+ </div>
214
+ </div>
215
+ </div>
216
+ <p class="text-xs text-gray-500 mt-4 text-center">Higher creativity modes may reduce factual accuracy.</p>
217
+ </div>
218
+ <div id="drop-zone" class="w-full text-center border-2 border-dashed border-[var(--card-border)] rounded-2xl p-10 transition-all duration-300 cursor-pointer hover:bg-[var(--card)] hover:border-[var(--primary)]">
219
+ <div class="flex flex-col items-center justify-center pointer-events-none">
220
+ <svg class="mx-auto h-12 w-12 text-gray-500" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 16.5V9.75m0 0l3-3m-3 3l-3 3M6.75 19.5a4.5 4.5 0 01-1.41-8.775 5.25 5.25 0 0110.233-2.33 3 3 0 013.758 3.848A3.752 3.752 0 0118 19.5H6.75z"></path></svg>
221
+ <p class="mt-4 text-sm font-medium text-gray-400">Drag & drop files or <span class="text-[var(--primary)] font-semibold">click to upload</span></p>
222
+ <p class="text-xs text-gray-400 mt-1">Supports PDF, DOCX, TXT</p>
223
+ <p id="file-name" class="mt-2 text-xs text-gray-500"></p>
224
+ </div>
225
+ <input id="file-upload" type="file" class="hidden" accept=".pdf,.txt,.docx" multiple>
226
  </div>
227
  </div>
228
  </div>
229
 
230
+ <div id="loading-overlay" class="hidden fixed inset-0 bg-[var(--background)] bg-opacity-80 backdrop-blur-sm flex flex-col items-center justify-center z-50">
231
  <div class="loader"></div>
232
  <p id="loading-text" class="mt-6 text-sm font-medium"></p>
233
+ <p id="loading-subtext" class="mt-2 text-xs text-gray-400"></p>
234
  </div>
235
  </main>
236
 
 
244
  const loadingOverlay = document.getElementById('loading-overlay');
245
  const loadingText = document.getElementById('loading-text');
246
  const loadingSubtext = document.getElementById('loading-subtext');
 
247
  const chatForm = document.getElementById('chat-form');
248
  const chatInput = document.getElementById('chat-input');
249
  const chatSubmitBtn = document.getElementById('chat-submit-btn');
250
  const chatWindow = document.getElementById('chat-window');
251
  const chatContent = document.getElementById('chat-content');
252
+ const modelSelect = document.getElementById('model-select');
253
+ const temperatureSelect = document.getElementById('temperature-select');
254
  const chatFilename = document.getElementById('chat-filename');
255
+ const chatSessionInfo = document.getElementById('chat-session-info');
256
 
257
+ let sessionId = sessionStorage.getItem('cognichat_session_id');
 
 
 
 
 
258
 
 
259
  dropZone.addEventListener('click', () => fileUploadInput.click());
260
 
261
  ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
262
+ dropZone.addEventListener(eventName, e => {e.preventDefault(); e.stopPropagation();});
 
 
 
 
 
263
  });
264
+ ['dragenter', 'dragover'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.add('drop-zone--over')));
265
+ ['dragleave', 'drop'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.remove('drop-zone--over')));
266
+
 
267
  dropZone.addEventListener('drop', (e) => {
268
+ if (e.dataTransfer.files.length > 0) handleFiles(e.dataTransfer.files);
 
269
  });
 
270
  fileUploadInput.addEventListener('change', (e) => {
271
  if (e.target.files.length > 0) handleFiles(e.target.files);
272
  });
273
 
 
 
274
  async function handleFiles(files) {
275
  const formData = new FormData();
276
+ let fileNames = Array.from(files).map(f => f.name);
277
+ for (const file of files) { formData.append('file', file); }
278
+
279
+ formData.append('model_name', modelSelect.value);
280
+ formData.append('temperature', temperatureSelect.value);
281
 
282
  fileNameSpan.textContent = `Selected: ${fileNames.join(', ')}`;
283
+ await uploadAndProcessFiles(formData);
284
  }
285
 
286
+ async function uploadAndProcessFiles(formData) {
287
  loadingOverlay.classList.remove('hidden');
288
+ loadingText.textContent = `Processing document(s)...`;
289
+ loadingSubtext.textContent = "Creating a knowledge base may take a minute. Please hold on tight!";
290
 
291
  try {
292
  const response = await fetch('/upload', { method: 'POST', body: formData });
293
  const result = await response.json();
 
294
  if (!response.ok) throw new Error(result.message || 'Unknown error occurred.');
 
 
 
 
 
 
 
295
 
296
+ sessionId = result.session_id;
297
+ sessionStorage.setItem('cognichat_session_id', sessionId);
298
+
299
+ chatFilename.innerHTML = `Chatting with: <strong>${result.filename}</strong>`;
300
+ chatFilename.title = result.filename;
301
+
302
+ chatSessionInfo.innerHTML = `
303
+ <span class="text-gray-500 dark:text-gray-500 italic block hover:text-gray-300 transition-colors cursor-pointer" onclick="location.reload()">
304
+ Refresh to change settings
305
+ </span>`;
306
+
307
+ const modelOption = modelSelect.querySelector(`option[value="${result.model_name}"]`);
308
+ const simpleModelName = modelOption ? modelOption.textContent : result.model_name;
309
+
310
+ const modelInfo = {
311
+ model: result.model_name,
312
+ mode: result.mode,
313
+ simpleModelName: simpleModelName
314
+ };
315
+
316
  uploadContainer.classList.add('hidden');
317
  chatContainer.classList.remove('hidden');
318
+ appendMessage("I've analyzed your documents. What would you like to know?", "bot", modelInfo);
319
 
320
  } catch (error) {
321
  console.error('Upload error:', error);
322
  alert(`Error: ${error.message}`);
323
  } finally {
324
  loadingOverlay.classList.add('hidden');
 
325
  fileNameSpan.textContent = '';
326
  fileUploadInput.value = '';
327
  }
328
  }
329
 
 
330
  chatForm.addEventListener('submit', async (e) => {
331
  e.preventDefault();
332
  const question = chatInput.value.trim();
 
338
  chatSubmitBtn.disabled = true;
339
 
340
  const typingIndicator = showTypingIndicator();
341
+
 
 
342
  try {
 
 
 
 
 
343
  const response = await fetch('/chat', {
344
  method: 'POST',
345
  headers: { 'Content-Type': 'application/json' },
346
+ body: JSON.stringify({ question, session_id: sessionId }),
347
  });
 
348
  if (!response.ok) throw new Error(`Server error: ${response.statusText}`);
 
 
 
 
 
 
 
 
 
349
 
350
+ const result = await response.json();
351
+
352
+ const modelOption = modelSelect.querySelector(`option[value="${result.model_name}"]`);
353
+ const simpleModelName = modelOption ? modelOption.textContent : result.model_name;
354
+ const modelInfo = {
355
+ model: result.model_name,
356
+ mode: result.mode,
357
+ simpleModelName: simpleModelName
358
+ };
359
+
360
  typingIndicator.remove();
361
+ const botMessageContainer = appendMessage('', 'bot', modelInfo);
362
+ const contentDiv = botMessageContainer.querySelector('.markdown-content');
363
+ contentDiv.innerHTML = marked.parse(result.answer);
 
 
364
  contentDiv.querySelectorAll('pre').forEach(addCopyButton);
365
+ scrollToBottom();
366
+ addTextToSpeechControls(botMessageContainer, result.answer);
 
 
 
 
367
  } catch (error) {
368
  console.error('Chat error:', error);
369
+ typingIndicator.remove();
370
+ appendMessage(`Error: ${error.message}`, 'bot');
 
 
 
 
371
  } finally {
372
  chatInput.disabled = false;
373
  chatSubmitBtn.disabled = false;
 
375
  }
376
  });
377
 
378
+ // --- FINAL, CORRECT appendMessage function ---
379
+ function appendMessage(text, sender, modelInfo = null) {
 
380
  const messageWrapper = document.createElement('div');
381
+ const iconSVG = sender === 'user'
382
+ ? `<div class="bg-blue-200 dark:bg-gray-700 p-2.5 rounded-full flex-shrink-0 mt-1"><svg class="w-5 h-5 text-blue-700 dark:text-blue-300" viewBox="0 0 24 24"><path fill="currentColor" d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"></path></svg></div>`
 
 
383
  : `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;
384
 
385
+ let senderHTML;
386
+ if (sender === 'user') {
387
+ senderHTML = '<p class="font-medium text-sm mb-1">You</p>';
388
+ } else {
389
+ let modelInfoHTML = '';
390
+ if (modelInfo && modelInfo.simpleModelName) {
391
+ modelInfoHTML = `
392
+ <span class="ml-2 text-xs font-normal text-gray-400">
393
+ (Model: ${modelInfo.simpleModelName} Mode: ${modelInfo.mode})
394
+ </span>
395
+ `;
396
+ }
397
+ senderHTML = `<div class="font-medium text-sm mb-1 flex items-center">CogniChat ${modelInfoHTML}</div>`;
398
  }
399
 
400
+ messageWrapper.className = `flex items-start gap-4`;
401
+ messageWrapper.innerHTML = `
402
+ ${iconSVG}
403
+ <div class="flex-1 pt-1">
404
+ ${senderHTML}
405
+ <div class="text-base markdown-content">${text ? marked.parse(text) : ''}</div>
406
+ <div class="tts-controls mt-2"></div>
407
+ </div>
408
+ `;
409
  chatContent.appendChild(messageWrapper);
410
  scrollToBottom();
411
+ return messageWrapper.querySelector('.flex-1');
 
412
  }
413
 
414
  function showTypingIndicator() {
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  const indicator = document.createElement('div');
416
+ indicator.id = 'typing-indicator';
417
+ indicator.className = `flex items-start gap-4`;
418
+ indicator.innerHTML = `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div><div class="flex-1 pt-1"><p class="font-medium text-sm mb-1">CogniChat is thinking...</p><div class="typing-indicator"><span></span><span></span><span></span></div></div>`;
419
+ chatContent.appendChild(indicator);
 
 
 
 
 
420
  scrollToBottom();
421
+ return indicator;
 
422
  }
423
 
424
+ function scrollToBottom() { chatWindow.scrollTo({ top: chatWindow.scrollHeight, behavior: 'smooth' }); }
425
+
 
 
 
 
 
426
  function addCopyButton(pre) {
427
  const button = document.createElement('button');
428
  button.className = 'copy-code-btn';
429
  button.textContent = 'Copy';
430
  pre.appendChild(button);
 
431
  button.addEventListener('click', () => {
432
+ navigator.clipboard.writeText(pre.querySelector('code').innerText)
433
+ .then(() => {
434
+ button.textContent = 'Copied!';
435
+ setTimeout(() => button.textContent = 'Copy', 2000);
436
+ });
437
  });
438
  }
439
 
440
+ // (TTS functions remain unchanged)
441
+ let currentAudio, currentPlayingButton;
 
 
442
  const playIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M8 5v14l11-7z"/></svg>`;
443
  const pauseIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M6 19h4V5H6v14zm8-14v14h4V5h-4z"/></svg>`;
 
 
444
  function addTextToSpeechControls(messageBubble, text) {
445
+ if (!text.trim()) return;
446
+ const speakButton = document.createElement('button');
447
+ speakButton.className = 'speak-btn mt-2 px-3 py-1.5 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50';
448
+ speakButton.title = 'Listen to this message';
449
+ speakButton.innerHTML = `${playIconSVG} <span>Listen</span>`;
450
+ messageBubble.querySelector('.tts-controls').appendChild(speakButton);
451
+ speakButton.addEventListener('click', () => handleTTS(text, speakButton));
 
 
 
452
  }
453
 
454
  async function handleTTS(text, button) {
455
  if (button === currentPlayingButton) {
456
  if (currentAudio && !currentAudio.paused) {
457
  currentAudio.pause();
458
+ button.innerHTML = `${playIconSVG} <span>Listen</span>`;
 
459
  } else if (currentAudio && currentAudio.paused) {
460
  currentAudio.play();
 
461
  button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
462
  }
463
  return;
464
  }
 
465
  resetAllSpeakButtons();
 
466
  currentPlayingButton = button;
 
467
  button.innerHTML = `<div class="tts-button-loader"></div> <span>Loading...</span>`;
468
  button.disabled = true;
469
 
470
  try {
471
+ const response = await fetch('/tts', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ text }) });
 
 
 
 
472
  if (!response.ok) throw new Error('Failed to generate audio.');
 
473
  const blob = await response.blob();
474
+ currentAudio = new Audio(URL.createObjectURL(blob));
 
475
  currentAudio.play();
 
 
476
  button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
477
+ currentAudio.onended = resetAllSpeakButtons;
 
 
 
 
 
 
 
478
  } catch (error) {
479
  console.error('TTS Error:', error);
 
 
 
480
  resetAllSpeakButtons();
481
  } finally {
482
  button.disabled = false;
 
485
 
486
  function resetAllSpeakButtons() {
487
  document.querySelectorAll('.speak-btn').forEach(btn => {
488
+ btn.innerHTML = `${playIconSVG} <span>Listen</span>`;
 
489
  btn.disabled = false;
490
  });
491
  if (currentAudio) {