Srikesh committed on
Commit e4efe09 · verified · 1 Parent(s): c53d978

Update app.py

Files changed (1):
  1. app.py +65 -193

app.py CHANGED
@@ -1,206 +1,78 @@
-import os
 import gradio as gr
-from PyPDF2 import PdfReader
-from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+import fitz  # PyMuPDF
 import faiss
 import numpy as np
-import math
-import time
-
-# ---------- CONFIG ----------
-EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
-GEN_MODEL_NAME = "google/flan-t5-base"  # fast & capable
-CHUNK_SIZE = 500       # characters per chunk (approx 250-350 tokens)
-CHUNK_OVERLAP = 100    # overlap between chunks to preserve context
-TOP_K = 4              # number of chunks retrieved
-MAX_NEW_TOKENS = 150   # generation length (keep small for speed)
-GEN_TEMPERATURE = 0.0  # deterministic, faster
-NORMALIZE_EMB = True
-# ----------------------------
-
-# Global state
-embedder = SentenceTransformer(EMBED_MODEL_NAME)
-tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
-gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
-# Use the pipeline for convenience (it wraps tokenizer+model)
-qa_pipeline = pipeline(
-    "text2text-generation",
-    model=gen_model,
-    tokenizer=tokenizer,
-    device=-1,  # CPU (Spaces default). If GPU available, change to 0.
-)
-
-faiss_index = None
-pdf_chunks = []        # list[str]
-pdf_embeddings = None  # numpy array (N, dim)
-last_loaded_filename = None
-last_loaded_at = None
-
-# ---------- utilities ----------
-def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
-    if not text:
-        return []
-    chunks = []
-    start = 0
-    length = len(text)
-    while start < length:
-        end = start + chunk_size
-        chunk = text[start:end].strip()
-        if chunk:
-            chunks.append(chunk)
-        start = end - overlap  # move with overlap
-        if start < 0:
-            start = 0
-    return chunks
-
-def build_faiss_index(embeddings: np.ndarray):
-    dim = embeddings.shape[1]
-    # IndexFlatIP with normalized vectors -> cosine similarity
-    index = faiss.IndexFlatIP(dim)
-    faiss.normalize_L2(embeddings)
-    index.add(embeddings)
-    return index
-
-def embed_texts(texts):
-    # sentence-transformers returns numpy arrays
-    embeddings = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
-    if NORMALIZE_EMB:
-        faiss.normalize_L2(embeddings)
-    return embeddings
-
-# ---------- Gradio functions ----------
-def process_pdf(pdf_file):
-    """
-    Upload and process PDF. Builds FAISS index and stores chunks & embeddings in memory.
-    Returns status message and basic metadata.
-    """
-    global faiss_index, pdf_chunks, pdf_embeddings, last_loaded_filename, last_loaded_at
-
-    if pdf_file is None:
-        return "⚠️ No file uploaded."
-
-    try:
-        # Extract text
-        reader = PdfReader(pdf_file.name)
-        full_text = []
-        for p in reader.pages:
-            text = p.extract_text()
-            if text:
-                full_text.append(text)
-        text = "\n".join(full_text).strip()
-        if not text:
-            return "⚠️ No readable text found in PDF."
-
-        # Chunk text
-        pdf_chunks = chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP)
-
-        # Embed chunks (batch)
-        pdf_embeddings = embed_texts(pdf_chunks)
-
-        # Build FAISS index
-        faiss_index = build_faiss_index(np.copy(pdf_embeddings))
-
-        last_loaded_filename = os.path.basename(pdf_file.name)
-        last_loaded_at = time.time()
-
-        return f"✅ PDF processed. {len(pdf_chunks)} chunks indexed. Ready for Q&A."
-    except Exception as e:
-        return f"❌ Error processing PDF: {e}"
-
+from sentence_transformers import SentenceTransformer
+from transformers import pipeline
+
+# ⚡ Load models once for efficiency
+embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+# NOTE: Mixtral-8x7B is a very large model; it needs high-memory GPU hardware,
+# not the default CPU Space runtime.
+qa_model = pipeline("text-generation", model="mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+# Store embeddings globally
+index = None
+chunks = []
+
+# 🧠 Extract text safely from PDF
+def extract_text_from_pdf(pdf_file):
+    try:
+        # gr.File may pass a filepath string or a tempfile wrapper; resolve to a path
+        path = getattr(pdf_file, "name", pdf_file)
+        with fitz.open(path) as doc:
+            text = ""
+            for page in doc:
+                text += page.get_text("text") + "\n"
+            return text.strip()
+    except Exception as e:
+        raise RuntimeError(f"PDF extraction error: {str(e)}")
+
+# 🧱 Create FAISS index from PDF text
+def create_index(text):
+    global index, chunks
+    # Split text into fixed-size chunks for retrieval context
+    chunk_size = 800
+    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
+    vectors = embedding_model.encode(chunks, convert_to_numpy=True)
+    index = faiss.IndexFlatL2(vectors.shape[1])
+    index.add(vectors)
+
+# 💬 Chat function
 def chat_with_pdf(query):
-    """
-    Retrieve relevant chunks and generate an answer using the generator model.
-    Designed for low-latency responses.
-    """
-    global faiss_index, pdf_chunks, pdf_embeddings
-
-    if faiss_index is None or pdf_chunks is None or len(pdf_chunks) == 0:
-        return "⚠️ Please upload and process a PDF first."
-
-    if not query or not query.strip():
-        return "⚠️ Please enter a question."
-
-    query = query.strip()
-
-    # Embed query
-    q_emb = embedder.encode([query], convert_to_numpy=True)
-    if NORMALIZE_EMB:
-        faiss.normalize_L2(q_emb)
-
-    # Search top-k
-    top_k = min(TOP_K, len(pdf_chunks))
-    distances, indices = faiss_index.search(q_emb, top_k)
-    indices = indices[0].tolist()
-
-    # Compose context from retrieved chunks (concatenate, truncate if too long)
-    retrieved = [pdf_chunks[i] for i in indices]
-    context = "\n\n".join(retrieved)
-
-    # Build prompt - be concise and reference context
-    system_prompt = (
-        "You are a helpful assistant that answers questions using only the provided context. "
-        "If the answer is not contained in the context, say 'I don't know based on the document.' "
-        "Be concise and factual."
-    )
-    prompt = (
-        f"{system_prompt}\n\n"
-        f"Context:\n{context}\n\n"
-        f"Question: {query}\n\n"
-        f"Answer:"
-    )
-
-    # Limit prompt size by truncating context from the left if it's too long
-    # Keep the question + system prompt + rightmost part of context
-    max_prompt_chars = 3000  # heuristic to keep generation fast
-    if len(prompt) > max_prompt_chars:
-        # keep the question and system prompt, then rightmost slice of context
-        right_context = context[-2000:]
-        prompt = f"{system_prompt}\n\nContext:\n{right_context}\n\nQuestion: {query}\n\nAnswer:"
-
-    # Generate
+    if index is None:
+        return "❗ Please upload a PDF first."
+
+    # Get top-k relevant chunks (cap k at the number of chunks)
+    q_vector = embedding_model.encode([query])
+    k = min(3, len(chunks))
+    D, I = index.search(np.array(q_vector).astype("float32"), k)
+    context = " ".join([chunks[i] for i in I[0]])
+
+    # Generate answer (do_sample=True so temperature takes effect)
+    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
+    response = qa_model(prompt, max_new_tokens=200, do_sample=True, temperature=0.3)[0]["generated_text"]
+    return response.split("Answer:")[-1].strip()
+
+# 📄 Handle PDF uploads
+def handle_pdf_upload(pdf_file):
     try:
-        out = qa_pipeline(
-            prompt,
-            max_new_tokens=MAX_NEW_TOKENS,
-            do_sample=False,
-            temperature=GEN_TEMPERATURE,
-            num_return_sequences=1,
-        )
-        answer = out[0]["generated_text"].strip()
-
-        # Safety: if model hallucinates beyond context, keep it short
-        return answer
+        if pdf_file is None:  # .change also fires when the file input is cleared
+            return "⚠️ No file uploaded."
+        text = extract_text_from_pdf(pdf_file)
+        if not text:
+            return "❌ No readable text found in the PDF. It may be scanned."
+        create_index(text)
+        return "✅ PDF processed successfully. You can now ask questions!"
     except Exception as e:
-        return f"❌ Generation error: {e}"
-
-# ---------- Gradio UI ----------
-with gr.Blocks(title="PDF Chat (fast, retrieval-augmented)") as demo:
-    gr.Markdown("# 📚 Chat with your PDF — optimized for speed")
-    gr.Markdown(
-        "Upload a PDF, click **Process PDF**, then ask questions. "
-        "This app uses semantic search (FAISS) + a lightweight generator for quick responses."
-    )
-
-    with gr.Row():
-        file_in = gr.File(label="Upload PDF (PDF only)")
-        process_btn = gr.Button("Process PDF")
-    status = gr.Textbox(label="Status", interactive=False)
-
-    process_btn.click(fn=process_pdf, inputs=[file_in], outputs=[status])
-
-    gr.Markdown("---")
-    query = gr.Textbox(label="Ask a question about the PDF", placeholder="e.g. What is the main conclusion?")
-    ask_btn = gr.Button("Ask")
-    answer = gr.Textbox(label="Answer", lines=6)
-
-    ask_btn.click(fn=chat_with_pdf, inputs=[query], outputs=[answer])
-
-    gr.Markdown(
-        "Notes:\n"
-        "- The app keeps the processed PDF in memory for the session (no DB).\n"
-        "- Designed for low latency; tune CHUNK_SIZE/TOP_K/MAX_NEW_TOKENS for speed/quality tradeoffs."
-    )
-
-if __name__ == "__main__":
-    demo.launch()
+        return f"❌ Error processing PDF: {str(e)}"
+
+# 🎨 Gradio Interface
+with gr.Blocks() as app:
+    gr.Markdown("## 🤖 Chat with Your PDF — Fast & Reliable AI Assistant")
+    pdf_input = gr.File(label="Upload a PDF", file_types=[".pdf"])
+    status_box = gr.Textbox(label="Status", interactive=False)

+    pdf_input.change(fn=handle_pdf_upload, inputs=pdf_input, outputs=status_box)
+
+    # ChatInterface calls fn(message, history); adapt it to chat_with_pdf(query)
+    gr.ChatInterface(fn=lambda message, history: chat_with_pdf(message), title="Ask Questions about your PDF")
+
+app.launch()
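
The rewritten app.py imports fitz (PyMuPDF) instead of PyPDF2, so the Space's requirements.txt must change in step. A minimal sketch of the dependency list the new imports imply (unpinned; exact package versions are an assumption, not part of this commit):

    gradio
    pymupdf
    faiss-cpu
    numpy
    sentence-transformers
    transformers
    torch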
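
One behavioral change to be aware of: create_index splits the text on fixed 800-character boundaries with no overlap, while the removed chunk_text kept a 100-character overlap so sentences straddling a boundary stay retrievable. A minimal sketch of an overlapping splitter that could drop into create_index (helper name is hypothetical):

    def chunk_with_overlap(text, chunk_size=800, overlap=100):
        # Consecutive chunks share `overlap` characters, so text near a
        # chunk boundary appears in two chunks and remains searchable.
        step = chunk_size - overlap
        return [chunk for chunk in
                (text[i:i + chunk_size].strip() for i in range(0, len(text), step))
                if chunk]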
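
The retrieval metric also changes: the old code L2-normalized embeddings and searched an IndexFlatIP, which makes the inner product equal to cosine similarity, whereas create_index now feeds raw vectors to IndexFlatL2. On unnormalized embeddings the two metrics can rank chunks differently. A sketch of the cosine variant, reusing the commit's embedding_model (assumed already loaded):

    def create_cosine_index(texts):
        vectors = embedding_model.encode(texts, convert_to_numpy=True).astype("float32")
        faiss.normalize_L2(vectors)                # in-place unit normalization
        idx = faiss.IndexFlatIP(vectors.shape[1])  # inner product on unit vectors = cosine
        idx.add(vectors)
        return idx  # normalize query vectors the same way before idx.search(...)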