# pdf_chat / app.py
import os
import gradio as gr
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import faiss
import numpy as np
import time
# ---------- CONFIG ----------
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-base" # fast & capable
CHUNK_SIZE = 500        # characters per chunk (roughly 100-150 tokens of English text)
CHUNK_OVERLAP = 100 # overlap between chunks to preserve context
TOP_K = 4 # number of chunks retrieved
MAX_NEW_TOKENS = 150 # generation length (keep small for speed)
GEN_TEMPERATURE = 0.0   # unused: generation below uses greedy decoding (do_sample=False), which is already deterministic
NORMALIZE_EMB = True
# ----------------------------
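# Sizing intuition (rough heuristic, ~4 characters per English token): a 500-character chunk
# is ~125 tokens, so TOP_K=4 retrieved chunks contribute roughly 500 tokens of context before
# the prompt-length cap in chat_with_pdf() applies.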
# Global state
embedder = SentenceTransformer(EMBED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
# Use the pipeline for convenience (it wraps tokenizer+model)
qa_pipeline = pipeline(
    "text2text-generation",
    model=gen_model,
    tokenizer=tokenizer,
    device=-1,  # CPU (Spaces default). If GPU available, change to 0.
)
faiss_index = None
pdf_chunks = [] # list[str]
pdf_embeddings = None # numpy array (N, dim)
last_loaded_filename = None
last_loaded_at = None
# ---------- utilities ----------
def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split text into overlapping character windows."""
    if not text:
        return []
    chunks = []
    start = 0
    length = len(text)
    while start < length:
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= length:
            break  # final window reached; avoid emitting a duplicate tail chunk
        start = end - overlap  # slide forward, keeping `overlap` chars for context
    return chunks
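# Example of the sliding window (hypothetical input, shown here only as a comment):
#   chunk_text("x" * 1200, chunk_size=500, overlap=100)
#   -> three chunks covering characters [0:500], [400:900], [800:1200]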
def build_faiss_index(embeddings: np.ndarray):
    dim = embeddings.shape[1]
    # IndexFlatIP with normalized vectors -> cosine similarity
    index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(embeddings)
    index.add(embeddings)
    return index
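# Illustrative sketch (hypothetical 2-d vectors, not part of the app's flow): with
# unit-length rows, inner product equals cosine similarity, so IndexFlatIP scores
# lie in [-1, 1] and an exact match scores 1.0.
#   v = np.array([[1.0, 0.0], [0.6, 0.8]], dtype=np.float32)  # already unit length
#   demo_index = faiss.IndexFlatIP(2)
#   demo_index.add(v)
#   scores, ids = demo_index.search(v[:1], 2)  # scores ~ [[1.0, 0.6]]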
def embed_texts(texts):
    # sentence-transformers returns numpy arrays
    embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    if NORMALIZE_EMB:
        faiss.normalize_L2(embeddings)
    return embeddings
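# Note: all-MiniLM-L6-v2 produces 384-dimensional float32 embeddings, so the FAISS index
# is an exact (flat) search over a (num_chunks, 384) matrix; for the few hundred chunks a
# typical PDF yields, exact search is effectively instant.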
# ---------- Gradio functions ----------
def process_pdf(pdf_file):
    """
    Upload and process PDF. Builds FAISS index and stores chunks & embeddings in memory.
    Returns a status message.
    """
    global faiss_index, pdf_chunks, pdf_embeddings, last_loaded_filename, last_loaded_at
    if pdf_file is None:
        return "⚠️ No file uploaded."
    try:
        # Extract text (gr.File may hand us a filepath string or a tempfile-like object)
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        reader = PdfReader(pdf_path)
        full_text = []
        for p in reader.pages:
            text = p.extract_text()
            if text:
                full_text.append(text)
        text = "\n".join(full_text).strip()
        if not text:
            return "⚠️ No readable text found in PDF."
        # Chunk text
        pdf_chunks = chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP)
        # Embed chunks (batch)
        pdf_embeddings = embed_texts(pdf_chunks)
        # Build FAISS index (pass a copy: faiss.normalize_L2 modifies its argument in place)
        faiss_index = build_faiss_index(np.copy(pdf_embeddings))
        last_loaded_filename = os.path.basename(pdf_path)
        last_loaded_at = time.time()
        return f"✅ PDF processed. {len(pdf_chunks)} chunks indexed. Ready for Q&A."
    except Exception as e:
        return f"❌ Error processing PDF: {e}"
def chat_with_pdf(query):
    """
    Retrieve relevant chunks and generate an answer using the generator model.
    Designed for low-latency responses.
    """
    global faiss_index, pdf_chunks, pdf_embeddings
    if faiss_index is None or not pdf_chunks:
        return "⚠️ Please upload and process a PDF first."
    if not query or not query.strip():
        return "⚠️ Please enter a question."
    query = query.strip()
    # Embed query
    q_emb = embedder.encode([query], convert_to_numpy=True)
    if NORMALIZE_EMB:
        faiss.normalize_L2(q_emb)
    # Search top-k
    top_k = min(TOP_K, len(pdf_chunks))
    distances, indices = faiss_index.search(q_emb, top_k)
    indices = indices[0].tolist()
    # Compose context from retrieved chunks (concatenated; FAISS returns them best-match first)
    retrieved = [pdf_chunks[i] for i in indices]
    context = "\n\n".join(retrieved)
    # Build prompt - be concise and reference context
    system_prompt = (
        "You are a helpful assistant that answers questions using only the provided context. "
        "If the answer is not contained in the context, say 'I don't know based on the document.' "
        "Be concise and factual."
    )
    prompt = (
        f"{system_prompt}\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )
    # Limit prompt size by truncating the context if it's too long.
    # Keep the question + system prompt + the top-ranked slice of context.
    max_prompt_chars = 3000  # heuristic to keep generation fast
    if len(prompt) > max_prompt_chars:
        # the best-matching chunks come first, so keep the head of the context
        top_context = context[:2000]
        prompt = f"{system_prompt}\n\nContext:\n{top_context}\n\nQuestion: {query}\n\nAnswer:"
    # Generate
    try:
        out = qa_pipeline(
            prompt,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy decoding: deterministic, so no temperature is passed
            num_return_sequences=1,
        )
        answer = out[0]["generated_text"].strip()
        return answer
    except Exception as e:
        return f"❌ Generation error: {e}"
# ---------- Gradio UI ----------
with gr.Blocks(title="PDF Chat (fast, retrieval-augmented)") as demo:
    gr.Markdown("# 📚 Chat with your PDF — optimized for speed")
    gr.Markdown(
        "Upload a PDF, click **Process PDF**, then ask questions. "
        "This app uses semantic search (FAISS) + a lightweight generator for quick responses."
    )
    with gr.Row():
        file_in = gr.File(label="Upload PDF (PDF only)", file_types=[".pdf"])  # restrict the picker to PDFs
        process_btn = gr.Button("Process PDF")
    status = gr.Textbox(label="Status", interactive=False)
    process_btn.click(fn=process_pdf, inputs=[file_in], outputs=[status])
    gr.Markdown("---")
    query = gr.Textbox(label="Ask a question about the PDF", placeholder="e.g. What is the main conclusion?")
    ask_btn = gr.Button("Ask")
    answer = gr.Textbox(label="Answer", lines=6)
    ask_btn.click(fn=chat_with_pdf, inputs=[query], outputs=[answer])
    gr.Markdown(
        "Notes:\n"
        "- The app keeps the processed PDF in memory for the session (no DB).\n"
        "- Designed for low latency; tune CHUNK_SIZE/TOP_K/MAX_NEW_TOKENS for speed/quality tradeoffs."
    )
if __name__ == "__main__":
    demo.launch()
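# To try this locally (assuming the libraries imported above are installed):
#   pip install gradio PyPDF2 sentence-transformers transformers faiss-cpu torch
#   python app.py   # then open the local URL Gradio prints (default http://127.0.0.1:7860)
# On Hugging Face Spaces, demo.launch() with no arguments is all that is needed.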