import os
import time

import faiss
import gradio as gr
import numpy as np
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
# --- Configuration ---
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-base"
CHUNK_SIZE = 500       # characters per chunk
CHUNK_OVERLAP = 100    # characters shared by consecutive chunks
TOP_K = 4              # chunks retrieved per query
MAX_NEW_TOKENS = 150   # cap on answer length
NORMALIZE_EMB = True   # L2-normalize embeddings so inner product equals cosine similarity
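# Rough sizing at these defaults: the chunker advances CHUNK_SIZE - CHUNK_OVERLAP = 400
# characters per step, so a ~40,000-character document (roughly 20 dense pages)
# produces on the order of 40_000 / 400 = 100 chunks to embed and index.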
|
|
# --- Models (loaded once at import time so per-request latency stays low) ---
embedder = SentenceTransformer(EMBED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)

qa_pipeline = pipeline(
    "text2text-generation",
    model=gen_model,
    tokenizer=tokenizer,
    device=-1,
)
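# device=-1 runs generation on CPU; pass device=0 (or a torch.device) to use
# the first CUDA GPU when one is available.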
|
|
# --- In-memory session state (no database; lives for the process lifetime) ---
faiss_index = None
pdf_chunks = []
pdf_embeddings = None
last_loaded_filename = None
last_loaded_at = None
|
|
def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split text into overlapping character chunks for embedding."""
    if not text:
        return []
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    chunks = []
    start = 0
    length = len(text)
    while start < length:
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= length:
            break  # final chunk emitted; stepping back would only duplicate text
        start = end - overlap
    return chunks
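# Example with the defaults above: a 1,200-character string yields chunks covering
# [0:500], [400:900], and [800:1200]; consecutive chunks share CHUNK_OVERLAP
# characters so sentences split at a chunk boundary remain retrievable.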
|
|
def build_faiss_index(embeddings: np.ndarray):
    """Build an exact inner-product index over the given embeddings."""
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(embeddings)  # in place; re-normalizing already-unit vectors is a no-op
    index.add(embeddings)
    return index
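# IndexFlatIP performs exact (brute-force) inner-product search; on L2-normalized
# vectors that is cosine similarity. For corpora far beyond a single PDF, an
# approximate index such as faiss.IndexIVFFlat would be the usual swap, but exact
# search is plenty fast at a few hundred chunks.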
|
|
def embed_texts(texts):
    """Embed a list of texts as a float32 numpy array."""
    embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    if NORMALIZE_EMB:
        faiss.normalize_L2(embeddings)
    return embeddings
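# SentenceTransformer.encode also takes a batch_size argument (default 32);
# raising it can speed up indexing of long PDFs, especially on a GPU.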
|
|
def process_pdf(pdf_file):
    """
    Upload and process a PDF. Builds the FAISS index and stores chunks and
    embeddings in memory. Returns a status message.
    """
    global faiss_index, pdf_chunks, pdf_embeddings, last_loaded_filename, last_loaded_at

    if pdf_file is None:
        return "⚠️ No file uploaded."

    try:
        # Depending on the Gradio version, gr.File yields either a tempfile-like
        # object with a .name attribute or a plain filepath string.
        path = pdf_file if isinstance(pdf_file, str) else pdf_file.name

        reader = PdfReader(path)
        pages = []
        for p in reader.pages:
            page_text = p.extract_text()
            if page_text:
                pages.append(page_text)
        text = "\n".join(pages).strip()
        if not text:
            return "⚠️ No readable text found in PDF (it may be scanned images)."

        pdf_chunks = chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP)
        pdf_embeddings = embed_texts(pdf_chunks)
        # Index a copy: build_faiss_index normalizes in place, and the stored
        # embeddings should stay untouched.
        faiss_index = build_faiss_index(np.copy(pdf_embeddings))

        last_loaded_filename = os.path.basename(path)
        last_loaded_at = time.time()

        return f"✅ PDF processed. {len(pdf_chunks)} chunks indexed. Ready for Q&A."
    except Exception as e:
        return f"❌ Error processing PDF: {e}"
|
|
def chat_with_pdf(query):
    """
    Retrieve the most relevant chunks and generate an answer with the
    generator model. Designed for low-latency responses.
    """
    if faiss_index is None or not pdf_chunks:
        return "⚠️ Please upload and process a PDF first."

    if not query or not query.strip():
        return "⚠️ Please enter a question."

    query = query.strip()

    # Embed the query the same way the chunks were embedded.
    q_emb = embedder.encode([query], convert_to_numpy=True)
    if NORMALIZE_EMB:
        faiss.normalize_L2(q_emb)

    # Retrieve the top-k most similar chunks (FAISS pads missing hits with -1).
    top_k = min(TOP_K, len(pdf_chunks))
    _, indices = faiss_index.search(q_emb, top_k)
    retrieved = [pdf_chunks[i] for i in indices[0] if i != -1]
    context = "\n\n".join(retrieved)

    # Grounded prompt: the model is told to answer only from the context.
    system_prompt = (
        "You are a helpful assistant that answers questions using only the provided context. "
        "If the answer is not contained in the context, say 'I don't know based on the document.' "
        "Be concise and factual."
    )
    prompt = (
        f"{system_prompt}\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )

    # Crude length guard: flan-t5's tokenizer defaults to a 512-token maximum,
    # so trim the retrieved context rather than risk the question and
    # instructions being cut off.
    max_prompt_chars = 3000
    if len(prompt) > max_prompt_chars:
        trimmed_context = context[-2000:]
        prompt = f"{system_prompt}\n\nContext:\n{trimmed_context}\n\nQuestion: {query}\n\nAnswer:"

    try:
        out = qa_pipeline(
            prompt,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy decoding; temperature is ignored when not sampling
            num_return_sequences=1,
        )
        return out[0]["generated_text"].strip()
    except Exception as e:
        return f"❌ Generation error: {e}"
|
|
# --- Gradio UI ---
with gr.Blocks(title="PDF Chat (fast, retrieval-augmented)") as demo:
    gr.Markdown("# 📚 Chat with your PDF — optimized for speed")
    gr.Markdown(
        "Upload a PDF, click **Process PDF**, then ask questions. "
        "This app uses semantic search (FAISS) + a lightweight generator for quick responses."
    )

    with gr.Row():
        file_in = gr.File(label="Upload PDF (PDF only)", file_types=[".pdf"])
        process_btn = gr.Button("Process PDF")
        status = gr.Textbox(label="Status", interactive=False)

    process_btn.click(fn=process_pdf, inputs=[file_in], outputs=[status])

    gr.Markdown("---")
    query = gr.Textbox(label="Ask a question about the PDF", placeholder="e.g. What is the main conclusion?")
    ask_btn = gr.Button("Ask")
    answer = gr.Textbox(label="Answer", lines=6)

    ask_btn.click(fn=chat_with_pdf, inputs=[query], outputs=[answer])

    gr.Markdown(
        "Notes:\n"
        "- The app keeps the processed PDF in memory for the session (no DB).\n"
        "- Designed for low latency; tune CHUNK_SIZE/TOP_K/MAX_NEW_TOKENS for speed/quality tradeoffs."
    )


if __name__ == "__main__":
    demo.launch()
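    # launch() also accepts share=True for a temporary public link, or
    # server_name="0.0.0.0" to serve on all interfaces (useful in containers).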