Srikesh committed on
Commit c53d978 · verified · 1 Parent(s): 2ceabc6

Create app.py

Files changed (1)
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
+ import os
+ import gradio as gr
+ from PyPDF2 import PdfReader
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ import faiss
+ import numpy as np
+ import math
+ import time
+
+ # ---------- CONFIG ----------
+ EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+ GEN_MODEL_NAME = "google/flan-t5-base" # fast & capable
+ CHUNK_SIZE = 500 # characters per chunk (roughly 100-150 tokens)
+ CHUNK_OVERLAP = 100 # overlap between chunks to preserve context
+ TOP_K = 4 # number of chunks retrieved
+ MAX_NEW_TOKENS = 150 # generation length (keep small for speed)
+ GEN_TEMPERATURE = 0.0 # unused with greedy decoding (do_sample=False); kept for tuning
+ NORMALIZE_EMB = True
+ # ----------------------------
+
+ # Global state
+ embedder = SentenceTransformer(EMBED_MODEL_NAME)
+ tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
+ gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
+ # Use the pipeline for convenience (it wraps tokenizer+model)
+ qa_pipeline = pipeline(
+     "text2text-generation",
+     model=gen_model,
+     tokenizer=tokenizer,
+     device=-1, # CPU (Spaces default). If GPU available, change to 0.
+ )
+
+ faiss_index = None
+ pdf_chunks = [] # list[str]
+ pdf_embeddings = None # numpy array (N, dim)
+ last_loaded_filename = None
+ last_loaded_at = None
+
+ # ---------- utilities ----------
+ def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
+     if not text:
+         return []
+     chunks = []
+     start = 0
+     length = len(text)
+     while start < length:
+         end = start + chunk_size
+         chunk = text[start:end].strip()
+         if chunk:
+             chunks.append(chunk)
+         start = end - overlap # move with overlap
+         if start <= end - chunk_size:
+             # guard against an infinite loop if overlap >= chunk_size
+             start = end
+     return chunks
+
+ def build_faiss_index(embeddings: np.ndarray):
+     dim = embeddings.shape[1]
+     # IndexFlatIP with normalized vectors -> cosine similarity
+     index = faiss.IndexFlatIP(dim)
+     faiss.normalize_L2(embeddings)
+     index.add(embeddings)
+     return index
+
+ def embed_texts(texts):
+     # sentence-transformers returns numpy arrays
+     embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
+     if NORMALIZE_EMB:
+         faiss.normalize_L2(embeddings)
+     return embeddings
+
+ # ---------- Gradio functions ----------
+ def process_pdf(pdf_file):
+     """
+     Upload and process PDF. Builds FAISS index and stores chunks & embeddings in memory.
+     Returns a status message.
+     """
+     global faiss_index, pdf_chunks, pdf_embeddings, last_loaded_filename, last_loaded_at
+
+     if pdf_file is None:
+         return "⚠️ No file uploaded."
+
+     try:
+         # gr.File may pass a filepath string or a tempfile-like object depending on the Gradio version
+         pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
+
+         # Extract text
+         reader = PdfReader(pdf_path)
+         full_text = []
+         for p in reader.pages:
+             text = p.extract_text()
+             if text:
+                 full_text.append(text)
+         text = "\n".join(full_text).strip()
+         if not text:
+             return "⚠️ No readable text found in PDF."
+
+         # Chunk text
+         pdf_chunks = chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP)
+
+         # Embed chunks (batch)
+         pdf_embeddings = embed_texts(pdf_chunks)
+
+         # Build FAISS index
+         faiss_index = build_faiss_index(np.copy(pdf_embeddings))
+
+         last_loaded_filename = os.path.basename(pdf_path)
+         last_loaded_at = time.time()
+
+         return f"✅ PDF processed. {len(pdf_chunks)} chunks indexed. Ready for Q&A."
+     except Exception as e:
+         return f"❌ Error processing PDF: {e}"
+
+ def chat_with_pdf(query):
+     """
+     Retrieve relevant chunks and generate an answer using the generator model.
+     Designed for low-latency responses.
+     """
+     global faiss_index, pdf_chunks, pdf_embeddings
+
+     if faiss_index is None or pdf_chunks is None or len(pdf_chunks) == 0:
+         return "⚠️ Please upload and process a PDF first."
+
+     if not query or not query.strip():
+         return "⚠️ Please enter a question."
+
+     query = query.strip()
+
+     # Embed query
+     q_emb = embedder.encode([query], convert_to_numpy=True)
+     if NORMALIZE_EMB:
+         faiss.normalize_L2(q_emb)
+
+     # Search top-k
+     top_k = min(TOP_K, len(pdf_chunks))
+     distances, indices = faiss_index.search(q_emb, top_k)
+     indices = indices[0].tolist()
+
+     # Compose context from retrieved chunks (concatenate, truncate if too long)
+     retrieved = [pdf_chunks[i] for i in indices]
+     context = "\n\n".join(retrieved)
+
+     # Build prompt - be concise and reference context
+     system_prompt = (
+         "You are a helpful assistant that answers questions using only the provided context. "
+         "If the answer is not contained in the context, say 'I don't know based on the document.' "
+         "Be concise and factual."
+     )
+     prompt = (
+         f"{system_prompt}\n\n"
+         f"Context:\n{context}\n\n"
+         f"Question: {query}\n\n"
+         f"Answer:"
+     )
+
+     # Limit prompt size by truncating context from the left if it's too long
+     # Keep the question + system prompt + rightmost part of context
+     max_prompt_chars = 3000 # heuristic to keep generation fast
+     if len(prompt) > max_prompt_chars:
+         # keep the question and system prompt, then rightmost slice of context
+         right_context = context[-2000:]
+         prompt = f"{system_prompt}\n\nContext:\n{right_context}\n\nQuestion: {query}\n\nAnswer:"
+
+     # Generate (greedy decoding: do_sample=False already makes output deterministic)
+     try:
+         out = qa_pipeline(
+             prompt,
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=False,
+             num_return_sequences=1,
+         )
+         answer = out[0]["generated_text"].strip()
+         return answer
+     except Exception as e:
+         return f"❌ Generation error: {e}"
+
+ # ---------- Gradio UI ----------
+ with gr.Blocks(title="PDF Chat (fast, retrieval-augmented)") as demo:
+     gr.Markdown("# 📚 Chat with your PDF — optimized for speed")
+     gr.Markdown(
+         "Upload a PDF, click **Process PDF**, then ask questions. "
+         "This app uses semantic search (FAISS) + a lightweight generator for quick responses."
+     )
+
+     with gr.Row():
+         file_in = gr.File(label="Upload PDF (PDF only)")
+         process_btn = gr.Button("Process PDF")
+     status = gr.Textbox(label="Status", interactive=False)
+
+     process_btn.click(fn=process_pdf, inputs=[file_in], outputs=[status])
+
+     gr.Markdown("---")
+     query = gr.Textbox(label="Ask a question about the PDF", placeholder="e.g. What is the main conclusion?")
+     ask_btn = gr.Button("Ask")
+     answer = gr.Textbox(label="Answer", lines=6)
+
+     ask_btn.click(fn=chat_with_pdf, inputs=[query], outputs=[answer])
+
+     gr.Markdown(
+         "Notes:\n"
+         "- The app keeps the processed PDF in memory for the session (no DB).\n"
+         "- Designed for low latency; tune CHUNK_SIZE/TOP_K/MAX_NEW_TOKENS for speed/quality tradeoffs."
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
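
This commit adds only app.py. For the Space to build, the imports above imply a requirements.txt along these lines (not part of this commit; the unpinned package list is an assumption inferred from the imports):

    # requirements.txt (assumed; versions left unpinned)
    gradio
    PyPDF2
    sentence-transformers
    transformers
    torch
    faiss-cpu
    numpy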