import os
import gradio as gr
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import faiss
import numpy as np
import time

# ---------- CONFIG ----------
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-base"  # fast & capable
CHUNK_SIZE = 500         # characters per chunk (approx 250-350 tokens)
CHUNK_OVERLAP = 100      # overlap between chunks to preserve context
TOP_K = 4                # number of chunks retrieved
MAX_NEW_TOKENS = 150     # generation length (keep small for speed)
GEN_TEMPERATURE = 0.0    # deterministic, faster
NORMALIZE_EMB = True
# ----------------------------

# Global state
embedder = SentenceTransformer(EMBED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)

# Use the pipeline for convenience (it wraps tokenizer + model)
qa_pipeline = pipeline(
    "text2text-generation",
    model=gen_model,
    tokenizer=tokenizer,
    device=-1,  # CPU (Spaces default). If a GPU is available, change to 0.
)

faiss_index = None
pdf_chunks = []         # list[str]
pdf_embeddings = None   # numpy array (N, dim)
last_loaded_filename = None
last_loaded_at = None


# ---------- utilities ----------
def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split text into overlapping character chunks."""
    if not text:
        return []
    chunks = []
    start = 0
    length = len(text)
    while start < length:
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        start = end - overlap  # move forward with overlap
        if start < 0:
            start = 0
    return chunks


def build_faiss_index(embeddings: np.ndarray):
    dim = embeddings.shape[1]
    # IndexFlatIP with normalized vectors -> cosine similarity
    index = faiss.IndexFlatIP(dim)
    faiss.normalize_L2(embeddings)
    index.add(embeddings)
    return index


def embed_texts(texts):
    # sentence-transformers returns numpy arrays
    embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    if NORMALIZE_EMB:
        faiss.normalize_L2(embeddings)
    return embeddings


# ---------- Gradio functions ----------
def process_pdf(pdf_file):
    """
    Upload and process a PDF. Builds the FAISS index and stores chunks &
    embeddings in memory. Returns a status message.
    """
    global faiss_index, pdf_chunks, pdf_embeddings, last_loaded_filename, last_loaded_at

    if pdf_file is None:
        return "⚠️ No file uploaded."

    try:
        # gr.File may pass a tempfile-like object (with .name) or a plain path
        # string, depending on the Gradio version; handle both.
        pdf_path = getattr(pdf_file, "name", pdf_file)

        # Extract text
        reader = PdfReader(pdf_path)
        full_text = []
        for p in reader.pages:
            text = p.extract_text()
            if text:
                full_text.append(text)
        text = "\n".join(full_text).strip()

        if not text:
            return "⚠️ No readable text found in PDF."

        # Chunk text
        pdf_chunks = chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP)

        # Embed chunks (batch)
        pdf_embeddings = embed_texts(pdf_chunks)

        # Build FAISS index
        faiss_index = build_faiss_index(np.copy(pdf_embeddings))

        last_loaded_filename = os.path.basename(pdf_path)
        last_loaded_at = time.time()

        return f"✅ PDF processed. {len(pdf_chunks)} chunks indexed. Ready for Q&A."
    except Exception as e:
        return f"❌ Error processing PDF: {e}"


def chat_with_pdf(query):
    """
    Retrieve relevant chunks and generate an answer using the generator model.
    Designed for low-latency responses.
    """
    global faiss_index, pdf_chunks, pdf_embeddings

    if faiss_index is None or pdf_chunks is None or len(pdf_chunks) == 0:
        return "⚠️ Please upload and process a PDF first."
    if not query or not query.strip():
        return "⚠️ Please enter a question."
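
    # Retrieval: the query is embedded with the same SentenceTransformer used
    # for the chunks; since both query and chunk vectors are L2-normalized,
    # the IndexFlatIP inner-product search below ranks chunks by cosine
    # similarity.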
    query = query.strip()

    # Embed query
    q_emb = embedder.encode([query], convert_to_numpy=True)
    if NORMALIZE_EMB:
        faiss.normalize_L2(q_emb)

    # Search top-k
    top_k = min(TOP_K, len(pdf_chunks))
    distances, indices = faiss_index.search(q_emb, top_k)
    indices = indices[0].tolist()

    # Compose context from retrieved chunks (concatenate, truncate if too long)
    retrieved = [pdf_chunks[i] for i in indices]
    context = "\n\n".join(retrieved)

    # Build prompt - be concise and reference context
    system_prompt = (
        "You are a helpful assistant that answers questions using only the provided context. "
        "If the answer is not contained in the context, say 'I don't know based on the document.' "
        "Be concise and factual."
    )
    prompt = (
        f"{system_prompt}\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )

    # Limit prompt size by truncating context from the left if it's too long.
    # Keep the question + system prompt + rightmost part of context.
    max_prompt_chars = 3000  # heuristic to keep generation fast
    if len(prompt) > max_prompt_chars:
        # keep the question and system prompt, then rightmost slice of context
        right_context = context[-2000:]
        prompt = f"{system_prompt}\n\nContext:\n{right_context}\n\nQuestion: {query}\n\nAnswer:"

    # Generate
    try:
        out = qa_pipeline(
            prompt,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=GEN_TEMPERATURE,
            num_return_sequences=1,
        )
        answer = out[0]["generated_text"].strip()
        # Safety: if model hallucinates beyond context, keep it short
        return answer
    except Exception as e:
        return f"❌ Generation error: {e}"


# ---------- Gradio UI ----------
with gr.Blocks(title="PDF Chat (fast, retrieval-augmented)") as demo:
    gr.Markdown("# 📚 Chat with your PDF — optimized for speed")
    gr.Markdown(
        "Upload a PDF, click **Process PDF**, then ask questions. "
        "This app uses semantic search (FAISS) + a lightweight generator for quick responses."
    )

    with gr.Row():
        file_in = gr.File(label="Upload PDF (PDF only)")
        process_btn = gr.Button("Process PDF")
    status = gr.Textbox(label="Status", interactive=False)
    process_btn.click(fn=process_pdf, inputs=[file_in], outputs=[status])

    gr.Markdown("---")

    query = gr.Textbox(label="Ask a question about the PDF", placeholder="e.g. What is the main conclusion?")
    ask_btn = gr.Button("Ask")
    answer = gr.Textbox(label="Answer", lines=6)
    ask_btn.click(fn=chat_with_pdf, inputs=[query], outputs=[answer])

    gr.Markdown(
        "Notes:\n"
        "- The app keeps the processed PDF in memory for the session (no DB).\n"
        "- Designed for low latency; tune CHUNK_SIZE/TOP_K/MAX_NEW_TOKENS for speed/quality tradeoffs."
    )

if __name__ == "__main__":
    demo.launch()
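
# ---------------------------------------------------------------------------
# Deployment note (assumption: a typical Hugging Face Spaces setup; adjust and
# pin versions as needed). The imports above correspond roughly to this
# requirements.txt; torch is installed automatically as a dependency of
# sentence-transformers / transformers:
#
#   gradio
#   PyPDF2
#   sentence-transformers
#   transformers
#   faiss-cpu
#   numpy
# ---------------------------------------------------------------------------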