import os
import sys
import json
import streamlit as st
import warnings
import traceback
import logs
import chromadb
import hashlib
import sqlite3
import regex as re
from pinecone import Pinecone
from typing import Optional, Dict, Any
from sentence_transformers import util

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

from sentence_transformers import SentenceTransformer
from configuration import Configuration
from rag_scripts.rag_pipeline import RAGPipeline
from rag_scripts.documents_processing.chunking import PyMuPDFChunker
from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
from rag_scripts.llm.llmResponse import GROQLLM
from rag_scripts.evaluation.evaluator import RAGEvaluator

class RAGOperations:
    VALID_VECTOR_DB = {'chroma', 'faiss', 'pinecone'}

    @staticmethod
    def check_db(vector_db_type: str, db_path: str, collection_name: str) -> bool:
        """Return True if a usable index/collection already exists for the given vector DB type."""
        try:
            if vector_db_type not in RAGOperations.VALID_VECTOR_DB:
                logs.logger.info(f"Invalid Vector DB: {vector_db_type}")
                raise ValueError(f"Invalid vector DB type: {vector_db_type}")
            if vector_db_type.lower() == 'pinecone':
                pc = Pinecone(api_key=Configuration.PINECONE_API_KEY)
                return collection_name in pc.list_indexes().names()
            elif vector_db_type.lower() == 'chroma':
                if not os.path.exists(db_path):
                    return False
                client = chromadb.PersistentClient(path=db_path)
                try:
                    client.get_collection(collection_name)
                    return True
                except Exception:
                    return False
            elif vector_db_type.lower() == "faiss":
                faiss_index_file = os.path.join(db_path, f"{collection_name}.faiss")
                faiss_doc_store_file = os.path.join(db_path, f"{collection_name}_docs.pkl")
                return os.path.exists(faiss_index_file) and os.path.exists(faiss_doc_store_file)
            return False
        except Exception:
            traceback.print_exc()
            logs.logger.info(f"Exception in checking {vector_db_type} existence")
            return False
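
    # Example (hypothetical db_path and collection name; assumes Configuration provides a
    # valid PINECONE_API_KEY when 'pinecone' is used) -- a quick way to probe whether an
    # index already exists before triggering an expensive rebuild:
    #
    #   if not RAGOperations.check_db('chroma', './data/chroma_db', 'hr-policy-collection'):
    #       print("No existing Chroma collection; a build_index() run is required.")
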
    @staticmethod
    def get_pipeline_params(chunk_size: Optional[int] = None,
                            chunk_overlap: Optional[int] = None,
                            embedding_model: Optional[str] = None,
                            vector_db_type: Optional[str] = None,
                            llm_model: Optional[str] = None,
                            temperature: Optional[float] = None,
                            top_p: Optional[float] = None,
                            max_tokens: Optional[int] = None,
                            re_ranker_model: Optional[str] = None,
                            use_tuned: bool = False) -> Dict[str, Any]:
        """Assemble pipeline parameters from explicit arguments, configuration defaults and, optionally, tuned values."""
        try:
            best_param_path = os.path.join(Configuration.DATA_DIR, 'best_params.json')
            params = {
                'document_path': Configuration.FULL_PDF_PATH,
                'chunk_size': chunk_size if chunk_size is not None else Configuration.DEFAULT_CHUNK_SIZE,
                'chunk_overlap': chunk_overlap if chunk_overlap is not None else Configuration.DEFAULT_CHUNK_OVERLAP,
                'embedding_model_name': embedding_model if embedding_model is not None else Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL,
                'vector_db_type': vector_db_type if vector_db_type is not None else "chroma",
                'llm_model_name': llm_model if llm_model is not None else Configuration.DEFAULT_GROQ_LLM_MODEL,
                'db_path': None,
                'collection_name': Configuration.COLLECTION_NAME,
                'vector_db': None,
                'temperature': temperature if temperature is not None else 0.1,
                'top_p': top_p if top_p is not None else 0.95,
                'max_tokens': max_tokens if max_tokens is not None else 1500,
                're_ranker_model': re_ranker_model if re_ranker_model is not None else Configuration.DEFAULT_RERANKER,
            }
            if use_tuned and os.path.exists(best_param_path):
                with open(best_param_path, 'r') as f:
                    best_params = json.load(f)
                logs.logger.info(f"Best params: {best_params} from the file {best_param_path}")
                params.update({
                    'vector_db_type': best_params.get('vector_db_type', params['vector_db_type']),
                    'embedding_model_name': best_params.get('embedding_model', params['embedding_model_name']),
                    'chunk_overlap': best_params.get('chunk_overlap', params['chunk_overlap']),
                    'chunk_size': best_params.get('chunk_size', params['chunk_size']),
                    're_ranker_model': best_params.get('re_ranker_model', params['re_ranker_model'])})
            if use_tuned:
                tuned_db_type = params['vector_db_type']
                params['db_path'] = os.path.join(Configuration.DATA_DIR, 'TunedDB',
                                                 tuned_db_type) if tuned_db_type != 'pinecone' else ""
                params['collection_name'] = 'tuned-' + Configuration.COLLECTION_NAME
                if tuned_db_type in ['chroma', 'faiss']:
                    os.makedirs(params['db_path'], exist_ok=True)
                logs.logger.info(f"Tuned db path: {params['db_path']}")
            else:
                params['db_path'] = (Configuration.CHROMA_DB_PATH if params['vector_db_type'] == 'chroma'
                                     else Configuration.FAISS_DB_PATH if params['vector_db_type'] == 'faiss'
                                     else "")
                if params['vector_db_type'] in ['chroma', 'faiss']:
                    os.makedirs(params['db_path'], exist_ok=True)
                    logs.logger.info(f"Created directory for {params['vector_db_type']} at {params['db_path']}")
            return params
        except Exception as ex:
            logs.logger.info(f"Exception in get_pipeline_params: {ex}")
            traceback.print_exc()
            raise  # Propagate instead of silently returning None
    @staticmethod
    def check_embedding_dimension(vector_db_type: str, db_path: str,
                                  collection_name: str, embedding_model: str) -> bool:
        """Verify that the embedding model's output dimension matches the existing Chroma collection."""
        if vector_db_type != 'chroma':
            return True
        try:
            client = chromadb.PersistentClient(path=db_path)
            collection = client.get_collection(collection_name)
            model = SentenceTransformer(embedding_model)
            sample_embedding = model.encode(["test"])[0]
            try:
                expected_dim = collection._embedding_function.dim
            except AttributeError:
                peek_result = collection.peek(limit=1)
                embeddings = peek_result.get('embeddings')
                if embeddings is not None and len(embeddings) > 0:
                    expected_dim = len(embeddings[0])
                else:
                    return False
            actual_dim = len(sample_embedding)
            logs.logger.info(f"Expected dimension: {expected_dim} Actual dimension: {actual_dim}")
            return expected_dim == actual_dim
        except Exception as ex:
            logs.logger.info(f"Error checking embedding dimension: {ex}")
            return False
    @staticmethod
    def initialize_pipeline(params: Dict[str, Any]) -> RAGPipeline:
        """Build a RAGPipeline from a parameter dict produced by get_pipeline_params()."""
        try:
            embedder = SentenceTransformerEmbedder(model_name=params['embedding_model_name'])
            chunkerObj = PyMuPDFChunker(
                pdf_path=params['document_path'],
                chunk_size=params['chunk_size'],
                chunk_overlap=params['chunk_overlap'])
            llm_model = params['llm_model_name']
            vector_db = None
            if params['vector_db_type'] == 'chroma':
                vector_db = chromaDBVectorDB(embedder=embedder,
                                             db_path=params['db_path'],
                                             collection_name=params['collection_name'])
            elif params['vector_db_type'] == 'faiss':
                vector_db = FAISSVectorDB(embedder=embedder,
                                          db_path=params['db_path'],
                                          collection_name=params['collection_name'])
            elif params['vector_db_type'] == 'pinecone':
                vector_db = PineconeVectorDB(embedder=embedder,
                                             db_path=params['db_path'],
                                             collection_name=params['collection_name'])
            else:
                raise ValueError(f"Unknown vector_db_type: {params['vector_db_type']}")
            return RAGPipeline(document_path=params['document_path'],
                               chunker=chunkerObj, embedder=embedder,
                               vector_db=vector_db,
                               llm=GROQLLM(model_name=llm_model),
                               re_ranker_model_name=params['re_ranker_model'] if params['re_ranker_model'] else Configuration.DEFAULT_RERANKER)
        except Exception as ex:
            logs.logger.info(f"Exception in pipeline initialize: {ex}")
            traceback.print_exc()
            sys.exit(1)
    @staticmethod
    def run_build_job(chunk_size: Optional[int] = None,
                      chunk_overlap: Optional[int] = None,
                      embedding_model: Optional[str] = None,
                      vector_db_type: Optional[str] = None,
                      llm_model: Optional[str] = None,
                      temperature: Optional[float] = None,
                      top_p: Optional[float] = None,
                      max_tokens: Optional[int] = None,
                      re_ranker_model: Optional[str] = None,
                      use_tuned: bool = False) -> None:
        """Resolve pipeline parameters, initialise the pipeline and build the vector index."""
        try:
            params = RAGOperations.get_pipeline_params(chunk_size=chunk_size,
                                                       chunk_overlap=chunk_overlap,
                                                       embedding_model=embedding_model,
                                                       vector_db_type=vector_db_type,
                                                       llm_model=llm_model,
                                                       temperature=temperature,
                                                       top_p=top_p,
                                                       max_tokens=max_tokens,
                                                       re_ranker_model=re_ranker_model,
                                                       use_tuned=use_tuned)
            pipeline = RAGOperations.initialize_pipeline(params)
            pipeline.build_index()
            logs.logger.info("RAG build job completed")
        except Exception as ex:
            logs.logger.info(f"Exception in run build job: {ex}")
            traceback.print_exc()
            raise
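
    # Minimal sketch of building the index ahead of time (e.g. from a one-off script or a
    # startup hook) rather than lazily on the first query. The argument values here are
    # illustrative, not the project defaults:
    #
    #   RAGOperations.run_build_job(chunk_size=1024, chunk_overlap=200,
    #                               embedding_model="all-MiniLM-L6-v2",
    #                               vector_db_type="chroma",
    #                               llm_model="llama-3.1-8b-instant")
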
    @staticmethod
    def run_search_job(query: Optional[str] = None,
                       k: int = 5, raw: bool = False,
                       use_tuned: bool = False,
                       llm_model: Optional[str] = None,
                       user_context: Optional[Dict[str, str]] = None,
                       temperature: Optional[float] = None,
                       top_p: Optional[float] = None,
                       max_tokens: Optional[int] = None,
                       chunk_size: Optional[int] = None,
                       chunk_overlap: Optional[int] = None,
                       embedding_model: Optional[str] = None,
                       vector_db_type: Optional[str] = None,
                       re_ranker_model: Optional[str] = None,
                       use_rag: bool = True) -> Dict[str, Any]:
        """Answer a query (raw retrieval or RAG+LLM), rebuilding the index if needed, and evaluate the response."""
        try:
            params = RAGOperations.get_pipeline_params(chunk_size=chunk_size,
                                                       chunk_overlap=chunk_overlap,
                                                       embedding_model=embedding_model,
                                                       vector_db_type=vector_db_type,
                                                       llm_model=llm_model,
                                                       temperature=temperature,
                                                       top_p=top_p,
                                                       max_tokens=max_tokens,
                                                       re_ranker_model=re_ranker_model,
                                                       use_tuned=use_tuned)
            vector_db_type = params['vector_db_type']
            db_path = params['db_path']
            collection_name = params['collection_name']
            pipeline = RAGOperations.initialize_pipeline(params)
            db_exists = RAGOperations.check_db(vector_db_type, db_path, collection_name)
            if use_rag:
                if not db_exists:
                    pipeline.build_index()
                elif pipeline.vector_db.count_documents() == 0:
                    pipeline.build_index()
                elif not RAGOperations.check_embedding_dimension(vector_db_type, db_path,
                                                                 collection_name, params['embedding_model_name']):
                    logs.logger.info("Embedding dimension mismatch. Rebuilding the index")
                    pipeline.vector_db.delete_collection(collection_name)
                    pipeline.build_index()
                else:
                    logs.logger.info(f"Using existing {vector_db_type} database with collection: {collection_name}")
                if pipeline.vector_db.count_documents() == 0:
                    logs.logger.info("No documents found in vector database after re-build")
                    sys.exit(1)
            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)
            user_query = query if query else input("Enter your Query: ")
            if user_query.lower() == 'exit':
                return {}
            expected_answers = None
            expected_keywords = []
            query_found = False
            try:
                with open(Configuration.EVAL_DATA_PATH, 'r') as f:
                    eval_data = json.load(f)
                for item in eval_data:
                    if item.get('query', '').strip().lower() == user_query.strip().lower():
                        expected_keywords = item.get('expected_keywords', [])
                        expected_answers = item.get('expected_answer_snippet', "")
                        query_found = True
                        break
                if not expected_keywords and not expected_answers:
                    logs.logger.info("No evaluation data found for query in json")
            except Exception as ex:
                logs.logger.info(f"No json file : {ex}")
            retrieved_documents = []
            if raw:
                retrieved_documents = pipeline.retrieve_raw_documents(user_query, k=k * 2)
                logs.logger.info("Raw documents retrieved")
                logs.logger.info(json.dumps(retrieved_documents, indent=4))
                if not retrieved_documents:
                    response = {"summary": "No relevant documents found",
                                "sources": []}
                else:
                    # Re-rank the retrieved documents by cosine similarity to the query
                    query_embedding = evaluator.embedder.encode(user_query,
                                                                convert_to_tensor=True, normalize_embeddings=True)
                    similarities = [(doc, util.cos_sim(query_embedding,
                                                       evaluator.embedder.encode(doc['content'],
                                                                                 convert_to_tensor=True,
                                                                                 normalize_embeddings=True)).item())
                                    for doc in retrieved_documents]
                    similarities.sort(key=lambda x: x[1], reverse=True)
                    top_docs = similarities[:min(3, len(similarities))]
                    truncated_content = []
                    for doc, sim in top_docs:
                        # Keep only the paragraphs most similar to the query
                        content_paragraphs = re.split(r'\n\s*\n', doc['content'].strip())
                        para_sims = [(para, util.cos_sim(query_embedding,
                                                         evaluator.embedder.encode(para.strip(), convert_to_tensor=True,
                                                                                   normalize_embeddings=True)).item())
                                     for para in content_paragraphs if para.strip()]
                        para_sims.sort(key=lambda x: x[1], reverse=True)
                        top_paras = [para for para, para_sim in para_sims[:2] if para_sim >= 0.3]
                        if len(top_paras) < 1:  # Fall back to at least one paragraph
                            top_paras = [para for para, _ in para_sims[:1]]
                        truncated_content.append('\n\n'.join(top_paras))
                    response = {
                        "summary": "\n".join(truncated_content),
                        "sources": [{"document_id": f"DOC {idx + 1}",
                                     "page": str(doc['metadata'].get("page_number", "NA")),
                                     "section": doc['metadata'].get("section", "NA"),
                                     "clause": doc['metadata'].get("clause", "NA")}
                                    for idx, (doc, _) in enumerate(top_docs)]}
            else:
                logs.logger.info("LLM+RAG")
                response = pipeline.query(user_query, k=k,
                                          include_metadata=True,
                                          user_context=user_context)
                retrieved_documents = pipeline.retrieve_raw_documents(user_query, k=k)
            final_expected_answer = expected_answers if expected_answers is not None else ""
            additional_eval_metrices = {}
            if not query_found:
                logs.logger.info(f"No query found in eval_data.json: {user_query}")
                raw_reference_for_score = evaluator._syntesize_raw_reference(retrieved_documents)
                if not final_expected_answer.strip():
                    final_expected_answer = raw_reference_for_score
                retrieved_documents_content = [doc.get('content', '') for doc in retrieved_documents]
                llm_as_judge = evaluator._evaluate_with_llm(user_query,
                                                            response.get('summary', ''),
                                                            retrieved_documents_content)
                if llm_as_judge:
                    additional_eval_metrices.update(llm_as_judge)
                output = {"query": user_query, "response": response, "evaluation": llm_as_judge}
                logs.logger.info(json.dumps(output, indent=4))
                return output
            else:
                eval_result = evaluator.evaluate_response(user_query, response, retrieved_documents,
                                                          expected_keywords, expected_answers)
                output = {"query": user_query, "response": response, "evaluation": eval_result}
                logs.logger.info(json.dumps(output, indent=2, ensure_ascii=False))
                return output
        except Exception as ex:
            logs.logger.info(f"Exception in run search job {ex}")
            traceback.print_exc()
            return {"error": str(ex)}
    @staticmethod
    def run_llm_with_prompt(run_type: str,
                            temperature: float = 0.1,
                            top_p: float = 0.95,
                            max_tokens: int = 1500) -> Dict[str, Any]:
        """Query the LLM directly (with or without a structured prompt), without RAG retrieval."""
        try:
            params = RAGOperations.get_pipeline_params()
            pipeline = RAGOperations.initialize_pipeline(params)
            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)
            system_message = (
                "You are an expert assistant for Flykite Airlines HR Policy Queries. "
                "Provide concise, accurate and policy-specific answers based solely on the provided context. "
                "Structure your response clearly, using bullet points and newlines where applicable. "
                "If the context lacks information, state that clearly and avoid speculation."
            ) if run_type == 'prompting' else None
            user_query = input("Enter your query: ")
            expected_answer = None
            expected_keywords = []
            try:
                with open(Configuration.EVAL_DATA_PATH, 'r') as f:
                    eval_data = json.load(f)
                for item in eval_data:
                    if item.get('query', '').strip().lower() == user_query.strip().lower():
                        expected_answer = item.get('expected_answer_snippet', "")
                        expected_keywords = item.get('expected_keywords', [])
                        break
            except Exception as ex:
                logs.logger.info(f"Error loading eval_data.json for query {user_query}: {ex}")
            if run_type == 'prompting':
                prompt = (
                    f"You are an expert assistant for Flykite Airlines HR Policy Queries. "
                    f"Answer the following question with a structured response, using bullet points or sections where applicable. "
                    f"Base your answer solely on the query and avoid hallucination.\n"
                    f"Question: \n {user_query} \n"
                    f"Answer: ")
            else:
                prompt = user_query
            response = pipeline.llm.generate_response(
                prompt=prompt,
                system_message=system_message,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens
            )
            retrieved_documents = []
            eval_result = evaluator.evaluate_response(user_query,
                                                      response,
                                                      retrieved_documents,
                                                      expected_keywords,
                                                      expected_answer)
            output = {"query": user_query,
                      "response": {
                          "summary": response.strip(),
                          "sources": ["LLM Response Not RAG loaded"]},
                      "evaluation": eval_result}
            logs.logger.info(json.dumps(output, indent=2))
            return output
        except Exception as ex:
            logs.logger.info(f"Exception in LLM prompting response: {ex}")
            traceback.print_exc()
            return {"error": str(ex)}
    @staticmethod
    def login() -> Optional[Dict[str, str]]:
        """Interactive (CLI) login against users.db; returns the user's profile or exits on failure."""
        username = input("Enter your username: ")
        password = input("Enter your password: ")
        hashed_password = hashlib.sha256(password.encode()).hexdigest()
        try:
            conn = sqlite3.connect('users.db')
            cursor = conn.cursor()
            cursor.execute(
                "SELECT username, jobrole, department, location FROM users WHERE username = ? AND password = ?",
                (username, hashed_password)
            )
            user = cursor.fetchone()
            logs.logger.info(f"{user}")
            conn.close()
            if user:
                return {"username": user[0], "role": user[1], "department": user[2], "location": user[3]}
            else:
                logs.logger.info("Invalid username or password")
                sys.exit(1)
        except sqlite3.Error as ex:
            logs.logger.info(f"SQLite error during login: {ex}")
            return None
    @staticmethod
    def authenticate_user(username, password) -> Optional[Dict[str, str]]:
        """Validate credentials against users.db and return the user's profile, or None if invalid."""
        hashed_password = hashlib.sha256(password.encode()).hexdigest()
        conn = sqlite3.connect('users.db')
        cursor = conn.cursor()
        cursor.execute(
            "SELECT username, jobrole, department, location FROM users WHERE username = ? AND password = ?",
            (username, hashed_password)
        )
        user = cursor.fetchone()
        conn.close()
        if user:
            return {"username": user[0], "role": user[1], "department": user[2], "location": user[3]}
        return None
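
    # login()/authenticate_user() assume a local SQLite file `users.db` with a `users` table
    # holding SHA-256 hashed passwords. A minimal, hypothetical bootstrap sketch (table and
    # column names inferred from the queries above; the seeded row is illustrative only):
    #
    #   conn = sqlite3.connect('users.db')
    #   conn.execute("""CREATE TABLE IF NOT EXISTS users (
    #                       username TEXT PRIMARY KEY,
    #                       password TEXT NOT NULL,
    #                       jobrole TEXT, department TEXT, location TEXT)""")
    #   conn.execute("INSERT OR IGNORE INTO users VALUES (?, ?, ?, ?, ?)",
    #                ("admin", hashlib.sha256(b"change-me").hexdigest(), "admin", "IT", "HQ"))
    #   conn.commit()
    #   conn.close()
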
    @staticmethod
    def home_page():
        """Landing page: login form for anonymous visitors, session summary and logout for logged-in users."""
        st.title("Welcome to Flykite RAG System")
        if 'logged_in' not in st.session_state:
            st.session_state.logged_in = False
        if 'user_info' not in st.session_state:
            st.session_state.user_info = None
        if not st.session_state.logged_in:
            st.subheader("Login")
            with st.form("login_form"):
                username = st.text_input("Username")
                password = st.text_input("Password", type="password")
                login_button = st.form_submit_button("Login")
                if login_button:
                    user_data = RAGOperations.authenticate_user(username, password)
                    if user_data:
                        st.session_state.logged_in = True
                        st.session_state.user_info = user_data
                        st.session_state.user_context = {
                            "role": user_data['role'],
                            "department": user_data['department'],
                            "location": user_data['location']
                        }
                        st.success(f"Logged in as {user_data['username']} ({user_data['role']})")
                        st.session_state.page = "User" if user_data['role'] != 'admin' else "Admin"
                        st.rerun()  # Rerun so main_app renders the new page
                    else:
                        st.error("Invalid username or password.")
        else:
            st.write(
                f"You are logged in as **{st.session_state.user_info['username']}** (Role: **{st.session_state.user_info['role']}**)")
            if st.button("Logout"):
                st.session_state.logged_in = False
                st.session_state.user_info = None
                st.session_state.user_context = None
                st.session_state.page = "Home"  # Redirect to home on logout
                st.rerun()
    @staticmethod
    def admin_page():
        """Admin dashboard: hyperparameter tuning form and a RAG test-query form."""
        st.title("Admin Dashboard")
        st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")
        if st.session_state.user_info and st.session_state.user_info['role'] == 'admin':
            st.header("RAG Hypertuning")
            st.info("Run hyperparameter tuning to find the best RAG configuration and build a tuned index.")
            with st.form("hypertune_form"):
                st.write("Hypertuning parameters:")
                llm_model_ht = st.selectbox("LLM Model for Hypertuning Evaluation",
                                            options=["llama-3.3-70b-versatile", "llama-3.1-8b-instant"],
                                            index=["llama-3.3-70b-versatile", "llama-3.1-8b-instant"].index(
                                                Configuration.DEFAULT_GROQ_LLM_MODEL) if Configuration.DEFAULT_GROQ_LLM_MODEL in [
                                                "llama-3.3-70b-versatile", "llama-3.1-8b-instant"] else 0,
                                            key="llm_model_ht_select")
                # Hyperparameter ranges/options for tuning
                st.subheader("Hyperparameter Ranges/Options:")
                chunk_sizes = st.multiselect("Chunk Sizes to Test",
                                             options=[512, 1024, 2048],
                                             default=[512],
                                             key="chunk_sizes_ht")
                chunk_overlaps = st.multiselect("Chunk Overlaps to Test",
                                                options=[150, 200, 400],
                                                default=[150],
                                                key="chunk_overlaps_ht")
                embedding_models = st.multiselect("Embedding Models to Test",
                                                  options=["all-MiniLM-L6-v2", "all-mpnet-base-v2",
                                                           "paraphrase-MiniLM-L3-v2", "multi-qa-mpnet-base-dot-v1"],
                                                  default=["all-MiniLM-L6-v2", "all-mpnet-base-v2"],
                                                  key="embedding_models_ht")
                re_ranker_models = st.multiselect("Re-ranker Models to Test",
                                                  options=["cross-encoder/ms-marco-MiniLM-L-6-v2",
                                                           "cross-encoder/ms-marco-TinyBERT-L-2", "None"],
                                                  default=["cross-encoder/ms-marco-MiniLM-L-6-v2"],
                                                  key="re_ranker_models_ht")
                vector_db_types = st.multiselect("Vector DB Types to Test",
                                                 options=['chroma', 'faiss', 'pinecone'],
                                                 default=['chroma'],
                                                 key="vector_db_types_ht")
                search_type = st.radio("Hypertuning Search Type",
                                       options=["random", "grid"],
                                       index=0,  # Default to random
                                       key="search_type_ht")
                n_iter = st.number_input("Number of Hyper-tuning Iterations (for Random Search)",
                                         min_value=1, value=3, step=1,
                                         help="Only applicable for 'random' search type.",
                                         key="n_iter_ht")
                hypertune_button = st.form_submit_button("Run Hypertune Job")
            if hypertune_button:
                if not chunk_sizes or not chunk_overlaps or not embedding_models or not re_ranker_models or not vector_db_types:
                    st.error("Please select at least one option for all hyperparameter categories.")
                else:
                    # Handle 'None' for re-ranker model: replace the "None" string with an actual None
                    final_re_ranker_models = [
                        None if model == "None" else model for model in re_ranker_models
                    ]
                    st.write("Starting RAG Hypertuning. This may take a while...")
                    with st.spinner("Running hypertuning..."):
                        try:
                            result = RAGOperations.run_hypertune_job(
                                llm_model=llm_model_ht,
                                chunk_size_to_test=chunk_sizes,
                                chunk_overlap_to_test=chunk_overlaps,
                                embedding_models_to_test=embedding_models,
                                re_ranker_model=final_re_ranker_models,
                                vector_db_types_to_test=vector_db_types,
                                search_type=search_type,
                                n_iter=n_iter if search_type == "random" else None  # n_iter only applies to random search
                            )
                            if result and "error" not in result:
                                st.success("Hypertuning completed and tuned index built!")
                                st.subheader("Best Parameters Found:")
                                st.json(result.get('best_params', {}))
                                if 'best_score' in result:
                                    st.write(f"Best Score: {result['best_score']:.4f}")
                                if 'best_metrics' in result:
                                    st.subheader("Best Metrics:")
                                    st.json(result['best_metrics'])
                            else:
                                st.error(f"Hypertuning failed: {result.get('error', 'Unknown error')}")
                        except Exception as e:
                            st.error(f"An unexpected error occurred during hypertuning: {e}")
                            st.exception(e)  # Display full traceback in Streamlit
            st.header("RAG Testing")
            st.info("Test the RAG pipeline with a specific query, optionally using the tuned database.")
            with st.form("rag_test_form"):
                test_query = st.text_area("Enter a test query for the RAG system:",
                                          value="What is the policy on annual leave?",
                                          key="test_query_input")
                use_tuned_db = st.checkbox("Use Tuned RAG Database (if hypertuned previously)", value=True,
                                           key="use_tuned_db_checkbox")
                display_raw = st.checkbox("Display Raw Retrieved Documents only (no LLM)",
                                          key="display_raw_docs_checkbox")
                k_value = st.slider("Number of documents to retrieve (k)", min_value=1, max_value=10, value=5,
                                    key="k_value_slider")
                test_rag_button = st.form_submit_button("Run RAG Test Query")
            if test_rag_button:
                st.write("Running RAG test query...")
                with st.spinner("Getting RAG response..."):
                    try:
                        result = RAGOperations.run_search_job(
                            query=test_query,
                            k=k_value,
                            raw=display_raw,
                            use_tuned=use_tuned_db,
                            llm_model=st.session_state.get('llm_model_ht_select',
                                                           Configuration.DEFAULT_GROQ_LLM_MODEL),
                            user_context=st.session_state.user_context
                        )
                        if result and "error" not in result:
                            st.success("RAG Test Query Completed!")
                            st.subheader("RAG Response:")
                            if display_raw:
                                st.json(result.get('response', {}))
                            else:
                                response_data = result.get('response', {})
                                if 'summary' in response_data:
                                    st.write(response_data['summary'])
                                    if 'sources' in response_data and response_data['sources']:
                                        st.subheader("Sources:")
                                        for source in response_data['sources']:
                                            if isinstance(source, dict):
                                                st.markdown(
                                                    f"- **Document ID:** {source.get('document_id', 'N/A')}, **Page:** {source.get('page', 'N/A')}, **Section:** {source.get('section', 'N/A')}, **Clause:** {source.get('clause', 'N/A')}")
                                            else:
                                                st.markdown(f"- {source}")
                                else:
                                    st.json(response_data)
                            if 'evaluation' in result:
                                st.subheader("Evaluation Results:")
                                st.json(result['evaluation'])
                        else:
                            st.error(f"RAG test query failed: {result.get('error', 'Unknown error')}")
                    except Exception as e:
                        st.error(f"An unexpected error occurred during RAG test: {e}")
                        st.exception(e)
        else:
            st.warning("You do not have administrative privileges to view this page.")
            if st.button("Go to User Page"):
                st.session_state.page = "User"
                st.rerun()
    @staticmethod
    def run_hypertune_job(llm_model: Optional[str] = None,
                          chunk_size_to_test: Optional[list[int]] = None,
                          chunk_overlap_to_test: Optional[list[int]] = None,
                          embedding_models_to_test: Optional[list[str]] = None,
                          vector_db_types_to_test: Optional[list[str]] = None,
                          re_ranker_model: Optional[list[str]] = None,
                          search_type: str = "random",
                          n_iter: Optional[int] = 3) -> Dict[str, Any]:
        """Run grid/random hyperparameter search, persist the best params and build the tuned index."""
        try:
            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)
            result = evaluator.evaluate_combined_params_grid(
                chunk_size_to_test=chunk_size_to_test if chunk_size_to_test is not None else [512, 1024, 2048],
                chunk_overlap_to_test=chunk_overlap_to_test if chunk_overlap_to_test is not None else [100, 200, 400],
                embedding_models_to_test=embedding_models_to_test if embedding_models_to_test is not None else [
                    "all-MiniLM-L6-v2",
                    "all-mpnet-base-v2",
                    "paraphrase-MiniLM-L3-v2",
                    "multi-qa-mpnet-base-dot-v1"],
                vector_db_types_to_test=vector_db_types_to_test if vector_db_types_to_test is not None else ['chroma'],
                llm_model_name=llm_model,
                re_ranker_model=re_ranker_model if re_ranker_model is not None else [
                    "cross-encoder/ms-marco-MiniLM-L-6-v2",
                    "cross-encoder/ms-marco-TinyBERT-L-2"],
                search_type=search_type,
                n_iter=n_iter)
            best_parameter = result['best_params']
            best_score = result['best_score']
            pkl_file = result['pkl_file']
            best_metrics = result['best_metrics']
            best_param_path = os.path.join(Configuration.DATA_DIR, 'best_params.json')
            with open(best_param_path, 'w') as f:
                json.dump(best_parameter, f, indent=4)
            tuned_db = best_parameter['vector_db_type']
            tuned_path = os.path.join(Configuration.DATA_DIR, 'TunedDB', tuned_db)
            if tuned_db != 'pinecone':
                os.makedirs(tuned_path, exist_ok=True)
            tuned_collection_name = "tuned-" + Configuration.COLLECTION_NAME
            tuned_params = {
                'document_path': Configuration.FULL_PDF_PATH,
                'chunk_size': best_parameter.get('chunk_size', Configuration.DEFAULT_CHUNK_SIZE),
                'chunk_overlap': best_parameter.get('chunk_overlap', Configuration.DEFAULT_CHUNK_OVERLAP),
                'embedding_model_name': best_parameter.get('embedding_model',
                                                           Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL),
                'vector_db_type': tuned_db,
                'llm_model_name': llm_model,
                'db_path': tuned_path if tuned_db != 'pinecone' else "",
                'collection_name': tuned_collection_name,
                'vector_db': None,
                're_ranker_model': best_parameter.get('re_ranker_model', Configuration.DEFAULT_RERANKER)
            }
            tuned_pipeline = RAGOperations.initialize_pipeline(tuned_params)
            tuned_pipeline.build_index()
            return result
        except Exception as ex:
            logs.logger.info(f"Exception in hypertune: {ex}")
            traceback.print_exc()
            return {"error": str(ex)}  # Return error for Streamlit to display
    @staticmethod
    def user_page():
        """User-facing query page: submit a question and display the RAG (or raw retrieval) answer."""
        st.title("Flykite HR Policy Query")
        st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")
        st.info("Ask any question about the Flykite Airlines HR policy document.")
        with st.form("user_query_form"):
            user_query = st.text_area("Your Query:", height=100, key="user_query_input")
            response_type = st.radio("Choose Response Type:",
                                     options=["LLM Tuned Response (RAG + LLM)",
                                              "RAG Raw Response (Retrieved Docs Only)"],
                                     index=0, key="response_type_radio")
            k_value_user = st.slider("Number of documents to consider (k)", min_value=1, max_value=10, value=5,
                                     key="k_value_user_slider")
            submit_query_button = st.form_submit_button("Get Answer")
        if submit_query_button and user_query:
            st.subheader("Response:")
            with st.spinner("Fetching answer..."):
                try:
                    display_raw = (response_type == "RAG Raw Response (Retrieved Docs Only)")
                    # Direct call to RAGOperations.run_search_job
                    result = RAGOperations.run_search_job(
                        query=user_query,
                        raw=display_raw,
                        k=k_value_user,
                        use_tuned=True,  # User page always uses the tuned index if available
                        user_context=st.session_state.user_context  # Pass user context
                    )
                    if result and "error" not in result:
                        response_data = result.get('response', {})
                        evaluation = result.get('evaluation', {})
                        if display_raw:
                            st.json(response_data)  # Raw output is already formatted
                        else:
                            if 'summary' in response_data:
                                st.markdown(response_data['summary'])
                                if 'sources' in response_data and response_data['sources']:
                                    st.subheader("Sources:")
                                    for source in response_data['sources']:
                                        if isinstance(source, dict):
                                            st.markdown(
                                                f"- **Document ID:** {source.get('document_id', 'N/A')}, **Page:** {source.get('page', 'N/A')}, **Section:** {source.get('section', 'N/A')}, **Clause:** {source.get('clause', 'N/A')}")
                                        else:  # Fallback for raw string sources
                                            st.markdown(f"- {source}")
                            else:
                                st.json(response_data)
                        if evaluation:
                            # st.markdown(f"**Evaluation Results:** **Groundedness Score** {evaluation.get('Groundedness score', 'N/A')}, **Relevance Score:** {evaluation.get('Relevance score', 'N/A')}, **Reasoning** {evaluation.get('Reasoning', 'N/A')}")
                            st.json(evaluation)
                    else:
                        st.error(
                            f"Failed to get a response: {result.get('error', 'Unknown error')}. Please try again.")
                except Exception as e:
                    st.error(f"An unexpected error occurred during user query: {e}")
                    st.error(traceback.format_exc())
        elif submit_query_button and not user_query:
            st.warning("Please enter a query.")
def main_app():
    """Top-level Streamlit router: sidebar navigation plus page dispatch based on session state."""
    st.sidebar.title("Navigation")
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False
    if 'page' not in st.session_state:
        st.session_state.page = "Home"
    if not st.session_state.logged_in:
        st.session_state.page = "Home"
        RAGOperations.home_page()
    else:
        st.sidebar.button("Home", on_click=lambda: st.session_state.update(page="Home"))
        if st.session_state.user_info and st.session_state.user_info['role'] == 'admin':
            st.sidebar.button("Admin Dashboard", on_click=lambda: st.session_state.update(page="Admin"))
            st.sidebar.button("User Query", on_click=lambda: st.session_state.update(page="User"))
        else:
            st.sidebar.button("User Query", on_click=lambda: st.session_state.update(page="User"))
        if st.session_state.page == "Home":
            RAGOperations.home_page()
        elif st.session_state.page == "Admin":
            RAGOperations.admin_page()
        elif st.session_state.page == "User":
            RAGOperations.user_page()


if __name__ == "__main__":
    main_app()
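
# To launch the UI locally (assuming this file is saved as app.py and that the `src/` package,
# users.db, the configured PDF and the evaluation data are present alongside it):
#
#   streamlit run app.py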