Spaces:
Running
Running
| from flask import Flask, render_template, request, redirect, url_for, session, flash | |
| import os | |
| from werkzeug.utils import secure_filename | |
| from retrival import generate_data_store,update_data_store,approximate_bpe_token_counter | |
| from langchain_community.vectorstores import Chroma | |
| import chromadb | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain_core.prompts import PromptTemplate, ChatPromptTemplate | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from huggingface_hub import InferenceClient | |
| from langchain.schema import Document | |
| from langchain_core.documents import Document | |
| from dotenv import load_dotenv | |
| import re | |
| import numpy as np | |
| import glob | |
| import shutil | |
| from werkzeug.utils import secure_filename | |
| import asyncio | |
| import nltk | |
| nltk.download('punkt_tab') | |
| import nltk | |
| nltk.download('averaged_perceptron_tagger_eng') | |
# Flask application object used by every route handler in this module.
app = Flask(__name__)

# Secret key that signs the session cookie (chat history, flashed messages).
# A key generated with os.urandom() changes on every start, which silently
# invalidates existing sessions and breaks multi-worker deployments, so a
# stable key from the environment is preferred.  The random key remains the
# fallback when FLASK_SECRET_KEY is not set.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(24)

# Configurations: folders for raw uploads, document vector stores and
# table vector stores (one sub-directory per database name).
UPLOAD_FOLDER = "uploads/"
VECTOR_DB_FOLDER = "VectorDB/"
TABLE_DB_FOLDER = "TableDB/"
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['DEBUG'] = True
app.config['ENV'] = 'development'

# Make sure the storage folders exist before any request is served.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(VECTOR_DB_FOLDER, exist_ok=True)
os.makedirs(TABLE_DB_FOLDER, exist_ok=True)

# Global variables: paths of the currently selected document / table stores.
# They are process-wide (not per session) and are set by select_db().
CHROMA_PATH = None
TABLE_PATH = None
| ######################################################################################################################################################## | |
| ####----------------------------------------------------------------- Prompt Templates ------------------------------------------------------------#### | |
| ######################################################################################################################################################## | |
# Prompt used by chat() when no table store is selected (TABLE_PATH is None).
# Placeholders: {context} receives the retrieved document chunks joined with
# "###", {question} receives the user's query text.
PROMPT_TEMPLATE_DOC = """
<s>[INST] You are a retrieval-augmented generation (RAG) assistant. Your task is to generate a response strictly based on the given context. Follow these instructions:
- Use only the provided context; do not add external information.
- The context contains multiple retrieved chunks separated by "###". Choose only the most relevant chunks to answer the question and ignore unrelated ones.
- If available, use the provided source information to support the response.
- Answer concisely and factually.
Context:
{context}
---
Question:
{question}
Response:
[/INST]
"""
# Prompt used by chat() when a table store is selected (TABLE_PATH is set).
# Placeholders: {context} receives the retrieved document chunks joined with
# "###", {table} receives the retrieved table chunks, {question} receives the
# user's query text.
PROMPT_TEMPLATE_TAB = """
<s>[INST] You are a retrieval-augmented generation (RAG) assistant. Your task is to generate a response strictly based on the given context. Follow these instructions:
- Use only the provided context; do not add external information.
- The context contains multiple retrieved chunks separated by "###". Choose only the most relevant chunks to answer the question and ignore unrelated ones.
- If available, use the provided source information to support the response.
- If a table is provided as html, incorporate its relevant details into the response while maintaining a structured format.
- Answer concisely and factually.
Context:
{context}
---
Table:
{table}
---
Question:
{question}
Response:
[/INST]
"""
| ######################################################################################################################################################## | |
| ####--------------------------------------------------------------- Flask APP ROUTES --------------------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def home():
    """Render the landing page from the ``home.html`` template."""
    landing_page = render_template('home.html')
    return landing_page
| ######################################################################################################################################################## | |
| ####---------------------------------------------------------------- routes for chat --------------------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def chat():
    """Answer a user query against the currently selected database.

    On GET the chat page is rendered with the history kept in the session.
    On POST the submitted ``query_text`` is used to retrieve context from the
    selected document store (and the table store, when one is selected), a
    prompt is built and sent to the Hugging Face endpoint, and the
    (query, answer) pair is appended to the session history.

    Returns:
        The rendered ``chat.html`` page, or a redirect to ``list_dbs`` when
        no database is selected or any step of the POST handling fails.
    """
    try:
        if 'history' not in session:
            session['history'] = []
        print("sessionhist1", session['history'])
        print(f"Selected DB: {CHROMA_PATH}")

        if request.method == 'POST':
            query_text = request.form['query_text']
            if CHROMA_PATH is None:
                flash("Please select a database first!", "error")
                return redirect(url_for('list_dbs'))

            prompt = _build_chat_prompt(query_text)
            print("results------------------->", prompt)

            # Model definition and invocation.
            llm = HuggingFaceEndpoint(
                repo_id="mistralai/Mistral-7B-Instruct-v0.3",
                max_new_tokens=2000,
                task="text-generation",
                temperature=0.8,
                huggingfacehub_api_token=os.environ["HF_TOKEN"],
            )
            data = llm.invoke(prompt)

            # Replace answers that admit missing context with a fixed reply.
            if _is_non_answer(data):
                data = "We do not have information related to your query on our end."

            # Save the query and answer to the session history; the list is
            # mutated in place, so the session must be flagged as modified.
            session['history'].append((query_text, data))
            session.modified = True
            print("sessionhist2", session['history'])
            return render_template(
                'chat.html',
                query_text=query_text,
                answer=data,
                token_count=approximate_bpe_token_counter(data),
                history=session['history'],
                old_db=CHROMA_PATH,
            )
    except Exception as e:
        # This handler used to report "Error in Creating DB", which was a
        # copy-paste from create_db() and misled users about what failed.
        flash(f"Error while answering the query: {e}", "error")
        return redirect(url_for('list_dbs'))
    return render_template('chat.html', history=session['history'], old_db=CHROMA_PATH)


# Phrases that mark an answer as "the context does not cover this".
# NOTE(review): `\bmention\b` and `\bdoes not\b` also match many legitimate
# answers; the pattern is kept unchanged to preserve existing behaviour.
_NON_ANSWER_PATTERN = (
    r'\bmention\b|\bnot mention\b|\bnot mentioned\b|\bnot contain\b|\bnot include\b'
    r'|\bnot provide\b|\bdoes not\b|\bnot explicitly\b|\bnot explicitly mentioned\b'
)


def _is_non_answer(text):
    """Return True when *text* contains one of the non-answer phrases."""
    return re.search(_NON_ANSWER_PATTERN, text, re.IGNORECASE) is not None


def _build_chat_prompt(query_text):
    """Retrieve context for *query_text* and return the formatted prompt."""
    # Document store: the three most similar chunks.
    doc_embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=doc_embeddings)
    query_embedding = doc_embeddings.embed_query(query_text)
    results_document = db.similarity_search_by_vector_with_relevance_scores(
        embedding=query_embedding,
        k=3,
    )
    print("results------------------->", results_document)
    print("============================================")
    print("============================================")
    context_text_document = " \n\n###\n\n ".join(
        f"Source: {doc.metadata.get('source', '')} Page_content:{doc.page_content}\n"
        for doc, _score in results_document
    )

    if TABLE_PATH is None:
        template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE_DOC)
        return template.format(context=context_text_document, question=query_text)

    # Table store: it is opened with a different embedding model, so the query
    # must be embedded with that model too.  Reusing the MiniLM vector from
    # above would search with a vector from another embedding space.
    # NOTE(review): assumes the table store was indexed with this same model —
    # confirm against the code that builds TableDB.
    table_embeddings = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
    tdb = Chroma(persist_directory=TABLE_PATH, embedding_function=table_embeddings)
    table_query_embedding = table_embeddings.embed_query(query_text)
    results_table = tdb.similarity_search_by_vector_with_relevance_scores(
        embedding=table_query_embedding,
        k=2,
    )
    print("results------------------->", results_table)
    context_text_table = "\n\n---\n\n".join(doc.page_content for doc, _score in results_table)

    template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE_TAB)
    return template.format(
        context=context_text_document,
        table=context_text_table,
        question=query_text,
    )
| ######################################################################################################################################################## | |
| ####---------------------------------------------------------------- routes for create-db ---------------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def create_db():
    """Create a new database from uploaded files.

    On GET the ``create_db.html`` form is rendered.  On POST the form field
    ``db_name`` and the uploads in ``file`` (single files) or ``folder``
    (folder upload) are stored under ``UPLOAD_FOLDER/<db_name>`` and passed to
    ``generate_data_store``.  When single files are present the folder
    uploads are ignored.

    Returns:
        A ``(message, 400)`` tuple for invalid input, a redirect to
        ``list_dbs`` after processing (or on error), or the rendered form.
    """
    try:
        if request.method == 'POST':
            db_name = request.form.get('db_name', '').strip()
            if not db_name:
                return "Database name is required", 400

            # db_name is user input that ends up in filesystem paths: refuse
            # path separators, and names whose sanitised form is empty (they
            # would make the upload directory collapse to UPLOAD_FOLDER).
            upload_dir_name = secure_filename(db_name)
            if not upload_dir_name or "/" in db_name or "\\" in db_name:
                return "Invalid database name", 400

            # Get uploaded files
            files = request.files.getlist('folder')  # Folder uploads (multiple files)
            single_files = request.files.getlist('file')  # Single file uploads
            print("==================folder==>", files)
            print("==================single_files==>", single_files)

            single_uploads = [f for f in single_files if (f.filename or "").strip()]
            folder_uploads = [f for f in files if (f.filename or "").strip()]
            # Ensure at least one valid file is uploaded
            if not single_uploads and not folder_uploads:
                return "No files uploaded", 400

            # Create upload directory
            upload_base_path = os.path.join(app.config['UPLOAD_FOLDER'], upload_dir_name)
            print(f"Base Upload Path: {upload_base_path}")
            os.makedirs(upload_base_path, exist_ok=True)

            # Single files take precedence; folder uploads are only processed
            # when no single file was submitted.
            if single_uploads:
                uploads, label = single_uploads, "single"
            else:
                uploads, label = folder_uploads, "folder"

            saved = 0
            for upload in uploads:
                file_name = secure_filename(upload.filename)
                if not file_name:
                    # Sanitising removed every character; saving would target
                    # the directory itself, so skip this upload.
                    continue
                file_path = os.path.join(upload_base_path, file_name)
                print(f"Saving {label} file to: {file_path}")
                upload.save(file_path)
                saved += 1
            if not saved:
                return "No files uploaded", 400
            if single_uploads:
                print("Single file uploaded, skipping folder processing.")

            # Generate datastore, then report success on both upload paths.
            asyncio.run(generate_data_store(upload_base_path, db_name))
            flash(f"{db_name} created successfully!", "success")
            return redirect(url_for('list_dbs'))
    except Exception as e:
        flash(f"Error in Creating DB: {e}", "error")
        return redirect(url_for('list_dbs'))
    return render_template('create_db.html')
| ######################################################################################################################################################## | |
| ####------------------------------------------------------- routes for list-dbs and documents -----------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def list_dbs():
    """List the databases found as sub-directories of ``VECTOR_DB_FOLDER``.

    Redirects to ``create_db`` with an error flash when there are none;
    otherwise renders ``list_dbs.html`` with the directory names.
    """
    vector_dbs = []
    for entry in os.listdir(VECTOR_DB_FOLDER):
        candidate = os.path.join(VECTOR_DB_FOLDER, entry)
        if os.path.isdir(candidate):
            vector_dbs.append(entry)

    if len(vector_dbs) == 0:
        flash("NO available DBs! Let create new db", "error")
        return redirect(url_for('create_db'))

    return render_template('list_dbs.html', vector_dbs=vector_dbs)
def select_db(db_name):
    """Make ``db_name`` the active database and redirect to the chat page.

    Sets the module-level ``CHROMA_PATH`` to ``VECTOR_DB_FOLDER/<db_name>``
    and ``TABLE_PATH`` to ``TABLE_DB_FOLDER/<db_name>`` (or ``None`` when that
    directory does not exist).

    Args:
        db_name: Name of the database directory to select.

    Returns:
        A redirect to ``chat`` on success, or to ``list_dbs`` with an error
        flash when the name is invalid or the database does not exist.
    """
    global CHROMA_PATH
    global TABLE_PATH

    # Selecting a missing directory would later make the vector store open
    # (and create) an empty database, so validate before touching the globals.
    # Names with separators or dot-segments could point outside VECTOR_DB_FOLDER.
    chroma_path = os.path.join(VECTOR_DB_FOLDER, db_name).replace("\\", "/")
    is_plain_name = db_name not in (".", "..") and "/" not in db_name and "\\" not in db_name
    if not is_plain_name or not os.path.isdir(chroma_path):
        flash(f"{db_name} Database does not exist", "error")
        return redirect(url_for('list_dbs'))

    # Selecting the Document Vector DB (previous value printed first).
    print(f"Selected DB: {CHROMA_PATH}")
    print("---------------------------------------------------------")
    CHROMA_PATH = chroma_path
    print(f"Selected DB: {CHROMA_PATH}")
    print("---------------------------------------------------------")

    # Selecting the Table Vector DB, which is optional.
    table_db_path = os.path.join(TABLE_DB_FOLDER, db_name).replace("\\", "/")
    TABLE_PATH = table_db_path if os.path.exists(table_db_path) else None
    print(f"Selected Table DB: {TABLE_PATH}")

    flash(f"{db_name} Database has been selected", "success")
    return redirect(url_for('chat'))
| ######################################################################################################################################################## | |
| ####---------------------------------------------------------- routes for modification of dbs -----------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def modify_db(db_name):
    """Flash a selection notice and render ``modify_dbs.html`` for ``db_name``."""
    selection_notice = f"{db_name} Database is selected"
    flash(selection_notice, "success")
    print(db_name)
    page = render_template('modify_dbs.html', db_name=db_name)
    return page
| ######################################################################################################################################################## | |
| ####--------------------------------------------------------- routes for update exisiting of dbs --------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def update_db(db_name):
    """Add uploaded files to the existing database ``db_name``.

    On POST the uploads in ``file`` (single files) or ``folder`` (folder
    upload) are stored under ``UPLOAD_FOLDER/<db_name>`` and passed to
    ``update_data_store``.  When single files are present the folder uploads
    are ignored.  On GET, or after an error, ``update_db.html`` is rendered.

    Args:
        db_name: Name of the database to update.

    Returns:
        A ``(message, 400)`` tuple for invalid input, a redirect to
        ``modify_db`` after a successful update, or the rendered form.
    """
    try:
        if db_name and request.method == 'POST':
            print(db_name)
            # Get all files from the uploaded folder / single-file fields.
            files = request.files.getlist('folder')  # Folder uploads (multiple files)
            single_files = request.files.getlist('file')  # Single file uploads
            print("============from_update======folder==>", files)
            print("============from_update======single_files==>", single_files)

            single_uploads = [f for f in single_files if (f.filename or "").strip()]
            folder_uploads = [f for f in files if (f.filename or "").strip()]
            # Ensure at least one valid file is uploaded
            if not single_uploads and not folder_uploads:
                return "No files uploaded", 400

            # An empty sanitised name would collapse the upload directory to
            # UPLOAD_FOLDER itself, mixing files of every database.
            upload_dir_name = secure_filename(db_name)
            if not upload_dir_name:
                return "Invalid database name", 400

            # Create upload directory
            upload_base_path = os.path.join(app.config['UPLOAD_FOLDER'], upload_dir_name)
            print(f"Base Upload Path: {upload_base_path}")
            os.makedirs(upload_base_path, exist_ok=True)

            # Single files take precedence; folder uploads are only processed
            # when no single file was submitted.
            if single_uploads:
                uploads, label = single_uploads, "single"
            else:
                uploads, label = folder_uploads, "folder"

            saved = 0
            for upload in uploads:
                file_name = secure_filename(upload.filename)
                if not file_name:
                    # Nothing left after sanitising; saving would target the
                    # directory itself, so skip this upload.
                    continue
                file_path = os.path.join(upload_base_path, file_name)
                print(f"Saving {label} file to: {file_path}")
                upload.save(file_path)
                saved += 1
            if not saved:
                return "No files uploaded", 400
            if single_uploads:
                print("Single file uploaded, skipping folder processing.")

            # Update the datastore first; only report success once it returns,
            # so a failed update never shows a success message.
            asyncio.run(update_data_store(upload_base_path, db_name))
            flash(f"{db_name} updated successfully!", "success")
            return redirect(url_for('modify_db', db_name=db_name))
    except Exception as e:
        print(f"got unexpected error {e}")
        flash("got unexpected error while updating", "error")
    return render_template('update_db.html', db_name=db_name)
| ######################################################################################################################################################## | |
| ####--------------------------------------------------------- routes for removing the of dbs ------------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def remove_db(db_name):
    """Delete the document and table stores of ``db_name`` from disk.

    Removes ``./VectorDB/<db_name>`` and ``./TableDB/<db_name>`` when they
    exist and clears the active selection if it pointed at the removed
    database.

    Args:
        db_name: Name of the database directory to remove.

    Returns:
        A redirect to ``list_dbs`` in every case; the outcome is reported
        through a flashed message.
    """
    global CHROMA_PATH
    global TABLE_PATH

    # db_name is interpolated into a path handed to shutil.rmtree, so a value
    # such as ".." or one containing separators must never get that far.
    if not db_name or db_name in (".", "..") or "/" in db_name or "\\" in db_name:
        flash("Invalid database name", "error")
        return redirect(url_for('list_dbs'))

    print(db_name)
    vector_dir = f"./VectorDB/{db_name}"
    table_dir = f"./TableDB/{db_name}"
    try:
        for target in (vector_dir, table_dir):
            if os.path.exists(target):
                shutil.rmtree(target)

        # Drop a selection that now points at a deleted directory; otherwise
        # the next chat query would open the removed path again.
        if CHROMA_PATH is not None and os.path.normpath(CHROMA_PATH) == os.path.normpath(vector_dir):
            CHROMA_PATH = None
            TABLE_PATH = None

        flash(f"{db_name} Database Removed successfully", "success")
    except Exception as e:
        print(f"Error while removing database: {e}")
        flash(f"Error while removing database: {e}", "error")
    return redirect(url_for('list_dbs'))
| ######################################################################################################################################################## | |
| ####--------------------------------------------------------- routes for removing specific dbs ----------------------------------------------------#### | |
| ######################################################################################################################################################## | |
def delete_doc(db_name):
    """Delete one document's entries from the stores of ``db_name``.

    On GET the distinct ``filename`` metadata values of the ``langchain``
    collection are listed in ``delete_doc.html``.  On POST every entry whose
    ``filename`` equals the form field ``list_doc`` is deleted from the
    document store and, when a table store exists, from that store too.

    Args:
        db_name: Name of the database whose documents are managed.

    Returns:
        The rendered ``delete_doc.html`` page on GET, a redirect to
        ``modify_db`` after a POST or on error, or a redirect to ``list_dbs``
        when the database directory does not exist.
    """
    try:
        DB_PATH = f"./VectorDB/{db_name}"
        TAB_PATH = f"./TableDB/{db_name}"

        # Do not open a client on a missing path: report the problem instead
        # of letting the client work against a database that is not there.
        if not os.path.isdir(DB_PATH):
            flash(f"Database '{db_name}' does not exist.", "error")
            return redirect(url_for('list_dbs'))

        client = chromadb.PersistentClient(path=DB_PATH)
        # Select the collection and fetch the metadata of all entries.
        collection = client.get_collection("langchain")
        results = collection.get(include=["metadatas"])
        # Extract unique file names; entries without metadata are skipped.
        file_list = {item["filename"] for item in results["metadatas"] if item and "filename" in item}
        print("file_list", file_list)

        if request.method == 'POST':
            list_doc = request.form.get('list_doc')
            print("list_doc", list_doc)

            # A missing or unknown name would delete nothing yet still report
            # success, so reject it explicitly.
            if not list_doc or list_doc not in file_list:
                flash("Please select a document that exists in this database.", "error")
                return redirect(url_for('modify_db', db_name=db_name))

            # Delete from the VectorDB collection
            collection.delete(where={"filename": list_doc})
            flash(f"The document '{list_doc}' has been removed from VectorDB.", "success")

            # Delete the document from TableDB as well when that store exists.
            if os.path.exists(TAB_PATH):
                client_tab = chromadb.PersistentClient(path=TAB_PATH)
                collect_tab = client_tab.get_collection("langchain")
                result_tab = collect_tab.get(include=["metadatas"])
                file_list_tab = {
                    item["filename"] for item in result_tab["metadatas"] if item and "filename" in item
                }
                print("TableDB file_list:", file_list_tab)
                if list_doc in file_list_tab:
                    collect_tab.delete(where={"filename": list_doc})
                    flash(f"The document '{list_doc}' has also been removed from TableDB.", "success")
                else:
                    flash(f"The document '{list_doc}' was not found in TableDB.", "warning")
            else:
                print("Note: TableDB does not exist.")
                flash(f"TableDB path '{TAB_PATH}' does not exist.", "warning")
            return redirect(url_for('modify_db', db_name=db_name))

        return render_template('delete_doc.html', db_name=db_name, file_list=file_list)
    except Exception as e:
        flash(f"Error while deleting documents: {e}", "error")
        return redirect(url_for('modify_db', db_name=db_name))
| ######################################################################################################################################################## | |
| ####---------------------------------------------------------------------- App MAIN ---------------------------------------------------------------#### | |
| ######################################################################################################################################################## | |
if __name__ == "__main__":
    # Start the built-in server only when the module is executed directly,
    # with the debugger and the auto-reloader both switched off.
    run_options = {"debug": False, "use_reloader": False}
    app.run(**run_options)