riteshraut committed on
Commit f7d42c1
Parent(s): 2f22e27
fix
Browse files
  - Copy.gitattributes    +0 -35
  - Copy.gitignore        +0 -6
  .env - Copy.example     +0 -2
  app.py                  +207 -172
  evaluate.py             +205 -0
  query_expansion.py      +525 -0
  rag_processor.py        +446 -60
  templates/index.html    +230 -345
- Copy.gitattributes  DELETED
@@ -1,35 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
- Copy.gitignore  DELETED
@@ -1,6 +0,0 @@
-.env
-/uploads/
-/vectorstores/
-/.cache/
-__pycache__/
-*.pyc
.env - Copy.example  DELETED
@@ -1,2 +0,0 @@
-# Copy this file to .env and fill in your API key
-GROQ_API_KEY=your_groq_api_key_here
app.py  CHANGED
@@ -1,5 +1,4 @@
 import os
-import time
 import uuid
 from flask import Flask, request, render_template, session, jsonify, Response
 from werkzeug.utils import secure_filename
@@ -10,55 +9,143 @@ import re
 import io
 from gtts import gTTS
 from langchain_core.documents import Document
-
-from langchain_community.document_loaders import (
-    TextLoader,
-    Docx2txtLoader,
-)
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_experimental.text_splitter import SemanticChunker
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
-from langchain.retrievers import EnsembleRetriever
 from langchain_community.retrievers import BM25Retriever
 from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain.storage import InMemoryStore
-
 
 app = Flask(__name__)
 app.config['SECRET_KEY'] = os.urandom(24)
 
 is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
-if is_hf_spaces:
-    app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
-else:
-    app.config['UPLOAD_FOLDER'] = 'uploads'
 
 try:
     os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
-    print(f"Upload folder ready: {app.config['UPLOAD_FOLDER']}")
 except Exception as e:
-    print(f"Failed to create upload folder: {e}")
     app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
     os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
-    print(f"Using fallback upload folder: {app.config['UPLOAD_FOLDER']}")
 
-
 message_histories = {}
-doc_stores = {}  # To hold the InMemoryStore for each session
 
-print("Loading embedding model...")
 try:
-    hf_token = os.getenv("HF_TOKEN")
     EMBEDDING_MODEL = HuggingFaceEmbeddings(
-        model_name="
         model_kwargs={'device': 'cpu'},
-        encode_kwargs={'normalize_embeddings': True}
     )
-    print("Embedding model loaded.")
-except Exception as e:
-
-
 
 def load_pdf_with_fallback(filepath):
     try:
@@ -66,51 +153,40 @@
         with fitz.open(filepath) as pdf_doc:
             for page_num, page in enumerate(pdf_doc):
                 text = page.get_text()
-                if text.strip():
-                    docs.append(Document(
-                        page_content=text,
-                        metadata={
-                            "source": os.path.basename(filepath),
-                            "page": page_num + 1,
-                        }
-                    ))
         if docs:
-            print(f"✓ Successfully loaded PDF: {filepath}")
             return docs
-        else:
-            raise ValueError("No text content found in PDF.")
-    except Exception as e:
-        print(f"PyMuPDF failed for {filepath}: {e}")
-        raise
 
-LOADER_MAPPING = {
-    ".txt": TextLoader,
-    ".pdf": load_pdf_with_fallback,
-    ".docx": Docx2txtLoader,
-}
 
 def get_session_history(session_id: str) -> ChatMessageHistory:
-    if session_id not in message_histories:
-        message_histories[session_id] = ChatMessageHistory()
     return message_histories[session_id]
 
 @app.route('/health', methods=['GET'])
-def health_check():
-    return jsonify({'status': 'healthy'}), 200
 
 @app.route('/', methods=['GET'])
-def index():
-    return render_template('index.html')
 
 @app.route('/upload', methods=['POST'])
 def upload_files():
     files = request.files.getlist('file')
     if not files or all(f.filename == '' for f in files):
         return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
 
-    all_docs = []
-
-
     for file in files:
         if file and file.filename:
             filename = secure_filename(file.filename)
@@ -118,169 +194,128 @@
             try:
                 file.save(filepath)
                 file_ext = os.path.splitext(filename)[1].lower()
-                if file_ext not in LOADER_MAPPING:
-                    raise ValueError("Unsupported file format.")
-
                 loader_func = LOADER_MAPPING[file_ext]
                 docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
-
-                if not docs:
-                    raise ValueError("No content extracted.")
-
                 all_docs.extend(docs)
                 processed_files.append(filename)
-                print(f"✓ Successfully processed: {filename}")
             except Exception as e:
-
-
-                failed_files.append(f"{filename} ({error_msg})")
 
     if not all_docs:
-
-        if failed_files:
-            error_summary += " Reasons: " + ", ".join(failed_files)
-        return jsonify({'status': 'error', 'message': error_summary}), 400
 
     try:
-        print("Starting RAG pipeline setup...")
-
-
-            separators=["\n\n", "\n", ". ", " ", ""],  # Prioritize natural breaks
-            length_function=len)
-        child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
-
-        parent_docs = parent_splitter.split_documents(all_docs)
-        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
-
-        child_docs = []
-        for i, doc in enumerate(parent_docs):
-            _id = doc_ids[i]
-            sub_docs = child_splitter.split_documents([doc])
-            for child in sub_docs:
-                child.metadata["doc_id"] = _id
-            child_docs.extend(sub_docs)
-
-        store = InMemoryStore()
-        store.mset(list(zip(doc_ids, parent_docs)))
 
         vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
-
-        print(f"
-
-        bm25_retriever = BM25Retriever.from_documents(child_docs)
-        bm25_retriever.k = 3
 
-
 
-        ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, faiss_retriever],
-            weights=[0.4, 0.6]
-        )
-        print("Created Hybrid Retriever for child documents.")
-
         session_id = str(uuid.uuid4())
 
-
 
-
 
-
-
-
-        success_msg = f"Successfully processed: {', '.join(processed_files)}"
-        if failed_files:
-            success_msg += f"\nFailed to process: {', '.join(failed_files)}"
-
         return jsonify({
             'status': 'success',
-            'filename': success_msg,
-            'session_id': session_id
         })
-
     except Exception as e:
-        import traceback
-
-        return jsonify({'status': 'error', 'message': f'Failed during RAG setup: {e}'}), 500
 
 @app.route('/chat', methods=['POST'])
 def chat():
     data = request.get_json()
-    question = data.get('question')
-    session_id = session.get('session_id') or data.get('session_id')
 
-    if not question
-
 
     try:
-
-
 
-
-        print("--- STARTING DIAGNOSTIC RUN ---")
-        print(f"Original Question: {question}")
-        print("="*50 + "\n")
-
-        rewritten_query = chain_components["rewriter"].invoke({"question": question, "chat_history": get_session_history(session_id).messages})
-        #print(f"--- 1. Rewritten Query ---\n{rewritten_query}\n")
-
-        hyde_doc = chain_components["hyde"].invoke({"question": rewritten_query})
-        #print(f"--- 2. HyDE Document ---\n{hyde_doc}\n")
 
-
-
-        #for i, doc in enumerate(final_retrieved_docs):
-            #print(f"  Doc {i+1}: {doc.page_content[:150]}... (Source: {doc.metadata.get('source')})")
-        #print("\n")
-
-        final_context_docs = chain_components["parent_fetcher"].invoke(final_retrieved_docs)
-        #print(f"--- 4. Final {len(final_context_docs)} Parent Docs for LLM ---")
-        #for i, doc in enumerate(final_context_docs):
-            #print(f"  Final Doc {i+1} (Source: {doc.metadata.get('source')}, Page: {doc.metadata.get('page')}):\n  '{doc.page_content[:300]}...'\n---")
 
-        #
-
-
 
-
 
-        return jsonify({'answer': answer_string})
-
     except Exception as e:
-        import traceback
-
-        return jsonify({'status': 'error', 'message': 'An error occurred while getting the answer.'}), 500
 
 def clean_markdown_for_tts(text: str) -> str:
-    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
-    text = re.sub(r'\_(.*?)\_', r'\1', text)
-    text = re.sub(r'`(.*?)`', r'\1', text)
-    text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
-    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
-    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
-    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
-    text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
     text = re.sub(r'\n+', ' ', text)
     return text.strip()
 
 @app.route('/tts', methods=['POST'])
 def text_to_speech():
-    data = request.get_json()
-    text = data.get('text')
-
-    if not text:
-        return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
-
     try:
-        clean_text = clean_markdown_for_tts(text)
-        tts = gTTS(clean_text, lang='en')
-        mp3_fp = io.BytesIO()
-        tts.write_to_fp(mp3_fp)
-        mp3_fp.seek(0)
         return Response(mp3_fp, mimetype='audio/mpeg')
     except Exception as e:
-        print(f"TTS Error: {e}")
         return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
 
 if __name__ == '__main__':
     port = int(os.environ.get("PORT", 7860))
-
 import os
 import uuid
 from flask import Flask, request, render_template, session, jsonify, Response
 from werkzeug.utils import secure_filename
 import io
 from gtts import gTTS
 from langchain_core.documents import Document
+from langchain_community.document_loaders import TextLoader, Docx2txtLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
+from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
+from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain_community.retrievers import BM25Retriever
 from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain.storage import InMemoryStore
+from sentence_transformers.cross_encoder import CrossEncoder
 
 app = Flask(__name__)
 app.config['SECRET_KEY'] = os.urandom(24)
 
+# --- FIX: Use STRING keys for the dictionary ---
+# Maps temperature strings (from the form) to the mode labels
+TEMPERATURE_LABELS = {
+    "0.2": "Precise",
+    "0.4": "Confident",
+    "0.6": "Balanced",
+    "0.8": "Flexible",
+    "1.0": "Creative"
+}
+
+class LocalReranker(BaseDocumentCompressor):
+    model: Any
+    top_n: int = 5
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def compress_documents(
+        self, documents: Sequence[Document], query: str, callbacks=None
+    ) -> Sequence[Document]:
+        if not documents:
+            return []
+        pairs = [[query, doc.page_content] for doc in documents]
+        scores = self.model.predict(pairs, show_progress_bar=False)
+        doc_scores = list(zip(documents, scores))
+        sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
+        top_docs = []
+        for doc, score in sorted_doc_scores[: self.top_n]:
+            doc.metadata["rerank_score"] = float(score)
+            top_docs.append(doc)
+        return top_docs
+
+def create_optimized_parent_child_chunks(all_docs):
+    if not all_docs:
+        print("❌ CHUNKING: No input documents provided!")
+        return [], [], []
+
+    parent_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=900, chunk_overlap=200, separators=["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""]
+    )
+    child_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=350, chunk_overlap=80, separators=["\n", ". ", "! ", "? ", "; ", ", ", " ", ""]
+    )
+    parent_docs = parent_splitter.split_documents(all_docs)
+    doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
+    child_docs = []
+
+    for i, parent_doc in enumerate(parent_docs):
+        parent_id = doc_ids[i]
+        children = child_splitter.split_documents([parent_doc])
+        for j, child in enumerate(children):
+            child.metadata.update({
+                "doc_id": parent_id, "chunk_index": j, "total_chunks": len(children),
+                "is_first_chunk": j == 0, "is_last_chunk": j == len(children) - 1,
+            })
+            if len(children) > 1:
+                if j == 0: child.page_content = "[Beginning] " + child.page_content
+                elif j == len(children) - 1: child.page_content = "[Continues...] " + child.page_content
+            child_docs.append(child)
+
+    print(f"✅ CHUNKING: Created {len(parent_docs)} parent and {len(child_docs)} child chunks.")
+    return parent_docs, child_docs, doc_ids
+
+def get_context_aware_parents(docs: List[Document], store: InMemoryStore) -> List[Document]:
+    if not docs: return []
+    parent_scores, child_content_by_parent = {}, {}
+    for doc in docs:
+        parent_id = doc.metadata.get("doc_id")
+        if parent_id:
+            parent_scores[parent_id] = parent_scores.get(parent_id, 0) + 1
+            if parent_id not in child_content_by_parent: child_content_by_parent[parent_id] = []
+            child_content_by_parent[parent_id].append(doc.page_content)
+
+    parent_ids = list(parent_scores.keys())
+    parents = store.mget(parent_ids)
+    enhanced_parents = []
+
+    for i, parent in enumerate(parents):
+        if parent is not None:
+            parent_id = parent_ids[i]
+            if parent_id in child_content_by_parent:
+                child_excerpts = "\n".join(child_content_by_parent[parent_id][:3])
+                enhanced_content = f"{parent.page_content}\n\nRelevant excerpts:\n{child_excerpts}"
+                enhanced_parent = Document(
+                    page_content=enhanced_content,
+                    metadata={**parent.metadata, "child_relevance_score": parent_scores[parent_id], "matching_children": len(child_content_by_parent[parent_id])}
+                )
+                enhanced_parents.append(enhanced_parent)
+        else:
+            print(f"❌ PARENT_FETCH: Parent {parent_ids[i]} not found in store!")
+
+    enhanced_parents.sort(key=lambda p: p.metadata.get("child_relevance_score", 0), reverse=True)
+    return enhanced_parents
+
 is_hf_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"))
+app.config['UPLOAD_FOLDER'] = '/tmp/uploads' if is_hf_spaces else 'uploads'
 
 try:
     os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+    print(f"📁 Upload folder ready: {app.config['UPLOAD_FOLDER']}")
 except Exception as e:
+    print(f"❌ Failed to create upload folder, falling back to /tmp: {e}")
     app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
     os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
 
+session_data = {}
 message_histories = {}
 
+print("🔄 Loading embedding model...")
 try:
     EMBEDDING_MODEL = HuggingFaceEmbeddings(
+        model_name="sentence-transformers/all-MiniLM-L6-v2",
         model_kwargs={'device': 'cpu'},
+        encode_kwargs={'normalize_embeddings': True}
     )
+    print("✅ Embedding model loaded.")
+except Exception as e: print(f"❌ FATAL: Could not load embedding model. Error: {e}"); raise e
+
+print("🔄 Loading reranker model...")
+try:
+    RERANKER_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
+    print("✅ Reranker model loaded.")
+except Exception as e: print(f"❌ FATAL: Could not load reranker model. Error: {e}"); raise e
 
 def load_pdf_with_fallback(filepath):
     try:
         with fitz.open(filepath) as pdf_doc:
             for page_num, page in enumerate(pdf_doc):
                 text = page.get_text()
+                if text.strip(): docs.append(Document(page_content=text, metadata={"source": os.path.basename(filepath), "page": page_num + 1}))
         if docs:
+            print(f"✅ Loaded PDF: {os.path.basename(filepath)} - {len(docs)} pages")
             return docs
+        else: raise ValueError("No text content found in PDF.")
+    except Exception as e: print(f"❌ PyMuPDF failed for {filepath}: {e}"); raise
 
+LOADER_MAPPING = {".txt": TextLoader, ".pdf": load_pdf_with_fallback, ".docx": Docx2txtLoader}
 
 def get_session_history(session_id: str) -> ChatMessageHistory:
+    if session_id not in message_histories: message_histories[session_id] = ChatMessageHistory()
     return message_histories[session_id]
 
 @app.route('/health', methods=['GET'])
+def health_check(): return jsonify({'status': 'healthy'}), 200
 
 @app.route('/', methods=['GET'])
+def index(): return render_template('index.html')
 
 @app.route('/upload', methods=['POST'])
 def upload_files():
     files = request.files.getlist('file')
+
+    # Get temperature as a string for the dictionary key
+    temperature_str = request.form.get('temperature', '0.2')
+    temperature = float(temperature_str)  # Convert to float for the LLM
+    model_name = request.form.get('model_name', 'moonshotai/kimi-k2-instruct')
+    print(f"⚙️ UPLOAD: Model: {model_name}, Temp: {temperature}")
+
     if not files or all(f.filename == '' for f in files):
         return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
 
+    all_docs, processed_files, failed_files = [], [], []
+    print(f"📁 Processing {len(files)} file(s)...")
     for file in files:
         if file and file.filename:
             filename = secure_filename(file.filename)
             try:
                 file.save(filepath)
                 file_ext = os.path.splitext(filename)[1].lower()
+                if file_ext not in LOADER_MAPPING: raise ValueError("Unsupported file format.")
                 loader_func = LOADER_MAPPING[file_ext]
                 docs = loader_func(filepath) if file_ext == ".pdf" else loader_func(filepath).load()
+                if not docs: raise ValueError("No content extracted.")
                 all_docs.extend(docs)
                 processed_files.append(filename)
             except Exception as e:
+                print(f"✗ Error processing {filename}: {e}")
+                failed_files.append(f"{filename} ({e})")
 
     if not all_docs:
+        return jsonify({'status': 'error', 'message': f"Failed to process all files. Reasons: {', '.join(failed_files)}"}), 400
 
+    print(f"✅ UPLOAD: Processed {len(processed_files)} files.")
     try:
+        print("🔄 Starting RAG pipeline setup...")
+        parent_docs, child_docs, doc_ids = create_optimized_parent_child_chunks(all_docs)
+        if not child_docs: raise ValueError("No child documents created during chunking.")
 
         vectorstore = FAISS.from_documents(child_docs, EMBEDDING_MODEL)
+        store = InMemoryStore(); store.mset(list(zip(doc_ids, parent_docs)))
+        print(f"✅ Indexed {len(child_docs)} document chunks.")
 
+        bm25_retriever = BM25Retriever.from_documents(child_docs); bm25_retriever.k = 12
+        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 12})
+        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.6, 0.4])
+        reranker = LocalReranker(model=RERANKER_MODEL, top_n=5)
+        def get_parents(docs: List[Document]) -> List[Document]: return get_context_aware_parents(docs, store)
+        compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)
+        final_retriever = compression_retriever | get_parents
 
         session_id = str(uuid.uuid4())
+        rag_chain, api_key_manager = create_rag_chain(
+            retriever=final_retriever, get_session_history_func=get_session_history,
+            model_name=model_name, temperature=temperature
+        )
+        # Store the float temperature
+        session_data[session_id] = {'chain': rag_chain, 'model_name': model_name, 'temperature': temperature, 'api_key_manager': api_key_manager}
 
+        success_msg = f"Processed: {', '.join(processed_files)}"
+        if failed_files: success_msg += f". Failed: {', '.join(failed_files)}"
 
+        # Get the mode label using the STRING key
+        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
 
+        print(f"✅ UPLOAD COMPLETE: Session {session_id} is ready.")
+        # Return all info needed by the frontend
         return jsonify({
             'status': 'success',
+            'filename': success_msg,
+            'session_id': session_id,
+            'model_name': model_name,
+            'mode': mode_label
         })
     except Exception as e:
+        import traceback; traceback.print_exc()
+        return jsonify({'status': 'error', 'message': f'RAG setup failed: {e}'}), 500
 
 @app.route('/chat', methods=['POST'])
 def chat():
     data = request.get_json()
+    question, session_id = data.get('question'), data.get('session_id') or session.get('session_id')
 
+    if not question: return jsonify({'status': 'error', 'message': 'No question provided.'}), 400
+    if not session_id or session_id not in session_data:
+        print(f"❌ CHAT: Invalid session {session_id}.")
+        return jsonify({'status': 'error', 'message': 'Invalid session. Please upload documents first.'}), 400
 
     try:
+        session_info = session_data[session_id]
+        rag_chain = session_info['chain']
 
+        # --- START: BUGFIX & FEATURE UPDATE ---
 
+        # 1. Get model name from session
+        model_name = session_info['model_name']
 
+        # 2. Get temperature (float) and convert to string for lookup
+        temperature_float = session_info['temperature']
+        temperature_str = str(temperature_float)
 
+        # 3. Get the correct mode label
+        mode_label = TEMPERATURE_LABELS.get(temperature_str, temperature_str)
+
+        # --- END: BUGFIX & FEATURE UPDATE ---
+
+        print(f"💬 CHAT: Invoking chain for session {session_id}...")
+        answer = rag_chain.invoke({"question": question}, config={"configurable": {"session_id": session_id}})
+        print(f"✅ CHAT: Answer generated.")
+
+        # Return all info needed by the frontend
+        return jsonify({
+            'answer': answer,
+            'model_name': model_name,
+            'mode': mode_label
+        })
 
     except Exception as e:
+        import traceback; traceback.print_exc()
+        return jsonify({'status': 'error', 'message': f'Error during chat: {e}'}), 500
 
 def clean_markdown_for_tts(text: str) -> str:
+    text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text); text = re.sub(r'\_(.*?)\_', r'\1', text)
+    text = re.sub(r'`(.*?)`', r'\1', text); text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
+    text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE); text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
+    text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE); text = re.sub(r'^\s*[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)
     text = re.sub(r'\n+', ' ', text)
     return text.strip()
 
 @app.route('/tts', methods=['POST'])
 def text_to_speech():
+    data = request.get_json(); text = data.get('text')
+    if not text: return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
     try:
+        clean_text = clean_markdown_for_tts(text); tts = gTTS(clean_text, lang='en')
+        mp3_fp = io.BytesIO(); tts.write_to_fp(mp3_fp); mp3_fp.seek(0)
        return Response(mp3_fp, mimetype='audio/mpeg')
     except Exception as e:
+        print(f"❌ TTS Error: {e}")
        return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
 
 if __name__ == '__main__':
     port = int(os.environ.get("PORT", 7860))
+    print(f"🚀 Starting Flask app on port {port}")
+    app.run(host="0.0.0.0", port=port, debug=False, threaded=False)
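
The reworked endpoints can be exercised end to end. Below is a minimal client-side sketch, not part of the commit, assuming the app is running locally on its default port 7860 and that notes.pdf is a placeholder document name:

import requests

BASE = "http://localhost:7860"

# Upload a document; temperature is sent as a string, which is exactly
# what the string-keyed TEMPERATURE_LABELS lookup expects.
with open("notes.pdf", "rb") as f:  # placeholder file name
    up = requests.post(
        f"{BASE}/upload",
        files={"file": f},
        data={"temperature": "0.4", "model_name": "moonshotai/kimi-k2-instruct"},
    ).json()
print(up["mode"])  # "Confident" for temperature "0.4"

# Ask a question against the session created by the upload.
chat = requests.post(
    f"{BASE}/chat",
    json={"question": "Summarize the document.", "session_id": up["session_id"]},
).json()
print(chat["answer"], chat["model_name"], chat["mode"])

The `mode` field comes from the string-keyed TEMPERATURE_LABELS lookup that this commit fixes; the earlier float keys could never match the form's string values.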
evaluate.py  ADDED
@@ -0,0 +1,205 @@
+import os
+import asyncio
+import uuid
+from dotenv import load_dotenv
+from datasets import Dataset
+import pandas as pd
+from typing import Sequence, Any, List
+
+# Ragas and LangChain components
+from ragas import evaluate
+from ragas.metrics import (
+    faithfulness,
+    answer_relevancy,
+    context_recall,
+    context_precision,
+)
+from ragas.testset import TestsetGenerator
+# NOTE: The 'evolutions' import has been completely removed.
+
+# Your specific RAG components from app.py
+from langchain_groq import ChatGroq
+from langchain_community.document_loaders import PyMuPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain.storage import InMemoryStore
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
+from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
+from langchain_core.documents import Document
+from sentence_transformers.cross_encoder import CrossEncoder
+from rag_processor import create_rag_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+import fitz
+
+# Load environment variables
+load_dotenv()
+
+# --- Re-implementing LocalReranker from app.py ---
+class LocalReranker(BaseDocumentCompressor):
+    model: Any
+    top_n: int = 3
+    class Config:
+        arbitrary_types_allowed = True
+    def compress_documents(self, documents: Sequence[Document], query: str, callbacks=None) -> Sequence[Document]:
+        if not documents: return []
+        pairs = [[query, doc.page_content] for doc in documents]
+        scores = self.model.predict(pairs, show_progress_bar=False)
+        doc_scores = list(zip(documents, scores))
+        sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True)
+        top_docs = []
+        for doc, score in sorted_doc_scores[:self.top_n]:
+            doc.metadata['rerank_score'] = float(score)
+            top_docs.append(doc)
+        return top_docs
+
+# --- Helper Functions ---
+def load_pdf_with_fallback(filepath):
+    """Load PDF using PyMuPDF"""
+    try:
+        docs = []
+        with fitz.open(filepath) as pdf_doc:
+            for page_num, page in enumerate(pdf_doc):
+                text = page.get_text()
+                if text.strip():
+                    docs.append(Document(
+                        page_content=text,
+                        metadata={"source": os.path.basename(filepath), "page": page_num + 1}
+                    ))
+        if docs:
+            print(f"✓ Successfully loaded PDF: {filepath}")
+            return docs
+        else:
+            raise ValueError("No text content found in PDF.")
+    except Exception as e:
+        print(f"✗ PyMuPDF failed for {filepath}: {e}")
+        raise
+
+async def main():
+    """Main execution function"""
+    print("\n" + "="*60 + "\nSTARTING RAGAS EVALUATION\n" + "="*60)
+
+    pdf_path = "uploads/Unit_-_1_Introduction.pdf"
+    if not os.path.exists(pdf_path):
+        print(f"✗ Error: PDF not found at {pdf_path}")
+        return
+
+    try:
+        # --- 1. Setup Models ---
+        print("\n--- 1. Initializing Models ---")
+        groq_api_key = os.getenv("GROQ_API_KEY")
+        if not groq_api_key or groq_api_key == "your_groq_api_key_here":
+            raise ValueError("GROQ_API_KEY not found or is a placeholder.")
+
+        generator_llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_api_key)
+        critic_llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api_key)
+        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+        reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device='cpu')
+        print("✓ Models initialized.")
+
+        # --- 2. Setup RAG Pipeline ---
+        print("\n--- 2. Setting up RAG Pipeline ---")
+        documents = load_pdf_with_fallback(pdf_path)
+
+        # Split documents
+        parent_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=400)
+        child_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)
+        parent_docs = parent_splitter.split_documents(documents)
+        doc_ids = [str(uuid.uuid4()) for _ in parent_docs]
+
+        child_docs = []
+        for i, doc in enumerate(parent_docs):
+            _id = doc_ids[i]
+            sub_docs = child_splitter.split_documents([doc])
+            for child in sub_docs:
+                child.metadata["doc_id"] = _id
+            child_docs.extend(sub_docs)
+
+        store = InMemoryStore()
+        store.mset(list(zip(doc_ids, parent_docs)))
+        vectorstore = FAISS.from_documents(child_docs, embedding_model)
+
+        bm25_retriever = BM25Retriever.from_documents(child_docs, k=10)
+        faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
+        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.4, 0.6])
+
+        reranker = LocalReranker(model=reranker_model, top_n=5)
+        compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=ensemble_retriever)
+
+        def get_parents(docs: List[Document]) -> List[Document]:
+            parent_ids = {d.metadata["doc_id"] for d in docs}
+            return store.mget(list(parent_ids))
+
+        final_retriever = compression_retriever | get_parents
+
+        message_histories = {}
+        def get_session_history(session_id: str):
+            if session_id not in message_histories:
+                message_histories[session_id] = ChatMessageHistory()
+            return message_histories[session_id]
+
+        rag_chain = create_rag_chain(final_retriever, get_session_history)
+        print("✓ RAG chain created successfully.")
+
+        # --- 3. Generate Testset ---
+        print("\n--- 3. Generating Test Questions ---")
+        generator = TestsetGenerator.from_langchain(generator_llm, critic_llm, embedding_model)
+
+        # Generate a simple test set without complex distributions
+        testset = generator.generate_with_langchain_docs(documents, testset_size=5)
+        print("✓ Testset generated.")
+
+        # --- 4. Run RAG Chain on Testset ---
+        print("\n--- 4. Running RAG Chain to Generate Answers ---")
+        test_questions = [item['question'] for item in testset.to_pandas().to_dict('records')]
+        ground_truths = [item['ground_truth'] for item in testset.to_pandas().to_dict('records')]
+
+        answers = []
+        contexts = []
+
+        for i, question in enumerate(test_questions):
+            print(f"  Processing question {i+1}/{len(test_questions)}...")
+            # Retrieve contexts
+            retrieved_docs = final_retriever.invoke(question)
+            contexts.append([doc.page_content for doc in retrieved_docs])
+            # Get answer from chain
+            config = {"configurable": {"session_id": str(uuid.uuid4())}}
+            answer = await rag_chain.ainvoke({"question": question}, config=config)
+            answers.append(answer)
+
+        # --- 5. Evaluate with Ragas ---
+        print("\n--- 5. Evaluating Results with Ragas ---")
+        eval_data = {
+            'question': test_questions,
+            'answer': answers,
+            'contexts': contexts,
+            'ground_truth': ground_truths
+        }
+        eval_dataset = Dataset.from_dict(eval_data)
+
+        result = evaluate(
+            eval_dataset,
+            metrics=[faithfulness, answer_relevancy, context_precision, context_recall],
+            llm=critic_llm,
+            embeddings=embedding_model
+        )
+
+        print("\n" + "="*60 + "\nEVALUATION RESULTS\n" + "="*60)
+        print(result)
+
+        # --- 6. Save Results ---
+        print("\n--- 6. Saving Results ---")
+        results_df = result.to_pandas()
+        results_df.to_csv("evaluation_results.csv", index=False)
+        print("✓ Evaluation results saved to evaluation_results.csv")
+
+        print("\n" + "="*60 + "\nEVALUATION COMPLETE!\n" + "="*60)
+
+    except Exception as e:
+        print(f"\n✗ An error occurred during the process: {e}")
+        import traceback
+        traceback.print_exc()
+
+if __name__ == "__main__":
+    asyncio.run(main())
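
Since evaluate.py writes one row per test question to evaluation_results.csv, per-metric averages can be pulled out afterwards. A small sketch, not part of the commit, assuming the Ragas result frame names its columns after the four metrics requested above:

import pandas as pd

df = pd.read_csv("evaluation_results.csv")
# Column names are assumed to match the requested metrics; the guard
# skips any that differ in the installed Ragas version.
for col in ["faithfulness", "answer_relevancy", "context_precision", "context_recall"]:
    if col in df.columns:
        print(f"{col}: {df[col].mean():.3f}")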
query_expansion.py  ADDED
@@ -0,0 +1,525 @@
+# utils/query_expansion.py
+
+"""
+Query Expansion System for CogniChat RAG Application
+
+This module implements advanced query expansion techniques to improve retrieval quality:
+- QueryAnalyzer: Extracts intent, entities, and keywords
+- QueryRephraser: Generates natural language variations
+- MultiQueryExpander: Creates diverse query formulations
+- MultiHopReasoner: Connects concepts across documents
+- FallbackStrategies: Handles edge cases gracefully
+
+Author: CogniChat Team
+Date: October 19, 2025
+"""
+
+import re
+from typing import List, Dict, Any, Optional
+from dataclasses import dataclass
+from enum import Enum
+
+
+class QueryStrategy(Enum):
+    """Query expansion strategies with different complexity levels."""
+    QUICK = "quick"                    # 2 queries - fast, minimal expansion
+    BALANCED = "balanced"              # 3-4 queries - good balance
+    COMPREHENSIVE = "comprehensive"    # 5-6 queries - maximum coverage
+
+
+@dataclass
+class QueryAnalysis:
+    """Results from query analysis."""
+    intent: str                    # question, definition, comparison, explanation, etc.
+    entities: List[str]            # Named entities extracted
+    keywords: List[str]            # Important keywords
+    complexity: str                # simple, medium, complex
+    domain: Optional[str] = None   # Technical domain if detected
+
+
+@dataclass
+class ExpandedQuery:
+    """Container for expanded query variations."""
+    original: str
+    variations: List[str]
+    strategy_used: QueryStrategy
+    analysis: QueryAnalysis
+
+
+class QueryAnalyzer:
+    """
+    Analyzes queries to extract intent, entities, and key information.
+    Uses LLM-based analysis for intelligent query understanding.
+    """
+
+    def __init__(self, llm=None):
+        """
+        Initialize QueryAnalyzer.
+
+        Args:
+            llm: Optional LangChain LLM for advanced analysis
+        """
+        self.llm = llm
+        self.intent_patterns = {
+            'definition': r'\b(what is|define|meaning of|definition)\b',
+            'how_to': r'\b(how to|how do|how can|steps to)\b',
+            'comparison': r'\b(compare|difference|versus|vs|better than)\b',
+            'explanation': r'\b(why|explain|reason|cause)\b',
+            'listing': r'\b(list|enumerate|what are|types of)\b',
+            'example': r'\b(example|instance|sample|case)\b',
+        }
+
+    def analyze(self, query: str) -> QueryAnalysis:
+        """
+        Analyze query to extract intent, entities, and keywords.
+
+        Args:
+            query: User's original query
+
+        Returns:
+            QueryAnalysis object with extracted information
+        """
+        query_lower = query.lower()
+
+        # Detect intent
+        intent = self._detect_intent(query_lower)
+
+        # Extract entities (simplified - can be enhanced with NER)
+        entities = self._extract_entities(query)
+
+        # Extract keywords
+        keywords = self._extract_keywords(query)
+
+        # Assess complexity
+        complexity = self._assess_complexity(query, entities, keywords)
+
+        # Detect domain
+        domain = self._detect_domain(query_lower)
+
+        return QueryAnalysis(
+            intent=intent,
+            entities=entities,
+            keywords=keywords,
+            complexity=complexity,
+            domain=domain
+        )
+
+    def _detect_intent(self, query_lower: str) -> str:
+        """Detect query intent using pattern matching."""
+        for intent, pattern in self.intent_patterns.items():
+            if re.search(pattern, query_lower):
+                return intent
+        return 'general'
+
+    def _extract_entities(self, query: str) -> List[str]:
+        """Extract named entities (simplified version)."""
+        # Look for capitalized words (potential entities)
+        words = query.split()
+        entities = []
+
+        for word in words:
+            # Skip common words at sentence start
+            if word[0].isupper() and word.lower() not in ['what', 'how', 'why', 'when', 'where', 'which']:
+                entities.append(word)
+
+        # Look for quoted terms
+        quoted = re.findall(r'"([^"]+)"', query)
+        entities.extend(quoted)
+
+        return list(set(entities))
+
+    def _extract_keywords(self, query: str) -> List[str]:
+        """Extract important keywords from query."""
+        # Remove stop words (simplified list)
+        stop_words = {
+            'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
+            'what', 'how', 'why', 'when', 'where', 'which', 'who',
+            'do', 'does', 'did', 'can', 'could', 'should', 'would',
+            'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'
+        }
+
+        # Split and filter
+        words = re.findall(r'\b\w+\b', query.lower())
+        keywords = [w for w in words if w not in stop_words and len(w) > 2]
+
+        return keywords[:10]  # Limit to top 10
+
+    def _assess_complexity(self, query: str, entities: List[str], keywords: List[str]) -> str:
+        """Assess query complexity."""
+        word_count = len(query.split())
+        entity_count = len(entities)
+        keyword_count = len(keywords)
+
+        # Simple scoring
+        score = word_count + (entity_count * 2) + (keyword_count * 1.5)
+
+        if score < 15:
+            return 'simple'
+        elif score < 30:
+            return 'medium'
+        else:
+            return 'complex'
+
+    def _detect_domain(self, query_lower: str) -> Optional[str]:
+        """Detect technical domain if present."""
+        domains = {
+            'programming': ['code', 'function', 'class', 'variable', 'algorithm', 'debug'],
+            'data_science': ['model', 'dataset', 'training', 'prediction', 'accuracy'],
+            'machine_learning': ['neural', 'network', 'learning', 'ai', 'deep learning'],
+            'web': ['html', 'css', 'javascript', 'api', 'frontend', 'backend'],
+            'database': ['sql', 'query', 'database', 'table', 'index'],
+            'security': ['encryption', 'authentication', 'vulnerability', 'attack'],
+        }
+
+        for domain, keywords in domains.items():
+            if any(kw in query_lower for kw in keywords):
+                return domain
+
+        return None
+
+
+class QueryRephraser:
+    """
+    Generates natural language variations of queries using multiple strategies.
+    """
+
+    def __init__(self, llm=None):
+        """
+        Initialize QueryRephraser.
+
+        Args:
+            llm: LangChain LLM for generating variations
+        """
+        self.llm = llm
+
+    def generate_variations(
+        self,
+        query: str,
+        analysis: QueryAnalysis,
+        strategy: QueryStrategy = QueryStrategy.BALANCED
+    ) -> List[str]:
+        """
+        Generate query variations based on strategy.
+
+        Args:
+            query: Original query
+            analysis: Query analysis results
+            strategy: Expansion strategy to use
+
+        Returns:
+            List of query variations
+        """
+        variations = [query]  # Always include original
+
+        if strategy == QueryStrategy.QUICK:
+            # Just add synonym variation
+            variations.append(self._synonym_variation(query, analysis))
+
+        elif strategy == QueryStrategy.BALANCED:
+            # Add synonym, expanded, and simplified versions
+            variations.append(self._synonym_variation(query, analysis))
+            variations.append(self._expanded_variation(query, analysis))
+            variations.append(self._simplified_variation(query, analysis))
+
+        elif strategy == QueryStrategy.COMPREHENSIVE:
+            # Add all variations
+            variations.append(self._synonym_variation(query, analysis))
+            variations.append(self._expanded_variation(query, analysis))
+            variations.append(self._simplified_variation(query, analysis))
+            variations.append(self._keyword_focused(query, analysis))
+            variations.append(self._context_variation(query, analysis))
+            # Add one more: alternate phrasing
+            if analysis.intent in ['how_to', 'explanation']:
+                variations.append(f"Guide to {' '.join(analysis.keywords[:3])}")
+
+        # Remove duplicates and None values
+        variations = [v for v in variations if v]
+        return list(dict.fromkeys(variations))  # Preserve order, remove dupes
+
+    def _synonym_variation(self, query: str, analysis: QueryAnalysis) -> str:
+        """Generate variation using synonyms."""
+        # Common synonym replacements
+        synonyms = {
+            'error': 'issue',
+            'problem': 'issue',
+            'fix': 'resolve',
+            'use': 'utilize',
+            'create': 'generate',
+            'make': 'create',
+            'get': 'retrieve',
+            'show': 'display',
+            'find': 'locate',
+            'explain': 'describe',
+        }
+
+        words = query.lower().split()
+        for i, word in enumerate(words):
+            if word in synonyms:
+                words[i] = synonyms[word]
+                break  # Only replace one word to keep natural
+
+        return ' '.join(words).capitalize()
+
+    def _expanded_variation(self, query: str, analysis: QueryAnalysis) -> str:
+        """Generate expanded version with more detail."""
+        if analysis.intent == 'definition':
+            return f"Detailed explanation and definition of {' '.join(analysis.keywords)}"
+        elif analysis.intent == 'how_to':
+            return f"Step-by-step guide on {query.lower()}"
+        elif analysis.intent == 'comparison':
+            return f"Comprehensive comparison: {query}"
+        else:
+            # Add qualifying words
+            return f"Detailed information about {query.lower()}"
+
+    def _simplified_variation(self, query: str, analysis: QueryAnalysis) -> str:
+        """Generate simplified version focusing on core concepts."""
+        # Use just the keywords
+        if len(analysis.keywords) >= 2:
+            return ' '.join(analysis.keywords[:3])
+        return query
+
+    def _keyword_focused(self, query: str, analysis: QueryAnalysis) -> str:
+        """Create keyword-focused variation for BM25."""
+        keywords = analysis.keywords + analysis.entities
+        return ' '.join(keywords[:5])
+
+    def _context_variation(self, query: str, analysis: QueryAnalysis) -> str:
+        """Add contextual information if domain detected."""
+        if analysis.domain:
+            return f"{query} in {analysis.domain} context"
+        return query
+
+
+class MultiQueryExpander:
+    """
+    Main query expansion orchestrator that combines analysis and rephrasing.
+    """
+
+    def __init__(self, llm=None):
+        """
+        Initialize MultiQueryExpander.
+
+        Args:
+            llm: LangChain LLM for advanced expansions
+        """
+        self.analyzer = QueryAnalyzer(llm)
+        self.rephraser = QueryRephraser(llm)
+
+    def expand(
+        self,
+        query: str,
+        strategy: QueryStrategy = QueryStrategy.BALANCED,
+        max_queries: int = 6
+    ) -> ExpandedQuery:
+        """
+        Expand query into multiple variations.
+
+        Args:
+            query: Original user query
+            strategy: Expansion strategy
+            max_queries: Maximum number of queries to generate
+
+        Returns:
+            ExpandedQuery object with all variations
+        """
+        # Analyze query
+        analysis = self.analyzer.analyze(query)
+
+        # Generate variations
+        variations = self.rephraser.generate_variations(query, analysis, strategy)
+
+        # Limit to max_queries
+        variations = variations[:max_queries]
+
+        return ExpandedQuery(
+            original=query,
+            variations=variations,
+            strategy_used=strategy,
+            analysis=analysis
+        )
+
+
+class MultiHopReasoner:
+    """
+    Implements multi-hop reasoning to connect concepts across documents.
+    Useful for complex queries that require information from multiple sources.
+    """
+
+    def __init__(self, llm=None):
+        """
+        Initialize MultiHopReasoner.
+
+        Args:
+            llm: LangChain LLM for reasoning
+        """
+        self.llm = llm
+
+    def generate_sub_queries(self, query: str, analysis: QueryAnalysis) -> List[str]:
+        """
+        Break complex query into sub-queries for multi-hop reasoning.
+
+        Args:
+            query: Original complex query
+            analysis: Query analysis
+
+        Returns:
+            List of sub-queries
+        """
+        sub_queries = [query]
+
+        # For comparison queries, create separate queries for each entity
+        if analysis.intent == 'comparison' and len(analysis.entities) >= 2:
+            for entity in analysis.entities[:2]:
+                sub_queries.append(f"Information about {entity}")
+        elif analysis.intent == 'comparison' and len(analysis.keywords) >= 2:
+            # Fallback: use keywords if no entities found
+            for keyword in analysis.keywords[:2]:
+                sub_queries.append(f"Information about {keyword}")
+
+        # For how-to queries, break into steps
+        if analysis.intent == 'how_to' and len(analysis.keywords) >= 2:
+            main_topic = ' '.join(analysis.keywords[:2])
+            sub_queries.append(f"Prerequisites for {main_topic}")
+            sub_queries.append(f"Steps to {main_topic}")
+
+        # For complex questions, create focused sub-queries
+        if analysis.complexity == 'complex' and len(analysis.keywords) > 3:
+            # Create queries focusing on different keyword groups
+            mid = len(analysis.keywords) // 2
+            sub_queries.append(' '.join(analysis.keywords[:mid]))
+            sub_queries.append(' '.join(analysis.keywords[mid:]))
+
+        return sub_queries[:5]  # Limit to 5 sub-queries
+
+
+class FallbackStrategies:
+    """
+    Implements fallback strategies for queries that don't retrieve good results.
+    """
+
+    @staticmethod
+    def simplify_query(query: str) -> str:
+        """Simplify query by removing modifiers and focusing on core terms."""
+        # Remove question words
+        query = re.sub(r'\b(what|how|why|when|where|which|who|can|could|should|would)\b', '', query, flags=re.IGNORECASE)
+
+        # Remove common phrases
|
| 408 |
+
query = re.sub(r'\b(is|are|was|were|be|been|the|a|an)\b', '', query, flags=re.IGNORECASE)
|
| 409 |
+
|
| 410 |
+
# Clean up extra spaces
|
| 411 |
+
query = re.sub(r'\s+', ' ', query).strip()
|
| 412 |
+
|
| 413 |
+
return query
|
| 414 |
+
|
| 415 |
+
@staticmethod
|
| 416 |
+
def broaden_query(query: str, analysis: QueryAnalysis) -> str:
|
| 417 |
+
"""Broaden query to increase recall."""
|
| 418 |
+
# Remove specific constraints
|
| 419 |
+
query = re.sub(r'\b(specific|exactly|precisely|only|just)\b', '', query, flags=re.IGNORECASE)
|
| 420 |
+
|
| 421 |
+
# Add general terms
|
| 422 |
+
if analysis.keywords:
|
| 423 |
+
return f"{analysis.keywords[0]} overview"
|
| 424 |
+
|
| 425 |
+
return query
|
| 426 |
+
|
| 427 |
+
@staticmethod
|
| 428 |
+
def focus_entities(analysis: QueryAnalysis) -> str:
|
| 429 |
+
"""Create entity-focused query as fallback."""
|
| 430 |
+
if analysis.entities:
|
| 431 |
+
return ' '.join(analysis.entities)
|
| 432 |
+
elif analysis.keywords:
|
| 433 |
+
return ' '.join(analysis.keywords[:3])
|
| 434 |
+
return ""
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# Convenience function for easy integration
|
| 438 |
+
def expand_query_simple(
|
| 439 |
+
query: str,
|
| 440 |
+
strategy: str = "balanced",
|
| 441 |
+
llm=None
|
| 442 |
+
) -> List[str]:
|
| 443 |
+
"""
|
| 444 |
+
Simple function to expand a query without dealing with classes.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
query: User's query to expand
|
| 448 |
+
strategy: "quick", "balanced", or "comprehensive"
|
| 449 |
+
llm: Optional LangChain LLM
|
| 450 |
+
|
| 451 |
+
Returns:
|
| 452 |
+
List of expanded query variations
|
| 453 |
+
|
| 454 |
+
Example:
|
| 455 |
+
>>> queries = expand_query_simple("How do I debug Python code?", strategy="balanced")
|
| 456 |
+
>>> print(queries)
|
| 457 |
+
['How do I debug Python code?', 'How do I resolve Python code?', ...]
|
| 458 |
+
"""
|
| 459 |
+
expander = MultiQueryExpander(llm=llm)
|
| 460 |
+
strategy_enum = QueryStrategy(strategy)
|
| 461 |
+
expanded = expander.expand(query, strategy=strategy_enum)
|
| 462 |
+
return expanded.variations
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# Example usage and testing
|
| 466 |
+
if __name__ == "__main__":
|
| 467 |
+
# Example 1: Simple query expansion
|
| 468 |
+
print("=" * 60)
|
| 469 |
+
print("Example 1: Simple Query Expansion")
|
| 470 |
+
print("=" * 60)
|
| 471 |
+
|
| 472 |
+
query = "What is machine learning?"
|
| 473 |
+
queries = expand_query_simple(query, strategy="balanced")
|
| 474 |
+
|
| 475 |
+
print(f"\nOriginal: {query}")
|
| 476 |
+
print(f"\nExpanded queries ({len(queries)}):")
|
| 477 |
+
for i, q in enumerate(queries, 1):
|
| 478 |
+
print(f" {i}. {q}")
|
| 479 |
+
|
| 480 |
+
# Example 2: Complex query with full analysis
|
| 481 |
+
print("\n" + "=" * 60)
|
| 482 |
+
print("Example 2: Complex Query with Analysis")
|
| 483 |
+
print("=" * 60)
|
| 484 |
+
|
| 485 |
+
expander = MultiQueryExpander()
|
| 486 |
+
query = "How do I compare the performance of different neural network architectures?"
|
| 487 |
+
result = expander.expand(query, strategy=QueryStrategy.COMPREHENSIVE)
|
| 488 |
+
|
| 489 |
+
print(f"\nOriginal: {result.original}")
|
| 490 |
+
print(f"\nAnalysis:")
|
| 491 |
+
print(f" Intent: {result.analysis.intent}")
|
| 492 |
+
print(f" Entities: {result.analysis.entities}")
|
| 493 |
+
print(f" Keywords: {result.analysis.keywords}")
|
| 494 |
+
print(f" Complexity: {result.analysis.complexity}")
|
| 495 |
+
print(f" Domain: {result.analysis.domain}")
|
| 496 |
+
print(f"\nExpanded queries ({len(result.variations)}):")
|
| 497 |
+
for i, q in enumerate(result.variations, 1):
|
| 498 |
+
print(f" {i}. {q}")
|
| 499 |
+
|
| 500 |
+
# Example 3: Multi-hop reasoning
|
| 501 |
+
print("\n" + "=" * 60)
|
| 502 |
+
print("Example 3: Multi-Hop Reasoning")
|
| 503 |
+
print("=" * 60)
|
| 504 |
+
|
| 505 |
+
reasoner = MultiHopReasoner()
|
| 506 |
+
analyzer = QueryAnalyzer()
|
| 507 |
+
|
| 508 |
+
query = "Compare Python and Java for web development"
|
| 509 |
+
analysis = analyzer.analyze(query)
|
| 510 |
+
sub_queries = reasoner.generate_sub_queries(query, analysis)
|
| 511 |
+
|
| 512 |
+
print(f"\nOriginal: {query}")
|
| 513 |
+
print(f"\nSub-queries for multi-hop reasoning:")
|
| 514 |
+
for i, sq in enumerate(sub_queries, 1):
|
| 515 |
+
print(f" {i}. {sq}")
|
| 516 |
+
|
| 517 |
+
# Example 4: Fallback strategies
|
| 518 |
+
print("\n" + "=" * 60)
|
| 519 |
+
print("Example 4: Fallback Strategies")
|
| 520 |
+
print("=" * 60)
|
| 521 |
+
|
| 522 |
+
query = "What is the specific difference between supervised and unsupervised learning?"
|
| 523 |
+
analysis = analyzer.analyze(query)
|
| 524 |
+
|
| 525 |
+
|
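Note: the diff rendering cuts Example 4 off right after the analysis step. For reference, a minimal sketch of how the three fallback helpers defined above would be exercised — this continuation is illustrative only and is not part of the commit:

```python
# Hypothetical continuation of Example 4 (illustrative, not in the commit).
# Uses `query` and `analysis` from the lines above.
print(f"\nOriginal:       {query}")
print(f"Simplified:     {FallbackStrategies.simplify_query(query)}")
print(f"Broadened:      {FallbackStrategies.broaden_query(query, analysis)}")
print(f"Entity-focused: {FallbackStrategies.focus_entities(analysis)}")
```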
rag_processor.py
CHANGED
|
@@ -3,93 +3,479 @@ from dotenv import load_dotenv
 from operator import itemgetter
 from langchain_groq import ChatGroq
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables.history import RunnableWithMessageHistory

…
 """
…
 """
…
-llm = ChatGroq(model_name="moonshotai/kimi-k2-instruct-0905", api_key=api_key, temperature=0.1)

…
-# 2. Query Rewriting Chain
-rewrite_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question that is optimized for a vector database.

…
-**Follow-up Question:**
-{question}

…
-def get_parents(docs):
-    parent_ids = {d.metadata.get("doc_id") for d in docs}
-    return store.mget(list(parent_ids))

…
-rag_template = """You are CogniChat, an expert document analysis assistant. Your task is to answer the user's question based *only* on the provided context.
…
 """
 rag_prompt = ChatPromptTemplate.from_messages([
     ("system", rag_template),
     MessagesPlaceholder(variable_name="chat_history"),
     ("human", "{question}"),
 ])

…
-)
…
-#
…
     conversational_rag_chain,
     get_session_history_func,
     input_messages_key="question",
     history_messages_key="chat_history",
 )

-print("

-return
-    "rewriter": query_rewriter_chain,
-    "hyde": hyde_chain,
-    "base_retriever": base_retriever,
-    "parent_fetcher": parent_fetcher_chain,
-    "final_chain": final_chain
-}
 from operator import itemgetter
 from langchain_groq import ChatGroq
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.documents import Document
+from query_expansion import expand_query_simple
+from typing import List, Optional
+import time

+class GroqAPIKeyManager:
+    """Manages multiple Groq API keys with automatic rotation and fallback."""
+
+    def __init__(self, api_keys: List[str]):
+        """
+        Initialize with a list of API keys.
+
+        Args:
+            api_keys: List of Groq API keys to use
+        """
+        self.api_keys = [key for key in api_keys if key and key != "your_groq_api_key_here"]
+        if not self.api_keys:
+            raise ValueError("No valid API keys provided!")
+
+        self.current_index = 0
+        self.failed_keys = set()
+        self.success_count = {key: 0 for key in self.api_keys}
+        self.failure_count = {key: 0 for key in self.api_keys}
+
+        print(f"🔑 API Key Manager: Loaded {len(self.api_keys)} API keys")
+
+    def get_current_key(self) -> str:
+        """Get the current API key."""
+        return self.api_keys[self.current_index]
+
+    def mark_success(self, api_key: str):
+        """Mark an API key as successful."""
+        if api_key in self.success_count:
+            self.success_count[api_key] += 1
+            # Remove from failed keys if it was there
+            if api_key in self.failed_keys:
+                self.failed_keys.remove(api_key)
+                print(f"  ✅ API Key #{self.api_keys.index(api_key) + 1} recovered!")
+
+    def mark_failure(self, api_key: str):
+        """Mark an API key as failed."""
+        if api_key in self.failure_count:
+            self.failure_count[api_key] += 1
+            self.failed_keys.add(api_key)
+
+    def rotate_to_next_key(self) -> bool:
+        """
+        Rotate to the next available API key.
+
+        Returns:
+            True if a new key is available, False if all keys failed
+        """
+        initial_index = self.current_index
+        attempts = 0
+
+        while attempts < len(self.api_keys):
+            self.current_index = (self.current_index + 1) % len(self.api_keys)
+            attempts += 1
+
+            current_key = self.api_keys[self.current_index]
+
+            # If we've tried all keys, allow retry even failed ones
+            if attempts >= len(self.api_keys):
+                print(f"  ⚠️ All keys attempted, retrying with key #{self.current_index + 1}")
+                return True
+
+            # Skip recently failed keys unless it's been a while
+            if current_key not in self.failed_keys:
+                print(f"  🔄 Switching to API Key #{self.current_index + 1}")
+                return True
+
+        return False
+
+    def get_statistics(self) -> str:
+        """Get statistics about API key usage."""
+        stats = []
+        for i, key in enumerate(self.api_keys):
+            success = self.success_count[key]
+            failure = self.failure_count[key]
+            status = "❌ FAILED" if key in self.failed_keys else "✅ ACTIVE"
+            masked_key = key[:8] + "..." + key[-4:] if len(key) > 12 else "***"
+            stats.append(f"  Key #{i+1} ({masked_key}): {success} success, {failure} failures [{status}]")
+        return "\n".join(stats)
+
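A quick sketch of the rotation flow this class implements — illustrative only, with made-up key strings; it exercises the manager without touching the network:

```python
# Illustrative only (not part of the commit): drive the key manager by hand.
manager = GroqAPIKeyManager(["gsk_example_key_one_0000", "gsk_example_key_two_0000"])

key = manager.get_current_key()       # -> first key
manager.mark_failure(key)             # e.g. a rate limit was hit
if manager.rotate_to_next_key():      # advances current_index past failed keys
    key = manager.get_current_key()   # -> second key
manager.mark_success(key)             # would also clear it from failed_keys

print(manager.get_statistics())       # per-key success/failure report
```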
+def load_api_keys_from_hf_secrets() -> List[str]:
     """
+    Load API keys from Hugging Face Spaces Secrets.
+
+    In your Hugging Face Space settings, add these secrets:
+    - GROQ_API_KEY_1
+    - GROQ_API_KEY_2
+    - GROQ_API_KEY_3
+    - GROQ_API_KEY_4
+
+    Returns:
+        List of API keys retrieved from HF secrets
     """
+    api_keys = []
+    secret_names = ['GROQ_API_KEY_1', 'GROQ_API_KEY_2', 'GROQ_API_KEY_3', 'GROQ_API_KEY_4']
+
+    print("🔐 Loading API keys from Hugging Face Secrets...")
+
+    for secret_name in secret_names:
+        try:
+            # HF Spaces secrets are available as environment variables
+            api_key = os.getenv(secret_name)
+
+            if api_key and api_key.strip() and api_key != "your_groq_api_key_here":
+                api_keys.append(api_key.strip())
+                print(f"  ✅ Loaded: {secret_name}")
+            else:
+                print(f"  ⚠️ Not found or empty: {secret_name}")
+        except Exception as e:
+            print(f"  ❌ Error loading {secret_name}: {str(e)}")
+
+    # ADD THIS RETURN STATEMENT - this was missing!
+    return api_keys

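For local development outside Spaces, the same names can live in a `.env` file; a sketch assuming python-dotenv, which the module already imports via `load_dotenv`:

```python
# Illustrative local setup (not part of the commit). Example .env contents:
#
#   GROQ_API_KEY_1=gsk_...
#   GROQ_API_KEY_2=gsk_...
#
from dotenv import load_dotenv

load_dotenv()                            # exposes GROQ_API_KEY_* to os.getenv
keys = load_api_keys_from_hf_secrets()   # same loader works in both environments
manager = GroqAPIKeyManager(keys)
```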
+def create_llm_with_fallback(
+    api_key_manager: GroqAPIKeyManager,
+    model_name: str,
+    temperature: float,
+    max_retries: int = 3
+) -> ChatGroq:
+    """
+    Create a ChatGroq LLM with automatic API key fallback.
+
+    Args:
+        api_key_manager: Manager handling multiple API keys
+        model_name: Name of the model to use
+        temperature: Temperature setting
+        max_retries: Maximum number of retry attempts
+
+    Returns:
+        ChatGroq instance
+    """
+    for attempt in range(max_retries):
+        current_key = api_key_manager.get_current_key()
+
+        try:
+            llm = ChatGroq(
+                model_name=model_name,
+                api_key=current_key,
+                temperature=temperature
+            )
+            # Test the connection with a simple call
+            test_result = llm.invoke("test")
+            api_key_manager.mark_success(current_key)
+            return llm
+
+        except Exception as e:
+            error_msg = str(e).lower()
+            api_key_manager.mark_failure(current_key)
+
+            # Check if it's a rate limit or auth error
+            if "rate" in error_msg or "limit" in error_msg:
+                print(f"  ⚠️ Rate limit hit on API Key #{api_key_manager.current_index + 1}")
+            elif "auth" in error_msg or "api" in error_msg:
+                print(f"  ❌ Authentication failed on API Key #{api_key_manager.current_index + 1}")
+            else:
+                print(f"  ❌ Error with API Key #{api_key_manager.current_index + 1}: {str(e)[:50]}")
+
+            # Try next key if available
+            if attempt < max_retries - 1:
+                if api_key_manager.rotate_to_next_key():
+                    print(f"  🔄 Retrying with next API key (Attempt {attempt + 2}/{max_retries})...")
+                    time.sleep(1)  # Brief pause before retry
+                else:
+                    raise ValueError("All API keys failed!")
+            else:
+                raise ValueError(f"Failed to initialize LLM after {max_retries} attempts")
+
+    raise ValueError("Failed to create LLM with any available API key")

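Wiring it up looks roughly like this (a sketch; note the eager `llm.invoke("test")` probe above spends one real request per attempt, which is the cost of verifying a key up front):

```python
# Illustrative wiring (not part of the commit).
manager = GroqAPIKeyManager(load_api_keys_from_hf_secrets())
llm = create_llm_with_fallback(
    api_key_manager=manager,
    model_name="moonshotai/kimi-k2-instruct",  # default used by create_rag_chain below
    temperature=0.2,
)
print(llm.invoke("Say hi in five words.").content)
```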
+def create_multi_query_retriever(base_retriever, llm, strategy: str = "balanced"):
+    """Wraps a base retriever with query expansion capabilities."""
+    def multi_query_retrieve(query: str) -> List[Document]:
+        """Retrieves documents using expanded query variations."""
+        query_variations = expand_query_simple(query, strategy=strategy, llm=llm)
+        all_docs = []
+        seen_content = set()
+        for i, query_var in enumerate(query_variations):
+            try:
+                docs = base_retriever.invoke(query_var)
+                for doc in docs:
+                    content_hash = hash(doc.page_content)
+                    if content_hash not in seen_content:
+                        seen_content.add(content_hash)
+                        all_docs.append(doc)
+            except Exception as e:
+                print(f"  ✗ Query Expansion Error (Query {i+1}): {str(e)[:50]}")
+                continue
+        print(f"  📊 Query Expansion: Retrieved {len(all_docs)} unique documents.")
+        return all_docs
+    return multi_query_retrieve

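The wrapper is a plain function, so it composes with any retriever exposing `.invoke`. A sketch — `vectorstore` here is a stand-in for whatever FAISS/Chroma store the app builds, not a name from this commit:

```python
# Illustrative only: wrap an existing vector-store retriever.
base = vectorstore.as_retriever(search_kwargs={"k": 4})  # `vectorstore` is hypothetical
retrieve = create_multi_query_retriever(base, llm=None, strategy="balanced")

docs = retrieve("How do I debug Python code?")  # fans out over expanded variations
for d in docs[:3]:
    print(d.metadata.get("source"), "->", d.page_content[:80])
```

Passing `llm=None` works because `expand_query_simple` accepts an optional LLM and falls back to its rule-based variations.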
+def get_system_prompt(temperature: float) -> str:
+    """
+    Returns a system prompt dynamically based on temperature setting.
+
+    Temperature ranges:
+    - 0.0-0.4: Highly factual, structured, conservative
+    - 0.4-0.8: Balanced approach with moderate creativity
+    - 0.8-1.0: Creative, engaging, storytelling mode
+    """
+
+    if temperature <= 0.4:
+        # Conservative, structured prompt
+        return """You are CogniChat, an expert document analysis assistant specializing in comprehensive and well-structured answers.
+
+RESPONSE GUIDELINES:
+
+**Structure & Formatting:**
+- Start with a direct answer to the question
+- Use **bold** for key terms, important concepts, and technical terminology
+- Use bullet points (•) for lists, features, or multiple items
+- Use numbered lists (1., 2., 3.) for steps, procedures, or sequential information
+- Use ### Headers to organize different sections or topics
+- Add blank lines between sections for readability
+
+**Source Citation:**
+- Always cite information using: [Source: filename, Page: X]
+- Place citations at the end of your final answer only
+- Do not cite sources within the body of your answer
+- Multiple sources: [Source: doc1.pdf, Page: 3; doc2.pdf, Page: 7]
+
+**Completeness:**
+- Provide thorough, detailed answers using ALL relevant information from context
+- Summarize and properly elaborate each point for increased clarity
+- If the question has multiple parts, address each part clearly
+
+**Accuracy:**
+- ONLY use information from the provided context documents below
+- If information is incomplete, state what IS available and what ISN'T
+- If the answer isn't in the context, clearly state: "I cannot find this information in the uploaded documents"
+- Never make assumptions or add information not in the context
+
+---
+
+{context}
+
+---
+
+Now answer the following question comprehensively using the context above:"""
+
+    elif temperature <= 0.8:
+        # Balanced prompt
+        return """You are CogniChat, an intelligent document analysis assistant that combines accuracy with engaging communication.
+
+RESPONSE GUIDELINES:
+
+**Communication Style:**
+- Present information in a clear, engaging manner
+- Use **bold** for emphasis on important concepts
+- Balance structure with natural flow
+- Make complex topics accessible and interesting
+
+**Content Approach:**
+- Ground your response firmly in the provided context
+- Add helpful explanations and connections between concepts
+- Use analogies or examples when they help clarify ideas (but keep them brief)
+- Organize information logically with headers (###) and lists where appropriate
+
+**Source Attribution:**
+- Cite sources at the end: [Source: filename, Page: X]
+- Be transparent about what the documents do and don't contain
+
+**Accuracy:**
+- Base your answer on the context documents provided
+- If information is partial, explain what's available
+- Acknowledge gaps: "The documents don't cover this aspect"
+
+---
+
+{context}
+
+---
+
+Now answer the following question in an engaging yet accurate way:"""
+
+    else:  # temperature > 0.8
+        # Creative, engaging prompt
+        return """You are CogniChat, a creative and insightful document analyst who transforms information into engaging, memorable experiences.
+
+🎨 CREATIVE RESPONSE GUIDELINES:
+
+**Your Mission:**
+Transform the document content into compelling, creative responses while staying true to the facts. Think of yourself as a skilled storyteller who brings information to life!
+
+**Creative Techniques - Use Liberally:**
+- **Vivid Language**: Use descriptive, evocative language that paints mental pictures
+- **Analogies & Metaphors**: Create memorable comparisons that illuminate concepts
+- **Narrative Flow**: Tell a story when appropriate - build tension, reveal insights progressively
+- **Engaging Hooks**: Start with something intriguing that captures attention
+- **Real-World Connections**: Bridge abstract concepts to tangible, relatable scenarios
+- **Thought-Provoking Questions**: Pose rhetorical questions that spark curiosity
+- **Dynamic Formatting**: Use varied structures - not just bullet points. Try prose paragraphs, short punchy sentences, strategic emphasis
+
+**Creative Freedom:**
+- Interpret and synthesize information creatively
+- Make insightful connections between different pieces of information
+- Present the same facts in novel, interesting ways
+- Use formatting creatively: emojis (when appropriate), varied paragraph lengths, strategic **emphasis**
+- Vary your tone based on content: enthusiastic for exciting topics, contemplative for complex ones
+
+**Boundaries of Creativity:**
+- ✅ Creative presentation, interpretation, and synthesis of facts
+- ✅ Memorable analogies and explanatory examples
+- ✅ Engaging narrative structure and compelling language
+- ❌ Never invent facts not in the documents
+- ❌ Don't contradict the source material
+- ❌ Acknowledge when information isn't available (but do so creatively!)
+
+**Source Attribution:**
+- Weave citations naturally into your narrative
+- End with: [Source: filename, Page: X]
+
+---
+
+{context}
+
+---
+
+Now, using your creative prowess and the context above, craft an engaging and memorable answer to this question:"""

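Because the prompt is chosen once at construction time, the creativity mode is fixed per chain. A tiny illustrative check of the selection behavior (not part of the commit):

```python
# Print the opening line of each temperature mode's system prompt.
for t in (0.2, 0.6, 0.9):
    prompt = get_system_prompt(t)
    print(f"temperature={t} -> {prompt.splitlines()[0][:70]}")
```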
+def create_rag_chain(
+    retriever,
+    get_session_history_func,
+    enable_query_expansion=True,
+    expansion_strategy="balanced",
+    model_name: str = "moonshotai/kimi-k2-instruct",
+    temperature: float = 0.2,
+    api_keys: Optional[List[str]] = None
+):
+    """
+    Creates an advanced RAG chain with temperature-adaptive prompting and API key rotation.
+
+    Args:
+        retriever: Document retriever
+        get_session_history_func: Function to get session history
+        enable_query_expansion: Whether to enable query expansion
+        expansion_strategy: Strategy for query expansion
+        model_name: Name of the LLM model
+        temperature: Temperature setting (0.0-1.0)
+        api_keys: Optional list of API keys. If None, loads from environment
+    """
+
+    # Load API keys from HF Secrets
+    if api_keys is None:
+        api_keys = load_api_keys_from_hf_secrets()
+
+    if not api_keys:
+        raise ValueError(
+            "No valid API keys found! Please set GROQ_API_KEY or GROQ_API_KEY_1, "
+            "GROQ_API_KEY_2, GROQ_API_KEY_3, GROQ_API_KEY_4 in your .env file"
+        )
+
+    # Initialize API key manager
+    api_key_manager = GroqAPIKeyManager(api_keys)
+
+    print(f"⚙️ RAG: Initializing LLM - Model: {model_name}, Temp: {temperature}")
+
+    # Display creativity mode based on temperature
+    if temperature <= 0.4:
+        creativity_mode = "FACTUAL & STRUCTURED"
+    elif temperature <= 0.8:
+        creativity_mode = "BALANCED & ENGAGING"
+    else:
+        creativity_mode = "CREATIVE & STORYTELLING"
+    print(f"🎭 Creativity Mode: {creativity_mode}")
+
+    # Create LLM with fallback
+    llm = create_llm_with_fallback(api_key_manager, model_name, temperature)
+    print(f"✅ LLM initialized with API Key #{api_key_manager.current_index + 1}")
+
+    if enable_query_expansion:
+        print(f"✨ RAG: Query Expansion ENABLED (Strategy: {expansion_strategy})")
+        enhanced_retriever = create_multi_query_retriever(
+            base_retriever=retriever,
+            llm=llm,
+            strategy=expansion_strategy
+        )
+    else:
+        enhanced_retriever = retriever
+
+    rewrite_template = """You are an expert at optimizing search queries for document retrieval.
+
+Given the conversation history and a follow-up question, create a comprehensive standalone question that:
+1. Incorporates all relevant context from the chat history
+2. Expands abbreviations and resolves all pronouns (it, they, this, that, etc.)
+3. Includes key technical terms and concepts that would help find relevant documents
+4. Maintains the original intent, specificity, and detail level
+5. If the question asks for comparison or multiple items, ensure all items are in the query
+
+Chat History:
+{chat_history}
+
+Follow-up Question: {question}
+
+Optimized Standalone Question:"""
+    rewrite_prompt = ChatPromptTemplate.from_messages([
+        ("system", rewrite_template),
+        MessagesPlaceholder(variable_name="chat_history"),
+        ("human", "{question}")
+    ])
+    query_rewriter = rewrite_prompt | llm | StrOutputParser()
+
+    def format_docs(docs):
+        """Format retrieved documents with clear structure and metadata."""
+        if not docs:
+            return "No relevant documents found in the knowledge base."
+
+        formatted_parts = []
+        for i, doc in enumerate(docs, 1):
+            source = doc.metadata.get('source', 'Unknown Document')
+            page = doc.metadata.get('page', 'N/A')
+            rerank_score = doc.metadata.get('rerank_score')
+            content = doc.page_content.strip()
+
+            doc_header = f"{'='*60}\nDOCUMENT {i}\n{'='*60}"
+            metadata_line = f"Source: {source} | Page: {page}"
+            if rerank_score:
+                metadata_line += f" | Relevance: {rerank_score:.3f}"
+
+            formatted_parts.append(
+                f"{doc_header}\n"
+                f"{metadata_line}\n"
+                f"{'-'*60}\n"
+                f"{content}\n"
+            )
+        return f"RETRIEVED CONTEXT ({len(docs)} documents):\n\n" + "\n".join(formatted_parts)
+
+    # Get temperature-adaptive system prompt
+    rag_template = get_system_prompt(temperature)
+
     rag_prompt = ChatPromptTemplate.from_messages([
         ("system", rag_template),
         MessagesPlaceholder(variable_name="chat_history"),
         ("human", "{question}"),
     ])

+    # Rewriter input construction
+    rewriter_input = RunnableParallel({
+        "question": itemgetter("question"),
+        "chat_history": itemgetter("chat_history"),
+    })
+
+    # Main retrieval pipeline
+    retrieval_chain = rewriter_input | query_rewriter | enhanced_retriever | format_docs
+
+    # Final conversational RAG chain
+    conversational_rag_chain = RunnableParallel({
+        "context": retrieval_chain,
+        "question": itemgetter("question"),
+        "chat_history": itemgetter("chat_history"),
+    }) | rag_prompt | llm | StrOutputParser()
+
+    chain_with_memory = RunnableWithMessageHistory(
         conversational_rag_chain,
         get_session_history_func,
         input_messages_key="question",
         history_messages_key="chat_history",
     )

+    print("✅ RAG: Chain created successfully.")
+    print("\n" + api_key_manager.get_statistics())

+    return chain_with_memory, api_key_manager  # Return manager for statistics
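End to end, a caller would use the new two-value return roughly like this — a sketch with an illustrative in-memory history store; `retriever` stands in for whatever retriever the app constructs:

```python
# Illustrative usage (not part of the commit).
from langchain_community.chat_message_histories import ChatMessageHistory

histories = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    # One message history per session id, created on demand.
    return histories.setdefault(session_id, ChatMessageHistory())

chain, key_manager = create_rag_chain(retriever, get_session_history)

answer = chain.invoke(
    {"question": "What does the document say about pricing?"},
    config={"configurable": {"session_id": "demo-session"}},
)
print(answer)                        # final string from StrOutputParser
print(key_manager.get_statistics())  # per-key usage after the call
```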
templates/index.html
CHANGED
@@ -7,7 +7,7 @@
     <script src="https://cdn.tailwindcss.com"></script>
     <link rel="preconnect" href="https://fonts.googleapis.com">
     <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
-    <link href="https://fonts.googleapis.com/css2?family=Google+Sans:wght@400;500;700&family=Roboto:wght@400;500&display=swap" rel="stylesheet">
     <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
     <style>
         :root {

@@ -20,23 +20,28 @@
             --input-bg: #e8f0fe;
             --user-bubble: #d9e7ff;
             --bot-bubble: #f1f3f4;
         }

-        /* Dark mode styles */
         .dark {
-            --background: #
-            --foreground: #
-            --primary: #
-            --primary-hover: #
-            --card: #
-            --card-border: #
-            --input-bg: #
-            --user-bubble: #
-            --bot-bubble: #
         }

         body {
-            font-family: 'Google Sans', 'Roboto', sans-serif;
             background-color: var(--background);
             color: var(--foreground);
             overflow: hidden;

@@ -44,15 +49,14 @@
         #chat-window::-webkit-scrollbar { width: 8px; }
         #chat-window::-webkit-scrollbar-track { background: transparent; }
-        #chat-window::-webkit-scrollbar-thumb { background-color: #
         .dark #chat-window::-webkit-scrollbar-thumb { background-color: #5f6368; }

         .drop-zone--over {
             border-color: var(--primary);
-            box-shadow: 0 0
         }

-        /* Loading Spinner */
         .loader {
             width: 48px;
             height: 48px;

@@ -82,7 +86,6 @@
             100% { transform: rotate(360deg); }
         }

-        /* Typing Indicator Animation */
         .typing-indicator span {
             height: 10px;
             width: 10px;

@@ -98,173 +101,136 @@
             40% { transform: scale(1.0); }
         }

-        .markdown-content
…
-        }
-        .markdown-content
…
-        }
-        .markdown-content h1 { font-size: 1.75em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.5rem; }
-        .markdown-content h2 { font-size: 1.5em; }
-        .markdown-content h3 { font-size: 1.25em; }
-        .markdown-content h4 { font-size: 1.1em; }
-        .markdown-content ul, .markdown-content ol {
-            padding-left: 1.75rem;
-            margin-bottom: 1rem;
-        }
-        .markdown-content li {
-            margin-bottom: 0.5rem;
-        }
-        .dark .markdown-content ul > li::marker { color: var(--primary); }
-        .markdown-content ul > li::marker { color: var(--primary); }
-        .markdown-content a {
-            color: var(--primary);
-            text-decoration: none;
-            font-weight: 500;
-            border-bottom: 1px solid transparent;
-            transition: all 0.2s ease-in-out;
-        }
-        .markdown-content a:hover {
-            border-bottom-color: var(--primary-hover);
-        }
-        .markdown-content blockquote {
-            margin: 1.5rem 0;
-            padding-left: 1.5rem;
-            border-left: 4px solid var(--card-border);
-            color: #6c757d;
-            font-style: italic;
-        }
-        .dark .markdown-content blockquote {
-            color: #adb5bd;
-        }
-        .markdown-content hr {
-            border: none;
-            border-top: 1px solid var(--card-border);
-            margin: 2rem 0;
-        }
-        .markdown-content table {
-            width: 100%;
-            border-collapse: collapse;
-            margin: 1.5rem 0;
-            font-size: 0.9em;
-            box-shadow: 0 1px 3px rgba(0,0,0,0.05);
-            border-radius: 8px;
-            overflow: hidden;
-        }
-        .markdown-content th, .markdown-content td {
-            border: 1px solid var(--card-border);
-            padding: 0.75rem 1rem;
-            text-align: left;
-        }
-        .markdown-content th {
-            background-color: var(--bot-bubble);
-            font-weight: 500;
-        }
-        .markdown-content code {
-            background-color: rgba(0,0,0,0.05);
-            padding: 0.2rem 0.4rem;
-            border-radius: 0.25rem;
-            font-family: 'Roboto Mono', monospace;
-            font-size: 0.9em;
-        }
-        .dark .markdown-content code {
-            background-color: rgba(255,255,255,0.1);
-        }
-        .markdown-content pre {
-            position: relative;
-            background-color: #f8f9fa;
-            border: 1px solid var(--card-border);
-            border-radius: 0.5rem;
-            margin-bottom: 1rem;
-        }
-        .dark .markdown-content pre {
-            background-color: #2e2f32;
-        }
-        .markdown-content pre code {
-            background: none;
-            padding: 1rem;
-            display: block;
-            overflow-x: auto;
-        }
-        .markdown-content pre .copy-code-btn {
-            position: absolute;
-            top: 0.5rem;
-            right: 0.5rem;
-            background-color: #e8eaed;
-            border: 1px solid #dadce0;
-            color: #5f6368;
-            padding: 0.3rem 0.6rem;
-            border-radius: 0.25rem;
-            cursor: pointer;
-            opacity: 0;
-            transition: opacity 0.2s;
-            font-size: 0.8em;
-        }
-        .dark .markdown-content pre .copy-code-btn {
-            background-color: #3c4043;
-            border-color: #5f6368;
-            color: #e8eaed;
-        }
-        .markdown-content pre:hover .copy-code-btn {
-            opacity: 1;
-        }

-        /* Spinner for the TTS button */
         .tts-button-loader {
             width: 16px;
             height: 16px;
-            border: 2px solid currentColor;
             border-radius: 50%;
             display: inline-block;
             box-sizing: border-box;
             animation: rotation 0.8s linear infinite;
-            border-bottom-color: transparent;
         }
     </style>
 </head>
 <body class="w-screen h-screen dark">
     <main id="main-content" class="h-full flex flex-col transition-opacity duration-500">
         <div id="chat-container" class="hidden flex-1 flex flex-col w-full mx-auto overflow-hidden">
-…
             </header>
             <div id="chat-window" class="flex-1 overflow-y-auto p-4 md:p-6 lg:p-10">
-                <div id="chat-content" class="max-w-4xl mx-auto space-y-8">
-                </div>
             </div>
-            <div class="p-4 flex-shrink-0 bg-
-                <form id="chat-form" class="max-w-4xl mx-auto bg-[var(--card)] rounded-full p-2 flex items-center shadow-
                     <input type="text" id="chat-input" placeholder="Ask a question about your documents..." class="flex-grow bg-transparent focus:outline-none px-4 text-sm" autocomplete="off">
-                    <button type="submit" id="chat-submit-btn" class="bg-[var(--primary)] hover:bg-[var(--primary-hover)] text-white p-2 rounded-full transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed
-                        <svg class="w-5 h-5" viewBox="0 0
                     </button>
                 </form>
             </div>
         </div>

         <div id="upload-container" class="flex-1 flex flex-col items-center justify-center p-8 transition-opacity duration-300">
-            <div class="text-center">
-                <h1 class="text-5xl font-
-                <
-                <
…
             </div>
         </div>
     </div>

-    <div id="loading-overlay" class="hidden fixed inset-0 bg-[var(--background)] bg-opacity-80 backdrop-blur-sm flex flex-col items-center justify-center z-50
         <div class="loader"></div>
         <p id="loading-text" class="mt-6 text-sm font-medium"></p>
-        <p id="loading-subtext" class="mt-2 text-xs text-gray-
     </div>
 </main>

@@ -278,94 +244,89 @@
         const loadingOverlay = document.getElementById('loading-overlay');
         const loadingText = document.getElementById('loading-text');
         const loadingSubtext = document.getElementById('loading-subtext');
         const chatForm = document.getElementById('chat-form');
         const chatInput = document.getElementById('chat-input');
         const chatSubmitBtn = document.getElementById('chat-submit-btn');
         const chatWindow = document.getElementById('chat-window');
         const chatContent = document.getElementById('chat-content');
         const chatFilename = document.getElementById('chat-filename');

-        let sessionId =
-        const storedSessionId = sessionStorage.getItem('cognichat_session_id');
-        if (storedSessionId) {
-            sessionId = storedSessionId;
-            console.debug('Restored session ID from storage:', sessionId);
-        }

-        // --- File Upload Logic ---
         dropZone.addEventListener('click', () => fileUploadInput.click());

         ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
-            dropZone.addEventListener(eventName,
-            document.body.addEventListener(eventName, preventDefaults, false);
-        });
-
-        ['dragenter', 'dragover'].forEach(eventName => {
-            dropZone.addEventListener(eventName, () => dropZone.classList.add('drop-zone--over'));
         });
-        ['
         dropZone.addEventListener('drop', (e) => {
-…
-            if (files.length > 0) handleFiles(files);
         });
         fileUploadInput.addEventListener('change', (e) => {
             if (e.target.files.length > 0) handleFiles(e.target.files);
         });

-        function preventDefaults(e) { e.preventDefault(); e.stopPropagation(); }

         async function handleFiles(files) {
             const formData = new FormData();
-            let fileNames =
-            for (const file of files) {
…
             fileNameSpan.textContent = `Selected: ${fileNames.join(', ')}`;
-            await uploadAndProcessFiles(formData
         }

-        async function uploadAndProcessFiles(formData
             loadingOverlay.classList.remove('hidden');
-            loadingText.textContent = `Processing
-            loadingSubtext.textContent = "
             try {
                 const response = await fetch('/upload', { method: 'POST', body: formData });
                 const result = await response.json();
                 if (!response.ok) throw new Error(result.message || 'Unknown error occurred.');
-                if (result.session_id) {
-                    sessionId = result.session_id;
-                    sessionStorage.setItem('cognichat_session_id', sessionId);
-                    console.debug('Stored session ID from upload:', sessionId);
-                } else {
-                    console.warn('Upload response missing session_id field.');
-                }

                 uploadContainer.classList.add('hidden');
                 chatContainer.classList.remove('hidden');
-                appendMessage("I've analyzed your documents. What would you like to know?", "bot");
             } catch (error) {
                 console.error('Upload error:', error);
                 alert(`Error: ${error.message}`);
             } finally {
                 loadingOverlay.classList.add('hidden');
-                loadingSubtext.textContent = '';
                 fileNameSpan.textContent = '';
                 fileUploadInput.value = '';
             }
         }

-        // --- Chat Logic ---
         chatForm.addEventListener('submit', async (e) => {
             e.preventDefault();
             const question = chatInput.value.trim();

@@ -377,53 +338,36 @@
             chatSubmitBtn.disabled = true;

             const typingIndicator = showTypingIndicator();
-            let contentDiv = null;
             try {
-                const requestBody = { question: question };
-                if (sessionId) {
-                    requestBody.session_id = sessionId;
-                }
                 const response = await fetch('/chat', {
                     method: 'POST',
                     headers: { 'Content-Type': 'application/json' },
-                    body: JSON.stringify(
                 });
                 if (!response.ok) throw new Error(`Server error: ${response.statusText}`);

-                // ============================ MODIFICATION START ==============================
-                // Parse the JSON response instead of reading a stream
-                const result = await response.json();
-                const answer = result.answer; // Extract the 'answer' field
-                if (!answer) {
-                    throw new Error("Received an empty or invalid response from the server.");
-                }
                 typingIndicator.remove();
-                botMessageContainer = appendMessage('', 'bot');
-                contentDiv = botMessageContainer.querySelector('.markdown-content');
-                // Use the extracted answer for rendering
-                contentDiv.innerHTML = marked.parse(answer);
                 contentDiv.querySelectorAll('pre').forEach(addCopyButton);
-                scrollToBottom();
-                // Use the extracted answer for TTS
-                addTextToSpeechControls(botMessageContainer, answer);
-                // ============================ MODIFICATION END ==============================
             } catch (error) {
                 console.error('Chat error:', error);
…
-                    contentDiv.innerHTML = `<p class="text-red-500">Error: ${error.message}</p>`;
-                } else {
-                    appendMessage(`Error: ${error.message}`, 'bot');
-                }
             } finally {
                 chatInput.disabled = false;
                 chatSubmitBtn.disabled = false;

@@ -431,166 +375,108 @@
             }
         });

-        // ---
         function appendMessage(text, sender) {
             const messageWrapper = document.createElement('div');
-            const iconSVG = sender === 'user'
-                ? `<div class="bg-blue-100 dark:bg-gray-700 p-2.5 rounded-full flex-shrink-0 mt-1"><svg class="w-5 h-5 text-blue-600 dark:text-blue-300" viewBox="0 0 24 24"><path fill="currentColor" d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"></path></svg></div>`
                 : `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;

…
             chatContent.appendChild(messageWrapper);
             scrollToBottom();
-            return messageBubble;
         }

         function showTypingIndicator() {
-            const indicatorWrapper = document.createElement('div');
-            indicatorWrapper.className = `flex items-start gap-4`;
-            indicatorWrapper.id = 'typing-indicator';
-
-            const iconSVG = `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;
-
-            const messageBubble = document.createElement('div');
-            messageBubble.className = 'flex-1 pt-1';
-
-            const senderName = document.createElement('p');
-            senderName.className = 'font-medium text-sm mb-1';
-            senderName.textContent = 'CogniChat is thinking...';
-
             const indicator = document.createElement('div');
-            indicator.
-            indicator.
-            messageBubble.appendChild(indicator);
-            indicatorWrapper.innerHTML = iconSVG;
-            indicatorWrapper.appendChild(messageBubble);
-            chatContent.appendChild(indicatorWrapper);
             scrollToBottom();
-            return indicatorWrapper;
         }

-        function scrollToBottom() {
-…
-                top: chatWindow.scrollHeight,
-                behavior: 'smooth'
-            });
-        }

         function addCopyButton(pre) {
             const button = document.createElement('button');
             button.className = 'copy-code-btn';
             button.textContent = 'Copy';
             pre.appendChild(button);
             button.addEventListener('click', () => {
…
             });
         }

-        //
-        let currentAudio
-        let currentPlayingButton = null;
         const playIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M8 5v14l11-7z"/></svg>`;
         const pauseIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M6 19h4V5H6v14zm8-14v14h4V5h-4z"/></svg>`;
         function addTextToSpeechControls(messageBubble, text) {
…
-            ttsControls.appendChild(speakButton);
-            speakButton.addEventListener('click', () => handleTTS(text, speakButton));
-        }
         }

         async function handleTTS(text, button) {
             if (button === currentPlayingButton) {
                 if (currentAudio && !currentAudio.paused) {
                     currentAudio.pause();
-                    button.
-                    button.innerHTML = `${playIconSVG} <span>Play</span>`;
                 } else if (currentAudio && currentAudio.paused) {
                     currentAudio.play();
-                    button.setAttribute('data-state', 'playing');
                     button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
                 }
                 return;
             }
             resetAllSpeakButtons();
             currentPlayingButton = button;
-            button.setAttribute('data-state', 'loading');
             button.innerHTML = `<div class="tts-button-loader"></div> <span>Loading...</span>`;
             button.disabled = true;

             try {
-                const response = await fetch('/tts', {
-                    method: 'POST',
-                    headers: { 'Content-Type': 'application/json' },
-                    body: JSON.stringify({ text: text })
-                });
                 if (!response.ok) throw new Error('Failed to generate audio.');
                 const blob = await response.blob();
-                currentAudio = new Audio(audioUrl);
                 currentAudio.play();
-                button.setAttribute('data-state', 'playing');
                 button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
-                currentAudio.onended = () => {
-                    button.setAttribute('data-state', 'play');
-                    button.innerHTML = `${playIconSVG} <span>Play</span>`;
-                    currentAudio = null;
-                    currentPlayingButton = null;
-                };
             } catch (error) {
                 console.error('TTS Error:', error);
-                button.setAttribute('data-state', 'error');
-                button.innerHTML = `${playIconSVG} <span>Error</span>`;
-                alert('Failed to play audio. Please try again.');
                 resetAllSpeakButtons();
             } finally {
                 button.disabled = false;

@@ -599,8 +485,7 @@
         function resetAllSpeakButtons() {
             document.querySelectorAll('.speak-btn').forEach(btn => {
-                btn.
-                btn.innerHTML = `${playIconSVG} <span>Play</span>`;
                 btn.disabled = false;
             });
             if (currentAudio) {
     <script src="https://cdn.tailwindcss.com"></script>
     <link rel="preconnect" href="https://fonts.googleapis.com">
     <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
+    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Google+Sans:wght@400;500;700&family=Roboto:wght@400;500&display=swap" rel="stylesheet">
     <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
     <style>
         :root {

             --input-bg: #e8f0fe;
             --user-bubble: #d9e7ff;
             --bot-bubble: #f1f3f4;
+            --select-bg: #ffffff;
+            --select-border: #dadce0;
+            --select-text: #1f1f1f;
         }

         .dark {
+            --background: #111827;
+            --foreground: #e5e7eb;
+            --primary: #3b82f6;
+            --primary-hover: #60a5fa;
+            --card: #1f2937;
+            --card-border: #4b5563;
+            --input-bg: #374151;
+            --user-bubble: #374151;
+            --bot-bubble: #374151;
+            --select-bg: #374151;
+            --select-border: #6b7280;
+            --select-text: #f3f4f6;
         }

         body {
+            font-family: 'Inter', 'Google Sans', 'Roboto', sans-serif;
             background-color: var(--background);
             color: var(--foreground);
             overflow: hidden;

         #chat-window::-webkit-scrollbar { width: 8px; }
         #chat-window::-webkit-scrollbar-track { background: transparent; }
+        #chat-window::-webkit-scrollbar-thumb { background-color: #4b5563; border-radius: 20px; }
         .dark #chat-window::-webkit-scrollbar-thumb { background-color: #5f6368; }

         .drop-zone--over {
             border-color: var(--primary);
+            box-shadow: 0 0 20px rgba(59, 130, 246, 0.4);
         }

         .loader {
             width: 48px;
             height: 48px;

             100% { transform: rotate(360deg); }
         }

         .typing-indicator span {
             height: 10px;
             width: 10px;

             40% { transform: scale(1.0); }
         }

+        .markdown-content p { margin-bottom: 1rem; line-height: 1.75; }
+        .markdown-content h1, .markdown-content h2, .markdown-content h3 { font-weight: 600; margin-top: 1.5rem; margin-bottom: 0.75rem; line-height: 1.3; }
+        .markdown-content h1 { font-size: 1.5em; border-bottom: 1px solid var(--card-border); padding-bottom: 0.3rem;}
+        .markdown-content h2 { font-size: 1.25em; }
+        .markdown-content h3 { font-size: 1.1em; }
+        .markdown-content ul, .markdown-content ol { padding-left: 1.75rem; margin-bottom: 1rem; }
+        .markdown-content li { margin-bottom: 0.5rem; }
+        .markdown-content a { color: var(--primary); text-decoration: none; font-weight: 500; }
+        .markdown-content pre { position: relative; background-color: #2e2f32; border: 1px solid var(--card-border); border-radius: 0.5rem; margin-bottom: 1rem; font-size: 0.9em;}
+        .markdown-content pre code { background: none; padding: 1rem; display: block; overflow-x: auto; }
+        .markdown-content pre .copy-code-btn { position: absolute; top: 0.5rem; right: 0.5rem; background-color: #3c4043; border: 1px solid #5f6368; color: #e8eaed; padding: 0.3rem 0.6rem; border-radius: 0.25rem; cursor: pointer; opacity: 0; transition: opacity 0.2s; font-size: 0.8em;}
+        .markdown-content pre:hover .copy-code-btn { opacity: 1; }
|
|
|
        .tts-button-loader {
            width: 16px;
            height: 16px;
+           border: 2px solid currentColor;
            border-radius: 50%;
            display: inline-block;
            box-sizing: border-box;
            animation: rotation 0.8s linear infinite;
+           border-bottom-color: transparent;
+       }
+
+       .select-wrapper {
+           position: relative;
+       }
+       .select-wrapper select {
+           background-color: var(--select-bg);
+           border: 1px solid var(--select-border);
+           color: var(--select-text);
+           padding: 0.75rem 2.5rem 0.75rem 1rem;
+           border-radius: 0.75rem;
+           font-size: 0.875rem;
+           width: 100%;
+           appearance: none;
+           -webkit-appearance: none;
+           transition: all 0.2s ease-in-out;
+           cursor: pointer;
+           background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 20 20'%3e%3cpath stroke='%239ca3af' stroke-linecap='round' stroke-linejoin='round' stroke-width='1.5' d='M6 8l4 4 4-4'/%3e%3c/svg%3e");
+           background-position: right 0.75rem center;
+           background-repeat: no-repeat;
+           background-size: 1.25em 1.25em;
        }
    </style>
</head>
<body class="w-screen h-screen dark">
    <main id="main-content" class="h-full flex flex-col transition-opacity duration-500">
        <div id="chat-container" class="hidden flex-1 flex flex-col w-full mx-auto overflow-hidden">
+
+           <!-- --- CORRECT HEADER (Center/Right Layout) --- -->
+           <header class="p-4 border-b border-[var(--card-border)] flex-shrink-0 flex justify-between items-center w-full">
+               <div class="w-1/4"></div>
+               <div class="w-1/2 text-center">
+                   <h1 class="text-xl font-medium tracking-wide">CogniChat</h1>
+                   <p id="chat-filename" class="text-xs text-gray-400 mt-1 truncate"></p>
+               </div>
+               <div id="chat-session-info" class="w-1/4 text-right text-xs">
+                   <!-- This will be populated by JavaScript -->
+               </div>
            </header>
+           <!-- --- END HEADER --- -->
+
            <div id="chat-window" class="flex-1 overflow-y-auto p-4 md:p-6 lg:p-10">
+               <div id="chat-content" class="max-w-4xl mx-auto space-y-8"></div>
            </div>
+           <div class="p-4 flex-shrink-0 bg-opacity-50 backdrop-blur-md border-t border-[var(--card-border)]">
+               <form id="chat-form" class="max-w-4xl mx-auto bg-[var(--card)] rounded-full p-2 flex items-center shadow-lg border border-[var(--card-border)] focus-within:ring-2 focus-within:ring-[var(--primary)] transition-all">
                    <input type="text" id="chat-input" placeholder="Ask a question about your documents..." class="flex-grow bg-transparent focus:outline-none px-4 text-sm" autocomplete="off">
+                   <button type="submit" id="chat-submit-btn" class="bg-[var(--primary)] hover:bg-[var(--primary-hover)] text-white p-2.5 rounded-full transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed" title="Send">
+                       <svg class="w-5 h-5" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-8.707l-3-3a1 1 0 00-1.414 1.414L10.586 9H7a1 1 0 100 2h3.586l-1.293 1.293a1 1 0 101.414 1.414l3-3a1 1 0 000-1.414z" clip-rule="evenodd"></path></svg>
                    </button>
                </form>
            </div>
        </div>

        <div id="upload-container" class="flex-1 flex flex-col items-center justify-center p-8 transition-opacity duration-300">
+           <div class="text-center max-w-xl w-full">
+               <h1 class="text-5xl font-bold mb-3 tracking-tight">CogniChat</h1>
+               <p class="text-lg text-gray-400 mb-8">Upload your documents to start a conversation.</p>
+               <div class="mb-8 p-5 bg-[var(--card)] rounded-2xl border border-[var(--card-border)] shadow-lg">
+                   <div class="flex flex-col sm:flex-row items-center gap-6">
+                       <div class="w-full sm:w-1/2">
+                           <div class="flex items-center gap-2 mb-2">
+                               <svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path d="M7 3a1 1 0 000 2h6a1 1 0 100-2H7zM4 7a1 1 0 011-1h10a1 1 0 110 2H5a1 1 0 01-1-1zM2 11a2 2 0 012-2h12a2 2 0 012 2v4a2 2 0 01-2 2H4a2 2 0 01-2-2v-4z" /></svg>
+                               <label for="model-select" class="block text-sm font-medium text-gray-300">Model</label>
+                           </div>
+                           <div class="select-wrapper">
+                               <select id="model-select" name="model_name">
+                                   <option value="moonshotai/kimi-k2-instruct" selected>Kimi Instruct</option>
+                                   <option value="openai/gpt-oss-20b">GPT OSS 20b</option>
+                                   <option value="llama-3.3-70b-versatile">Llama 3.3 70b</option>
+                                   <option value="llama-3.1-8b-instant">Llama 3.1 8b Instant</option>
+                               </select>
+                           </div>
+                       </div>
+                       <div class="w-full sm:w-1/2">
+                           <div class="flex items-center gap-2 mb-2">
+                               <svg class="w-5 h-5 text-gray-400" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20" fill="currentColor"><path fill-rule="evenodd" d="M5.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7zM12 5.5a3.5 3.5 0 11-7 0 3.5 3.5 0 017 0zM14.5 16a3.5 3.5 0 100-7 3.5 3.5 0 000 7z" clip-rule="evenodd" /></svg>
+                               <label for="temperature-select" class="block text-sm font-medium text-gray-300">Mode</label>
+                           </div>
+                           <div class="select-wrapper">
+                               <select id="temperature-select" name="temperature">
+                                   <option value="0.2" selected>0.2 - Precise</option>
+                                   <option value="0.4">0.4 - Confident</option>
+                                   <option value="0.6">0.6 - Balanced</option>
+                                   <option value="0.8">0.8 - Flexible</option>
+                                   <option value="1.0">1.0 - Creative</option>
+                               </select>
+                           </div>
+                       </div>
+                   </div>
+                   <p class="text-xs text-gray-500 mt-4 text-center">Higher creativity modes may reduce factual accuracy.</p>
+               </div>
+               <div id="drop-zone" class="w-full text-center border-2 border-dashed border-[var(--card-border)] rounded-2xl p-10 transition-all duration-300 cursor-pointer hover:bg-[var(--card)] hover:border-[var(--primary)]">
+                   <div class="flex flex-col items-center justify-center pointer-events-none">
+                       <svg class="mx-auto h-12 w-12 text-gray-500" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="M12 16.5V9.75m0 0l3-3m-3 3l-3 3M6.75 19.5a4.5 4.5 0 01-1.41-8.775 5.25 5.25 0 0110.233-2.33 3 3 0 013.758 3.848A3.752 3.752 0 0118 19.5H6.75z"></path></svg>
+                       <p class="mt-4 text-sm font-medium text-gray-400">Drag & drop files or <span class="text-[var(--primary)] font-semibold">click to upload</span></p>
+                       <p class="text-xs text-gray-400 mt-1">Supports PDF, DOCX, TXT</p>
+                       <p id="file-name" class="mt-2 text-xs text-gray-500"></p>
+                   </div>
+                   <input id="file-upload" type="file" class="hidden" accept=".pdf,.txt,.docx" multiple>
                </div>
            </div>
        </div>
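Note on the two selects above: `model_name` and `temperature` are posted along with the upload form and become LLM parameters on the server. A minimal sketch of that mapping, assuming `langchain_groq.ChatGroq` (the actual construction lives in `rag_processor.py` and is not shown in this hunk):

```python
# Sketch only — shows how the form's model_name/temperature could be consumed.
# ChatGroq usage here is an assumption, not code from this commit.
from langchain_groq import ChatGroq

def build_llm(model_name: str, temperature: float) -> ChatGroq:
    # The UI labels ("0.2 - Precise" ... "1.0 - Creative") map directly to
    # the sampling temperature passed to the model.
    return ChatGroq(model=model_name, temperature=temperature)

llm = build_llm("moonshotai/kimi-k2-instruct", 0.2)
```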

+       <div id="loading-overlay" class="hidden fixed inset-0 bg-[var(--background)] bg-opacity-80 backdrop-blur-sm flex flex-col items-center justify-center z-50">
            <div class="loader"></div>
            <p id="loading-text" class="mt-6 text-sm font-medium"></p>
+           <p id="loading-subtext" class="mt-2 text-xs text-gray-400"></p>
        </div>
    </main>
...
        const loadingOverlay = document.getElementById('loading-overlay');
        const loadingText = document.getElementById('loading-text');
        const loadingSubtext = document.getElementById('loading-subtext');
        const chatForm = document.getElementById('chat-form');
        const chatInput = document.getElementById('chat-input');
        const chatSubmitBtn = document.getElementById('chat-submit-btn');
        const chatWindow = document.getElementById('chat-window');
        const chatContent = document.getElementById('chat-content');
+       const modelSelect = document.getElementById('model-select');
+       const temperatureSelect = document.getElementById('temperature-select');
        const chatFilename = document.getElementById('chat-filename');
+       const chatSessionInfo = document.getElementById('chat-session-info');

+       let sessionId = sessionStorage.getItem('cognichat_session_id');
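The `session_id` cached here keys every subsequent `/chat` request back to the knowledge base built at upload time. Purely as an illustration of that contract (none of these server-side names appear in this commit), the lookup could be as simple as:

```python
# Illustrative sketch only: a per-session registry keyed by the session_id
# the frontend keeps in sessionStorage. All names here are hypothetical.
from typing import Any

SESSIONS: dict[str, Any] = {}

def register_session(session_id: str, retriever: Any) -> None:
    # Remember which retriever/vector store belongs to a browser session.
    SESSIONS[session_id] = retriever

def get_retriever(session_id: str) -> Any:
    retriever = SESSIONS.get(session_id)
    if retriever is None:
        raise KeyError(f"Unknown session: {session_id}")
    return retriever
```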

        dropZone.addEventListener('click', () => fileUploadInput.click());

        ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
+           dropZone.addEventListener(eventName, e => { e.preventDefault(); e.stopPropagation(); });
        });
+       ['dragenter', 'dragover'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.add('drop-zone--over')));
+       ['dragleave', 'drop'].forEach(eventName => dropZone.addEventListener(eventName, () => dropZone.classList.remove('drop-zone--over')));
+
        dropZone.addEventListener('drop', (e) => {
+           if (e.dataTransfer.files.length > 0) handleFiles(e.dataTransfer.files);
        });
        fileUploadInput.addEventListener('change', (e) => {
            if (e.target.files.length > 0) handleFiles(e.target.files);
        });

        async function handleFiles(files) {
            const formData = new FormData();
+           let fileNames = Array.from(files).map(f => f.name);
+           for (const file of files) { formData.append('file', file); }
+
+           formData.append('model_name', modelSelect.value);
+           formData.append('temperature', temperatureSelect.value);

            fileNameSpan.textContent = `Selected: ${fileNames.join(', ')}`;
+           await uploadAndProcessFiles(formData);
        }

+       async function uploadAndProcessFiles(formData) {
            loadingOverlay.classList.remove('hidden');
+           loadingText.textContent = `Processing document(s)...`;
+           loadingSubtext.textContent = "Creating a knowledge base may take a minute. Please hold on tight!";

            try {
                const response = await fetch('/upload', { method: 'POST', body: formData });
                const result = await response.json();
                if (!response.ok) throw new Error(result.message || 'Unknown error occurred.');

+               sessionId = result.session_id;
+               sessionStorage.setItem('cognichat_session_id', sessionId);
+
+               chatFilename.innerHTML = `Chatting with: <strong>${result.filename}</strong>`;
+               chatFilename.title = result.filename;
+
+               chatSessionInfo.innerHTML = `
+                   <span class="text-gray-500 dark:text-gray-500 italic block hover:text-gray-300 transition-colors cursor-pointer" onclick="location.reload()">
+                       Refresh to change settings
+                   </span>`;
+
+               const modelOption = modelSelect.querySelector(`option[value="${result.model_name}"]`);
+               const simpleModelName = modelOption ? modelOption.textContent : result.model_name;
+
+               const modelInfo = {
+                   model: result.model_name,
+                   mode: result.mode,
+                   simpleModelName: simpleModelName
+               };
+
                uploadContainer.classList.add('hidden');
                chatContainer.classList.remove('hidden');
+               appendMessage("I've analyzed your documents. What would you like to know?", "bot", modelInfo);

            } catch (error) {
                console.error('Upload error:', error);
                alert(`Error: ${error.message}`);
            } finally {
                loadingOverlay.classList.add('hidden');
                fileNameSpan.textContent = '';
                fileUploadInput.value = '';
            }
        }

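For reference, the response contract `uploadAndProcessFiles()` relies on: JSON with `session_id`, `filename`, `model_name`, and `mode` on success, or `{ message }` with a non-2xx status on failure. A minimal Flask sketch of that shape (the real `/upload` route in `app.py` is more involved and is not reproduced here; `process_documents` is a hypothetical stand-in):

```python
# Sketch of the /upload contract consumed by uploadAndProcessFiles() above.
# Field names come from the JS; process_documents() is hypothetical.
import uuid
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/upload", methods=["POST"])
def upload():
    files = request.files.getlist("file")   # the JS appends each file under "file"
    if not files:
        return jsonify({"message": "No files provided."}), 400
    session_id = str(uuid.uuid4())
    # process_documents(files, session_id)  # hypothetical: load, chunk, embed, index
    return jsonify({
        "session_id": session_id,
        "filename": ", ".join(f.filename for f in files),
        "model_name": request.form.get("model_name"),
        "mode": request.form.get("temperature"),  # surfaced as "Mode" in the chat header
    })
```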
        chatForm.addEventListener('submit', async (e) => {
            e.preventDefault();
            const question = chatInput.value.trim();
...
            chatSubmitBtn.disabled = true;

            const typingIndicator = showTypingIndicator();
+
            try {
                const response = await fetch('/chat', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
+                   body: JSON.stringify({ question, session_id: sessionId }),
                });
                if (!response.ok) throw new Error(`Server error: ${response.statusText}`);

+               const result = await response.json();
+
+               const modelOption = modelSelect.querySelector(`option[value="${result.model_name}"]`);
+               const simpleModelName = modelOption ? modelOption.textContent : result.model_name;
+               const modelInfo = {
+                   model: result.model_name,
+                   mode: result.mode,
+                   simpleModelName: simpleModelName
+               };
+
                typingIndicator.remove();
+               const botMessageContainer = appendMessage('', 'bot', modelInfo);
+               const contentDiv = botMessageContainer.querySelector('.markdown-content');
+               contentDiv.innerHTML = marked.parse(result.answer);
                contentDiv.querySelectorAll('pre').forEach(addCopyButton);
+               scrollToBottom();
+               addTextToSpeechControls(botMessageContainer, result.answer);
            } catch (error) {
                console.error('Chat error:', error);
+               typingIndicator.remove();
+               appendMessage(`Error: ${error.message}`, 'bot');
            } finally {
                chatInput.disabled = false;
                chatSubmitBtn.disabled = false;
...
            }
        });
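The submit handler above sends `{ question, session_id }` and reads `answer`, `model_name`, and `mode` from the JSON reply. A minimal sketch of an endpoint honoring that contract (illustrative only; `get_rag_chain` is a hypothetical stand-in for the real retrieval pipeline):

```python
# Sketch of the /chat contract used by the submit handler above.
# Only the JSON field names are taken from the JS; get_rag_chain() is hypothetical.
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/chat", methods=["POST"])
def chat():
    data = request.get_json() or {}
    question = data.get("question")
    session_id = data.get("session_id")
    if not question or not session_id:
        return jsonify({"message": "Missing question or session_id."}), 400
    # answer = get_rag_chain(session_id).invoke(question)  # hypothetical
    answer = "..."  # produced by the session's RAG chain
    return jsonify({
        "answer": answer,                 # rendered client-side with marked.parse()
        "model_name": "moonshotai/kimi-k2-instruct",  # echoed for the header badge
        "mode": "Precise",
    })
```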

+       // --- FINAL, CORRECT appendMessage function ---
+       function appendMessage(text, sender, modelInfo = null) {
            const messageWrapper = document.createElement('div');
+           const iconSVG = sender === 'user'
+               ? `<div class="bg-blue-200 dark:bg-gray-700 p-2.5 rounded-full flex-shrink-0 mt-1"><svg class="w-5 h-5 text-blue-700 dark:text-blue-300" viewBox="0 0 24 24"><path fill="currentColor" d="M12 12c2.21 0 4-1.79 4-4s-1.79-4-4-4-4 1.79-4 4 1.79 4 4 4zm0 2c-2.67 0-8 1.34-8 4v2h16v-2c0-2.66-5.33-4-8-4z"></path></svg></div>`
                : `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div>`;

+           let senderHTML;
+           if (sender === 'user') {
+               senderHTML = '<p class="font-medium text-sm mb-1">You</p>';
+           } else {
+               let modelInfoHTML = '';
+               if (modelInfo && modelInfo.simpleModelName) {
+                   modelInfoHTML = `
+                       <span class="ml-2 text-xs font-normal text-gray-400">
+                           (Model: ${modelInfo.simpleModelName} • Mode: ${modelInfo.mode})
+                       </span>
+                   `;
+               }
+               senderHTML = `<div class="font-medium text-sm mb-1 flex items-center">CogniChat ${modelInfoHTML}</div>`;
            }

+           messageWrapper.className = `flex items-start gap-4`;
+           messageWrapper.innerHTML = `
+               ${iconSVG}
+               <div class="flex-1 pt-1">
+                   ${senderHTML}
+                   <div class="text-base markdown-content">${text ? marked.parse(text) : ''}</div>
+                   <div class="tts-controls mt-2"></div>
+               </div>
+           `;
            chatContent.appendChild(messageWrapper);
            scrollToBottom();
+           return messageWrapper.querySelector('.flex-1');
        }

        function showTypingIndicator() {
            const indicator = document.createElement('div');
+           indicator.id = 'typing-indicator';
+           indicator.className = `flex items-start gap-4`;
+           indicator.innerHTML = `<div class="bg-gray-200 dark:bg-gray-700 rounded-full flex-shrink-0 mt-1 text-xl flex items-center justify-center w-10 h-10">✨</div><div class="flex-1 pt-1"><p class="font-medium text-sm mb-1">CogniChat is thinking...</p><div class="typing-indicator"><span></span><span></span><span></span></div></div>`;
+           chatContent.appendChild(indicator);
            scrollToBottom();
+           return indicator;
        }

+       function scrollToBottom() { chatWindow.scrollTo({ top: chatWindow.scrollHeight, behavior: 'smooth' }); }
+
        function addCopyButton(pre) {
            const button = document.createElement('button');
            button.className = 'copy-code-btn';
            button.textContent = 'Copy';
            pre.appendChild(button);
            button.addEventListener('click', () => {
+               navigator.clipboard.writeText(pre.querySelector('code').innerText)
+                   .then(() => {
+                       button.textContent = 'Copied!';
+                       setTimeout(() => button.textContent = 'Copy', 2000);
+                   });
            });
        }

+       // (TTS functions remain unchanged)
+       let currentAudio, currentPlayingButton;
        const playIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M8 5v14l11-7z"/></svg>`;
        const pauseIconSVG = `<svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"><path d="M6 19h4V5H6v14zm8-14v14h4V5h-4z"/></svg>`;
        function addTextToSpeechControls(messageBubble, text) {
+           if (!text.trim()) return;
+           const speakButton = document.createElement('button');
+           speakButton.className = 'speak-btn mt-2 px-3 py-1.5 bg-blue-700 text-white rounded-full text-sm font-medium hover:bg-blue-800 transition-colors flex items-center gap-2 disabled:opacity-50';
+           speakButton.title = 'Listen to this message';
+           speakButton.innerHTML = `${playIconSVG} <span>Listen</span>`;
+           messageBubble.querySelector('.tts-controls').appendChild(speakButton);
+           speakButton.addEventListener('click', () => handleTTS(text, speakButton));
        }

        async function handleTTS(text, button) {
            if (button === currentPlayingButton) {
                if (currentAudio && !currentAudio.paused) {
                    currentAudio.pause();
+                   button.innerHTML = `${playIconSVG} <span>Listen</span>`;
                } else if (currentAudio && currentAudio.paused) {
                    currentAudio.play();
                    button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
                }
                return;
            }
            resetAllSpeakButtons();
            currentPlayingButton = button;
            button.innerHTML = `<div class="tts-button-loader"></div> <span>Loading...</span>`;
            button.disabled = true;

            try {
+               const response = await fetch('/tts', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ text }) });
                if (!response.ok) throw new Error('Failed to generate audio.');
                const blob = await response.blob();
+               currentAudio = new Audio(URL.createObjectURL(blob));
                currentAudio.play();
                button.innerHTML = `${pauseIconSVG} <span>Pause</span>`;
+               currentAudio.onended = resetAllSpeakButtons;
            } catch (error) {
                console.error('TTS Error:', error);
                resetAllSpeakButtons();
            } finally {
                button.disabled = false;
...
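The one-line `/tts` fetch above posts `{ text }` and expects raw audio back, which the handler wraps in an object URL for the `Audio` element. A minimal sketch of a matching endpoint, assuming Flask plus gTTS (which `app.py` imports); the route body shown here is illustrative, not the commit's actual implementation:

```python
# Sketch of the /tts contract: JSON text in, MP3 bytes out.
# gTTS matches the app.py import; this exact route body is an assumption.
import io
from flask import Flask, request, send_file
from gtts import gTTS

app = Flask(__name__)

@app.route("/tts", methods=["POST"])
def tts():
    text = (request.get_json() or {}).get("text", "").strip()
    if not text:
        return {"message": "No text provided."}, 400
    buf = io.BytesIO()
    gTTS(text=text).write_to_fp(buf)  # synthesize speech entirely in memory
    buf.seek(0)
    # audio/mpeg is what lets the browser turn the blob into a playable object URL
    return send_file(buf, mimetype="audio/mpeg")
```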

        function resetAllSpeakButtons() {
            document.querySelectorAll('.speak-btn').forEach(btn => {
+               btn.innerHTML = `${playIconSVG} <span>Listen</span>`;
                btn.disabled = false;
            });
            if (currentAudio) {