"""Flask App script for RAG chatbot"""
import gc
import os
import re
import tempfile
from flask import Flask, request, jsonify, render_template
# Disable CUDA and excessive parallel threads to save memory
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
# Flask app initialization
app = Flask(__name__, template_folder="templates", static_folder="static")
# Global state
retriever = None
LLM_model = None
api_key = None # API key will come from frontend
SYSTEM_MESSAGE = """
You are a RAG Assistant for the uploaded document.
Your role is to help users understand its contents clearly and accurately.
Rules:
1. Prioritize the document context first.
2. If the answer isn’t in the document, say you don’t know.
3. Be friendly, direct, and concise.
4. Avoid adding extra information unless asked.
"""
# routes
@app.route("/")
def home():
    return render_template("chat_page.html")

@app.route("/upload", methods=["POST"])
def upload_file():
"""Route handling document upload, splitting, chunking, and vectorization."""
global retriever, LLM_model, api_key
# Import heavy dependencies only when needed
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
    api_key = request.form.get("apiKey")
    if not api_key:
        return "API key missing!", 400
    uploaded = request.files.get("file")
    if not uploaded or uploaded.filename.strip() == "":
        return "No file uploaded", 400
    ext = uploaded.filename.rsplit(".", 1)[-1].lower()
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as tmp_file:
        uploaded.save(tmp_file.name)
        path = tmp_file.name
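    # delete=False keeps the file on disk after the `with` block exits, so
    # every return path below must remove it with os.unlink(path).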
    # Load the document
    try:
        loader = PyPDFLoader(path) if ext == "pdf" else TextLoader(path)
        documents = loader.load()
    except Exception as e:
        os.unlink(path)
        return f"Failed to read document: {e}", 400
    if not documents:
        os.unlink(path)
        return "No readable content found in the document.", 400
    # Split document into chunks
    splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)  # reduced chunk_size for low memory
    chunks = splitter.split_documents(documents)
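    # The 100-character overlap repeats a little text between adjacent chunks,
    # so a sentence cut at a chunk boundary still appears whole in one chunk.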
    # Light embedding model (fast + low memory)
    try:
        # embeds = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-MiniLM-L3-v2")
        embeds = HuggingFaceEmbeddings(model_name="./models/paraphrase-MiniLM-L3-v2")
        vector_store = FAISS.from_documents(chunks, embeds)
        retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})
    except Exception as e:
        os.unlink(path)
        return f"Embedding model failed: {e}", 500
    # Initialize chat model
    try:
        LLM_model = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=api_key)
    except Exception as e:
        os.unlink(path)  # clean up the temp file on this failure path too
        return f"Failed to initialize chat model: {e}", 500
    # Cleanup temp file and free memory
    os.unlink(path)
    del documents, chunks, vector_store
    gc.collect()
    return "Document processed successfully! You can now ask questions."
@app.route("/chat", methods=["POST"])
def chat():
"""Q&A route on uploaded document."""
global retriever, LLM_model
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
    if retriever is None or LLM_model is None:
        return jsonify({"error": "Please upload a document first."}), 400
    # Accept the question from form data or a JSON body; get_json(silent=True)
    # returns None instead of raising when the body is not JSON
    question = request.form.get("question") or (request.get_json(silent=True) or {}).get("question")
    if not question:
        return jsonify({"error": "No question provided."}), 400
    # Retrieve documents with retriever
    try:
        docs = retriever.invoke(question)
        context = "\n\n".join(d.page_content for d in docs)
    except Exception as e:
        return jsonify({"error": f"Retriever failed: {e}"}), 500
    # Prompt template (SYSTEM_MESSAGE supplies the assistant's rules)
    prompt_template = PromptTemplate(
        template=(
            SYSTEM_MESSAGE
            + "You are answering strictly based on this document.\n\n"
            "{context}\n\n"
            "Question: {question}\n\n"
            "Answer:"
        ),
        input_variables=["context", "question"],
    )
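    # The rendered prompt looks roughly like:
    #   <rules from SYSTEM_MESSAGE>
    #   You are answering strictly based on this document.
    #   <retrieved chunks>
    #   Question: <user question>
    #   Answer: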
    # Combine into a pipeline, reusing the context retrieved above
    # instead of invoking the retriever a second time
    chain = (
        RunnableParallel({
            "context": RunnableLambda(lambda _: context),
            "question": RunnablePassthrough(),
        })
        | prompt_template
        | LLM_model
        | StrOutputParser()
    )
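    # LangChain's `|` composes runnables: RunnableParallel builds the
    # {"context": ..., "question": ...} dict the prompt expects, the filled-in
    # prompt goes to the Gemini model, and StrOutputParser returns plain text.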
    try:
        response = chain.invoke(question).strip()
    except Exception as e:
        response = f"Error generating response: {str(e)}"
    # Clean markdown artifacts (bold/italic markers)
    cleaned = re.sub(r"\*\*(.*?)\*\*", r"\1", response)
    cleaned = re.sub(r"\*(.*?)\*", r"\1", cleaned)
    gc.collect()
    return jsonify({"answer": cleaned})

# run app
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)
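
# Example chat request once a document has been uploaded (question text is
# illustrative):
#   curl -X POST http://localhost:7860/chat -F "question=What is the document about?"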